/*
 * License: FreeBSD (Berkeley Software Distribution)
 * Copyright (c) 2013, Sara Sheehan, Kelley Harris, Yun S. Song
 * 
 * based on:
 * Copyright (c) 2011, Joshua Paul, Matthias SteinrŸcken, Yun Song
 */

package edu.berkeley.smcsd;

import java.util.Map;

import edu.berkeley.smcsd.CoreQuad;

import edu.berkeley.utility.HapConfig;
import edu.berkeley.utility.Haplotype;
import edu.berkeley.utility.Haplotype.GeneticTypeMultiplicity;
import edu.berkeley.utility.LogSum;
import edu.berkeley.utility.Utility;

import edu.berkeley.utility.Utility.BestViterbi;

/**
 * Original decoding class, with variable population size.
 *
 * <p>Runs forward-backward, posterior decoding, posterior-mean and Viterbi decoding of one
 * query haplotype against a configuration of "trunk" haplotypes. Hidden states are pairs
 * (time interval, trunk haplotype). All probabilities are kept in log space. Initial,
 * transition and emission weights come from the supplied {@link CoreQuad}. Every haplotype
 * multiplicity is assumed to be 1, so each trunk haplotype is weighted by {@code 1 / nTrunk}.
 *
 * <p>{@link #computeForwardBackward()} must be called before any of the expected-count or
 * posterior methods, because they read the tables it fills in.
 *
 * @param <H> haplotype type
 * @param <C> configuration type holding the trunk haplotypes
 */
public final class DecodeQuad<H extends Haplotype, C extends HapConfig<H>> {
	
	private final CoreQuad core;
	/** Number of time intervals (taken from core.numIntervals()). */
	private final int d;
	/** Number of trunk haplotypes (taken from core.nTrunk()). */
	private final int nTrunk;

	/** The query haplotype being decoded. */
	private final H haplotype;
	/** The configuration holding the trunk haplotypes. */
	private final C configuration;
	/** Number of loci in the query haplotype. */
	private final int L;
	
	// log forward / backward values, indexed [locus][time interval][trunk haplotype]
	private final double[][][] forwardPVals;
	private final double[][][] backwardPVals;
	
	// scratch buffers shared by the forward and backward recursions.
	// currQVals is a two-row rolling buffer (read row / write row alternate each locus);
	// currRVals / currTVals are always reset to -infinity / 0 and therefore act as neutral terms.
	private final double[][] currQVals;
	private final double[] currRVals;
	private final double[] currTVals;
	
	/** Log probability of the query haplotype; valid once forwardBackwardComputed is true. */
	private double totalLogProb;
	/** Set by computeForwardBackward(); guards methods that read the forward/backward tables. */
	private boolean forwardBackwardComputed;
	
	/**
	 * Allocates the forward/backward tables for decoding {@code haplotype} against
	 * {@code configuration}. No computation is done here.
	 *
	 * @param core          model parameters (intervals, transitions, emissions)
	 * @param haplotype     query haplotype to decode
	 * @param configuration configuration holding the trunk haplotypes
	 */
	public DecodeQuad(CoreQuad core, H haplotype, C configuration) {
		this.core = core;
		this.d = core.numIntervals();
		this.nTrunk = core.nTrunk();
		
		this.haplotype = haplotype;
		this.configuration = configuration;
		this.L = haplotype.getNumLoci();
		
		this.forwardPVals = new double[this.L][this.d][this.nTrunk];
		this.backwardPVals = new double[this.L][this.d][this.nTrunk];
	
		this.currQVals = new double[2][this.d];
		this.currRVals = new double[this.d];
		this.currTVals = new double[this.d];
		
		this.totalLogProb = 0;
		this.forwardBackwardComputed = false;
	}
	
	/**
	 * Runs the forward and backward recursions, then sums forward + backward over all states
	 * at the last locus to get the total log probability of the query haplotype.
	 *
	 * @return the total log probability (also stored for later normalization)
	 */
	public double computeForwardBackward() {
		
		GeneticTypeMultiplicity<H>[] hapIdxMap = this.configuration.getHapIdxMap();
		
		computeForwardsProbabilities(hapIdxMap);
		computeBackwardsProbabilities(hapIdxMap);

		LogSum totalProbCalc = new LogSum(this.nTrunk * this.d);
		int lastLocus = this.L - 1;
		for (int tIndex = 0; tIndex < this.d; tIndex++) {
			for (int hapIndex = 0; hapIndex < this.nTrunk; hapIndex++) {
				totalProbCalc.addLogSummand(unnormalizedLogPosterior(lastLocus, tIndex, hapIndex));
			}
		}

		this.totalLogProb = totalProbCalc.retrieveLogSum();
		this.forwardBackwardComputed = true;
		return this.totalLogProb;
	}
	
