module external doCSE;

ReadLibrary('Optimizations/Normalize');

# doCSE( pd )
#
# Runs common-subexpression elimination on a disassembled procedure
# description.  pd[6] is used as the body and pd[2] as the list of locals.
# Every value number that CompressTree reports as used more than once, and
# whose expression is not a numeric constant, gets a fresh local named
# _val_num_<number> appended to pd[2]; ExpandTree then rewrites the body to
# use those locals.  Returns the updated pd.
doCSE := proc( pd )
   compressed := CompressTree(pd[6]);
   tree := compressed[1];
   exprs := compressed[3];
   reused := compressed[5];

   # new locals are numbered after the ones that exist before this pass
   last_local := length(pd[2]);

   to_keep := [];
   for vn in reused do
      # numeric constants are never worth a temporary
      if not type(exprs[vn], numeric) then
         to_keep := append(to_keep, vn);
         pd[2] := append(pd[2], symbol('_val_num_'.string(vn)));
      fi:
   od:

   pd[6] := ExpandTree(to_keep, last_local, tree, exprs);
   pd;
end:

# constanttype( x ) -> boolean
#
# True when x is a number, a string, or a list/set whose type matches
# list({numeric,string}) / set({numeric,string}).
constanttype := proc(x) -> boolean;
   type(x,{numeric,string}) or
      type(x,{list({numeric,string}),set({numeric,string})});
end;
   
# valuetype( x ) -> boolean
#
# True when x is a number or a string.
valuetype := proc(x) -> boolean;
   type(x,numeric) or type(x,string);
end;
   
# CompressTree( n; value_nums, values, due_defs, multiple_uses, VALNUM )
#
# Value-numbering pass over the code tree n.  Expressions are replaced by
# whatever VALNUM returns for them (valnum(k) placeholders with the default
# GetValNum) and valdef(k) markers are inserted in front of the statement of
# a StatSeq in which value number k first appears.
#
# Parameters
#   n             node to process (leaf value or structure)
#   value_nums    table: expression -> value number (unassigned = not known
#                 or invalidated)
#   values        list: value number -> expression
#   due_defs      list of value numbers created but not yet given a valdef
#   multiple_uses set of value numbers that were looked up again after
#                 their creation
#   VALNUM        procedure used to number an expression; called with
#                 (expr, value_nums, values, due_defs, multiple_uses)
#
# Returns the expression sequence
#   (new node, value_nums, values, due_defs, multiple_uses, impure)
# where impure is a boolean.  Only the procedure-call branch and the generic
# branch ever produce impure = true themselves; everything else either
# returns false or forwards what VALNUM returns.
CompressTree := proc( n; (value_nums=table()) : structure(anything,table), (values=[]) : list, (due_defs=[]) : list, (multiple_uses={}) : set, (VALNUM=GetValNum) : procedure)
   # if it's a value type
   if(valuetype(n)) then
      # FIXME constants should be global
      VALNUM(n, value_nums, values, due_defs, multiple_uses);
   elif (not type(n, structure)) then
      # any other leaf (neither numeric/string nor structure) is passed
      # through unchanged
      # FIXME what should we do with calls to parse()? is it able to access
      # locals?
      (n, value_nums, values, due_defs, multiple_uses, false);
   else
      cmd := top(n);
      # empty structure with the same head as n; operands are appended below
      ret := (cmd)();

      if(cmd = Assign) then
	 # ---- assignment: operands 1..length-1 are targets, last is the rhs
	 # targets are copied unchanged
	 for i from 1 to length(n)-1 do
	    ret := append(ret, n[i]);
	 od:
	 # compress the right-hand side
	 opt := procname(n[length(n)], value_nums, values, due_defs, multiple_uses, VALNUM);
	 right := opt[1];
	 ret := append(ret, right);
	 value_nums := opt[2];
	 values := opt[3];
	 due_defs := opt[4];
	 multiple_uses := opt[5];
	 # read for symmetry with the other branches; this branch always
	 # returns impure = false
	 impure := opt[6];

	 # everything that depends on an assigned target is stale now:
	 # a target that is not a plain name invalidates the expressions
	 # containing it, a plain name invalidates all expressions for which
	 # uses_globals is true
	 for i from 1 to length(n)-1 do
	    var := GetInvalidatedName(n[i]);
	    if(not type(var,name)) then
	       value_nums := InvalidateAllUses(var, value_nums);
	    else
	       value_nums := InvalidateAllGlobals(value_nums);
	    fi:
	 od:

	 # if the rhs collapsed to a value number, every target becomes an
	 # alias for that number
	 if type(right,structure(anything,valnum)) then
	    num := right[1];
	    # n[i] will now be an alias for value number 'num'
	    # TODO can n[i] be anything evil like an eval() statement or
	    # whatever? are globals also evil?
	    for i from 1 to length(n)-1 do
	       left := n[i];

	       if type(left,structure(anything,select)) then
		  # indexed target: compress its second operand before using
		  # the target as a table key
		  # NOTE(review): ret already received n[i] above; whether
		  # this update of left[2] is also visible in ret depends on
		  # structure aliasing semantics -- confirm
		  opt := procname(left[2], value_nums, values, due_defs, multiple_uses, VALNUM);
		  left[2] := opt[1];
		  value_nums := opt[2];
		  values := opt[3];
		  due_defs := opt[4];
		  multiple_uses := opt[5];
		  impure := opt[6];
	       fi:
	       value_nums[left] := num;
	    od:
	 fi:

	 (ret, value_nums, values, due_defs, multiple_uses, false);
      elif(cmd = ForLoop) then
	 # ---- for loop: operand 1 is the loop variable, last operand the body
	 # the loop variable changes, so invalidate it like an assignment target
	 var := GetInvalidatedName(n[1]);
	 if(not type(var,name)) then
	    value_nums := InvalidateAllUses(var, value_nums);
	 else
	    value_nums := InvalidateAllGlobals(value_nums);
	 fi:
	 # snapshot of the numbering that is valid in front of the loop
	 backup := copy(value_nums);
	 # NOTE(review): only the loop variable is invalidated before the body
	 # is processed; an expression numbered before the loop whose operands
	 # are reassigned later in the body is still matched at the top of the
	 # body -- looks unsafe for iterations after the first, confirm

	 # all operands except the body are copied unchanged
	 for i from 1 to length(n)-1 do
	    ret := append(ret, n[i]);
	 od:

	 # make sure the body is a StatSeq so that valdefs have a place to go
	 subtree := n[length(n)];
	 if not type(subtree, structure(anything,StatSeq)) then
	    subtree := StatSeq(subtree);
	 fi:

	 opt := procname(subtree, value_nums, values, due_defs, multiple_uses, VALNUM);
	 ret := append(ret, opt[1]);
	 value_nums := opt[2];
	 values := opt[3];
	 due_defs := opt[4];
	 multiple_uses := opt[5];
	 impure := opt[6];

	 # discard all value nums inserted in subtree
	 # (continue with the snapshot, minus every entry the body invalidated)
	 value_nums := PropagateInvalidated(value_nums,backup);

	 (ret, value_nums, values, due_defs, multiple_uses, false);
      elif(cmd = IfStat) then
	 # ---- if statement: operands come in (condition, branch) pairs,
	 # optionally followed by one trailing else branch
	 backup := copy(value_nums);

	 for i from 1 to length(n)-1 by 2 do
	    # conditions are numbered with GetValNumNodef and an empty
	    # due_defs list: known values are reused, no new definitions are
	    # created, and opt[4] is deliberately not read back
	    opt := procname(n[i], value_nums, values, [], multiple_uses, GetValNumNodef);
	    ret := append(ret, opt[1]);
	    value_nums := opt[2];
	    values := opt[3];
	    multiple_uses := opt[5];
	    impure := opt[6];

	    subtree := n[i+1];
	    if not type(subtree, structure(anything,StatSeq)) then
	       subtree := StatSeq(subtree);
	    fi:

	    # snapshot in front of this branch
	    backup2 := copy(value_nums);

	    opt := procname(subtree, value_nums, values, due_defs, multiple_uses, VALNUM);
	    ret := append(ret, opt[1]);
	    value_nums := opt[2];
	    values := opt[3];
	    due_defs := opt[4];
	    multiple_uses := opt[5];
	    impure := opt[6];

	    # discard all value nums inserted in subtree
	    value_nums := PropagateInvalidated(value_nums,backup2);
	 od:

	 # the loop variable keeps its value after the loop: it equals
	 # length(n) exactly when there is an unpaired last operand, i.e. an
	 # else branch
	 if(i = length(n)) then
	    subtree := n[i];
	    if not type(subtree, structure(anything,StatSeq)) then
	       subtree := StatSeq(subtree);
	    fi:

	    opt := procname(subtree, value_nums, values, due_defs, multiple_uses, VALNUM);
	    ret := append(ret, opt[1]);
	    value_nums := opt[2];
	    values := opt[3];
	    due_defs := opt[4];
	    multiple_uses := opt[5];
	    impure := opt[6];
	 fi:

	 # discard all value nums inserted in subtree
	 # (continue with the numbering from in front of the whole statement,
	 # minus everything that is unassigned after the branches)
	 value_nums := PropagateInvalidated(value_nums,backup);

	 (ret, value_nums, values, due_defs, multiple_uses, false);
      elif(cmd = StatSeq) then
	 # ---- statement sequence: process the statements in order
	 for i from 1 to length(n) do
	    opt := procname(n[i], value_nums, values, due_defs, multiple_uses, VALNUM);
	    value_nums := opt[2];
	    values := opt[3];
	    due_defs := opt[4];
	    multiple_uses := opt[5];
	    impure := opt[6];

	    # put defs into dominator
	    # (a valdef marker for every pending value number is emitted in
	    # front of the statement that was just processed)
	    for j to length(due_defs) do
	       num := due_defs[j];
	       ret := append(ret, noeval(valdef(num)));
	    od:
	    due_defs := [];

	    ret := append(ret, opt[1]);
	 od:

	 # all pending definitions have been emitted, hence the empty list
	 (ret, value_nums, values, [], multiple_uses, false);
      elif(cmd = Local) then
	 # references to locals are left untouched
	 (n, value_nums, values, due_defs, multiple_uses, false);
      elif(cmd = Param) then
	 # references to parameters are left untouched
	 (n, value_nums, values, due_defs, multiple_uses, false);
      elif(cmd = structure and type(eval(n[1]),procedure)) then
	 # ---- procedure call: operand 1 is the callee, the rest are arguments
	 # calls to these are never value-numbered
	 donotoptimize := {print,printf};

	 cmd := n[1];
	 ret := append(ret, cmd);

	 if(not member(cmd,donotoptimize) and is_pure(cmd)) then
	    # pure callee: compress the arguments, remember whether any of
	    # them was impure
	    impure := false;
	    for i from 2 to length(n) do
	       opt := procname(n[i], value_nums, values, due_defs, multiple_uses, VALNUM);
	       ret := append(ret, opt[1]);
	       value_nums := opt[2];
	       values := opt[3];
	       due_defs := opt[4];
	       multiple_uses := opt[5];
	       impure := impure or opt[6];
	    od:

	    # do not compress if it has an impure subexpression
	    if(impure) then
	       value_nums := InvalidateAllGlobals(value_nums);
	       (ret, value_nums, values, due_defs, multiple_uses, true);
	    else
	       VALNUM(ret, value_nums, values, due_defs, multiple_uses);
	    fi:
	 else
	    # if this function is impure then the whole expression could be
	    # possibly impure

	    # FIXME globals might be changed using symbol() and impure functions
	    # -> assume a call to any impure function changes globals
	    # -> invalidate all expressions containing globals

	    impure := false;
	    for i from 2 to length(n) do
	       opt := procname(n[i], value_nums, values, due_defs, multiple_uses, VALNUM);
	       ret := append(ret, opt[1]);
	       value_nums := opt[2];
	       values := opt[3];
	       due_defs := opt[4];
	       multiple_uses := opt[5];
	       impure := impure or opt[6];
	    od:

	    # print/printf with pure arguments are the only calls in this
	    # branch that keep the global-dependent value numbers alive
	    if impure or not member(cmd,donotoptimize) then
	       value_nums := InvalidateAllGlobals(value_nums);
	    fi:
	    # the call itself is never numbered and is reported as impure
	    (ret, value_nums, values, due_defs, multiple_uses, true);
	 fi:
      else # FIXME!!!! only if result is value type
	 # ---- any other structure: compress every operand, then number the
	 # rebuilt node unless one of the operands was impure
	 impure := false;
         for i from 1 to length(n) do
	    opt := procname(n[i], value_nums, values, due_defs, multiple_uses, VALNUM);
	    ret := append(ret, opt[1]);
	    value_nums := opt[2];
	    values := opt[3];
	    due_defs := opt[4];
	    multiple_uses := opt[5];
	    impure := impure or opt[6];
	 od:

	 # do not compress if it has an impure subexpression
	 if(impure) then
	    (ret, value_nums, values, due_defs, multiple_uses, true);
	 else
	    VALNUM(ret, value_nums, values, due_defs, multiple_uses);
	 fi:
      fi:
   fi:
end:

# ExpandTree( to_keep, last_local, body, values )
#
# Turns the valnum/valdef placeholders in body back into ordinary code.
#   to_keep     value numbers that are stored in a local of their own
#   last_local  number of locals that existed before; the local for
#               to_keep[k] is Local(last_local+k)
#   body        tree to expand
#   values      list: value number -> expression
#
# valnum(k): Local(...) when k is kept, otherwise the expanded expression.
# valdef(k): Assign(Local(...), expanded expression) when k is kept,
#            otherwise NULL (the marker is dropped from its parent).
# Other structures are rebuilt operand by operand; leaves are returned as is.
ExpandTree := proc (to_keep:list(posint), last_local:{posint,0}, body, values:list)
   if type(body,valnum(anything)) then
      vn := body[1];
      if not member(vn,to_keep) then
         # not kept: substitute the expression itself
         procname(to_keep, last_local, values[vn], values);
      else
         slot := SearchArray(vn,to_keep);
         Local(last_local+slot);
      fi:
   elif type(body, valdef(anything)) then
      vn := body[1];
      if not member(vn,to_keep) then
         NULL;
      else
         slot := SearchArray(vn,to_keep);
         Assign(Local(last_local+slot), procname(to_keep, last_local, values[vn], values));
      fi:
   elif not type(body, structure) then
      body;
   else
      head := top(body);
      rebuilt := head();
      for k from 1 to length(body) do
         piece := procname(to_keep, last_local, body[k], values);
         # dropped valdef markers come back as NULL and are skipped
         if(not piece = NULL) then
            rebuilt := append(rebuilt, piece);
         fi:
      od:
      rebuilt;
   fi:
end:

# GetValNum( n; value_nums, values, due_defs, multiple_uses )
#
# Looks up the value number of expression n.  A known expression has its
# number added to multiple_uses.  An unknown one is appended to values,
# gets the new list length as its number, is entered into value_nums and
# queued in due_defs.
# Returns (valnum(number), value_nums, values, due_defs, multiple_uses, false).
GetValNum:= proc( n; (value_nums=table()) : structure(anything,table), (values=[]) : list, (due_defs=[]) : list, (multiple_uses={}) : set )
   id := value_nums[n];
   if id <> unassigned then
      # seen before
      multiple_uses := append(multiple_uses, id);
   else
      # first occurrence: allocate the next number
      values := append(values, n);
      id := length(values);
      value_nums[n] := id;
      due_defs := append(due_defs, id);
   fi:
   (noeval(valnum(id)), value_nums, values, due_defs, multiple_uses, false);
end:

# GetValNumNodef( n; value_nums, values, due_defs, multiple_uses )
#
# Like GetValNum, but never creates a new value number: a known expression
# is replaced by valnum(number) and recorded in multiple_uses, an unknown
# one is returned unchanged.
# Returns (node, value_nums, values, due_defs, multiple_uses, false).
GetValNumNodef := proc( n; (value_nums=table()) : structure(anything,table), (values=[]) : list, (due_defs=[]) : list, (multiple_uses={}) : set )
   known := value_nums[n];
   if known = unassigned then
      (n, value_nums, values, due_defs, multiple_uses, false);
   else
      multiple_uses := append(multiple_uses, known);
      (noeval(valnum(known)), value_nums, values, due_defs, multiple_uses, false);
   fi:
end:

# InvalidateAllUses( variable; value_nums )
#
# Marks as unassigned every entry of value_nums whose key is not a constant
# (constanttype) and contains variable (SearchSubtree).  Returns value_nums.
#
# Value-number aliases need no special treatment: the definition of the
# alias has already been executed, and changing the expression behind a
# value number does not change the alias either.  As long as the expression
# of the alias itself does not change, doing nothing is fine.
InvalidateAllUses := proc( variable; (value_nums=table()) : structure(anything,table) )
   for key in Indices(value_nums) do
      if value_nums[key] <> unassigned then
         if not constanttype(key) then
            if SearchSubtree(variable, key) then
               value_nums[key] := unassigned;
            fi:
         fi:
      fi:
   od:
   value_nums;
end:

# InvalidateAllGlobals( ; value_nums )
#
# Marks as unassigned every entry of value_nums whose key is not a constant
# (constanttype) and for which uses_globals is true.  Returns value_nums.
InvalidateAllGlobals := proc(;(value_nums=table()) : structure(anything,table) )
   for key in Indices(value_nums) do
      if value_nums[key] <> unassigned then
         if not constanttype(key) then
            if uses_globals(key) then
               value_nums[key] := unassigned;
            fi:
         fi:
      fi:
   od:
   value_nums;
end:

# uses_globals( s ) -> boolean
#
# For an expression: the verdict of uses_globalsR on the expression.
# For something that evaluates to a procedure: true when the procedure is
# builtin, otherwise the verdict of uses_globalsR on its disassembled body.
uses_globals := proc( s ) -> boolean;
   pure_procs := {noeval, length, size, member, type, log, log10, sqrt, op}; #keep symbol() out of this set!
   if not type(eval(s),procedure) then
      uses_globalsR(s, pure_procs)[1];
   else
      desc := disassemble(op(s));
      code := desc[6];
      if member(builtin, desc[3]) then
         # no body to inspect: assume the worst
         true;
      else
         uses_globalsR(code,pure_procs union {s})[1];
      fi:
   fi:
end:

# uses_globalsR( s; pchecked ) -> (boolean, set)
#
# Recursive worker for uses_globals.  The boolean is true when s is a name
# or a structure that contains a name at any depth.  pchecked is threaded
# through the recursion and returned as the second result; the active code
# never adds to it.
#
# An earlier variant, now disabled, looked into the bodies of called
# procedures (treating non-procedures and builtins as unsafe) and used
# pchecked to avoid visiting a procedure twice.
uses_globalsR := proc( s; (pchecked={}) : set(procedure) ) -> (boolean, set(procedure));
   found := false;
   if type(s,structure) then
      for k from 1 to length(s) do
         sub := procname(s[k], pchecked);
         pchecked := pchecked union sub[2];
         found := found or sub[1];
      od;
   elif type(s,name) then
      found := true;
   fi:
   (found, pchecked);
end:

# writes_globals( s ) -> boolean
#
# True when the code tree s contains, at any depth,
#   - an Assign with a plain name among its targets (all operands but the
#     last), or
#   - a ForLoop whose first operand is a plain name.
# Anything that is not a structure yields false.
#
# The previous version ended with an unconditional `false;` after the `fi`,
# which threw away the result of the recursive scan: only a match at the
# root node was ever reported.
#
# NOTE(review): a target of the form select(name, ...) is not recognised as
# a global write here -- confirm whether that is intended.
writes_globals := proc( s ) -> boolean;
   if not type(s,structure) then
      return(false);
   fi:

   if (top(s)=Assign) then
      for i from 1 to length(s)-1 do
         if type(s[i],name) then
            return(true);
         fi:
      od:
   elif (top(s)=ForLoop) and type(s[1],name) then
      return(true);
   fi:

   # nothing at this node: scan the operands and stop at the first hit
   for z in s do
      if procname(z) then
         return(true);
      fi:
   od:
   false;
end:

# is_pure( p ) -> boolean
#
# Front end for is_pureR: returns only its boolean verdict for procedure p.
# The leftover `option trace` was removed; it produced trace output for
# every purity query made during optimization.
is_pure := proc( p:procedure ) -> boolean;
   is_pureR(p)[1];
end:

# is_pureR( p; pchecked ) -> (boolean, set)
#
# Recursive purity check.  p counts as pure when
#   - it is one of the procedures listed in pure_procs, or
#   - its normalized body does not write globals (writes_globals), it is not
#     builtin, and everything found by FindCalledProcs is a procedure that
#     is itself pure.
# pchecked holds the procedures already under examination so that each one
# is visited only once.  The second result is the set of procedures
# examined by this call.
is_pureR := proc( p:procedure; (pchecked={}) : set(procedure) ) -> (boolean, set(procedure));
   pure_procs := {noeval, length, size, member, type}; #keep symbol() out of this set!
   if not member(p, pure_procs) then
      desc := doNormalize(disassemble(op(p)));
      code := desc[6];
      if writes_globals(code) then
         (false, {p});
      elif member(builtin, desc[3]) then
         (false, {p});
      else
         callees := FindCalledProcs(code);
         pchecked := pchecked union {p};
         ok := true;
         for callee in callees do
            if not type(callee, procedure) then
               # cannot be inspected, so it has to be treated as unsafe
               ok := false;
            elif not member(callee, pchecked) then
               sub := procname(callee, pchecked);
               pchecked := pchecked union sub[2];
               ok := ok and sub[1];
            fi:
            if not ok then break fi;
         od:
         (ok, pchecked);
      fi:
   else
      (true, {p});
   fi:
end:

# FindCalledProcs( s ) -> set
#
# Collects the callees of all call nodes in the code tree s.  A call node is
# a structure whose head is `structure`; its first operand is the callee and
# the remaining operands are the arguments.
#
# The arguments of a call are scanned as well.  Previously only the callee
# of the outermost call was returned, so a procedure called inside another
# call's argument list was never seen by the purity analysis.
FindCalledProcs := proc( s ) -> set(anything);
   if(top(s) = structure) then
      procs := {s[1]};
      for i from 2 to length(s) do
         procs := procs union procname(s[i]);
      od:
      procs;
   elif(type(s,structure)) then
      procs := {};
      for x in s do
         procs := procs union procname(x);
      od:
      procs;
   else
      {};
   fi:
end:

# PropagateInvalidated( src, tgt ) -> table
#
# For every index of tgt whose entry in src is unassigned, the entry in the
# result is set to unassigned as well.  The result starts out as tgt (no
# copy is made here).
PropagateInvalidated := proc( src : structure(anything, table), tgt : structure(anything, table)) -> structure(anything, table);
   merged := tgt;
   for key in Indices(tgt) do
      if src[key] = unassigned then
         merged[key] := unassigned;
      fi:
   od:
   merged;
end:

# SearchSubtree( n, tree ) -> boolean
#
# True when tree equals n, or when one of the structure operands of tree
# (searched recursively, left to right) does.  Operands that are not
# structures are not compared.
SearchSubtree := proc( n, tree ) -> boolean;
   if(tree = n) then
      true;
   else
      hit := false;
      last := length(tree);
      k := 1;
      while not hit and k <= last do
         child := tree[k];
         if type(child, structure) then
            hit := procname(n, child);
         fi:
         k := k+1;
      od:
      hit;
   fi;
end:

# GetInvalidatedName( n )
#
# Returns the first operand of n when the head of n is select; otherwise
# returns n itself.
GetInvalidatedName := proc( n : {structure,name} )
   if(top(n) <> select) then
      n;
   else
      n[1];
   fi:
end:

end: #module
