# Purpose: Split decomposition routines, Dress-Trees
# Author:  Lukas Knecht
# Created: 15 Feb 1993
# References:
# [1] Bandelt HJ, Dress AWM (1992): Split decomposition: a new and useful
#     approach to phylogenetic analysis of distance data.
#     Molecular Phylogenetics and Evolution 1(3):242-252.
IsolationIndex := proc (d: matrix(numeric), I: set)
  description
  'Computes isolation index for the split [I, {1..length(d)} minus I].';
  # Build the complementary side of the split: every index 1..n not in I.
  n := length (d);
  everyone := CreateArray (1..n);
  for t to n do everyone[t] := t od;
  rest := {op (everyone)} minus I;
  # best tracks the minimum, over all a,b in I and p,q in rest (not
  # necessarily distinct), of
  #   max (d[a,b]+d[p,q], d[a,p]+d[b,q], d[a,q]+d[b,p]) - d[a,b] - d[p,q].
  # It stays DBL_MAX when either side of the split is empty.
  best := DBL_MAX;
  for a in I do
    for b in I do
      dab := d[a,b];
      for p in rest do
        for q in rest do
          best := min (best,
                       max (dab + d[p,q], d[a,p] + d[b,q], d[a,q] + d[b,p])
                       - dab - d[p,q])
        od
      od
    od
  od;
  # The isolation index is half of that minimum.
  best / 2
end:

dSplits := proc (d: matrix(numeric))
  description
  'Computes the d-splits and their isolation indices from the distance matrix.
  Returns list([index,set]).';
  # The splits are built incrementally: starting from the single split of
  # the first two taxa, each further taxon i is tried on both sides of every
  # split found so far, and the trivial split {i} | {1..i-1} is added.
  # Only splits with a strictly positive isolation index are kept.
  # Each result entry is [isolation index, one side of the split]; the
  # result is sorted by decreasing isolation index.
  n := length (d);
  if n < 2 then
    error ('Must have at least two vertices')
  fi;
  # Report a malformed matrix separately instead of blaming the taxon count.
  if not type (d, array (numeric,n,n)) then
    error ('Distance matrix must be square')
  fi;
  # create initial splits from first two taxa
  res := [[d[1,2], {1}]];
  all := {1,2};
  for i from 3 to n do
    newres := [];
    for r in res do
      # check I union {i},K
      I := r[2];
      lI := length (I);
      K := all minus I;
      lK := length (K);
      # w holds twice the isolation index; it can only shrink from the
      # value of the parent split, and the loops stop as soon as it hits 0.
      w := 2 * r[1];
      for j in I while w > 0 do
        for k in K while w > 0 do
          # l <= k: each unordered pair {k,l} of K is visited once
          for l in K while l <= k and w > 0 do
            x := d[i,j] + d[k,l];
            w := min (w, max (x, d[i,k] + d[j,l], d[i,l] + d[j,k]) - x)
          od
        od
      od;
      if w > 0 then
        # store I union {i} when I is strictly smaller than K, otherwise K
        newres := append (newres, [w / 2, If (lI < lK, I union {i}, K)])
      fi;
      # check I,K union {i}
      w := 2 * r[1];
      for j in K while w > 0 do
        for k in I while w > 0 do
          for l in I while l <= k and w > 0 do
            x := d[i,j] + d[k,l];
            w := min (w, max (x, d[i,k] + d[j,l], d[i,l] + d[j,k]) - x)
          od
        od
      od;
      if w > 0 then
        newres := append (newres, [w / 2, If (lK < lI, K union {i}, I)])
      fi
    od;
    # check {1..i-1},{i}
    w := DBL_MAX;
    for k to i - 1 while w > 0 do
      for l to k while w > 0 do
        x := d[i,i] + d[k,l];
        w := min (w, max (x, d[i,k] + d[i,l]) - x)
      od
    od;
    if w > 0 then
      newres := append (newres, [w / 2, {i}])
    fi;
    res := newres;
    all := all union {i}
  od;
  sort (res, x -> -x[1])
end:

dSplitMetricSum := proc (splits: list ([numeric,set]), n: posint)
  description
  'Computes the split decomposable distances d1. n is the number of taxa.';
  # d1[a,b] is the sum of the weights s[1] of all splits whose stored side
  # s[2] contains exactly one of a and b. The diagonal stays 0.
  d1 := CreateArray (1..n, 1..n);
  for a from 2 to n do
    for b to a - 1 do
      total := 0;
      for s in splits do
        if length (s[2] intersect {a,b}) = 1 then
          total := total + s[1]
        fi
      od;
      # fill both triangles so the result is symmetric
      d1[a,b] := total;
      d1[b,a] := total
    od
  od;
  d1
end:

dSplitIndex := proc (d: matrix(numeric), splits: list ([numeric,set]))
  description
  'Computes the splittable fraction rho.';
  # rho = (sum of all entries of the split metric d1) /
  #       (sum of all entries of the input distances d)
  n := length (d);
  d1 := dSplitMetricSum (splits, n);
  num := 0;
  denom := 0;
  for row to n do
    num := num + sum (d1[row]);
    denom := denom + sum (d[row])
  od;
  num / denom
end:

# Union-find "find" with path compression: returns the representative of
# element i in the parent array cc (an element is a representative when
# cc[i] = i). Every element visited on the way is re-pointed directly at
# the representative, so cc is modified in place.
cc_find := proc (cc: array (posint), i: posint)
  parent := cc[i];
  if parent = i then
    return (i)
  fi;
  root := cc_find (cc, parent);
  cc[i] := root;
  root
end:

# Returns true iff the nodes in the set `nodes` form a single connected
# component when only edges between members of `nodes` are considered.
# dist is a label matrix in which DBL_MAX marks "no edge"; it is always
# indexed as dist[smaller, larger]. An empty set is reported as connected.
#
# Components are tracked with a union-find structure over the positions
# 1..length(nodes). When an edge joins two components, the larger root is
# linked to the smaller one, so the component containing position 1 always
# has root 1 and the final test is simply "every position resolves to 1".
# (Overwriting cc[i] with the root of j instead of linking the two roots
# would lose earlier unions, e.g. for edges 1-3 and 2-3.)
is_connected := proc (nodes: set, dist: matrix)
  n := length (nodes);
  cc := CreateArray (1..n);
  for i to n do cc[i] := i od;
  for i to n do
    for j to i - 1 do
      if dist[nodes[j],nodes[i]] <> DBL_MAX then
        ri := cc_find (cc, i);
        rj := cc_find (cc, j);
        if ri <> rj then
          cc[max (ri, rj)] := min (ri, rj)
        fi
      fi
    od
  od;
  for i to n do
    if cc_find (cc, i) <> 1 then
      return (false)
    fi
  od;
  true
end:
  
# Searches for a superset of `nodes` with exactly `maxnodes` members that
# is connected according to is_connected (nodes, dist).
#   nodes    - current set of node indices
#   dist     - label matrix, DBL_MAX meaning "no edge" (see is_connected)
#   adj      - adj[v] lists the neighbours of node v
#   maxnodes - required size of the result
# Returns the first such set found by depth-first extension with adjacent
# nodes, or {} when there is none. A set that is already larger than
# maxnodes can never shrink, so {} is returned for it immediately; an empty
# set has no neighbours to extend with and also yields {}.
min_connected := proc (nodes: set, dist: matrix, adj: list, maxnodes: posint)
  if length (nodes) = maxnodes then
    if is_connected (nodes, dist) then
      return (nodes)
    fi
  elif length (nodes) < maxnodes then
    # Try all adjacent nodes which are not in nodes
    tryset := {};
    for v in nodes do
      tryset := tryset union {op (adj[v])}
    od;
    tryset := tryset minus nodes;
    for t in tryset do
      res := min_connected (nodes union {t}, dist, adj, maxnodes);
      if length (res) > 0 then
        return (res)
      fi
    od
  fi;
  return ({})
end:

# Direction angle of the vector x = [x-component, y-component], returned in
# the range [0, 2*Pi). Note the argument order: the x-component comes first,
# unlike the usual atan2 (y, x) convention.
# The zero vector has no direction; 0 is returned for it instead of
# evaluating 0/0.
atan2 := proc (x: [numeric, numeric])
  if x[1] = 0 and x[2] = 0 then
    0
  elif abs (x[1]) < abs (x[2]) * 100 * DBL_EPSILON then
    # (numerically) vertical vector: avoid dividing by a vanishing x-component
    If (x[2] < 0, 3*Pi/2, Pi/2)
  else
    res := arctan (x[2] / x[1]);
    if x[1] < 0 then
      # left half plane: arctan result is off by Pi
      res + Pi
    elif res < 0 then
      # fourth quadrant: shift into [0, 2*Pi)
      res + 2 * Pi
    else
      res
    fi
  fi
end:

# Builds a split graph by starting from a graph with a single node that
# holds the set of all taxa and then applying the splits one at a time, in
# order of decreasing size of their stored side s[2]. Applying a split
# duplicates every node that carries taxa of s[2] (plus up to three extra
# nodes needed to make that node set connected), moves the taxa of s[2] to
# the duplicates and links each original to its duplicate by an edge
# labelled [s[1], splitnr].
# Note: the second returned value (newangles) has one entry per edge of the
# final graph, i.e. newangles[i] is the angle of edge i.
dSplitGraph := proc (splits: list ([numeric, set]), all: {posint, set})
  description
  'Computes a graph from a list of dSplits. Edges will have labels of the
  format [length, splitnr] where splitnr is an index into splits which
  corresponds to this edge. The procedure returns an expression sequence
  g: Graph, angles: array(length(splits),numeric) where angles contains a
  list of angles to be used as hints when drawing the edges of graph.
  all is the set of all taxa of the split or a posint if the set is 1..all.';

  # Score of the gap low..high between two neighbouring edge angles: the
  # width of the gap, reduced by 1/8 of the distance between the middle of
  # the gap and the preferred direction avg.
  angle_value := proc (low: numeric, high: numeric, avg: numeric)
    high - low - abs ((low + high) / 2 - avg) / 8
  end;
  
  # Normalise all to a set of taxa
  if type (all, posint) then
    res := CreateArray (1..all);
    for i to all do res[i] := i od;
    res := {op (res)}
  else
    res := all
  fi;
  # Initial graph: no edges; the taxa set is used as node content (each
  # node of the graph is a set of taxa, see the intersections below)
  res := Graph (Edges (), Nodes (res));
  # angles[k] is the drawing angle chosen for split number k
  angles := CreateArray (1..length (splits));
  for s in sort (splits, x->-length(x[2])) do
    # position of s in the caller's list; used in the edge labels
    splitnr := SearchArray (s, splits);
    nodes := res[Nodes];
    edges := res[Edges];
    # dist[a,b] (a < b) is the label of the edge a-b, DBL_MAX if none
    dist := res[Labels];
    # Collect all nodes to be splitted in source and their edge angles
    oldnr := length (nodes);
    source := [];
    for i to oldnr do
      if length (nodes[i] intersect s[2]) > 0 then
	source := append (source, i)
      fi
    od;
    source := {op (source)};
    # Complete to minimal set of source nodes such that all source nodes are
    # a connected component.
    # At most three additional nodes are tried; if no connected superset is
    # found, source is left as collected above.
    adj := res[Adjacencies];
    for maxnodes from length (source) to length (source) + 3 do
      newsource := min_connected (source, dist, adj, maxnodes);
      if length (newsource) = maxnodes then
	source := newsource;
	break
      fi
    od;
    # Collect edge angles of all source nodes
    # Edges are created below with the lower node number first, and the
    # stored angle is used for that direction; seen from the higher numbered
    # end the edge therefore points at angle + Pi.
    sourceangles := [];
    for i in source do
      for j to i - 1 do
	if dist[j,i] <> DBL_MAX then
	  sourceangles := append (sourceangles, 
				  mod (angles[dist[j,i,2]] + Pi, 2 * Pi))
	fi
      od;
      for j from i + 1 to oldnr do
	if dist[i,j] <> DBL_MAX then
	  sourceangles := append (sourceangles, angles[dist[i,j,2]])
	fi
      od
    od;
    sourceangles := sort (sourceangles);
    # Determine direction
    if length (sourceangles) = 0 then
      angle := 0
    elif length (sourceangles) = 1 then
      angle := sourceangles[1] + 3*Pi/4
    else
      # avg: direction opposite to the sum of the unit vectors of all
      # existing edge directions
      avg := [0,0];
      for a in sourceangles do
	avg := avg - [cos (a), sin (a)]
      od;
      avg := atan2 (avg);
      # Pick the gap between consecutive sorted angles with the best
      # angle_value; the initial candidate is the wrap-around gap from the
      # last angle to the first angle + 2*Pi.
      largest := (sourceangles[length(sourceangles)])..(sourceangles[1]+2*Pi);
      for i to length (sourceangles) - 1 do
	if angle_value (sourceangles[i], sourceangles[i+1], avg) >
	  angle_value (largest[1], largest[2], avg) then
	  largest := sourceangles[i]..sourceangles[i+1]
	fi
      od;
      # the new edges point into the middle of the chosen gap
      angle := (largest[2] + largest[1]) / 2
    fi;
    angle := mod (angle, 2 * Pi);
    angles[splitnr] := angle;
    # Split the collected nodes
    # The duplicate of the k-th member of source becomes node oldnr + k.
    for i in source do
      newnode := nodes[i] intersect s[2];
      nodes := append (nodes, newnode);
      j := length (nodes);
      nodes[i] := nodes[i] minus newnode;
      edges := append (edges, Edge ([s[1], splitnr], i, j))
    od;
    # Inherit all connections from the splitted nodes
    # (an edge between two source nodes is repeated, with the same label,
    # between their duplicates)
    for i to length (source) do
      for j to i - 1 do
	if dist[source[j], source[i]] <> DBL_MAX then
	  edges := append (edges, Edge (dist[source[j], source[i]],
					j + oldnr, i + oldnr))
	fi
      od
    od;
    res[Edges] := edges;
    res[Nodes] := nodes
  od;
  # Translate the per-split angles into one angle per edge
  edges := res[Edges];
  newangles := CreateArray (1..length (edges));
  for i to length (edges) do
    newangles[i] := angles[edges[i,1,2]]
  od;
  res, newangles
end:

# Recomputes the drawing angles of a split graph.
#   g      - split graph whose edge labels are [length, splitnr]
#   angles - angles[i] is the current angle of edge i of g
# The graph is collapsed to a cycle free graph (Collapse), made binary
# (BinTree) and converted to a Tree; PolishAngles_R then walks that tree and
# assigns each retained split the middle of an angular sector. Splits that
# are not edges of the collapsed graph are rotated by the average rotation
# of the retained splits that affect them.
# Returns the expression sequence g, newangles (one angle per edge of g).
# g itself is returned unchanged; when the collapsed graph is a single Leaf
# the input angles are returned unchanged as well.
PolishAngles := proc (g: Graph, angles: array (numeric))
  description
  'Attempts to polish angles by collapsing g to a tree.';
  edges := g[Edges];
  # m: largest split number occurring in an edge label
  m := 0;
  for e in edges do
    m := max (m, e[1,2])
  od;
  # splitangles[k]: angle of split k, taken from the edges carrying it
  # (the last such edge wins)
  splitangles := CreateArray (1..m);
  for i to length (edges) do
    splitangles[edges[i,1,2]] := angles[i]
  od;
  cg := BinTree (Collapse (g));
  t := Tree (cg);
  if op (0, t) = Leaf then return (g, angles) fi;
  # Compute split nrs of edges rotated by collapsed splits
  # edgelist[k]: all edges of g that belong to split k
  edgelist := CreateArray (1..m, []);
  for e in edges do
    edgelist[e[1,2]] := append (edgelist[e[1,2]], e)
  od;
  # affects[k] is a set only for splits that label an edge of the collapsed
  # graph; 0 marks all other splits until the loop below replaces it by {}
  affects := CreateArray (1..m, 0);
  for e in cg[Edges] do
    affects[e[1,2]] := {}
  od;
  dist := g[Labels];
  for i to m do
    if affects[i] <> 0 then
      # For every pair e, f of edges of split i: if their start nodes are
      # joined by an edge and their end nodes are joined by an edge of the
      # same split, that split is recorded as affected by split i.
      for j to length (edgelist[i]) do
	e := edgelist[i,j];
	for k to j - 1 do
	  f := edgelist[i,k];
	  d1 := dist[min (e[2], f[2]), max (e[2], f[2])];
	  d2 := dist[min (e[3], f[3]), max (e[3], f[3])];
	  if d1 <> DBL_MAX and d2 <> DBL_MAX and d1[2] = d2[2] then
	    affects[i] := affects[i] union {d1[2]}
	  fi
	od
      od
    else
      affects[i] := {}
    fi
  od;
  # Divide the full circle between the two subtrees of the root in
  # proportion to their DrawTCount values (presumably leaf counts - confirm
  # against the DrawTCount definition)
  r := DrawTCount(t[1]); l := DrawTCount(t[3]);
  ar := 2*Pi*r/(l+r);
  # corr[k] = [sum of rotations, number of rotations] collected for split k
  corr := CreateArray (1..m, 1..2);
  PolishAngles_R (t[1], 0, ar, t[2], t[4], 
		  cg[Edges], affects, splitangles, corr);
  PolishAngles_R (t[3], ar, 2*Pi - ar, t[2], t[5], 
		  cg[Edges], affects, splitangles, corr);
  # Apply the average collected rotation to every split that received one
  for i to m do
    if corr[i,2] > 0 then
      splitangles[i] := mod (splitangles[i] + corr[i,1] / corr[i,2], 2 * Pi)
    fi
  od;
  newangles := CreateArray (1..length (edges));
  for i to length (edges) do
    newangles[i] := splitangles[edges[i,1,2]]
  od;
  g, newangles
end:

# Recursive worker of PolishAngles. Handles the subtree t that is attached
# by the collapsed-graph edge abs(edgenr) and has been given the angular
# sector starting at low with the given width.
#   oldh    - t[2] of the parent tree node (presumably its height - confirm)
#   edgenr  - signed index into cedges; a negative sign means the edge is
#             traversed against its stored direction (the angle is turned
#             by Pi)
#   cedges  - edges of the collapsed binary graph
#   affects - affects[s]: splits that are rotated along with split s
#   angles  - per-split angles, updated in place
#   corr    - corr[k] = [sum, count] of rotations for split k, updated in
#             place
PolishAngles_R := proc (t: Tree, low: numeric, width: numeric, oldh: numeric,
			edgenr: integer, cedges: Edges, affects: list(set),
			angles: list(numeric), corr: list([numeric,integer]))
  s := cedges[abs (edgenr),1,2];
  angle := angles[s];
  if edgenr < 0 then angle := mod (angle + Pi, 2*Pi) fi;
  # rotation that moves this edge into the middle of its sector
  newrotate := low + width / 2 - angle;
  angles[s] := mod (angles[s] + newrotate, 2*Pi);
  # rotations of at most 0.01 are not propagated to the affected splits
  if abs (newrotate) > 0.01 then
    for i in affects[s] do
      corr[i,1] := corr[i,1] + newrotate;
      corr[i,2] := corr[i,2] + 1
    od
  fi;
  if op (0, t) = Tree then
    # Angles of the two child edges after applying this node's rotation;
    # they decide which child receives the first part of the sector.
    angle1 := mod (angles[cedges[abs(t[4]),1,2]] + If (t[4]<0, 0, Pi)
		   + newrotate, 2 * Pi);
    angle3 := mod (angles[cedges[abs(t[5]),1,2]] + If (t[5]<0, 0, Pi)
		   + newrotate, 2 * Pi);
    # tx[i] = [index of the subtree in t, index of its edge number in t]
    if mod (angle3 - angle1, 2*Pi) < Pi then
      tx := [[1,4], [3,5]]
    else
      tx := [[3,5], [1,4]]
    fi;
    r := abs (t[2] - oldh);
    # theta[i] / beta[i]: start and width of the sector of child i; the
    # sector is shared in proportion to the DrawTCount of the children
    theta := CreateArray (1..2);
    beta := CreateArray (1..2);
    theta[1] := low;
    tc1 := DrawTCount (t[tx[1,1]]);
    beta[1] := width * tc1 / (tc1 + DrawTCount (t[tx[2,1]]));
    theta[2] := low + beta[1];
    beta[2] := width - beta[1];
    for i to 2 do
      ht := DrawTHeight (t[tx[i,1]]); height := abs (ht[1] - ht[2]);
      if ht[1] = ht[2] then height := abs (ht[1]) fi;
      if height = 0 then height := 1 fi;
      # alpha = arcsin (sinalpha), written with arctan; it is dropped (0)
      # when |sinalpha| >= 1, when the whole sector is at least Pi/2 wide,
      # or when alpha exceeds Pi/4
      sinalpha := r * sin (beta[i]) / height;
      alpha := If (abs (sinalpha) < 1, arctan (sinalpha/sqrt(1-sinalpha^2)), 0);
      if beta[1]+beta[2]>=Pi/2 then alpha := 0 fi;
      if alpha>Pi/4 then alpha := 0 fi;
      # first child: sector start moves back by alpha;
      # second child: sector is widened by alpha
      if i = 1 then theta[i] := theta[i] - alpha
      else beta[i] := beta[i] + alpha fi;
      PolishAngles_R (t[tx[i,1]], theta[i], beta[i], t[2], t[tx[i,2]],
		      cedges, affects, angles, corr)
    od
  fi
end:

# Repeatedly takes the shortest edge whose split number is carried by more
# than one remaining edge, removes all remaining edges of that split and
# merges their end nodes, until every remaining split number occurs on one
# edge only. Returns a new Graph; g is not modified.
# The node labels of the result are the encoded() texts of the nodes that
# remain representatives; contents of nodes merged into them are not
# carried over.
Collapse := proc (g: Graph)
  description
  'Collapses cycles in g by removing edges.';
  # Get multiplicity and length of edges
  edges := g[Edges];
  nodes := g[Nodes];
  dist := g[Labels];
  n := length (edges);
  # kept[i]: edge i is still part of the graph
  kept := CreateArray (1..n, true);
  # edgelist[k]: indices of the edges with split number k; m: largest
  # split number
  edgelist := CreateArray (1..n, []);
  m := 0;
  for i to n do
    j := edges[i,1,2];
    m := max (m, j);
    edgelist[j] := append (edgelist[j], i)
  od;
  edgelist := edgelist[1..m];
  # map: union-find parent array over the nodes (see cc_find)
  map := CreateArray (1..length (nodes));
  for i to length (nodes) do
    map[i] := i
  od;
  do
    # Get shortest multiple edge
    # multiple[k] becomes true at the first kept edge of split k, so only
    # the second and later kept edges of a split are candidates for s
    multiple := CreateArray (1..n, false);
    s := 0;
    for i to n do
      if kept[i] then
	e := edges[i];
	if multiple[e[1,2]] and (s = 0 or e[1,1] < edges[s,1,1]) then
	  s := i
	fi;
	multiple[e[1,2]] := true
      fi
    od;
    # no split occurs on more than one kept edge: done
    if s = 0 then break fi;
    # Remap all affected nodes
    # (drop every kept edge of the chosen split and merge its second node
    # into its first node)
    for i to n do
      if kept[i] and edges[i,1,2] = edges[s,1,2] then
	kept[i] := false;
	map[cc_find (map, edges[i,3])] := cc_find (map, edges[i,2])
      fi
    od;
    # compress all paths so that map[i] is the representative of node i
    for i to length (map) do
      cc_find (map, i)
    od;
    # Two kept edges of the same split that now join the same pair of
    # merged nodes are duplicates; the earlier one in the list is dropped
    for l in edgelist do
      for i to length (l) do
	if kept[l[i]] then
	  e := edges[l[i]];
	  for j to i - 1 do
	    if kept[l[j]] then
	      f := edges[l[j]];
	      if {map[e[2]], map[e[3]]} = {map[f[2]], map[f[3]]} then
		kept[l[j]] := false
	      fi
	    fi
	  od
	fi
      od
    od
  od;
  # Renumber the representatives; new[i] is the node number of
  # representative i in the result
  newNodes := Nodes ();
  new := CreateArray (1..length (nodes));
  for i to length (nodes) do
    if map[i] = i then
      newNodes := append (newNodes, encoded (nodes[i]));
      new[i] := length (newNodes)
    fi
  od;
  # Re-create the kept edges between the renumbered representatives
  newEdges := Edges ();
  for i to n do
    if kept[i] then
      e := edges[i];
      newEdges := append (newEdges, Edge (e[1], new[map[e[2]]], new[map[e[3]]]))
    fi
  od;
  Graph (newEdges, newNodes)
end:

# Adds nodes and edges until every node has exactly one or three incident
# edges. Works on copies of the edges and nodes of g and returns the last
# Graph built from them. New nodes get the empty label '' and new edges the
# length 0.01 together with the split number of an existing edge at the
# node being repaired.
BinTree := proc (g: Graph)
  description
  'Converts a cycle free connected graph to a graph equivalent to a binary
  tree by introducing new nodes and edges.';
  edges := copy (g[Edges]);
  nodes := copy (g[Nodes]);
  do
    res := Graph (edges, nodes);
    inc := res[Incidences];
    # find the first node i whose number of incident edges is not 1 or 3
    for i to length (inc) while length(inc[i]) = 1 or length(inc[i]) = 3 do od;
    if i > length (inc) then break fi;
    if length (inc[i]) = 2 then
      # two incident edges: hang an additional short edge with a new empty
      # node on i
      nodes := append (nodes, '');
      edges := append (edges, Edge ([0.01,edges[inc[i,1],1,2]],i,length(nodes)))
    else
      # Introduce edge between first two and remaining nodes
      nodes := append (nodes, '');
      n := length (nodes);
      # e1 and e2 are modified in place below, which re-attaches the first
      # two incident edges of i to the new node n
      e1 := edges[inc[i,1]]; 
      e2 := edges[inc[i,2]];
      # new holds the two end nodes of the connecting edge as an expression
      # sequence: i,n if e1 started at i, otherwise n,i
      if e1[2] = i then
	e1[2] := n; new := i,n
      else
	e1[3] := n; new := n,i
      fi;
      if e2[2] = i then e2[2] := n else e2[3] := n fi;
      edges := append (edges, Edge ([0.01,e1[1,2]],new))
    fi
  od;
  res
end:

# Returns an array with one angle per edge of g, or [] when Tree (g) is a
# single Leaf. The full circle is divided between the two subtrees of the
# root in proportion to their DrawTCount values; TreeAngles_R fills in the
# angles of all edges below the root.
TreeAngles := proc (g: Graph)
  description
  'Find angles for edges of g (being a tree) in order to draw it.';
  
  t := Tree (g);
  if op (0, t) = Leaf then return ([]) fi;
  angles := CreateArray (1..length (g[Edges]));
  r := DrawTCount(t[1]); l := DrawTCount(t[3]);
  ar := 2*Pi*r/(l+r);
  # Only the entry abs(t[4]) is assigned here; it receives the mean of the
  # directions computed for both root subtrees (each turned by Pi when its
  # edge number is negative).
  # NOTE(review): abs(t[5]) is never assigned in this procedure - presumably
  # t[4] and t[5] refer to the same edge at the root; confirm against Tree().
  angles[abs(t[4])] := (mod (TreeAngles_R (t[1], 0, ar, t[2], angles) +
			     If (t[4] < 0, Pi, 0), 2*Pi) +
                        mod (TreeAngles_R (t[3], ar, 2*Pi - ar, t[2], angles) +
			     If (t[5] < 0, Pi, 0), 2*Pi)) / 2;
  angles
end:

# Recursive worker of TreeAngles. The subtree t owns the angular sector
# that starts at min and has the given width; oldh is t[2] of the parent
# node. For an internal node the sector is divided between the two children
# and the angle of each child edge abs(t[3+i]) is stored in angles (turned
# by Pi when the edge number is negative). angles is modified in place.
# Returns the middle of the sector, min + width / 2.
# Note: the parameter min hides the function min inside this procedure.
TreeAngles_R := proc (t, min, width, oldh, angles)
  if op(0,t) = Tree then
    r := abs (t[2] - oldh);
    # theta[i] / beta[i]: start and width of the sector of child i; the
    # sector is shared in proportion to the DrawTCount of the children
    theta := CreateArray (1..2);
    beta := CreateArray (1..2);
    theta[1] := min;
    tc1 := DrawTCount (t[1]);
    beta[1] := width * tc1 / (tc1 + DrawTCount (t[3]));
    theta[2] := min + beta[1];
    beta[2] := width - beta[1];
    for i to 2 do
      # child i is t[2*i-1], its edge number is t[3+i]
      ht := DrawTHeight (t[2*i-1]); height := abs (ht[1] - ht[2]);
      if ht[1] = ht[2] then height := abs (ht[1]) fi;
      if height = 0 then height := 1 fi;
      # alpha = arcsin (sinalpha), written with arctan; it is dropped (0)
      # when |sinalpha| >= 1, when the whole sector is at least Pi/2 wide,
      # or when alpha exceeds Pi/4
      sinalpha := r * sin (beta[i]) / height;
      alpha := If (abs (sinalpha) < 1, arctan (sinalpha/sqrt(1-sinalpha^2)), 0);
      if beta[1]+beta[2]>=Pi/2 then alpha := 0 fi;
      if alpha>Pi/4 then alpha := 0 fi;
      # first child: sector start moves back by alpha;
      # second child: sector is widened by alpha
      if i = 1 then theta[i] := theta[i] - alpha
      else beta[i] := beta[i] + alpha fi;
      angles[abs (t[3+i])] := mod (TreeAngles_R (t[2*i-1], theta[i], beta[i],
						 t[2], angles) + 
                                   If (t[3+i] < 0, Pi, 0), 2*Pi)
    od
  fi;  
  min + width / 2
end:

# Text form of a node or edge label:
#   numeric - the value rounded to the nearest integer, printed with %d
#   string  - the string itself
#   set     - the elements joined by commas (empty set gives '')
#   other   - the empty string
encoded := proc (x: anything)
  if type (x, numeric) then
    return (sprintf ('%d', round (x)))
  fi;
  if type (x, string) then
    return (x)
  fi;
  text := '';
  if type (x, set) then
    count := length (x);
    for pos to count do
      text := text.x[pos];
      # separator between elements, none after the last one
      if pos < count then text := text.',' fi
    od
  fi;
  text
end:

DrawSplitGraph := proc (g: Graph, angles: array (numeric), title: string)
  description
  'Draws graph g with edge number i at angle angles[i] (angles has one entry
  per edge of g). The optional title is placed at the top left of the plot.';
  # Node 1 is placed at the origin. Every other node gets its position from
  # an already placed neighbour: edge i runs from its first node to its
  # second node in direction angles[i] with length e[1,1].
  n := length (g[Nodes]);
  edges := g[Edges];
  xy := CreateArray (1..n, UNDEFINED);
  xy[1] := [0,0];
  # left: number of nodes still without a position. Sweep over the edges
  # until all nodes are placed or a whole sweep places nothing.
  oldleft := n;
  left := n - 1;
  while left < oldleft and left > 0 do
    oldleft := left;
    for i to length (edges) do
      e := edges[i];
      if xy[e[2]] <> UNDEFINED and xy[e[3]] = UNDEFINED then
        xy[e[3]] := xy[e[2]] + e[1,1] * [cos (angles[i]), sin (angles[i])];
        left := left - 1
      fi;
      if xy[e[3]] <> UNDEFINED and xy[e[2]] = UNDEFINED then
        xy[e[2]] := xy[e[3]] - e[1,1] * [cos (angles[i]), sin (angles[i])];
        left := left - 1
      fi
    od
  od;
  if left > 0 then error ('Disconnected graph') fi;
  x := zip ((x->x[1])(xy));
  y := zip ((x->x[2])(xy));
  minx := min (x); maxx := max (x);
  miny := min (y); maxy := max (y);
  pd := [];
  # labeled[k]: the length of split k has already been written next to one
  # of its edges
  labeled := CreateArray (1..max (zip ((x->x[1,2])([op (edges)]))), false);
  for e in edges do
    x1 := x[e[2]]; y1 := y[e[2]];
    x2 := x[e[3]]; y2 := y[e[3]];
    pd := append (pd, LINE (x1, y1, x2, y2));
    if not labeled[e[1,2]] then
      labeled[e[1,2]] := true;
      # only edges longer than 0.5 are labelled with their length
      if e[1,1] > 0.5 then
        # r scales the normal of the edge so that the label is offset from
        # the middle of the edge by 1/160 of the plot width. It is computed
        # only here, so unlabelled zero-length edges cannot cause a
        # division by zero.
        r := (maxx - minx) / 160 / sqrt ((x2-x1)^2 + (y2-y1)^2);
        pd := append (pd, CTEXT ((x1+x2)/2-r*(y2-y1), (y1+y2)/2+r*(x2-x1),
                                 encoded (e[1,1]), 8))
      fi
    fi
  od;
  # nodes with a non-empty label are drawn as a circle with the label text
  for i to n do
    t := encoded (g[Nodes,i]);
    if length (t) > 0 then
      pd := append (pd, CIRCLE (x[i], y[i], 7),
                    CTEXT (x[i], y[i]-(maxy-miny)/125, t))
    fi
  od;
  # title is optional: it is only used when a third argument was passed
  if nargs > 2 then
    pd := append (pd, LTEXT (minx, maxy, title))
  fi;
  DrawPlot (pd, proportional)
end:

DrawSplits := proc (splits: list ([numeric,set]), all: {posint, set})
  description
  'Draws a graph from a list of dSplits. all is the set of all taxa of the
  split or a posint if the set is 1..all.';
  # Each stage yields the expression sequence (graph, angles), which is
  # passed on as the two arguments of the next stage. No title is given to
  # DrawSplitGraph.
  drafted := dSplitGraph (splits, all);
  polished := PolishAngles (drafted);
  DrawSplitGraph (polished)
end:
