/*
 * License: FreeBSD (Berkeley Software Distribution)
 * Copyright (c) 2013, Sara Sheehan, Kelley Harris, Yun S. Song
 */

package edu.berkeley.smcsd;

import edu.berkeley.utility.HapConfig;
import edu.berkeley.utility.Haplotype;
import edu.berkeley.utility.Haplotype.GeneticTypeMultiplicity;
import edu.berkeley.utility.LogSum;

// forward and backward probabilities modified to be computed in O(L*n*d)
/**
 * Forward-backward decoding of one haplotype against a configuration of {@code nTrunk} trunk
 * haplotypes, over {@code d} time intervals and {@code L} loci.
 * <p>
 * The hidden state at a locus is a pair (time interval, trunk haplotype). Transitions are factored
 * through per-interval helper arrays (time marginals, cumulative "ancient" sums and a forward
 * coalescence accumulator). This lets both passes run in O(L*n*d).
 * <p>
 * Usage: construct, call {@link #computeForwardBackward()}, then query the posterior methods.
 * Methods named {@code computeLog...} return natural-log values. {@code computePosterior...} and
 * {@link #computeExpectedSelfTrans()} return values already exponentiated with {@code Math.exp}.
 */
public final class DecodeLinear<H extends Haplotype, C extends HapConfig<H>> {
	
	private static final String NOT_COMPUTED_MSG = "computeForwardBackward() must be called first";
	private static final double LN2 = Math.log(2);
	
	private final CoreLinear core;
	private final int d;      // number of time intervals (core.numIntervals())
	private final int nTrunk; // number of trunk haplotypes (core.nTrunk())
	
	private final H haplotype;
	private final C configuration;
	private final int L;      // number of loci in the decoded haplotype
	
	// log forward / backward values, indexed [locus][time interval][trunk haplotype]
	private final double[][][] forwardLogVals;
	private final double[][][] backwardLogVals;
	
	// forward: log-sum over haplotypes of forwardLogVals, indexed [locus][time interval].
	// backward: log of (1/nTrunk) * sum over haplotypes of emission(at this locus) * backward value.
	private final double[][] forwardLogMargTimeVals;
	private final double[][] backwardLogMargTimeVals;
	
	// forward: [locus][t] = log-sum of forwardLogMargTimeVals[locus][t'] over t' >= t.
	// backward: [locus][d-1] = logCoalNow(d-1) + marg(d-1), and for t < d-1
	//   [locus][t] = logsum(logCoalNow(t) + marg(t), log(1 - coalNow(t)) + ancient(t+1)).
	private final double[][] forwardLogAncientTimeVals;
	private final double[][] backwardLogAncientTimeVals;
	
	// forward "coal" accumulator: forwardLogCoalVals[locus][t] is written while computing locus+1
	// and feeds the coalesce-now / coalesce-later terms of the forward recursion and posteriors
	private final double[][] forwardLogCoalVals;
	
	private double totalLogProb;
	
	// true once computeForwardBackward() has run. A dedicated flag is needed because a
	// log-probability of exactly 0 (probability 1) is a legitimate value, so it cannot be a sentinel.
	private boolean computed;
	
	/**
	 * Allocates the forward/backward arrays. No computation happens until
	 * {@link #computeForwardBackward()} is called.
	 *
	 * @param core          model quantities (transition, emission and initial log-weights)
	 * @param haplotype     the haplotype being decoded; its number of loci sets {@code L}
	 * @param configuration the trunk configuration whose hap-index map supplies source alleles
	 */
	public DecodeLinear(CoreLinear core, H haplotype, C configuration) {
		this.core = core;
		this.d = core.numIntervals();
		this.nTrunk = core.nTrunk();
	
		this.haplotype = haplotype;
		this.configuration = configuration;
		this.L = haplotype.getNumLoci();
		
		this.forwardLogVals = new double[this.L][this.d][this.nTrunk];
		this.backwardLogVals = new double[this.L][this.d][this.nTrunk];
		
		this.forwardLogMargTimeVals = new double[this.L][this.d];
		this.backwardLogMargTimeVals = new double[this.L][this.d];
		
		this.forwardLogAncientTimeVals = new double[this.L][this.d];
		this.backwardLogAncientTimeVals = new double[this.L][this.d];
		
		this.forwardLogCoalVals = new double[this.L][this.d];
		
		this.totalLogProb = 0;
		this.computed = false;
	}
	
	/**
	 * Runs the forward and backward passes and stores the total log-probability.
	 * Must be called before any posterior or expected-count method.
	 *
	 * @return the log-sum over all (time, haplotype) states at the last locus of forward + backward
	 */
	public double computeForwardBackward() {
		
		GeneticTypeMultiplicity<H>[] hapIdxMap = this.configuration.getHapIdxMap();
		
		// fills the forward and backward values and all their helper arrays
		computeNewForwardsProbabilities(hapIdxMap);
		computeNewBackwardsProbabilities(hapIdxMap);
		
		// the backward values at the last locus are all 0 (log 1), so this sums the forward values
		LogSum totalProbCalc = new LogSum(this.nTrunk * this.d);
		int lastLocus = this.L-1;
		
		for (int tIndex = 0; tIndex < this.d; tIndex++) {
			for (int hapIndex = 0; hapIndex < this.nTrunk; hapIndex++) {
				totalProbCalc.addLogSummand(this.forwardLogVals[lastLocus][tIndex][hapIndex] + this.backwardLogVals[lastLocus][tIndex][hapIndex]);
			}
		}

		this.totalLogProb = totalProbCalc.retrieveLogSum();
		this.computed = true;
		return this.totalLogProb;
	}
	
