/*
 * License: FreeBSD (Berkeley Software Distribution)
 * Copyright (c) 2013, Sara Sheehan, Kelley Harris, Yun Song
 */

package edu.berkeley.smcsd;

import edu.berkeley.utility.Discretization;
import edu.berkeley.utility.LogSum;
import edu.berkeley.utility.ParamSet;
import edu.berkeley.utility.Discretization.DiscretizationInterval;

/**
 * Core whose transition computation is linear in the number of discretization intervals {@code d}.
 *
 * <p>For every time interval a separate {@link CoreLinearOneSize} is built, and its event
 * probabilities and mutation matrix are cached here in log space, indexed by interval.
 */
public class CoreLinear {
	
	private final ParamSet pSet;
	private final int nTrunk;
	private final DiscretizationInterval[] tPoints;
	private final int d;
	
	// new event probabilities
	
	// the probability that a locus with absorption state i > j recombines during time interval j and then coalesces immediately during interval i
	private final double[] logRecoDiffCoalSameArray;
	
	// the probability that a locus with absorption state i > j recombines during interval j but then escapes interval j to coalesce more anciently
	private final double[] logRecoDiffCoalLaterArray;
	
	// the probability that a locus with absorption state j recombines in state j and then immediately coalesces in state j
	private final double[] logRecoSameCoalSameArray;
	
	// the probability that a locus with absorption state j recombines in interval j and does not coalesce until some later interval
	private final double[] logRecoSameCoalLaterArray;
	
	// the probability that a site absorbing in state i does not recombine, such that the adjacent locus absorbs in state i as well
	private final double[] logNoRecoArray;
	
	// the probability that a site has recombined more recently than j, then gets absorbed into state j
	private final double[] logCoalNowArray;
	
	// the probability that a site had recombined more recently than j, but escapes j to coalesce more anciently
	private final double[] logCoalLaterArray;
	
	// indexed by time-index, allele (src), allele (dst)
	private final double[][][] logMutMatrices;
	
	// Log probability of a self-transition from j to j, indexed by interval.
	// Computed eagerly when the constructor is asked to (printExpectedSegs); otherwise
	// computed lazily on first call to getLogSelfTrans. volatile so that the lazily built
	// array is safely published if an instance is shared between threads (all other state
	// is final).
	private volatile double[] logSelfTrans;
	
	/**
	 * Builds the discretization and caches the per-interval event probabilities.
	 *
	 * @param pSet model parameters (number of alleles, stationary distribution, ...)
	 * @param nTrunk trunk size passed to the discretization routines
	 * @param times passed to {@code Discretization} to build the time intervals
	 * @param sizes per-interval sizes; {@code sizes[j]} is read for every interval
	 *        {@code j < d}. NOTE(review): presumably {@code sizes.length >= d} - confirm
	 *        against {@code Discretization.makeDiscretizationPoints}.
	 * @param descendingAscendingTable passed to {@code Discretization.computeNbarArray}
	 * @param printExpectedSegs when true, the self-transition probabilities are computed up
	 *        front; otherwise they are computed on first use
	 */
	public CoreLinear(ParamSet pSet, int nTrunk, double[] times, double[] sizes, double[] descendingAscendingTable, boolean printExpectedSegs) {

		this.pSet = pSet;
		this.nTrunk = nTrunk;
		double[] nbar = Discretization.computeNbarArray(times, sizes, nTrunk, descendingAscendingTable);
		this.tPoints = Discretization.makeDiscretizationPoints(times, sizes, this.nTrunk, nbar);
		this.d = this.tPoints.length;
		
		// compute new transition matrices
		this.logRecoDiffCoalSameArray = new double[this.d];
		this.logRecoDiffCoalLaterArray = new double[this.d];
		this.logRecoSameCoalSameArray = new double[this.d];
		this.logRecoSameCoalLaterArray = new double[this.d];
		this.logNoRecoArray = new double[this.d];
		this.logCoalNowArray = new double[this.d];
		this.logCoalLaterArray = new double[this.d];
		this.logMutMatrices = new double[this.d][][];
		
		// use a different core for each time interval
		for (int j = 0; j < this.d; j++) {
			CoreLinearOneSize core = new CoreLinearOneSize(this.pSet, this.tPoints[j], sizes[j], nbar[j]);
			
			this.logRecoDiffCoalSameArray[j] = core.getLogRecoDiffCoalSame();
			this.logRecoDiffCoalLaterArray[j] = core.getLogRecoDiffCoalLater();
			this.logRecoSameCoalSameArray[j] = core.getLogRecoSameCoalSame();
			this.logRecoSameCoalLaterArray[j] = core.getLogRecoSameCoalLater();
			this.logNoRecoArray[j] = core.getLogNoReco();
			this.logCoalNowArray[j] = core.getLogCoalNow();
			this.logCoalLaterArray[j] = core.getLogCoalLater();
			
			this.logMutMatrices[j] = core.getLogMutMatrix();
		}
		
		if (printExpectedSegs) {
			// must run after the per-interval arrays above are filled in
			this.logSelfTrans = this.computeLogSelfTrans();
		}
	}
	
	//----------------
	// PUBLIC GETTERS
	//----------------
	
	/** Returns the number of discretization intervals {@code d}. */
	public int numIntervals() { return this.d; }
	/** Returns the trunk size this core was built with. */
	public int nTrunk() { return this.nTrunk; }
	/** Returns the number of (known) alleles, as reported by the parameter set. */
	public int numAlleles() { return this.pSet.numAlleles(); }
	
	/** Returns the log of the quadrature weight of interval {@code tIdx}. */
	public double getLogInitialWeight(int tIdx) {
		return Math.log(this.tPoints[tIdx].weight);
	}
	
	// event probability getters (all values are natural-log probabilities for interval tIdx)
	public double getLogRecoDiffCoalSame(int tIdx) {
		return this.logRecoDiffCoalSameArray[tIdx];
	}
	public double getLogRecoDiffCoalLater(int tIdx) {
		return this.logRecoDiffCoalLaterArray[tIdx];
	}
	public double getLogRecoSameCoalSame(int tIdx) {
		return this.logRecoSameCoalSameArray[tIdx];
	}
	public double getLogRecoSameCoalLater(int tIdx) {
		return this.logRecoSameCoalLaterArray[tIdx];
	}
	public double getLogNoReco(int tIdx) {
		return this.logNoRecoArray[tIdx];
	}
	public double getLogCoalNow(int tIdx) {
		return this.logCoalNowArray[tIdx];
	}
	public double getLogCoalLater(int tIdx) {
		return this.logCoalLaterArray[tIdx];
	}
	
	/**
	 * Returns the log probability of a self-transition in interval {@code tIdx}.
	 *
	 * <p>If the constructor was not asked to precompute these values, they are computed (and
	 * cached) on the first call rather than failing with a {@code NullPointerException}.
	 */
	public double getLogSelfTrans(int tIdx) {
		double[] selfTrans = this.logSelfTrans;
		if (selfTrans == null) {
			selfTrans = this.computeLogSelfTrans();
			this.logSelfTrans = selfTrans;
		}
		return selfTrans[tIdx];
	}
	
	/**
	 * Returns the log emission probability of {@code dstAllele} given trunk allele
	 * {@code srcAllele} in interval {@code tIdx}, with support for missing data.
	 *
	 * <p>The allele index {@code numAlleles()} denotes an unknown base (N):
	 * <ul>
	 *   <li>unknown {@code dstAllele}: probability 1 (log 0);</li>
	 *   <li>unknown {@code srcAllele}: mutation probability summed over all trunk alleles,
	 *       weighted by the stationary distribution;</li>
	 *   <li>otherwise: the entry of the interval's mutation matrix.</li>
	 * </ul>
	 */
	public double getLogEmission(int srcAllele, int dstAllele, int tIdx) {
		int numAlleles = this.pSet.numAlleles();
		int unknown = numAlleles; // last allele always represents unknown base
		
		// if dstAllele is unknown (N), probability is 1
		if (dstAllele == unknown) {
			return 0; // log(1)
		}
		
		double[][] logMut = this.logMutMatrices[tIdx];
		
		// otherwise if srcAllele is unknown, sum over all possible values for srcAllele
		if (srcAllele == unknown) {
			double[] logStationary = this.pSet.getLogStationaryProb();
			LogSum mutSum = new LogSum(numAlleles);
			for (int trunkAllele = 0; trunkAllele < numAlleles; trunkAllele++) {
				mutSum.addLogSummand(logMut[trunkAllele][dstAllele] + logStationary[trunkAllele]);
			}
			return mutSum.retrieveLogSum();
		}
		
		// if we know both srcAllele and dstAllele, just return from the matrix
		return logMut[srcAllele][dstAllele];
	}
	
	/**
	 * Computes, for every interval j, the log self-transition probability used for expected
	 * segments:
	 * exp(RecoSameCoalSame(j))
	 *   + sum over k < j of exp(RecoDiffCoalLater(k)) * prod over m in (k, j) of exp(CoalLater(m)) * exp(CoalNow(j)).
	 *
	 * <p>Reads the private arrays directly (not the public getters) because this is invoked
	 * from the constructor, where calling overridable methods is unsafe.
	 *
	 * <p>NOTE(review): the sum does not include the no-recombination term (logNoReco);
	 * confirm this is intended for the expected-segments computation.
	 */
	private double[] computeLogSelfTrans() {
		
		double[] selfTrans = new double[this.d];

		for (int tIdx = 0; tIdx < this.d; tIdx++) {
			double coalNow = Math.exp(this.logCoalNowArray[tIdx]); // invariant over k
			double z = Math.exp(this.logRecoSameCoalSameArray[tIdx]);

			// recombine in an earlier interval k, escape every interval strictly between k
			// and tIdx (accumulated in escapeProd), then coalesce in tIdx
			double escapeProd = 1;
			for (int k = tIdx - 1; k >= 0; k--) {
				z += Math.exp(this.logRecoDiffCoalLaterArray[k]) * escapeProd * coalNow;
				escapeProd *= Math.exp(this.logCoalLaterArray[k]);
			}

			selfTrans[tIdx] = Math.log(z);
		}
		return selfTrans;
	}
}