	/** Returns the internal log forward table [locus][time][hap] (not a copy; do not modify). */
	public double[][][] getForwardLogProbs() { return this.forwardPVals; }
	/** Returns the internal log backward table [locus][time][hap] (not a copy; do not modify). */
	public double[][][] getBackwardLogProbs() { return this.backwardPVals; }
	
	/** Log of forward * backward for one state, i.e. the posterior before dividing by P(query). */
	private double unnormalizedLogPosterior(int locus, int tIndex, int hapIndex) {
		return this.forwardPVals[locus][tIndex][hapIndex] + this.backwardPVals[locus][tIndex][hapIndex];
	}
	
	// -----------------------------------
	// EXPECTED TRANSITIONS AND EMISSIONS
	//------------------------------------
	
	/**
	 * Expected occupancy of each time interval at the first locus, summed over haplotypes.
	 *
	 * @return array of length d with posterior probabilities (linear scale)
	 */
	public double[] computeExpectedInitialCounts() {
		
		assert this.forwardBackwardComputed : "computeForwardBackward() must be called first";
		double[] initialCounts = new double[this.d];
		final int locus = 0;
		
		for (int tIdx = 0; tIdx < this.d; tIdx++) {
			// sum over the haplotypes
			for (int hapIdx = 0; hapIdx < this.nTrunk; hapIdx++) {
				initialCounts[tIdx] += Math.exp(unnormalizedLogPosterior(locus, tIdx, hapIdx) - this.totalLogProb);
			}
		}
		return initialCounts;
	}
	
	/**
	 * Expected number of transitions from time interval i to time interval j, summed over all
	 * adjacent locus pairs and all haplotype pairs.
	 *
	 * <p>For a step locus to locus+1 the transition weight splits into a recombination part
	 * rec(i, j) / nTrunk, which factorizes over source and destination haplotypes, and, when
	 * i == j, a no-recombination part noRec(i) that keeps the same haplotype. The
	 * haplotype-summed factors depend only on (locus, time), so they are computed once up
	 * front instead of inside the (i, j) loop.
	 *
	 * @return d x d matrix A with A[i][j] = expected i -> j transitions (linear scale)
	 */
	public double[][] computeExpectedTransitions() {
		
		assert this.forwardBackwardComputed : "computeForwardBackward() must be called first";
		GeneticTypeMultiplicity<H>[] hapIdxMap = this.configuration.getHapIdxMap();
		
		int numSteps = Math.max(this.L - 1, 0);
		double n_ihap = 1; // hap multiplicity, always 1 here; double so n_ihap / nTrunk is not integer division
		double logHapWeight = Math.log(n_ihap / this.nTrunk);
		
		// logForwardSum[l][t]    = log sum_h F[l][t][h]
		// logBackwardEmSum[l][t] = log sum_h E(h, l+1, t) * B[l+1][t][h] / nTrunk
		// logStaySum[l][t]       = log sum_h F[l][t][h] * E(h, l+1, t) * B[l+1][t][h]
		double[][] logForwardSum = new double[numSteps][this.d];
		double[][] logBackwardEmSum = new double[numSteps][this.d];
		double[][] logStaySum = new double[numSteps][this.d];
		
		LogSum forwardOnlyCalc = new LogSum(this.nTrunk);
		LogSum backwardEmCalc  = new LogSum(this.nTrunk);
		LogSum forwardBackward = new LogSum(this.nTrunk);
		for (int locus = 0; locus < numSteps; locus++) {
			int dstAllele = this.haplotype.getAllele(locus + 1);
			for (int t = 0; t < this.d; t++) {
				forwardOnlyCalc.reset();
				backwardEmCalc.reset();
				forwardBackward.reset();
				for (int hapIdx = 0; hapIdx < this.nTrunk; hapIdx++) {
					int srcAllele = hapIdxMap[hapIdx].geneticType.getAllele(locus + 1);
					double logEmission = this.core.getLogEmission(srcAllele, dstAllele, t);
					double logForward = this.forwardPVals[locus][t][hapIdx];
					double logBackward = this.backwardPVals[locus + 1][t][hapIdx];
					
					forwardOnlyCalc.addLogSummand(logForward);
					backwardEmCalc.addLogSummand(logEmission + logBackward + logHapWeight);
					forwardBackward.addLogSummand(logForward + logEmission + logBackward);
				}
				logForwardSum[locus][t] = forwardOnlyCalc.retrieveLogSum();
				logBackwardEmSum[locus][t] = backwardEmCalc.retrieveLogSum();
				logStaySum[locus][t] = forwardBackward.retrieveLogSum();
			}
		}
		
		LogSum expectedSum = new LogSum(numSteps);
		double[][] A = new double[this.d][this.d];
		
		// for all TIME states i and j
		for (int iTime = 0; iTime < this.d; iTime++) {
			double logNoRecTransition = this.core.getLogNoRecombinationTransition(iTime);
			for (int jTime = 0; jTime < this.d; jTime++) {
				double logRecTransition = this.core.getLogRecombinationTransition(iTime, jTime);
				
				expectedSum.reset();
				for (int locus = 0; locus < numSteps; locus++) {
					double recTerm = logForwardSum[locus][iTime] + logBackwardEmSum[locus][jTime] + logRecTransition;
					if (iTime == jTime) {
						// staying in the same time interval may also come from "no recombination"
						double noRecTerm = logNoRecTransition + logStaySum[locus][iTime];
						expectedSum.addLogSummand(LogSum.computePairLogSum(recTerm, noRecTerm));
					} else {
						expectedSum.addLogSummand(recTerm);
					}
				}
				
				A[iTime][jTime] = Math.exp(expectedSum.retrieveLogSum() - this.totalLogProb);
			}
		}
		
		return A;
	}
	
	/**
	 * Expected number of loci at which the query haplotype carries allele a while in time
	 * interval i (posterior summed over haplotypes and over loci with that query allele).
	 *
	 * @return d x numAlleles matrix (linear scale)
	 */
	public double[][] computeExpectedEmissions() {
		
		assert this.forwardBackwardComputed : "computeForwardBackward() must be called first";
		
		int numAlleles = this.core.numAlleles();
		double[][] E = new double[this.d][numAlleles];
		LogSum expectedSum = new LogSum(this.L * this.nTrunk);
		
		for (int iTime = 0; iTime < this.d; iTime++) {
			for (int a = 0; a < numAlleles; a++) {
				expectedSum.reset();
				
				for (int locus = 0; locus < this.L; locus++) {
					if (this.haplotype.getAllele(locus) != a) {
						continue;
					}
					for (int iHap = 0; iHap < this.nTrunk; iHap++) {
						expectedSum.addLogSummand(unnormalizedLogPosterior(locus, iTime, iHap));
					}
				}
				
				E[iTime][a] = Math.exp(expectedSum.retrieveLogSum() - this.totalLogProb);
			}
		}
		
		return E;
	}
	
	// -------------------
	// DECODING FUNCTIONS
	//--------------------
	
	/**
	 * Prints one "DC hap locus time" line per locus for posterior-mean times.
	 * NOTE(review): the header names startLocus/endLocus, but each line carries only one locus
	 * index; the header text is kept unchanged for any downstream parsers.
	 */
	public static void printPosteriorMeanTime(int hapIdx, double[] absorptionTimes) {
		System.out.println("decoding hap startLocus endLocus absorptionTime");
		for (int curIdx = 0; curIdx < absorptionTimes.length; curIdx++) {
			System.out.println("DC " + hapIdx + " " + curIdx + " " + absorptionTimes[curIdx]);
		}
	}
	
	/**
	 * Prints the time-only decoding as runs: one "DC hap start end time" line per maximal run
	 * of consecutive loci with the same time (column 0). The end index is exclusive.
	 */
	public static void printPosteriorDecodingTime(int hapIdx, double[][] absorptionTimesProbs) {
		
		System.out.println("decoding hap startLocus endLocus absorptionTime");
		if (absorptionTimesProbs.length == 0) {
			return; // nothing decoded; avoid indexing row 0 of an empty array
		}
		int curIdx = 1, prevIdx = 0;
		double currentDecoding = absorptionTimesProbs[prevIdx][0];
		while (curIdx < absorptionTimesProbs.length) {
			if (currentDecoding != absorptionTimesProbs[curIdx][0]) {
				System.out.println("DC " + hapIdx + " " + prevIdx + " " + curIdx + " " + currentDecoding);
				currentDecoding = absorptionTimesProbs[curIdx][0];
				prevIdx = curIdx;
			}
			curIdx++;
		}
		System.out.println("DC " + hapIdx + " " + prevIdx + " " + curIdx + " " + currentDecoding);
	}
	
