#
#	IdenticalTrees( t1:Tree, t2:Tree )
#		return true/false if the trees have the same topology
#		leaves are considered equal if their labels are equal.
#
#					CK ???
#		rewritten	Gaston H. Gonnet (Jun 21, 2005)
#		once more	Gaston H. Gonnet (Jul 4, 2005)

# IdenticalTrees( t1:Tree, t2:Tree )
#   Returns false when the sorted leaf labels of the two trees differ,
#   true when they agree and there are at most 3 leaves, and otherwise the
#   result of comparing a canonical nested-list form of both trees.
#   Raises an error when no label occurs exactly once, because the
#   canonical form is anchored at such a unique "witness" label.
IdenticalTrees := proc( t1:Tree, t2:Tree )
# Collect the first field (the label) of every leaf of each tree.
l1 := [];  for z in Leaves(t1) do l1 := append(l1,z[1]) od;
l2 := [];  for z in Leaves(t2) do l2 := append(l2,z[1]) od;
l1 := sort(l1);  l2 := sort(l2);
n := length(l1);
# Different label collections can never match; with identical labels and
# at most 3 leaves the answer is true without looking at the shape.
if l1 <> l2 then return(false) elif n <= 3 then return(true) fi;

# Pick the witness: a label that appears only once in the sorted list.
# Try the first and the last entry, then scan the interior for an entry
# that differs from both of its neighbours.
if l1[1] <> l1[2] then wit := l1[1]
elif l1[-1] <> l1[-2] then wit := l1[-1]
else for i from 2 to n-1 while l1[i-1]=l1[i] or l1[i]=l1[i+1] do od;
     # i reaches n only when the scan ran off the end without a hit.
     if i >= n then
	error('IdenticalTrees requires at least one unique label') fi;
     wit := l1[i]
fi;

# FlattenTree: turn a tree into nested sorted two-element lists of labels.
# A leaf becomes its label; an internal node becomes the sorted pair of
# its flattened children (fields 1 and 3).
FlattenTree := proc(t:Tree)
if type(t,Leaf) then t[1]
else sort( [procname(t[1]),procname(t[3])] ) fi
end:

# BubbleUp: rearrange the nested pair until the witness is one of the two
# top-level entries.  In each step the top-level side containing the
# witness is split: its part holding the witness stays alone, and its
# other part is paired (sorted) with the former sibling.
BubbleUp := proc( f1:list )
f := f1;
while f[1] <> wit and f[2] <> wit do
    if has(f[1],wit) then
	 if has(f[1,1],wit) then
	      f := sort( [f[1,1],sort([f[1,2],f[2]])] )
	 else f := sort( [f[1,2],sort([f[1,1],f[2]])] ) fi
    elif has(f[2],wit) then
	 if has(f[2,1],wit) then
	      f := sort( [f[2,1],sort([f[2,2],f[1]])] )
	 else f := sort( [f[2,2],sort([f[2,1],f[1]])] ) fi
    else error('should not happen') fi;
od:
f
end:

# The trees are reported identical when both canonical forms are equal.
evalb( BubbleUp(FlattenTree(t1)) = BubbleUp(FlattenTree(t2)) )
end:



# GetLcaSubtree( t )
#   Returns a list holding the first field of every leaf below t, with the
#   leaves of the left subtree (field 1) before those of the right
#   subtree (field 3).
GetLcaSubtree := proc(t)
description 'Get all leaf numbers of tree t';
  # A leaf contributes a one-element list.
  if op(0, t) = Leaf then return( [t[1]] ) fi;
  # An internal node concatenates the lists of its two subtrees.
  return( [ op(GetLcaSubtree(t[1])), op(GetLcaSubtree(t[3])) ] );
end:

# ---------------------------


# TotalTreeWeight( t:Tree )
#   Returns the negated result of GetTreeLength_r(t).
TotalTreeWeight := proc (t: Tree)
  description
  'Returns the sum of the length of all branches im PAM units'; 
  return( -GetTreeLength_r(t) );
