#
#	Reduce Best Fitting Subtree algorithm for tree improvement
#
#	This heuristic attempts to improve a distance tree
#	based on the idea of collapsing/reducing the internal
#	nodes which have a best-fitting index.
#
#				Gaston H Gonnet (Dec 5, 2005)
#

module external RBFS_Tree, ComputeDimensionlessFit;


#
#           X       Y
#            \     /
#             \   /
#              \ /
#               o
#               |
#               :
#               :
#               :
#               |
#               o    <--- subtree in question
#              / \
#             /   \
#            /     \
#           A       o
#                  / \
#                 /   \
#                /     \
#               o       D
#              / \
#             /   \
#            /     \
#           B       C
#
#	k leaves in subtree, n leaves in total.
#	|A,B,...C| = k,  |X,Y,...|=n-k
#
#	branches to be set: 2*k-3 + 1 + (n-k)
#	distance relations: k*(k-1)/2 + k*(n-k)
#	degrees of freedom: (2*n-k-4)*(k-1)/2

RBFS_Tree := proc( t:Tree, Dist0:matrix, Var0:matrix ; 'Top'=((Top=1):posint) )
#
#  Try to improve the distance tree t by collapsing well-fitting subtrees
#  into single leaves, searching for good topologies on the smaller
#  (reduced) problem, and expanding the results back to the full leaf set.
#
#  Arguments:
#    t      - starting tree; each leaf must carry a positive integer index
#             in its third field or, failing that, in its first field
#    Dist0  - n x n matrix of pairwise distances
#    Var0   - n x n matrix of the variances of those distances
#    'Top'  - optional named argument (default 1): maximum number of
#             results returned
#
#  Returns a set of at most Top pairs [fit, tree], taken from the front
#  of the set of candidates collected in newt.
#
#  Errors: 'cannot find index in Leaf' when a leaf has no integer index.
#
#  The value DimensionlessFit is read right after every call to
#  LeastSquaresTree; presumably LeastSquaresTree sets it as a global
#  side effect - TODO confirm.
#  Progress messages are printed when printlevel > 1.
#

# Work on local names; presumably so that the nested procedures below
# can refer to the matrices (confirm against Darwin scoping rules).
Dist := Dist0;  Var := Var0;
# D2 holds the leaf-to-leaf distances implied by the tree t itself
D2 := Tree_matrix(t):
n := length(D2);
# fit of the input tree; it becomes one of the final candidates below
inidiml := ComputeDimensionlessFit(D2,Dist,Var);

# For every leaf index i record the height and the Leaf node itself
LeafLev := CreateArray(1..n):
LeafNodes := CreateArray(1..n):
for z in Leaves(t) do
    if length(z) >= 3 and type(z[3],posint) then i := z[3]
    elif type(z[1],posint) then i := z[1]
    else error(z,'cannot find index in Leaf') fi;
    LeafLev[i] := z[Height];
    LeafNodes[i] := z
od;

# HeiDir*z[Height] is highest for leaves, lowest for root
HeiDir := If( t[Height]-LeafLev[1] > 0, -1, 1 );
LeafLev := HeiDir*LeafLev;

# LeafInds collects one entry [fit, set of leaf indices, node] per
# internal node visited by ComputeFitness
LeafInds := [];
ComputeFitness := proc( t:Tree )
# NOTE(review): LeafInds is declared global here, whereas the later
# nested procedures use "external" for the same accumulate-into-the-
# enclosing-variable pattern - confirm that this is intended.
global LeafInds;
# Returns the set of leaf indices below t.  For every internal node
# that does not cover all n leaves, a fit index is computed and
# [fit, leaf set, node] is appended to LeafInds.
if type(t,Leaf) then
     { If( length(t) >= 3 and type(t[3],posint), t[3], t[1] ) }
else l1 := procname(t[1]);
     l3 := procname(t[3]);
     l13 := l1 union l3;
     k := length(l13);
     # a node covering every leaf has no outside leaves: nothing to score
     if n <= k then return(l13) fi;
     th := HeiDir*t[Height];
     fit := 0;

     # Part 1: weighted squared error between the given distances and
     # the tree distances, over all pairs of leaves inside the subtree
     for A in l13 do
	 DistA := Dist[A];
	 D2A := D2[A];
	 VarA := Var[A];
	 for B in l13 do if A < B then
	     fit := fit + (DistA[B]-D2A[B])^2 / VarA[B]
	 fi od
     od;

     # Part 2: for every leaf X outside the subtree, d is the distance
     # X-A minus the height difference between leaf A and node t.
     # w2 - w1^2/w0 is the 1/Var-weighted sum of squares of the d
     # values about their weighted mean.
     for X to n do if not member(X,l13) then
	 w0 := w1 := w2 := 0;
	 DistX := Dist[X];
	 VarX := Var[X];
	 for A in l13 do
	     d := DistX[A] - (LeafLev[A]-th);
	     iv := 1 / VarX[A];
	     w0 := w0 + iv;
	     w1 := w1 + d*iv;
	     w2 := w2 + d^2*iv;
	 od;
	 fit := fit + w2 - w1^2/w0
     fi od;
     # normalize by (2*n-k-4)*(k-1)/2 when that quantity is positive
     if (2*n-k-4)*(k-1)>0 then 
     	fit := fit / ( (2*n-k-4)*(k-1)/2 );
     fi;
     LeafInds := append(LeafInds,[fit,l13,t]);
     l13
fi
end:

ComputeFitness(t);
# presumably this orders the entries by their first component, the fit
# index - verify the ordering sort applies to lists
LeafInds := sort(LeafInds):

# Select a convenient set of subtrees with minimal fitness index
subts := {};
for i to length(LeafInds) do
    li := LeafInds[i];
    # keep the selected leaf sets non-nested: a selected subtree lying
    # inside the candidate is dropped in favour of the candidate; a
    # candidate lying inside a selected subtree is discarded
    for z in subts do
	if z[2] minus li[2] = {} then subts := subts minus {z}
	elif li[2] minus z[2] = {} then li := NULL;  break
	fi
    od:
    subts := subts union {li}:
    # stop when the new tree has less than 1/2 of the original number of nodes
    if 2*( sum(length(z[2]),z=subts) - length(subts) ) > n then break fi;
od:

# produce the new Dist and Var matrices
# map[i] describes leaf i of the reduced problem: either the index of an
# original leaf that was kept, or the [fit, leaf set, node] entry of a
# collapsed subtree.
# NOTE: the pair loop below relies on every integer entry of map coming
# before every list entry; otherwise 'should not happen' is raised.
map := subts union ( {seq(i,i=1..n)} minus {seq(op(z[2]),z=subts)} );
n2 := length(map);
Dist2 := CreateArray(1..n2,1..n2):
Var2 := CreateArray(1..n2,1..n2):
for i1 to n2 do for i2 from i1+1 to n2 do
    if type(map[i1],posint) then
	 x1 := map[i1];
	 if type(map[i2],posint) then 
	      # two original leaves: copy the entries unchanged
	      x2 := map[i2];
	      Dist2[i1,i2] := Dist2[i2,i1] := Dist[x1,x2];
	      Var2[i1,i2] := Var2[i2,i1] := Var[x1,x2];
	 else w0 := w1 := 0;
	      # original leaf against collapsed subtree: 1/Var-weighted
	      # mean of the distances to the subtree's leaves, each
	      # reduced by that leaf's height difference to the subtree
	      # root; clamped at 0
	      th := HeiDir*map[i2,3,Height];
	      for x2 in map[i2,2] do
		  w0 := w0 + 1/Var[x1,x2];
		  w1 := w1 + ( Dist[x1,x2] - (LeafLev[x2]-th) ) / Var[x1,x2];
	      od;
	      Dist2[i1,i2] := Dist2[i2,i1] := max(0,w1/w0);
	      Var2[i1,i2] := Var2[i2,i1] := 1/w0;
	 fi
    elif type(map[i2],list) then
	 # two collapsed subtrees: same weighted mean over all cross
	 # pairs, with the height differences removed on both sides
	 th := HeiDir*(map[i1,3,Height]+map[i2,3,Height]);
	 w0 := w1 := 0;
	 for x1 in map[i1,2] do for x2 in map[i2,2] do
	     w0 := w0 + 1/Var[x1,x2];
	     w1 := w1 + ( Dist[x1,x2] - LeafLev[x2] - LeafLev[x1] + th ) /
			Var[x1,x2];
	 od od;
	 Dist2[i1,i2] := Dist2[i2,i1] := max(0,w1/w0);
	 Var2[i1,i2] := Var2[i2,i1] := 1/w0;
    else error('should not happen')
    fi
od od;


# map the original tree into a reduced tree
# convt maps a node of the original tree (a kept leaf or the root of a
# collapsed subtree) to the Leaf that replaces it in the reduced tree
convt := table();
for i to length(map) do
    if type(map[i],posint) then
	 convt[ LeafNodes[map[i]] ] := Leaf(i,LeafNodes[map[i],2],i)
    else convt[ map[i,3] ] := Leaf(i,map[i,3,Height],i) fi