	/**
	 * Prints one "DC hap locus time probability" line per locus. Column 1 of the input is a log
	 * probability and is exponentiated for printing.
	 */
	public static void printPosteriorDecodingTimeProb(int hapIdx, double[][] absorptionTimesProbs) {
		System.out.println("decoding hap startLocus endLocus absorptionTime probability");
		for (int curIdx = 0; curIdx < absorptionTimesProbs.length; curIdx++) {
			System.out.println("DC " + hapIdx + " " + curIdx + " " + absorptionTimesProbs[curIdx][0] + " " + Math.exp(absorptionTimesProbs[curIdx][1]));
		}
	}
	
	/**
	 * Prints the time-and-haplotype decoding as runs: one "DC hap start end time absorptionHap"
	 * line per maximal run of loci with the same (time, hap) pair. The end index is exclusive.
	 */
	public static void printPosteriorDecoding(int hapIdx, double[][] absorptionTimesHaps) {
			
		System.out.println("decoding hap startLocus endLocus absorptionTime absorptionHap");
		if (absorptionTimesHaps.length == 0) {
			return; // nothing decoded; avoid indexing row 0 of an empty array
		}
		int curIdx = 1, prevIdx = 0;
		double currentDecodingTime = absorptionTimesHaps[prevIdx][0];
		int currentDecodingHap = (int) absorptionTimesHaps[prevIdx][1];
		while (curIdx < absorptionTimesHaps.length) {
			if (currentDecodingTime != absorptionTimesHaps[curIdx][0] || currentDecodingHap != absorptionTimesHaps[curIdx][1]) {
				System.out.println("DC " + hapIdx + " " + prevIdx + " " + curIdx + " " + currentDecodingTime + " " + currentDecodingHap);
				currentDecodingTime = absorptionTimesHaps[curIdx][0];
				currentDecodingHap = (int) absorptionTimesHaps[curIdx][1];
				prevIdx = curIdx;
			}
			curIdx++;
		}
		System.out.println("DC " + hapIdx + " " + prevIdx + " " + curIdx + " " + currentDecodingTime + " " + currentDecodingHap);
	}
	
	/**
	 * Representative time of interval {@code tIndex}: the midpoint of
	 * [points[tIndex], points[tIndex + 1]] for every interval except the last, whose point
	 * points[tIndex] is used as is.
	 */
	private static double intervalTime(double[] points, int tIndex) {
		if (tIndex != points.length - 1) {
			return (points[tIndex] + points[tIndex + 1]) / 2;
		}
		return points[tIndex];
	}
	
	/**
	 * Posterior decoding of time only (haplotype marginalized out).
	 *
	 * @return L x 2 array: [locus][0] = representative time of the most probable interval,
	 *         [locus][1] = its normalized log posterior probability
	 */
	public double[][] computePosteriorDecodingTime() {
		
		double[] timeIdxMap = this.core.getPoints();
		double[][] posteriorDecoding = computePosteriorProbabilityTime();
		
		double[][] absorptionTimesProbs = new double[this.L][2];
		for (int locus = 0; locus < this.L; locus++) {
			double bestT = 0;
			double bestP = Double.NEGATIVE_INFINITY;
			for (int tIndex = 0; tIndex < timeIdxMap.length; tIndex++) {
				double testLogProb = posteriorDecoding[locus][tIndex];
				if (testLogProb > bestP) {
					bestP = testLogProb;
					bestT = intervalTime(timeIdxMap, tIndex);
				}
			}
			absorptionTimesProbs[locus][0] = bestT;
			absorptionTimesProbs[locus][1] = bestP;
		}
		return absorptionTimesProbs;
	}
	
	/**
	 * Posterior decoding: the most probable (time, haplotype) state at each locus independently.
	 *
	 * @param hap2IndexMap maps each trunk haplotype to its index in the full list of haplotypes
	 * @return L x 2 array: [locus][0] = representative time, [locus][1] = overall hap index
	 */
	public double[][] computePosteriorDecoding(Map<Haplotype, Integer> hap2IndexMap) {
		
		double[] timeIdxMap = this.core.getPoints();
		GeneticTypeMultiplicity<H>[] configMap = this.configuration.getHapIdxMap();
		double[][] absorptionTimesHaps = new double[this.L][2]; // will hold time and hap
				
		for (int locus = 0; locus < this.L; locus++) {
			double bestP = Double.NEGATIVE_INFINITY;
			double bestT = 0;
			int bestH = 0;
			for (int tIndex = 0; tIndex < timeIdxMap.length; tIndex++) {
				for (int hapIdx = 0; hapIdx < this.nTrunk; hapIdx++) {
					// unnormalized is fine for argmax: the normalizer is shared by all states
					double testLogProb = unnormalizedLogPosterior(locus, tIndex, hapIdx);
					if (testLogProb > bestP) {
						bestP = testLogProb;
						bestT = intervalTime(timeIdxMap, tIndex);
						bestH = hapIdx;
					}
				}
			}
			
			// store our best ones, with the hap index referring to the original list of ALL haplotypes
			int bestHapOverallIndex = hap2IndexMap.get(configMap[bestH].geneticType);
			absorptionTimesHaps[locus][0] = bestT;
			absorptionTimesHaps[locus][1] = bestHapOverallIndex;
		}
		return absorptionTimesHaps;
	}
	
	/**
	 * Posterior mean time at each locus: the sum over intervals of
	 * P(interval | query) * representative time of the interval.
	 *
	 * @return array of length L with posterior mean times
	 */
	public double[] computePosteriorMeanTime() {
		
		double[] timeIdxMap = this.core.getPoints();
		double[][] posteriorDecoding = computePosteriorProbabilityTime();
		
		// log of each interval's representative time; the same for every locus
		double[] logIntervalTimes = new double[timeIdxMap.length];
		for (int tIndex = 0; tIndex < timeIdxMap.length; tIndex++) {
			logIntervalTimes[tIndex] = Math.log(intervalTime(timeIdxMap, tIndex));
		}
		
		double[] posteriorMeans = new double[this.L];
		LogSum meanSum = new LogSum(timeIdxMap.length);
		for (int locus = 0; locus < this.L; locus++) {
			meanSum.reset();
			for (int tIndex = 0; tIndex < timeIdxMap.length; tIndex++) {
				meanSum.addLogSummand(posteriorDecoding[locus][tIndex] + logIntervalTimes[tIndex]);
			}
			posteriorMeans[locus] = Math.exp(meanSum.retrieveLogSum());
		}
		
		return posteriorMeans;
	}
	
	/**
	 * Normalized log posterior of each time interval at each locus, with the haplotype
	 * summed out.
	 *
	 * @return L x d array of log probabilities
	 */
	public double[][] computePosteriorProbabilityTime() {
		
		double[][] margPosteriorDecoding = new double[this.L][this.d];
		LogSum margSum = new LogSum(this.nTrunk);
		for (int locus = 0; locus < this.L; locus++) {
			for (int tIndex = 0; tIndex < this.d; tIndex++) {
				margSum.reset();
				for (int hapIndex = 0; hapIndex < this.nTrunk; hapIndex++) {
					margSum.addLogSummand(unnormalizedLogPosterior(locus, tIndex, hapIndex));
				}
				margPosteriorDecoding[locus][tIndex] = margSum.retrieveLogSum() - this.totalLogProb;
			}
		}
				
		return margPosteriorDecoding;
	}
	
	/**
	 * Unnormalized log posterior (forward + backward) of every (locus, time, haplotype) state.
	 * Subtract the value returned by computeForwardBackward() to normalize.
	 *
	 * @return a fresh L x d x nTrunk array
	 */
	public double[][][] computePosteriorProbabilities() {
		
		double[][][] fullPosteriorProbabilities = new double[this.L][this.d][this.nTrunk];
		for (int locus = 0; locus < this.L; locus++) {
			for (int tIndex = 0; tIndex < this.d; tIndex++) {
				for (int hapIndex = 0; hapIndex < this.nTrunk; hapIndex++) {
					fullPosteriorProbabilities[locus][tIndex][hapIndex] = unnormalizedLogPosterior(locus, tIndex, hapIndex);
				}
			}
		}
		return fullPosteriorProbabilities;
	}
	
	/**
	 * Viterbi decoding over (time, haplotype) states. It does not require computeForwardBackward().
	 * Every current state scans every previous state, so the cost is O(L * (d * nTrunk)^2)
	 * time plus O(L * d * nTrunk) memory for the scores and back-pointers.
	 *
	 * @param hap2IndexMap maps each trunk haplotype to its index in the full list of haplotypes
	 * @return L x 2 array: [locus][0] = time from BestViterbi.getTime, [locus][1] = overall hap index
	 */
	public double[][] computeViterbiDecoding(Map<Haplotype, Integer> hap2IndexMap) {
		
		double lognTrunk = Math.log(this.nTrunk);
		GeneticTypeMultiplicity<H>[] configMap = this.configuration.getHapIdxMap();
		
		double[][][] V = new double[this.L][this.d][this.nTrunk];
		int[][][][] ptr = new int[this.L][2][this.d][this.nTrunk];
		
		// INITIALIZATION (first locus); hap multiplicity assumed 1, so hap weight is 1 / nTrunk
		int dstAllele = this.haplotype.getAllele(0);
		for (int tCurr = 0; tCurr < this.d; tCurr++) {
			for (int hCurr = 0; hCurr < this.nTrunk; hCurr++) {
				int srcAllele = configMap[hCurr].geneticType.getAllele(0);
				double logMutationProb = this.core.getLogEmission(srcAllele, dstAllele, tCurr);
				V[0][tCurr][hCurr] = logMutationProb + this.core.getLogInitialWeight(tCurr) - lognTrunk;
			}
		}
		
		// Transition weights do not depend on the locus, so compute them once:
		// logRecTrans[tPrev][tCurr] = recombine into any specific haplotype in interval tCurr;
		// logStayTrans[t]           = remain on the same (t, hap) state (recombination or not).
		double[][] logRecTrans = new double[this.d][this.d];
		double[] logStayTrans = new double[this.d];
		for (int tPrev = 0; tPrev < this.d; tPrev++) {
			for (int tCurr = 0; tCurr < this.d; tCurr++) {
				logRecTrans[tPrev][tCurr] = this.core.getLogRecombinationTransition(tPrev, tCurr) - lognTrunk;
			}
		}
		for (int t = 0; t < this.d; t++) {
			logStayTrans[t] = LogSum.computePairLogSum(logRecTrans[t][t], this.core.getLogNoRecombinationTransition(t));
		}
		
		// RECURSION
		double[][] testMatrix = new double[this.d][this.nTrunk];
		for (int l = 1; l < this.L; l++) {
			dstAllele = this.haplotype.getAllele(l);
			
			for (int tCurr = 0; tCurr < this.d; tCurr++) {
				for (int hCurr = 0; hCurr < this.nTrunk; hCurr++) {
					
					for (int tPrev = 0; tPrev < this.d; tPrev++) {
						for (int hPrev = 0; hPrev < this.nTrunk; hPrev++) {
							double logTrans = (tPrev == tCurr && hPrev == hCurr)
									? logStayTrans[tCurr]
									: logRecTrans[tPrev][tCurr];
							testMatrix[tPrev][hPrev] = V[l-1][tPrev][hPrev] + logTrans;
						}
					}
					
					int srcAllele = configMap[hCurr].geneticType.getAllele(l);
					double logMutationProb = this.core.getLogEmission(srcAllele, dstAllele, tCurr);
					
					// update Viterbi probability and back-pointer
					BestViterbi best = Utility.maxElement(testMatrix);
					V[l][tCurr][hCurr] = logMutationProb + best.prob;
					ptr[l][0][tCurr][hCurr] = best.timeIdx; // time index
					ptr[l][1][tCurr][hCurr] = best.hapIdx;  // haplotype index
				}
			}
		}
		
		// TERMINATION
		int[][] bestPath = new int[this.L][2];
		BestViterbi lastBest = Utility.maxElement(V[this.L-1]);
		bestPath[this.L-1][0] = lastBest.timeIdx;
		bestPath[this.L-1][1] = lastBest.hapIdx;
		double viterbiLogProb = lastBest.prob; // log probability of the best path
		
		// TRACEBACK
		for (int l = this.L-1; l > 0; l--) {
			bestPath[l-1][0] = ptr[l][0][bestPath[l][0]][bestPath[l][1]];
			bestPath[l-1][1] = ptr[l][1][bestPath[l][0]][bestPath[l][1]];
		}
		
		// STORE PATH
		double[] points = this.core.getPoints();
		double[][] absorptionTimesHaps = new double[this.L][2];
		for (int l = 0; l < this.L; l++) {
			BestViterbi best = new BestViterbi(bestPath[l][0], bestPath[l][1], viterbiLogProb);
			absorptionTimesHaps[l][0] = best.getTime(points);
			absorptionTimesHaps[l][1] = hap2IndexMap.get(configMap[best.hapIdx].geneticType);
		}
		
		return absorptionTimesHaps;
	}
	