	/** Returns the internal log forward array [locus][time][hap] (not a copy; do not modify). */
	public double[][][] getForwardLogProbs() { return this.forwardLogVals; }
	
	/** Returns the internal log backward array [locus][time][hap] (not a copy; do not modify). */
	public double[][][] getBackwardLogProbs() { return this.backwardLogVals; }
	
	// -------------------------------------------------------------
	// PER-LOCUS LOG POSTERIORS (result indexed [locus][time], locus in 0..L-2)
	// -------------------------------------------------------------
	
	/**
	 * 1) Log posterior of "reco diff, coal same" events between locus and locus+1.
	 * Entry [l][t] = fwdAncient[l][t+1] + logRecoDiffCoalSame(t) + bwdMarg[l+1][t] - logP.
	 * It is -infinity for the last interval, which has no more ancient interval.
	 */
	public double[][] computeLogRecoDiffCoalSameByLocus() {
	
		assert this.computed : NOT_COMPUTED_MSG;
		double[][] logPosteriorRecoDiffCoalSame = new double[this.L-1][this.d];
		
		for (int locus=0; locus < this.L-1; locus++) {
			for (int tIdx = 0; tIdx < this.d; tIdx++) {
				
				if (tIdx < this.d-1) {
					logPosteriorRecoDiffCoalSame[locus][tIdx] = this.forwardLogAncientTimeVals[locus][tIdx+1] + this.core.getLogRecoDiffCoalSame(tIdx) + this.backwardLogMargTimeVals[locus+1][tIdx] - this.totalLogProb;
				} else {
					logPosteriorRecoDiffCoalSame[locus][tIdx] = Double.NEGATIVE_INFINITY;
				}
			}
		}
		return logPosteriorRecoDiffCoalSame;
	}
	
	/**
	 * 2) Log posterior of "reco diff, coal later" events between locus and locus+1.
	 * Entry [l][t] = fwdAncient[l][t+1] + logRecoDiffCoalLater(t) + bwdAncient[l+1][t+1] - logP.
	 * It is -infinity for the last interval, where the lineage cannot coalesce later.
	 */
	public double[][] computeLogRecoDiffCoalLaterByLocus() {
		
		assert this.computed : NOT_COMPUTED_MSG;
		double[][] logRecoDiffCoalLaterByLocus = new double[this.L-1][this.d];
		
		for (int locus=0; locus < this.L-1; locus++) {
			for (int tIdx = 0; tIdx < this.d; tIdx++) {
				
				if (tIdx < this.d-1) {
					logRecoDiffCoalLaterByLocus[locus][tIdx] = this.forwardLogAncientTimeVals[locus][tIdx+1] + this.core.getLogRecoDiffCoalLater(tIdx) + this.backwardLogAncientTimeVals[locus+1][tIdx+1] - this.totalLogProb;
				} else {
					logRecoDiffCoalLaterByLocus[locus][tIdx] = Double.NEGATIVE_INFINITY;
				}
			}
		}
		return logRecoDiffCoalLaterByLocus;
	}
	
	/**
	 * 3) Log posterior of "reco same, coal same" events between locus and locus+1.
	 * Entry [l][t] = fwdMarg[l][t] + logRecoSameCoalSame(t) + bwdMarg[l+1][t] - logP.
	 */
	public double[][] computeLogRecoSameCoalSameByLocus() {
		
		assert this.computed : NOT_COMPUTED_MSG;
		double[][] logRecoSameCoalSameByLocus = new double[this.L-1][this.d];
		
		for (int locus=0; locus < this.L-1; locus++) {
			for (int tIdx = 0; tIdx < this.d; tIdx++) {
				logRecoSameCoalSameByLocus[locus][tIdx] = this.forwardLogMargTimeVals[locus][tIdx] + this.core.getLogRecoSameCoalSame(tIdx) + this.backwardLogMargTimeVals[locus+1][tIdx] - this.totalLogProb;
			}
		}
		return logRecoSameCoalSameByLocus;
	}
	