end:

# GetTreeLength_r( t:Tree )
#   Recursive helper: for every branch below t, adds the child's field 2
#   minus the parent's field 2 (presumably node heights - confirm) and
#   returns the total.  A leaf has no branches below it and yields 0.
GetTreeLength_r := proc (t: Tree)
  if op(0, t) = Leaf then return(0) fi;
  here := t[2];
  # Each side: everything below the child plus the branch to the child.
  viaLeft  := GetTreeLength_r(t[1]) + t[1, 2] - here;
  viaRight := GetTreeLength_r(t[3]) + t[3, 2] - here;
  return( viaLeft + viaRight );
end:


# ------------------------------------------


# AddSpecies( t:Tree, Species:list )
#   Returns a copy of t in which every node carries species sets in
#   fields 6 and 7.  A leaf gets the same one-element set in both fields;
#   the species is looked up in Species by t[1] when that is numeric and
#   by t[3] otherwise.  An internal node gets the union of the left
#   child's sets in field 6 and of the right child's sets in field 7.
#   Missing fields 4 and 5 are filled with 0.
AddSpecies := proc (t: Tree, Species: list)
  description
  '  Species:	List of species, this is used to distinguish between
		paralogous and orthologous changes.
   Every node of the tree contains the information of which species were on 
   the left (t[6]) and on the right (t[7]) side of the branch. If the tree length is less than 6,
   then the tree is expanded with 0 at position 4 and 5.
   (e.g. {MOUSE, YEAST, ECOLI}).'; 
  nfields := length(t);

  if op (0, t) = Leaf then
    # Choose the index into Species: the label itself if numeric,
    # otherwise the third field of the leaf.
    idx := t[1];
    if not type(idx, numeric) then idx := t[3] fi;
    here := {Species[idx]};
    if nfields < 4 then return( Leaf(t[1..3], 0, 0, here, here) ) fi;
    if nfields < 5 then return( Leaf(t[1..4], 0, here, here) ) fi;
    return( Leaf(t[1..5], here, here) );
  fi;

  # Internal node: annotate both subtrees first, then merge their sets.
  lsub := AddSpecies (t[1], Species);
  rsub := AddSpecies (t[3], Species);
  lspecies := (lsub[6] union lsub[7]);
  rspecies := (rsub[6] union rsub[7]);
  # Keep existing fields 4 and 5, pad with 0 where they are absent.
  f4 := 0;  f5 := 0;
  if nfields >= 4 then f4 := t[4] fi;
  if nfields >= 5 then f5 := t[5] fi;
  return( Tree(lsub, t[2], rsub, f4, f5, lspecies, rspecies) );
end:

# -------------------------------

# FindRules( t:Tree )
#   Runs FindRules_R over t and returns the rules it accumulated as a list.
#   Returns [] (after printing a hint) when t has fewer than 6 fields.
#   Side effects: calls Set(printgc=false) and overwrites the global
#   variable `list`, which is the accumulator shared with FindRules_R.
#   NOTE(review): `list` is also used as a parameter type name elsewhere
#   in this file; confirm that assigning a global of that name is harmless.
FindRules := proc(t: Tree)
global list; 
  description 
  'Checks the tree for any rules in the form: 
   a is closer to b than to c and returns a list of those rules.'; 
  Set(printgc=false); 
  # Fields 6 and 7 are read by FindRules_R, so shorter trees are rejected.
  if length(t)<6 then printf('Before this can be done, the tree has to be associated with species. \nUse  t := CT_Color (tree, ColorNr)(optional) and then t := AddSpecies(t, Species) first...'); 
    return([]); 
  fi; 
  # Start the accumulator as an empty string; FindRules_R extends it into
  # an expression sequence whose first element is this placeholder.
  list := copy('');  
  if printlevel > 3 then print('FindRules: this may take long for a large tree...'); fi;
  FindRules_R(t);
  # Wrap the sequence in a list and drop the leading placeholder.
  list := [list]; 
  list := If(length(list)>1, list[2..length(list)], []); 
  list; 