	// -----------------------------------
	// FORWARD AND BACKWARD PROBABILITIES
	//------------------------------------
	
	/**
	 * Fills forwardPVals in log space. With Q[l][t] = sum_h F[l][t][h]:
	 * F[0][t][h] = E(h, 0, t) * init(t) / nTrunk, and for l > 0
	 * F[l][t][h] = E(h, l, t) * (noRec(t) * F[l-1][t][h] + sum_s Q[l-1][s] * rec(s, t) / nTrunk).
	 * Q is kept in the two-row rolling buffer currQVals.
	 */
	private void computeForwardsProbabilities(GeneticTypeMultiplicity<H>[] hapIdxMap) {
				
		// log(multiplicity / nTrunk); multiplicity is assumed to be 1
		double logHapWeight = -Math.log(this.nTrunk);
		
		LogSum tValCalc = new LogSum(this.nTrunk);
		LogSum recValCalc = new LogSum(this.d);
		
		int currReadIdx = 0;
		for (int locus = 0; locus < this.L; locus++) {
			int dstAllele = this.haplotype.getAllele(locus);
			
			if (locus == 0) {
				// first locus: initial weight times emission
				for (int tIndex = 0; tIndex < this.d; tIndex++) {
					
					tValCalc.reset();
					for (int hapIdx = 0; hapIdx < this.nTrunk; hapIdx++) {
						int srcAllele = hapIdxMap[hapIdx].geneticType.getAllele(locus);
						
						double logMutationProb = this.core.getLogEmission(srcAllele, dstAllele, tIndex);
						this.forwardPVals[locus][tIndex][hapIdx] = logMutationProb + this.core.getLogInitialWeight(tIndex) + logHapWeight;
						tValCalc.addLogSummand(this.forwardPVals[locus][tIndex][hapIdx]);
					}

					this.currQVals[1-currReadIdx][tIndex] = tValCalc.retrieveLogSum();
					this.currRVals[tIndex] = Double.NEGATIVE_INFINITY;
					this.currTVals[tIndex] = 0;
				}
				
			} else {
				// all other loci
				for (int tIndex = 0; tIndex < this.d; tIndex++) {
					// recombination mass flowing into interval tIndex from every source interval
					recValCalc.reset();
					for (int tIndexSrc = 0; tIndexSrc < this.d; tIndexSrc++) {
						recValCalc.addLogSummand(this.currQVals[currReadIdx][tIndexSrc] + this.core.getLogRecombinationTransition(tIndexSrc, tIndex));
					}

					double logRecTransition = recValCalc.retrieveLogSum();
					double logNoRecTransition = this.core.getLogNoRecombinationTransition(tIndex);
					
					double logFullNoRecTransition = this.currTVals[tIndex] + logNoRecTransition;
					double logFullRecTransition = LogSum.computePairLogSum(logRecTransition, this.currRVals[tIndex]);
					
					tValCalc.reset();
					for (int hapIdx = 0; hapIdx < this.nTrunk; hapIdx++) {
						int srcAllele = hapIdxMap[hapIdx].geneticType.getAllele(locus);

						double logMutationProb = this.core.getLogEmission(srcAllele, dstAllele, tIndex);
						double logNoRecTerm = logFullNoRecTransition + this.forwardPVals[locus-1][tIndex][hapIdx];
						double logRecTerm = logFullRecTransition + logHapWeight;
						
						this.forwardPVals[locus][tIndex][hapIdx] = LogSum.computePairLogSum(logNoRecTerm, logRecTerm) + logMutationProb;
						tValCalc.addLogSummand(this.forwardPVals[locus][tIndex][hapIdx]);
					}
					
					this.currQVals[1-currReadIdx][tIndex] = tValCalc.retrieveLogSum();
					this.currRVals[tIndex] = Double.NEGATIVE_INFINITY;
					this.currTVals[tIndex] = 0;
				}
			}
			currReadIdx = 1 - currReadIdx;
		}
	}
	
	/**
	 * Fills backwardPVals in log space. With Q'[l][u] = sum_h E(h, l, u) * B[l][u][h] / nTrunk:
	 * B[L-1][t][h] = 1, and for l < L-1
	 * B[l][t][h] = noRec(t) * E(h, l+1, t) * B[l+1][t][h] + sum_u Q'[l+1][u] * rec(t, u).
	 * Q' is kept in the two-row rolling buffer currQVals.
	 */
	private void computeBackwardsProbabilities(GeneticTypeMultiplicity<H>[] hapIdxMap) {
		
		// log(multiplicity / nTrunk); multiplicity is assumed to be 1
		double logHapWeight = -Math.log(this.nTrunk);
		
		LogSum tValCalc = new LogSum(this.nTrunk);
		LogSum recValCalc = new LogSum(this.d);

		int currReadIdx = 0;
		for (int locus = this.L-1; locus >= 0; locus--) {
			int dstAlleleQ = this.haplotype.getAllele(locus);
			
			if (locus == this.L-1) {
				// last locus: backward value is log(1) = 0
				for (int tIndex = 0; tIndex < this.d; tIndex++) {
					tValCalc.reset();
					for (int hapIdx = 0; hapIdx < this.nTrunk; hapIdx++) {
						int srcAlleleQ = hapIdxMap[hapIdx].geneticType.getAllele(locus);
						
						this.backwardPVals[locus][tIndex][hapIdx] = 0;
						
						double logMutationProbQ = this.core.getLogEmission(srcAlleleQ, dstAlleleQ, tIndex);
						tValCalc.addLogSummand(logMutationProbQ + this.backwardPVals[locus][tIndex][hapIdx] + logHapWeight);
					}
					
					this.currQVals[1-currReadIdx][tIndex] = tValCalc.retrieveLogSum();
					this.currRVals[tIndex] = Double.NEGATIVE_INFINITY;
					this.currTVals[tIndex] = 0;
				}
			
			} else {
				// all other loci
				int dstAllele = this.haplotype.getAllele(locus+1);
				
				for (int tIndex = 0; tIndex < this.d; tIndex++) {
					// recombination mass flowing out of interval tIndex into every destination interval
					recValCalc.reset();
					for (int tIndexDst = 0; tIndexDst < this.d; tIndexDst++) {
						recValCalc.addLogSummand(this.currQVals[currReadIdx][tIndexDst] + this.core.getLogRecombinationTransition(tIndex, tIndexDst));
					}

					double logRecTransition = recValCalc.retrieveLogSum();
					double logNoRecTransition = this.core.getLogNoRecombinationTransition(tIndex);

					double logFullNoRecTransition = this.currTVals[tIndex] + logNoRecTransition;
					double logFullRecTransition = LogSum.computePairLogSum(logRecTransition, this.currRVals[tIndex]);
					
					tValCalc.reset();
					for (int hapIdx = 0; hapIdx < this.nTrunk; hapIdx++) {
						int srcAllele = hapIdxMap[hapIdx].geneticType.getAllele(locus+1);
						int srcAlleleQ = hapIdxMap[hapIdx].geneticType.getAllele(locus);
						
						double logMutationProb = this.core.getLogEmission(srcAllele, dstAllele, tIndex);
						double logNoRecTerm = logFullNoRecTransition + this.backwardPVals[locus+1][tIndex][hapIdx] + logMutationProb;
						this.backwardPVals[locus][tIndex][hapIdx] = LogSum.computePairLogSum(logNoRecTerm, logFullRecTransition);

						// contribution of this locus to Q' used by the next (lower) locus
						double logMutationProbQ = this.core.getLogEmission(srcAlleleQ, dstAlleleQ, tIndex);
						tValCalc.addLogSummand(logMutationProbQ + this.backwardPVals[locus][tIndex][hapIdx] + logHapWeight);
					}
					
					this.currQVals[1-currReadIdx][tIndex] = tValCalc.retrieveLogSum();
					this.currRVals[tIndex] = Double.NEGATIVE_INFINITY;
					this.currTVals[tIndex] = 0;
				}
			}
			currReadIdx = 1 - currReadIdx;
		}
	}
}