od;
Labels2 := [seq(i,i=1..n2)];
# anonymous recursive procedure applied at once to t: nodes without an
# entry in convt are rebuilt from their converted children, the others
# are replaced by their new Leaf
t2 := proc(t:Tree)
  x := convt[t];
  if x = unassigned then Tree(procname(t[1]),t[2],procname(t[3]))
  else x fi end(t);
#  Build reduced tree
t2 := LeastSquaresTree( Dist2, Var2, Labels2, t2 ):
# bestt is a set of [fit, reduced tree] pairs
bestt := {[DimensionlessFit,t2]};
inifit := bestt[1,1];
if printlevel > 1 then printf(
	'RBFS_Tree: %d nodes reduced to %d with fit index: %.8g\n',
	n, n2, inifit ) fi;
# keep at least 10 reduced trees, regardless of Top
M := max(10,Top);


#  Attempt to build a better reduced tree (keep best M trees)
to 5*M do
    t2 := LeastSquaresTree( Dist2, Var2, Labels2, NJRandom ):
    bestt := bestt union {[DimensionlessFit,t2]};
    # drop the first entry whose fit is within 1e-10 of its predecessor
    for i from 2 to length(bestt) do if bestt[i,1]-bestt[i-1,1] < 1e-10 then
	bestt := bestt minus {bestt[i]};
	break
    fi od;
    # NOTE(review): assumes bestt[1..M] yields the first M entries in a
    # form that {...} turns back into a set of pairs - confirm
    if length(bestt) > M then bestt := {bestt[1..M]} fi;
od:
if printlevel > 1 then printf(
	'RBFS_Tree: best new fit indices in reduced tree (n=%d): %a\n',
	n2, {seq(z[1],z=bestt)} ) fi;


# for each of the top trees, enlarge it and optimize
Labels := [seq(z[1],z=LeafNodes)];
# the candidates start with the input tree and its re-optimized version
t3 := LeastSquaresTree( Dist, Var, Labels, t );
newt := {[inidiml,t],[DimensionlessFit,t3]};

# if sufficiently large, add the trees from recursion
if n2 > 60 then bestt := bestt union
	procname( bestt[1,2], Dist2, Var2, 'Top'=round(1.5*M) ) fi;
for w in bestt do
    # expand the reduced tree w[2]: each leaf is replaced by the original
    # leaf or by the original subtree that it stands for
    t3 := proc(t:Tree)
	if type(t,Leaf) then 
	     x := map[t[3]];
	     if type(x,posint) then LeafNodes[x] else x[3] fi
	else Tree(procname(t[1]),t[2],procname(t[3])) fi
	end(w[2]);
    t3 := LeastSquaresTree( Dist, Var, Labels, t3 );
    newt := newt union {[DimensionlessFit,t3]};
    # drop the first entry whose fit is within 1e-10 of its predecessor
    for i from 2 to length(newt) do if newt[i,1]-newt[i-1,1] < 1e-10 then
	newt := newt minus {newt[i]};
	break
    fi od;
od:


#
#	Find, among the best trees, if it is possible to merge trees
#	with the same leaves.
#	A table, indexed by the set of leaves is constructed.  Only
#	the identical topologies are kept.
#
#	It is assumed that any subtree with less than 5 leaves is
#	already optimal (5-optim has already been done)
#	(the code below skips leaf sets of size <= 5 and of size > n-5)
Subtrees := table({});
for i to min(5,length(newt)) do
    z := newt[i];
    # walk candidate i; under each leaf set, record every subtree that
    # is not identical to one already stored for that leaf set
    proc( t:Tree ) external Subtrees;
    if type(t,Leaf) then 
         { If( length(t) >= 3 and type(t[3],posint), t[3], t[1] ) }
    else ls := procname(t[1]) union procname(t[3]);
	 if length(ls) <= 5 or length(ls) > n-5 then return(ls) fi;
	 sls := Subtrees[ls];
	 for z in sls do if IdenticalTrees(t,z) then return(ls) fi od;
	 Subtrees[ls] := sls union {t};
	 ls
    fi end( z[2] )
od;

# Eliminate the indices which are supersets if they contain the
# same number of subtrees, i.e.
#  length(Subt[z1]) = length(Subt[z2]) and z1 minus z2 = {}
# Subsets groups the leaf sets by the number of subtrees stored for
# them; leaf sets with fewer than 2 stored subtrees are cleared.
Subsets := table({});
for z in Indices(Subtrees) do
    lstz := length(Subtrees[z]);
    if lstz < 2 then Subtrees[z] := {};  next fi;
    Subsets[lstz] := Subsets[lstz] union {z};
od;
for z in Indices(Subsets) do
    Subsz := Subsets[z];
    # when leaf set i is contained in leaf set j, clear the entry of j
    for i to length(Subsz) do for j to length(Subsz) do if i <> j then
	if Subsz[i] minus Subsz[j] = {} then Subtrees[Subsz[j]] := {} fi
    fi od od
od:

#  Make a list of possible substitutions for the best tree
# every entry has the form (subtree of newt[1,2]) = (stored subtree over
# the same leaf set that is not identical to it)
PossSubs := [];
proc( t:Tree ) external PossSubs;
    if type(t,Leaf) then 
         { If( length(t) >= 3 and type(t[3],posint), t[3], t[1] ) }
    else ls := procname(t[1]) union procname(t[3]);
	 if length(ls) <= 5 or length(ls) > n-5 then return(ls) fi;
	 sls := Subtrees[ls];
	 for z in sls do if not IdenticalTrees(t,z) then
	     PossSubs := append( PossSubs, t=z );
	 fi od;
	 ls
    fi end( newt[1,2] );

# Add the substitutions to the newt list
# each substitution is applied on its own to a copy of the best tree and
# the result is re-optimized
newt2 := [];
for z in PossSubs do
    w := LeastSquaresTree( Dist, Var, Labels, subs(z,copy(newt[1,2])) );
    newt2 := append( newt2, [DimensionlessFit,w] );
od:
newt := newt union {op(newt2)};

# remove almost-identical dimless fit values
# (scanning from the end, so removals do not shift the entries still
# to be examined)
for i from length(newt) by -1 to 2 do if newt[i,1]-newt[i-1,1] < 1e-10 then
    newt := newt minus {newt[i]};
fi od;



if printlevel > 1 then printf(
	'RBFS_Tree: best final fit indices in original tree (n=%d): %a\n',
	n, {seq(z[1],z=newt)} ) fi;

# result: the first min(Top,length(newt)) entries of newt
{ seq(newt[i],i=1..min(Top,length(newt))) }

end:


ComputeDimensionlessFit := proc( input:{Tree,matrix}, Dist:matrix, Var:matrix )
#
#  Dimensionless index of how well a tree reproduces a distance matrix.
#
#    input - a Tree, or the n x n matrix of leaf-to-leaf tree distances
#    Dist  - n x n matrix of the given pairwise distances
#    Var   - n x n matrix of variances of Dist (a zero variance gives
#            that pair a weight of zero)
#
#  Returns 0 when n < 4, otherwise
#      1000 * S / sigma2 / ((n-2)*(n-3)/2) * h
#  where S is the 1/Var-weighted sum of squared differences between
#  Dist and the tree distances, sigma2 is the sample variance of the
#  upper-triangle entries of Dist, and h is the number of pairs divided
#  by the sum of the weights.
#  Raises 'invalid dimensions of the input matrices' when any of the
#  three matrices is not n x n.
#

# a Tree is converted to its distance matrix first
if type(input,Tree) then
    return( procname(Tree_matrix(input),Dist,Var) )
fi;

n := length(input);
if n < 4 then return(0) fi;
if length(Dist) <> n or length(Var) <> n or length(input[1]) <> n or
   length(Dist[1]) <> n or length(Var[1]) <> n then
    error('invalid dimensions of the input matrices')
fi;

npairs := n*(n-1)/2;
# the distances are shifted by Dist[1,2] before the sums for the
# variance are accumulated
shift := Dist[1,2];
wsse := 0;
wsum := 0;
dsum := 0;
dsqsum := 0;

# accumulate over the strict upper triangle
for i to n-1 do
    DistRow := Dist[i];
    VarRow := Var[i];
    TreeRow := input[i];
    for j from i+1 to n do
	v := VarRow[j];
	if v = 0 then w := 0 else w := 1/v fi;
	wsum := wsum + w;
	dij := DistRow[j];
	wsse := wsse + w*(dij-TreeRow[j])^2;
	dsum := dsum + dij-shift;
	dsqsum := dsqsum + (dij-shift)^2
    od
od;

hscale := npairs/wsum;
dvar := (npairs*dsqsum-dsum^2)/npairs/(npairs-1);
1000 * wsse / dvar / ((n-2)*(n-3)/2) * hscale;
end:

end;  # end module