end:

# FindRules_R( t:Tree )
#   Recursive worker for FindRules.  Walks every internal node of t and
#   appends the tests produced by IsAmbig to the global sequence `list`.
#   A node contributes only when its two species sets (fields 6 and 7)
#   are disjoint and together hold more than 2 species.
FindRules_R := proc(t: Tree)
global list; 
  # Progress output: print the rule count while (count mod 100) < 20.
  if printlevel > 3 and list <> '' then
    if mod(length([list]), 100) < 20 then printf('Nr rules: %d\n',length([list])); fi; 
  fi;
  if not type(t, Leaf) then 
    l := t[6]; r := t[7]; 
    # para is true when a species occurs on both sides of this node.
    int := intersect(t[6], t[7]); 
    if length(int) > 0 then para := true else para := false fi; 
    # Species found on one side only.
    le := l minus r; ri := r minus l; 
    if length(le) + length(ri)> 2 and para = false then 
      # NOTE(review): `list` is an expression sequence at this point, so
      # IsAmbig presumably receives it spread over several arguments and
      # sees only the leading placeholder as its third parameter - confirm.
      res := IsAmbig(le, ri, list); 
      # Only the generated tests (second result) are kept here.
      if res[2] <> [] then
        list := list, op(res[2]); 
      fi; 
    fi; 
    FindRules_R(t[1]); 
    FindRules_R(t[3]);  
  fi; 
end:

# FindSpeciesViolations( arg )
#   arg is either a Tree (rules are then obtained with FindRules) or a
#   list of rules of the form [a, {b, c}], where a is a set.
#   Every pair of rules (i < j) is tested; a pair is a contradiction when
#   a_i meets {b,c}_j, a_j meets {b,c}_i and the two pair-sets overlap.
#   Each contradiction is printed, and the list of contradicting rule
#   pairs [[rule_i, rule_j], ...] is returned ([] when there is none), as
#   the description promises.
FindSpeciesViolations := proc(arg: anything)
 description 'arg: a Tree or a list.
 If it is a tree, it must contain information (6, 7) about species. 
 Use AddSpecies to get such a tree. 
 From this tree a list of rules is generated (a closer to b than to c etc).
 If it is a list of those rules ([a, {b, c}], [d, {e, f}], ...) a list of 
 contradictions is returned';
 if type(arg, Tree) then rules := FindRules(arg)
 else rules := arg fi;
 nrules := length(rules);
 found := [];
 for i to nrules-1 do 
   a1 := rules[i, 1]; 
   a2 := rules[i, 2]; 
   for j from i + 1 to nrules do
      b2 := rules[j, 2]; 
      b1 := rules[j, 1]; 
      if intersect(a1, b2)<>{} and intersect(b1, a2)<>{} and intersect(a2, b2) <> {} then
	printf('%a %a -> %a %a\n', a1, a2, b1, b2); 
	# Remember the conflicting pair so it can be returned to the caller.
	found := append(found, [rules[i], rules[j]]);
      fi; 
   od; 
 od; 
 return(found);
end:

 
# IsAmbig( left, right, list )
#   Generates every test [{x}, {y, z}] with x taken from one side and the
#   pair y, z from the other side, first with (left, right) and then with
#   the roles swapped.  Each test is compared with every rule in list;
#   count is the number of (test, rule) combinations that conflict.
#   Prints an empty line with lprint() when count > 0.
#   Returns [count, tests].
IsAmbig := proc(left, right, list)
  nrules := length(list);
  generated := [];
  conflicts := 0;
  single := left;  paired := right;

  for pass to 2 do
    for i to length(single) do
      for j to length(paired)-1 do
        for k from j + 1 to length(paired) do
          cand := [{single[i]}, {paired[j], paired[k]}];
          generated := append(generated, cand);
          for m to nrules do
            rulePair := list[m, 2];
            ruleOne := list[m, 1];
            if intersect(cand[1], rulePair)<>{} and intersect(ruleOne, cand[2])<>{} and intersect(cand[2], rulePair) <> {} then
              conflicts := conflicts + 1;
            fi;
          od;
        od;
      od;
    od;
    # Second pass: singles come from the right, pairs from the left.
    keep := single;  single := paired;  paired := keep;
  od;

  if conflicts > 0 then lprint(); fi;
  return([conflicts, generated]);