	/**
	 * 4) Log posterior of "reco same, coal later" events between locus and locus+1.
	 * Entry [l][t] = fwdMarg[l][t] + logRecoSameCoalLater(t) + bwdAncient[l+1][t+1] - logP.
	 * It is -infinity for the last interval, where the lineage cannot coalesce later.
	 */
	public double[][] computeLogRecoSameCoalLaterByLocus() {
		
		assert this.computed : NOT_COMPUTED_MSG;
		double[][] logRecoSameCoalLaterByLocus = new double[this.L-1][this.d];
		
		for (int locus=0; locus < this.L-1; locus++) {
			for (int tIdx = 0; tIdx < this.d; tIdx++) {
				
				if (tIdx < this.d-1) {
					logRecoSameCoalLaterByLocus[locus][tIdx] = this.forwardLogMargTimeVals[locus][tIdx] + this.core.getLogRecoSameCoalLater(tIdx) + this.backwardLogAncientTimeVals[locus+1][tIdx+1] - this.totalLogProb;
				} else {
					logRecoSameCoalLaterByLocus[locus][tIdx] = Double.NEGATIVE_INFINITY;
				}
			}
		}
		return logRecoSameCoalLaterByLocus;
	}
	
	/**
	 * 5) Log posterior of "no reco" events between locus and locus+1. The sum runs over
	 * haplotypes h of fwd[l][t][h] + logNoReco(t) + emission(l+1, h, t) + bwd[l+1][t][h] - logP,
	 * with the state (t, h) kept across the step.
	 */
	public double[][] computeLogNoRecoByLocus() {
		
		assert this.computed : NOT_COMPUTED_MSG;
		GeneticTypeMultiplicity<H>[] hapIdxMap = this.configuration.getHapIdxMap();
		double[][] logNoRecoByLocus = new double[this.L-1][this.d];
		
		for (int locus=0; locus < this.L-1; locus++) {
			int dstAllele = this.haplotype.getAllele(locus+1);
			for (int tIdx = 0; tIdx < this.d; tIdx++) {
				
				LogSum hapSum = new LogSum(this.nTrunk);
				for (int hapIdx = 0; hapIdx < this.nTrunk; hapIdx++) {
					int srcAllele = hapIdxMap[hapIdx].geneticType.getAllele(locus+1);
					hapSum.addLogSummand(this.forwardLogVals[locus][tIdx][hapIdx] + this.core.getLogNoReco(tIdx) + this.core.getLogEmission(srcAllele, dstAllele, tIdx) + this.backwardLogVals[locus+1][tIdx][hapIdx] - this.totalLogProb);
				}
				logNoRecoByLocus[locus][tIdx] = hapSum.retrieveLogSum();
			}
		}
		return logNoRecoByLocus;
	}
	
	/**
	 * 6) Log posterior of "coal now" events, using the getLogCoalNow term, between locus and
	 * locus+1. Entry [l][t] = fwdCoal[l][t-1] + logCoalNow(t) + bwdMarg[l+1][t] - logP.
	 * It is -infinity for t = 0, where no earlier accumulator exists.
	 */
	public double[][] computeLogCoalNowByLocus() {
		
		assert this.computed : NOT_COMPUTED_MSG;
		double[][] logCoalNowByLocus = new double[this.L-1][this.d];
		
		for (int locus=0; locus < this.L-1; locus++) {
			for (int tIdx = 0; tIdx < this.d; tIdx++) {
				
				if (tIdx > 0) {
					logCoalNowByLocus[locus][tIdx] = this.forwardLogCoalVals[locus][tIdx-1] + this.core.getLogCoalNow(tIdx) + this.backwardLogMargTimeVals[locus+1][tIdx] - this.totalLogProb;
				} else {
					logCoalNowByLocus[locus][tIdx] = Double.NEGATIVE_INFINITY;
				}
			}
		}
		return logCoalNowByLocus;
	}

	/**
	 * 7) Log posterior of "coal later" events, using the getLogCoalLater term, between locus and
	 * locus+1. Entry [l][t] = fwdCoal[l][t-1] + logCoalLater(t) + bwdAncient[l+1][t+1] - logP.
	 * It is -infinity for t = 0 and for t = d-1.
	 */
	public double[][] computeLogCoalLaterByLocus() {
		
		assert this.computed : NOT_COMPUTED_MSG;
		double[][] logCoalLaterByLocus = new double[this.L-1][this.d];
		
		for (int locus=0; locus < this.L-1; locus++) {
			for (int tIdx = 0; tIdx < this.d; tIdx++) {
				
				if (tIdx > 0 && tIdx < this.d-1) {
					logCoalLaterByLocus[locus][tIdx] = this.forwardLogCoalVals[locus][tIdx-1] + this.core.getLogCoalLater(tIdx) + this.backwardLogAncientTimeVals[locus+1][tIdx+1] - this.totalLogProb;
				} else {
					logCoalLaterByLocus[locus][tIdx] = Double.NEGATIVE_INFINITY;
				}
			}
		}
		return logCoalLaterByLocus;
	}

	// -------------------------------------------------------------
	// EXPECTED EVENT COUNTS (summed over loci, probability scale)
	// -------------------------------------------------------------
	
	/**
	 * Posterior probability of each time interval at the first locus. Entry t is the sum over
	 * haplotypes of exp(fwd[0][t][h] + bwd[0][t][h] - logP).
	 */
	public double[] computePosteriorInitialCounts() {
		
		assert this.computed : NOT_COMPUTED_MSG;
		double[] initialCounts = new double[this.d];
		int locus = 0;
		
		for (int tIdx = 0; tIdx < this.d; tIdx++) {
			// sum over the haplotypes
			for (int hapIdx = 0; hapIdx < this.nTrunk; hapIdx++) {
				initialCounts[tIdx] += Math.exp(this.forwardLogVals[locus][tIdx][hapIdx] + this.backwardLogVals[locus][tIdx][hapIdx] - this.totalLogProb);
			}
		}
		return initialCounts;
	}
	
