#
#	Code to compute the score, distance, prob-identical and variance
#	of two aligned sequences by maximum likelihood
#
#	The model assumes that each position is either:
#
#	(1) Completely conserved with probability p
#
#	(2) Has evolved over a distance d according to a Markovian
#		model of evolution.
#
#	This function estimates both p and d by maximum likelihood.
#
#			Gaston H. Gonnet (May 8, 2007)
#
#
#	EstimatePamComCon( s1, s2, days )
#
#	s1, s2	aligned sequences of equal length; '_' marks a gap position
#	days	list of DayMatrix.  The mapping, alphabet size, gap-cost
#		model and rate matrix (logPAM1) are all taken from this list.
#
#	Returns [score, d, var(d)] where score = 10*log10(likelihood ratio),
#	d is the estimated distance and var(d) is minus the [2,2] entry of
#	the inverse Hessian of the log-likelihood in (p,d).
#	If EstimatePam already returns the PamNumber of the first matrix, or
#	every position is an identical pair, the result of EstimatePam is
#	returned unchanged.
#
#	Side effect: the estimated conserved fraction p is stored in the
#	global MLEstimatePamp (only when the Newton iteration is run).
#
#	Errors: sequences of different lengths; gap-opening cost not
#	logarithmic in the distance; singular Hessian; no convergence.
#
EstimatePamComCon := proc( s1:string, s2:string, days:list(DayMatrix) )
global MLEstimatePamp;

if length(s1) <> length(s2) then error(s1,s2,'are of different lengths') fi;
mapf := days[1,Mapping];
if mapf=AToInt then n := 20 else n := length(days[1,Sim]) fi;

# C[i1,i2] counts aligned symbol pairs; del counts gap openings and
# aadel counts gap extensions (a '_' preceded by another '_').
C := CreateArray(1..n,1..n):
del := 0;
aadel := 0;
for i to length(s1) do
    if s1[i] = '_' then
        if i=1 or s1[i-1] <> '_' then del := del+1
        else aadel := aadel+1 fi
    elif s2[i] = '_' then
        if i=1 or s2[i-1] <> '_' then del := del+1
        else aadel := aadel+1 fi
    else
        i1 := mapf(s1[i]);
        i2 := mapf(s2[i]);
        # symbols mapping outside 1..n are ignored
        if i1 > 0 and i1 <= n and i2 > 0 and i2 <= n then
            C[i1,i2] := C[i1,i2] + 1
        fi
    fi
od:

# model FixedDel = FD0 + FD1*ln(pam), fitted on the first and last matrix
# of days and verified on the middle one.
FD1 := (days[1,FixedDel]-days[-1,FixedDel]) /
        ln(days[1,PamDistance]/days[-1,PamDistance]);
FD0 := days[1,FixedDel]-FD1*ln(days[1,PamDistance]);
j := round(length(days)/2);
if | days[j,FixedDel] - FD1*ln(days[j,PamDistance]) - FD0 | > 1e-8 then
    error('deletion cost not logarithmic') fi;

# Gap contributions converted from 10*log10 units to natural logs.
# FD0d and ID0 do not depend on d or p; FD1d multiplies ln(d).
FD0d := del*ln(10)/10*FD0;
FD1d := del*ln(10)/10*FD1;
ID0 := ln(10)/10*aadel*days[1,IncDel];
Q := days[1,'logPAM1'];

# compute the frequencies f from exp(10*Q): f[j] is proportional to
# M[j,1]/M[1,j], then normalized to sum 1
M := exp(10*Q);
f := [ seq( M[j,1]/M[1,j], j=1..n ) ];
f := f / sum(f);

ep := EstimatePam(args);
# initial estimates
nident := sum(C[i,i],i=1..n);
pi := nident / length(s1);
pamPi := PamToPerIdent(ep[2])/100;
d := ep[2];  p := max(0,(pi-pamPi)/(1-pamPi));
lprint(ep[2],pi,pamPi,p);
if d=days[1,PamNumber] or nident = length(s1) then
      return( ep )
fi;

#
#	Maximum likelihood details (maple code)
#
#	Ai=Bi
#	  Le := (p + (1-p)*(M^d)[Ai,Ai]) * f[Ai] / f[Ai]^2;
#
#	Ai <> Bi
#	  Lu := (1-p)*(M^d)[Ai,Bi] * f[Bi] / (f[Ai]*f[Bi]);
#
#	spat := [ diff((M^d)[Ai,Ai],d,d) = log2Md[Ai,Ai],
#		  diff((M^d)[Ai,Ai],d) = logMd[Ai,Ai],
#		  (M^d)[Ai,Ai] = Md[Ai,Ai],
#		  diff((M^d)[Ai,Bi],d,d) = log2Md[Ai,Bi],
#		  diff((M^d)[Ai,Bi],d) = logMd[Ai,Bi],
#		  (M^d)[Ai,Bi] = Md[Ai,Bi] ];
#
#	Lep := diff(ln(Le),p);
#	Lup := diff(ln(Lu),p);
#	Led := diff(ln(Le),d);
#	Lud := diff(ln(Lu),d);
#
#	Lepp := factor(subs(spat,diff(Lep,p)));
#	Lupp := factor(subs(spat,diff(Lup,p)));
#	Lepd := factor(subs(spat,diff(Lep,d)));
#	Lupd := factor(subs(spat,diff(Lup,d)));
#	Ledd := subs(spat,diff(Led,d));
#	Ludd := factor(subs(spat,diff(Lud,d)));
#
#	Lep := subs(spat,Lep);
#	Lup := subs(spat,Lup);
#	Led := subs(spat,Led);
#	Lud := subs(spat,Lud);
#
#	for z in [p,d,pp,pd,dd] do
#	    printf( `\tL%s := L%s + C[Ai,Ai]*(%a);\n`, z, z, Le.z ) od;
#	for z in [p,d,pp,pd,dd] do
#	    printf( `\tL%s := L%s + C[Ai,Bi]*(%a);\n`, z, z, Lu.z ) od;
#
#        lprint( solve( {dp*Lpp+dd*Lpd=-Lp,dp*Lpd+dd*Ldd=-Ld}, {dp,dd} ));
#
#	lprint( -linalg[inverse]( [[Lpp,Lpd],[Lpd,Ldd]] )[2,2] );
#

# Md = exp(d*Q) and its first and second derivatives with respect to d
Md := exp(d*Q);
logMd := Q*Md;
log2Md := Q*logMd;

converged := false;
to 20 do
    # log-likelihood and its derivatives, starting with the gap terms
    L := FD0d + FD1d*ln(d) + ID0;
    Ld := FD1d/d;
    Ldd := -FD1d/d^2;
    Lp := 0;
    Lpp := 0;
    Lpd := 0;
    for Ai to n do for Bi to n do
        if C[Ai,Bi] = 0 then next
        elif Ai=Bi then
             L := L + C[Ai,Ai]*ln((p+(1-p)*Md[Ai,Ai])/f[Ai]);
             Lp := Lp + C[Ai,Ai]*((1-Md[Ai,Ai])/(p+(1-p)*Md[Ai,Ai]));
             Ld := Ld + C[Ai,Ai]*((1-p)*logMd[Ai,Ai]/(p+(1-p)*Md[Ai,Ai]));
             Lpp := Lpp - C[Ai,Ai]*((-1+Md[Ai,Ai])^2/
                  (-p-Md[Ai,Ai]+Md[Ai,Ai]*p)^2);
             Lpd := Lpd + C[Ai,Ai]*(-logMd[Ai,Ai]/(-p-Md[Ai,Ai]+Md[Ai,Ai]*p)^2);
             Ldd := Ldd + C[Ai,Ai]*((1-p)*log2Md[Ai,Ai]/(p+(1-p)*Md[Ai,Ai])-
                  (1-p)^2*logMd[Ai,Ai]^2/(p+(1-p)*Md[Ai,Ai])^2);
        else
             # Lu = (1-p)*Md[Ai,Bi]/f[Ai]; the (1-p) factor is what gives
             # the -1/(1-p) term in Lp below.  Lpd gets no contribution
             # because ln(Lu) separates into a p part and a d part.
             L := L + C[Ai,Bi]*ln((1-p)*Md[Ai,Bi]/f[Ai]);
             Lp := Lp + C[Ai,Bi]*(-1/(1-p));
             Ld := Ld + C[Ai,Bi]*(logMd[Ai,Bi]/Md[Ai,Bi]);
             Lpp := Lpp + C[Ai,Bi]*(-1/(-1+p)^2);
             Ldd := Ldd + C[Ai,Bi]*((log2Md[Ai,Bi]*Md[Ai,Bi]-logMd[Ai,Bi]^2)/
                  Md[Ai,Bi]^2);
        fi
    od od;

    # Newton step: solve the 2x2 system for (dp,dd); disc is minus the
    # determinant of the Hessian
    disc := Lpd^2-Lpp*Ldd;
    if |disc| < 1e-12*(Lpd^2+|Lpp*Ldd|) then
        error('singularity found, cannot use ML') fi;
    dp := -(-Ldd*Lp+Ld*Lpd)/disc;
    dd := -(-Lpp*Ld+Lp*Lpd)/disc;
    # halve the steps until 0 <= p+dp < 1 and d+dd > 0
    while p+dp >= 1 or p+dp < 0 do dp := dp/2 od;
    while  d+dd <= 0 do dd := dd/2 od;

    lprint( 10*L/ln(10), dp,p, dd,d );
    p := p+dp;
    d := d+dd;

    if |dp|< 1e-8 and |dd| < 1e-10*d then
        converged := true;
        break
    fi;
    # exp((d+dd)*Q) = exp(d*Q)*exp(dd*Q): update the matrices incrementally
    Mdp := exp(dd*Q);
    Md := Md*Mdp;
    logMd := logMd*Mdp;
    log2Md := log2Md*Mdp;
od:
MLEstimatePamp := p;

# A run that met the break criterion is converged by definition; otherwise
# fall back on the looser relative test after the 20 iterations.
if not converged and |dp|+|dd| >= 1e-8*d then
    error('did not converge',d,p,dp,dd) fi;

[10*L/ln(10),d,Lpp/(-Lpp*Ldd+Lpd^2)]
end:

# ReadLibrary(EstimatePamComCon):
# CreateDayMatrices();

# s1 := Rand(Protein(6000));  s2 := Mutate(s1,30);
# EstimatePam(s1,s2,DMS);
# EstimatePamComCon(s1,s2,DMS);

# S1 := Stat('d'); S2 := Stat('p'); S3 := Stat('var(d)');
# for i to 1000 do
# s1 := Rand(Protein(2000));
# s2 := s1[1..500].Mutate(s1[501..-1],30);
# ep := traperror(EstimatePamComCon(s1,s2,DMS));
# if ep=lasterror then next fi;
# S1+ep[2]; S2+MLEstimatePamp; S3+ep[3];
# od:
# print(S1,S2,S3);