end:

# CheckAmbigTree( t:Tree [, rules] )
#   Runs CheckTree on t and returns its result.  The rules are taken from
#   the second argument when one is given, otherwise from FindRules(t).
#   Progress messages are printed when printlevel > 3.
CheckAmbigTree := proc(t: Tree)
description ' Tree t must contain species information at position 6 and 7. 
 To get this, use AddSpecies. 
 The function checks for violations of rules such as "a is closer to b
 than to c in one place but a is closer to c than to b in another place".
 The number of violations for each subtree are counted and added to the
 tree at position 8. 
 If an additional argument is given, a list of rules, those
 rules are taken. (Function FindRules finds those rules :-)';
 # Caller-supplied rules win; otherwise derive them from the tree itself.
 if nargs <= 1 then rules := FindRules(t)
 else rules := args[2] fi;
 if printlevel > 3 then print('CheckTree: this may take long for a large tree...'); fi;
 checked := CheckTree(t, rules);
 if printlevel > 3 then printf('\nCheckTree: done\n'); fi;
 return(checked);
end:

# CheckTree( t:Tree, list:array )
#   Returns a copy of t (fields 1..7) with a count appended as field 8 on
#   every node.  For a leaf the count is 0.  For an internal node it is
#   the first result of IsAmbig, evaluated only when the species sets in
#   fields 6 and 7 are disjoint and together hold more than 2 species;
#   otherwise it is 0.  Prints a '.' per internal node when printlevel > 3.
#   The result is built from its parts rather than by assigning into t,
#   so the tree passed in by the caller is never modified.
CheckTree := proc(t: Tree, list: array)
  if type(t, Leaf) then
    return(Leaf(t[1..7], 0));
  fi;
  if printlevel > 3 then printf('.'); fi;

  # Annotate both subtrees before looking at this node.
  lsub := CheckTree(t[1], list);
  rsub := CheckTree(t[3], list);

  l := t[6]; r := t[7];
  le := l minus r; ri := r minus l;
  # para is true when a species occurs on both sides of this node.
  if length(intersect(l, r)) > 0 then para := true else para := false fi;

  ambig := 0;
  if length(le) + length(ri) > 2 and para = false then
    res := IsAmbig(le, ri, list);
    ambig := ambig + res[1];
  fi;
  return(Tree(lsub, t[2], rsub, t[4], t[5], t[6], t[7], ambig));
end:

# -----------------------------

# SubDist( t:Tree, i:integer, j:integer )
#   Obtains a subtree from GetSubTree_r(t, i, j), looks up the values
#   GetRootDist_r returns for i and for j in it, and returns
#   -(value_i + value_j - 2 * field 2 of the subtree).
SubDist := proc(t: Tree, i: integer, j: integer)
  description
  'Get the distance in PAM units from leaf i to leaf j'; 
  subtree := GetSubTree_r(t, i, j);
  atI := GetRootDist_r(subtree, i);
  atJ := GetRootDist_r(subtree, j);
  return( -(atI + atJ - 2*subtree[2]) )
end:

# GetRootDist_r( t, i )
#   Descends from t towards leaf i and returns field 2 of the leaf that is
#   reached.  At every node the left child (field 1) is taken when i is a
#   member of field 4, otherwise the right child (field 3).
#   NOTE(review): field 4 is assumed to hold the leaf numbers of the left
#   subtree, as prepared by the caller - confirm against GetSubTree_r.
#   Field 5 is not needed for the decision and is no longer read.
GetRootDist_r := proc(t, i)
  if member(i, t[4]) then child := t[1] else child := t[3] fi;
  if type(child, Leaf) then
    return( child[2] );
  fi;
  return( GetRootDist_r(child, i) );