	/**
	 * 1) Expected number of "reco diff, coal same" events per interval, summed over loci.
	 * The count is 0 for the last interval.
	 */
	public double[] computePosteriorRecoDiffCoalSame() {
	
		assert this.computed : NOT_COMPUTED_MSG;
		double[] hatRecoDiffCoalSame = new double[this.d];
		
		for (int tIdx = 0; tIdx < this.d; tIdx++) {
			for (int locus=0; locus < this.L-1; locus++) {
				
				// the last interval (tIdx = d-1) contributes nothing
				if (tIdx < this.d-1) {
					hatRecoDiffCoalSame[tIdx] += Math.exp(this.forwardLogAncientTimeVals[locus][tIdx+1] + this.core.getLogRecoDiffCoalSame(tIdx) + this.backwardLogMargTimeVals[locus+1][tIdx] - this.totalLogProb);
				}
			}
		}
		return hatRecoDiffCoalSame;
	}
	
	/**
	 * 2) Expected number of "reco diff, coal later" events per interval, summed over loci.
	 * The count is 0 for the last interval, where the lineage cannot coalesce later.
	 */
	public double[] computePosteriorRecoDiffCoalLater() {
		
		assert this.computed : NOT_COMPUTED_MSG;
		double[] hatRecoDiffCoalLater = new double[this.d];
		
		for (int tIdx = 0; tIdx < this.d; tIdx++) {
			for (int locus=0; locus < this.L-1; locus++) {
				
				// the last interval (tIdx = d-1) contributes nothing
				if (tIdx < this.d-1) {
					hatRecoDiffCoalLater[tIdx] += Math.exp(this.forwardLogAncientTimeVals[locus][tIdx+1] + this.core.getLogRecoDiffCoalLater(tIdx) + this.backwardLogAncientTimeVals[locus+1][tIdx+1] - this.totalLogProb);
				}
			}
		}
		return hatRecoDiffCoalLater;
	}
	
	/** 3) Expected number of "reco same, coal same" events per interval, summed over loci. */
	public double[] computePosteriorRecoSameCoalSame() {
		
		assert this.computed : NOT_COMPUTED_MSG;
		double[] hatRecoSameCoalSame = new double[this.d];
		
		for (int tIdx = 0; tIdx < this.d; tIdx++) {
			for (int locus=0; locus < this.L-1; locus++) {
				hatRecoSameCoalSame[tIdx] += Math.exp(this.forwardLogMargTimeVals[locus][tIdx] + this.core.getLogRecoSameCoalSame(tIdx) + this.backwardLogMargTimeVals[locus+1][tIdx] - this.totalLogProb);
			}
		}
		return hatRecoSameCoalSame;
	}
		
	/**
	 * 4) Expected number of "reco same, coal later" events per interval, summed over loci.
	 * The count is 0 for the last interval, where the lineage cannot coalesce later.
	 */
	public double[] computePosteriorRecoSameCoalLater() {
		
		assert this.computed : NOT_COMPUTED_MSG;
		double[] hatRecoSameCoalLater = new double[this.d];
		
		for (int tIdx = 0; tIdx < this.d; tIdx++) {
			for (int locus=0; locus < this.L-1; locus++) {
				
				// the last interval (tIdx = d-1) contributes nothing
				if (tIdx < this.d-1) {
					hatRecoSameCoalLater[tIdx] += Math.exp(this.forwardLogMargTimeVals[locus][tIdx] + this.core.getLogRecoSameCoalLater(tIdx) + this.backwardLogAncientTimeVals[locus+1][tIdx+1] - this.totalLogProb);
				}
			}
		}
		return hatRecoSameCoalLater;
	}
		
	/**
	 * 5) Expected number of "no reco" events per interval, summed over loci and haplotypes.
	 * Each term keeps the same (t, h) state from locus to locus+1.
	 */
	public double[] computePosteriorNoReco() {
		
		assert this.computed : NOT_COMPUTED_MSG;
		GeneticTypeMultiplicity<H>[] hapIdxMap = this.configuration.getHapIdxMap();
		double[] hatNoReco = new double[this.d];
		
		for (int tIdx = 0; tIdx < this.d; tIdx++) {
			for (int locus=0; locus < this.L-1; locus++) {
				int dstAllele = this.haplotype.getAllele(locus+1);
				
				// sum over the haplotypes
				for (int hapIdx = 0; hapIdx < this.nTrunk; hapIdx++) {
					int srcAllele = hapIdxMap[hapIdx].geneticType.getAllele(locus+1);
					hatNoReco[tIdx] += Math.exp(this.forwardLogVals[locus][tIdx][hapIdx] + this.core.getLogNoReco(tIdx) + this.core.getLogEmission(srcAllele, dstAllele, tIdx) + this.backwardLogVals[locus+1][tIdx][hapIdx] - this.totalLogProb);
				}
			}
		}
		return hatNoReco;
	}
		
	/**
	 * 6) Expected number of "coal now" events (the getLogCoalNow term) per interval, summed
	 * over loci. The count is 0 for t = 0.
	 */
	public double[] computePosteriorCoalNow() {
		
		assert this.computed : NOT_COMPUTED_MSG;
		double[] hatCoalNow = new double[this.d];
		
		for (int tIdx = 0; tIdx < this.d; tIdx++) {
			for (int locus=0; locus < this.L-1; locus++) {
				
				// the first interval (tIdx = 0) contributes nothing
				if (tIdx > 0) {
					hatCoalNow[tIdx] += Math.exp(this.forwardLogCoalVals[locus][tIdx-1] + this.core.getLogCoalNow(tIdx) + this.backwardLogMargTimeVals[locus+1][tIdx] - this.totalLogProb);
				}
			}
		}
		return hatCoalNow;
	}
	
	/**
	 * 7) Expected number of "coal later" events (the getLogCoalLater term) per interval, summed
	 * over loci. The count is 0 for t = 0 and for t = d-1.
	 */
	public double[] computePosteriorCoalLater() {
		
		assert this.computed : NOT_COMPUTED_MSG;
		double[] hatCoalLater = new double[this.d];
		
		for (int tIdx = 0; tIdx < this.d; tIdx++) {
			for (int locus=0; locus < this.L-1; locus++) {
				
				// the first (tIdx = 0) and last (tIdx = d-1) intervals contribute nothing
				if (tIdx > 0 && tIdx < this.d-1) {
					hatCoalLater[tIdx] += Math.exp(this.forwardLogCoalVals[locus][tIdx-1] + this.core.getLogCoalLater(tIdx) + this.backwardLogAncientTimeVals[locus+1][tIdx+1] - this.totalLogProb);
				}
			}
		}
		return hatCoalLater;
	}
	
	/**
	 * Computes E[t][a], the expected number of loci at which the decoded haplotype carries
	 * allele a while in time interval t. It is the sum, over loci with allele a, of the posterior
	 * probability of interval t (summed over haplotypes).
	 */
	public double[][] computePosteriorEmissions() {
		
		assert this.computed : NOT_COMPUTED_MSG;
		
		double[][] E = new double[this.d][this.core.numAlleles()];
		LogSum expectedSum = new LogSum(this.L * this.nTrunk);
		
		for (int iTime=0; iTime < this.d; iTime++) {
			for (int a=0; a < this.core.numAlleles(); a++) {
				expectedSum.reset();
				
				for (int locus=0; locus < this.L; locus++) {
					if (this.haplotype.getAllele(locus) == a) {
						for (int iHap=0; iHap < this.nTrunk; iHap++) {
							expectedSum.addLogSummand(this.forwardLogVals[locus][iTime][iHap] + this.backwardLogVals[locus][iTime][iHap]);
						}
					}
				}
				
				// NOTE(review): if no locus carries allele a, this relies on LogSum's behavior
				// for an empty sum (not visible here)
				double retrieveSum = expectedSum.retrieveLogSum();
				E[iTime][a] = Math.exp(retrieveSum - this.totalLogProb);
			}
		}
		
		return E;
	}
	
	/**
	 * Expected number of self-transitions (time interval t at locus l and again t at locus l+1),
	 * summed over adjacent locus pairs. This is needed for expected segment lengths.
	 * <p>
	 * For each locus pair the value is logsum(recTerm, noRecTerm), where:
	 * <ul>
	 * <li>recTerm = logsum_h fwd[l][t][h] + logsum_h (emission + bwd[l+1][t][h] + log(1/n))
	 *     + logSelfTrans(t)</li>
	 * <li>noRecTerm = logNoReco(t) + logsum_h (fwd[l][t][h] + emission + bwd[l+1][t][h])</li>
	 * </ul>
	 */
	public double[] computeExpectedSelfTrans() {
		
		assert this.computed : NOT_COMPUTED_MSG;
		GeneticTypeMultiplicity<H>[] hapIdxMap = this.configuration.getHapIdxMap();
		
		// reusable log-sum accumulators
		LogSum forwardOnlyCalc = new LogSum(this.nTrunk);
		LogSum backwardEmCalc  = new LogSum(this.nTrunk);
		LogSum forwardBackward = new LogSum(this.nTrunk);
		LogSum expectedSum     = new LogSum(this.L-1);
		
		// log(n_h / nTrunk) with haplotype multiplicity n_h always 1 here; hoisted as it is loop invariant
		final double logHapWeight = Math.log(1.0 / this.nTrunk);
		
		double[] hatSelfTrans = new double[this.d];
		for (int tIdx = 0; tIdx < this.d; tIdx++) {
				
			expectedSum.reset();
					
			for (int locus = 0; locus < this.L-1; locus++) {
				int dstAllele = this.haplotype.getAllele(locus+1);
				
				forwardOnlyCalc.reset();
				backwardEmCalc.reset();
				forwardBackward.reset();
				
				for (int hapIdx=0; hapIdx < this.nTrunk; hapIdx++) {
					int srcAllele = hapIdxMap[hapIdx].geneticType.getAllele(locus+1);
					double logEmission = this.core.getLogEmission(srcAllele, dstAllele, tIdx);
					double logFwd = this.forwardLogVals[locus][tIdx][hapIdx];
					double logBwd = this.backwardLogVals[locus+1][tIdx][hapIdx];
					
					forwardOnlyCalc.addLogSummand(logFwd);
					backwardEmCalc.addLogSummand(logEmission + logBwd + logHapWeight);
					forwardBackward.addLogSummand(logFwd + logEmission + logBwd);
				}
				
				double recTerm = forwardOnlyCalc.retrieveLogSum() + backwardEmCalc.retrieveLogSum() + this.core.getLogSelfTrans(tIdx);
				double noRecTerm = this.core.getLogNoReco(tIdx) + forwardBackward.retrieveLogSum();
				expectedSum.addLogSummand(LogSum.computePairLogSum(recTerm, noRecTerm));
			}
				
			hatSelfTrans[tIdx] = Math.exp(expectedSum.retrieveLogSum() - this.totalLogProb);
		}
		return hatSelfTrans;
	}
	
	// -----------------------------------
	// FORWARD AND BACKWARD PROBABILITIES
	//------------------------------------
	
	/**
	 * Forward pass. Fills forwardLogVals, forwardLogMargTimeVals, forwardLogAncientTimeVals and
	 * forwardLogCoalVals[0..L-2].
	 * <p>
	 * For locus > 0 and interval t, the shared recombination term
	 * (logsum of the terms below, minus log n) is:
	 * <ul>
	 * <li>RecoSameCoalSame(t) + fwdMarg[l-1][t]</li>
	 * <li>RecoDiffCoalSame(t) + fwdAncient[l-1][t+1], for t &lt; d-1</li>
	 * <li>CoalNow(t) + fwdCoal[l-1][t-1], for 0 &lt; t &lt; d-1; for t = d-1 the
	 *     fwdCoal[l-1][d-2] term is added without a CoalNow factor</li>
	 * </ul>
	 * It is combined with the no-recombination term NoReco(t) + fwd[l-1][t][h] and multiplied by
	 * the emission. Note: indices t+1 and t-1 assume d &gt;= 2 whenever L &gt;= 2.
	 */
	private void computeNewForwardsProbabilities(GeneticTypeMultiplicity<H>[] hapIdxMap) {
	
		double lognTrunk = Math.log(this.nTrunk);
		
		LogSum margTimesCalc = new LogSum(this.nTrunk);
		LogSum recoTermCalc = new LogSum(3);
		double logRecoTerm = 0;
		
		for (int locus = 0; locus < this.L; locus++) {
			int dstAllele = this.haplotype.getAllele(locus);

			// first locus: initial weight times emission, uniform over trunk haplotypes
			if (locus == 0) {
				for (int tIdx = 0; tIdx < this.d; tIdx++) {
				
					margTimesCalc.reset();
					for (int hapIdx = 0; hapIdx < this.nTrunk; hapIdx++) {
						int srcAllele = hapIdxMap[hapIdx].geneticType.getAllele(locus);
						double logMutationProb = this.core.getLogEmission(srcAllele, dstAllele, tIdx);
						
						this.forwardLogVals[locus][tIdx][hapIdx] = logMutationProb + this.core.getLogInitialWeight(tIdx) - lognTrunk;
						margTimesCalc.addLogSummand(this.forwardLogVals[locus][tIdx][hapIdx]);
					}
					this.forwardLogMargTimeVals[locus][tIdx] = margTimesCalc.retrieveLogSum();
				}
				
			// all other loci
			} else {
				
				for (int tIdx = 0; tIdx < this.d; tIdx++) {
					margTimesCalc.reset();
					
					// first time index
					if (tIdx == 0) {
						
						// recombination term (no earlier coal accumulator exists)
						logRecoTerm = LogSum.computePairLogSum(this.core.getLogRecoSameCoalSame(tIdx) + this.forwardLogMargTimeVals[locus-1][tIdx], this.core.getLogRecoDiffCoalSame(tIdx) + this.forwardLogAncientTimeVals[locus-1][tIdx+1]) - lognTrunk;
						
						// start the forwardLogCoalVals[locus-1] accumulator
						double coalTerm1 = this.core.getLogRecoSameCoalLater(tIdx) + this.forwardLogMargTimeVals[locus-1][tIdx];
						double coalTerm2 = this.core.getLogRecoDiffCoalLater(tIdx) + this.forwardLogAncientTimeVals[locus-1][tIdx+1];
						this.forwardLogCoalVals[locus-1][tIdx] = LogSum.computePairLogSum(coalTerm1, coalTerm2); // not divided by n here (done in the reco term)
						
					// middle time indices
					} else if (tIdx < this.d-1) {
						
						// recombination term
						recoTermCalc.reset();
						recoTermCalc.addLogSummand(this.core.getLogRecoSameCoalSame(tIdx) + this.forwardLogMargTimeVals[locus-1][tIdx]);
						recoTermCalc.addLogSummand(this.core.getLogRecoDiffCoalSame(tIdx) + this.forwardLogAncientTimeVals[locus-1][tIdx+1]);
						recoTermCalc.addLogSummand(this.core.getLogCoalNow(tIdx) + this.forwardLogCoalVals[locus-1][tIdx-1]);
						logRecoTerm = recoTermCalc.retrieveLogSum() - lognTrunk;
						
						// extend the forwardLogCoalVals[locus-1] accumulator: new terms plus the
						// previous interval's value carried through with CoalLater(t)
						double coalTerm1 = this.core.getLogRecoSameCoalLater(tIdx) + this.forwardLogMargTimeVals[locus-1][tIdx];
						double coalTerm2 = this.core.getLogRecoDiffCoalLater(tIdx) + this.forwardLogAncientTimeVals[locus-1][tIdx+1];
						double intermediate = LogSum.computePairLogSum(coalTerm1, coalTerm2);
						double coalTerm3 = this.core.getLogCoalLater(tIdx) + this.forwardLogCoalVals[locus-1][tIdx-1];
						this.forwardLogCoalVals[locus-1][tIdx] = LogSum.computePairLogSum(intermediate, coalTerm3);
					
					// last time index
					} else {
						
						// recombination term (no more ancient interval to recombine into)
						logRecoTerm = LogSum.computePairLogSum(this.core.getLogRecoSameCoalSame(tIdx) + this.forwardLogMargTimeVals[locus-1][tIdx], this.forwardLogCoalVals[locus-1][tIdx-1]) - lognTrunk;
						
						// log 0: cannot coalesce more anciently in the last time interval
						this.forwardLogCoalVals[locus-1][tIdx] = Double.NEGATIVE_INFINITY;
					}
						
					for (int hapIdx = 0; hapIdx < this.nTrunk; hapIdx++) {
						int srcAllele = hapIdxMap[hapIdx].geneticType.getAllele(locus);
						
						// emission (mutation) probability
						double logMutationProb = this.core.getLogEmission(srcAllele, dstAllele, tIdx);
						
						// transition: stay in the same (t, h) state, or the shared recombination term
						double logNoRecoTerm = this.core.getLogNoReco(tIdx) + this.forwardLogVals[locus-1][tIdx][hapIdx];
						double logTransProb = LogSum.computePairLogSum(logNoRecoTerm, logRecoTerm);
						
						this.forwardLogVals[locus][tIdx][hapIdx] = logMutationProb + logTransProb;
						margTimesCalc.addLogSummand(this.forwardLogVals[locus][tIdx][hapIdx]);
					}
						
					// marginal over haplotypes for this interval
					this.forwardLogMargTimeVals[locus][tIdx] = margTimesCalc.retrieveLogSum();
				}
			}
			
			// ancient vals: suffix log-sums of the marginals over t' >= t
			double runningAncientLogSum = Double.NEGATIVE_INFINITY;
			for (int i = this.d-1; i >= 0; i--) {
				runningAncientLogSum = LogSum.computePairLogSum(runningAncientLogSum, this.forwardLogMargTimeVals[locus][i]);
				this.forwardLogAncientTimeVals[locus][i] = runningAncientLogSum;
			}
		}
	}
	
	/**
	 * Backward pass. Fills backwardLogVals, backwardLogMargTimeVals and
	 * backwardLogAncientTimeVals.
	 * <p>
	 * backwardLogMargTimeVals[l][t] is log((1/n) * sum_h emission(l, h, t) * bwd[l][t][h]), so that
	 * locus l-1 can use it directly. currCoalVals[t] is a running log-sum over r &lt;= t of
	 * RecoDiffCoalLater(r) + bwdAncient[l+1][r+1] and RecoDiffCoalSame(r) + bwdMarg[l+1][r].
	 * Note: indices t+1 and t-1 assume d &gt;= 2 whenever L &gt;= 2.
	 */
	private void computeNewBackwardsProbabilities(GeneticTypeMultiplicity<H>[] hapIdxMap) {
				
		double lognTrunk = Math.log(this.nTrunk);
		
		LogSum margTimesCalc = new LogSum(this.nTrunk);
		LogSum recoTermCalc = new LogSum(3);
		double logRecoTerm = 0;
		double[] currCoalVals = new double[this.d];
		
		for (int locus = this.L-1; locus >= 0; locus--) {
			int dstAlleleQ = this.haplotype.getAllele(locus);
			
			// last locus: backward values are log 1 = 0
			if (locus == this.L-1) {
				for (int tIdx = 0; tIdx < this.d; tIdx++) {
					
					margTimesCalc.reset();
					for (int hapIdx = 0; hapIdx < this.nTrunk; hapIdx++) {
						int srcAlleleQ = hapIdxMap[hapIdx].geneticType.getAllele(locus);
						this.backwardLogVals[locus][tIdx][hapIdx] = 0;
						
						double logMutationProb = this.core.getLogEmission(srcAlleleQ, dstAlleleQ, tIdx);
						margTimesCalc.addLogSummand(logMutationProb + this.backwardLogVals[locus][tIdx][hapIdx] - lognTrunk);
					}
					this.backwardLogMargTimeVals[locus][tIdx] = margTimesCalc.retrieveLogSum();
				}
			
			// all other loci
			} else {
				int dstAllele = this.haplotype.getAllele(locus+1);
				
				for (int tIdx = 0; tIdx < this.d; tIdx++) {
					margTimesCalc.reset();
				
					// first time index
					if (tIdx == 0) {
						
						// recombination term
						logRecoTerm = LogSum.computePairLogSum(this.core.getLogRecoSameCoalSame(tIdx) + this.backwardLogMargTimeVals[locus+1][tIdx], this.core.getLogRecoSameCoalLater(tIdx) + this.backwardLogAncientTimeVals[locus+1][tIdx+1]);
						
						// start the currCoalVals accumulator
						currCoalVals[tIdx] = LogSum.computePairLogSum(this.core.getLogRecoDiffCoalLater(tIdx) + this.backwardLogAncientTimeVals[locus+1][tIdx+1], this.core.getLogRecoDiffCoalSame(tIdx) + this.backwardLogMargTimeVals[locus+1][tIdx]);
					
					// middle time indices
					} else if (tIdx < this.d-1) {
						
						// recombination term
						recoTermCalc.reset();
						recoTermCalc.addLogSummand(this.core.getLogRecoSameCoalSame(tIdx) + this.backwardLogMargTimeVals[locus+1][tIdx]);
						recoTermCalc.addLogSummand(currCoalVals[tIdx-1]);
						recoTermCalc.addLogSummand(this.core.getLogRecoSameCoalLater(tIdx) + this.backwardLogAncientTimeVals[locus+1][tIdx+1]);
						logRecoTerm = recoTermCalc.retrieveLogSum();
						
						// extend the currCoalVals accumulator
						double coalValsSum = LogSum.computePairLogSum(this.core.getLogRecoDiffCoalLater(tIdx) + this.backwardLogAncientTimeVals[locus+1][tIdx+1], this.core.getLogRecoDiffCoalSame(tIdx) + this.backwardLogMargTimeVals[locus+1][tIdx]);
						currCoalVals[tIdx] = LogSum.computePairLogSum(currCoalVals[tIdx-1], coalValsSum);
					
					// last time index
					} else {
						
						// recombination term (no more ancient interval)
						logRecoTerm = LogSum.computePairLogSum(this.core.getLogRecoSameCoalSame(tIdx) + this.backwardLogMargTimeVals[locus+1][tIdx], currCoalVals[tIdx-1]);
					}
					
					for (int hapIdx = 0; hapIdx < this.nTrunk; hapIdx++) {
						int srcAllele = hapIdxMap[hapIdx].geneticType.getAllele(locus+1);
						
						// emission at locus+1
						double logMutationProb = this.core.getLogEmission(srcAllele, dstAllele, tIdx);
						
						// stay in the same (t, h) state
						double logNoRecoTerm = this.core.getLogNoReco(tIdx) + this.backwardLogVals[locus+1][tIdx][hapIdx] + logMutationProb;
						
						this.backwardLogVals[locus][tIdx][hapIdx] = LogSum.computePairLogSum(logNoRecoTerm, logRecoTerm);
						
						// marginal for use at locus-1, weighted by the emission at this locus
						int srcAlleleQ = hapIdxMap[hapIdx].geneticType.getAllele(locus);
						double logMutationProbQ = this.core.getLogEmission(srcAlleleQ, dstAlleleQ, tIdx);
						margTimesCalc.addLogSummand(logMutationProbQ + this.backwardLogVals[locus][tIdx][hapIdx] - lognTrunk);
					}
						
					this.backwardLogMargTimeVals[locus][tIdx] = margTimesCalc.retrieveLogSum();
				}	
			}	
				
			// ancient vals, from the last interval down:
			// anc(t) = logsum(CoalNow(t) + marg(t), log(1 - CoalNow(t)) + anc(t+1))
			this.backwardLogAncientTimeVals[locus][this.d-1] = this.core.getLogCoalNow(this.d-1) + this.backwardLogMargTimeVals[locus][this.d-1];
			for (int i = this.d-2; i >= 0; i--) {
				double ancientTerm1 = this.core.getLogCoalNow(i) + this.backwardLogMargTimeVals[locus][i];
				double ancientTerm2 = log1mExp(this.core.getLogCoalNow(i)) + this.backwardLogAncientTimeVals[locus][i+1];
				this.backwardLogAncientTimeVals[locus][i] = LogSum.computePairLogSum(ancientTerm1, ancientTerm2);
			}
		}
	}
	
	/**
	 * Computes log(1 - exp(logP)) without the precision loss of the naive
	 * {@code Math.log(1 - Math.exp(logP))}. It uses expm1 when exp(logP) is close to 1 and log1p
	 * when exp(logP) is close to 0 (Maechler's log1mexp).
	 * logP = 0 gives -Infinity and logP = -Infinity gives 0, as in the naive form.
	 */
	private static double log1mExp(double logP) {
		return (logP > -LN2) ? Math.log(-Math.expm1(logP)) : Math.log1p(-Math.exp(logP));
	}
}