end:


# GetPathDistance( order:array, data [, units] )
#   Adds up the step lengths between consecutive entries of order and
#   returns half of that sum.  The step length depends on the type of data:
#     Tree                  SubDist(data, a, b)
#     array(array(Match))   the PamNumber field of data[a,b], or Sim/Length1
#                           when units contains "score"
#     array(array(numeric)) data[a,b]
#   units is only looked at for an AllAll.  Without it, or when it
#   contains "pam", PAM distances are used.  An explicit "PAM" used to
#   fall through both branches and yield 0; it is now handled, and an
#   unrecognised units value raises an error instead of returning 0.
#   Raises an error when data is missing or of another type.
GetPathDistance := proc(order: array)
  description
  'order:  order of tree or AllAll traversal
   
   If second argument is a tree, then the tree
   is traversed in the given order and the length 
   (only in PAM units!) of the path is returned.
   
   If second argument is an array of Matches (AllAll) then
   the AllAl is "traversed" in the given order and
   the path length is returned. The score is always
   divided by the length of the match
   
   The second argument may also be a distance matrix.
   
   If a third argument is given ("PAM" or "SCORE") the units
   of the distance can be chosen (for the AllAll)';
   
   if nargs < 2 then error('Second argument must be a tree or an AllAll'); fi;
   path := 0; 
   if type(args[2], Tree) then 
     t := args[2];
     for i to length(order)-1 do
         path := path + SubDist(t, order[i], order[i + 1]); 
     od:
   elif type(args[2], array(array(Match))) then
     AllAll := args[2];
     # Decide the unit once: PAM is the default, "score" selects Sim/Length1.
     usescore := false;
     if nargs >= 3 then
       if SearchString('score', args[3]) > -1 then usescore := true
       elif SearchString('pam', args[3]) > -1 then usescore := false
       else error('Third argument must be PAM or SCORE') fi;
     fi;
     if usescore = true then
       for i to length(order)-1 do
         path := path + AllAll[order[i], order[i + 1], Sim]/AllAll[order[i], order[i + 1], Length1];
       od;
     else
       for i to length(order)-1 do
         path := path + AllAll[order[i], order[i + 1], PamNumber]
       od;
     fi;
   elif type(args[2], array(array(numeric))) then
     Dist := args[2];
     for i to length(order)-1 do
       path := path + Dist[order[i], order[i + 1]];
     od:
   else
     error('Invalid type of second argument');
   fi;
   return(path/2);
 end:
 
# GetMATreeNew( MA:array(string) )
#   Fills symmetric n x n matrices with the second and third results of
#   EstimatePam for every pair of entries of MA (using the global DMS),
#   then calls MinSquareTree on them inside traperror.
#   Returns [tree, Dist, Var]; when MinSquareTree fails, the error text
#   and a hint are printed and 0 is returned in place of the tree.
GetMATreeNew := proc(MA: array(string))

#taken from /usr/people/sgc/code/devdarwin/lib/Local/MA/MAlignment_tools
#and slightly modified

description 'Estimates Dist and Var Matrices from an alignment';

  nseq := length(MA);
  dmat := CreateArray(1..nseq,1..nseq);
  vmat := CreateArray(1..nseq,1..nseq);
  # Only the upper triangle is estimated; each value is mirrored.
  for a to nseq-1 do
    for b from a+1 to nseq do
      est := EstimatePam(MA[a],MA[b],DMS);
      dmat[b,a] := est[2];
      dmat[a,b] := est[2];
      vmat[b,a] := est[3];
      vmat[a,b] := est[3];
    od;
  od;
  # traperror hands back a string when MinSquareTree raised an error.
  fitted := traperror(MinSquareTree(dmat, vmat));
  if type(fitted, string) then
    print(fitted);
    print('There are probably sequences of length 0!');
    fitted := 0
  fi;
  return([fitted,dmat,vmat]);
end:

