/*
 * License: FreeBSD (Berkeley Software Distribution)
 * Copyright (c) 2013, Sara Sheehan, Kelley Harris, Yun Song
 */

package edu.berkeley.utility;

import static org.junit.Assert.assertTrue;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.math.FunctionEvaluationException;
import org.junit.Test;

import edu.berkeley.smcsd.CoreLinear;
import edu.berkeley.smcsd.CoreQuad;
import edu.berkeley.smcsd.DecodeLinear;
import edu.berkeley.smcsd.DecodeQuad;
import edu.berkeley.smcsd.Estep;
import edu.berkeley.smcsd.EstepLinearLol;
import edu.berkeley.smcsd.EstepLinearPac;
import edu.berkeley.smcsd.EstepQuadLol;
import edu.berkeley.smcsd.EstepQuadPac;
import edu.berkeley.smcsd.MstepLinear;
import edu.berkeley.smcsd.MstepLinearLol;
import edu.berkeley.smcsd.MstepLinearPac;
import edu.berkeley.smcsd.MstepQuad;
import edu.berkeley.smcsd.MstepQuadLol;
import edu.berkeley.smcsd.MstepQuadPac;
import edu.berkeley.utility.Discretization.DiscretizationInterval;
import edu.berkeley.utility.ParamSet;

public class UnitTests {
	
	private final double PRECISION1 = 1e-10;
	private final double PRECISION2 = 1e-9;
	private final double PRECISION3 = 1e-7; // for expected segments only
	
	private final String PSET_FILENAME = "ex/params.txt";
	private final String FASTA_FILENAME = "ex/data.fasta";
	private final String REF_FILENAME  = "ex/reference.fasta";
	private final String VCF_FILENAME = "ex/data_strippedVcf.txt";
	private final double[] TIMES = {0, 0.25, 0.5, 1, 2};
	private final double[] SIZES = {1, 0.1, 0.4, 1.5, 3};
	private final int NTRUNK = 9;
	private final int NUM_PERMS = 5;
	private final int SEED = 4711;
	private final int NUM_CORES = 1;
	private final int DECODING = 0;
	private final int[] PARAM_PATTERN = {1, 1, 1, 1, 1};
	private final double MIN_SIZE = 0.01;
	private final double MAX_SIZE = 1000;
	
	// this function computes the z(i,j) probabilities in terms of the new event probabilities
	// should be the same as before
	private double newLogRecoProb(CoreLinear core, int i, int j) {

		double z = 0;

		if (i == j) {
			z += Math.exp(core.getLogRecoSameCoalSame(i));
		}

		// build up the first part of the product
		double prod = 1;
		for (int m=j-1; m >= Math.min(i, j); m--) {
			prod *= Math.exp(core.getLogCoalLater(m));
		}
		double firstProd = prod; // store for later (in the case when i < j)

		// compute the sum
		for (int k=Math.min(i,j)-1; k >= 0; k--) {
			z += Math.exp(core.getLogRecoDiffCoalLater(k)) * prod * Math.exp(core.getLogCoalNow(j));
			prod *= Math.exp(core.getLogCoalLater(k));
		}

		if (i < j) {
			z += Math.exp(core.getLogRecoSameCoalLater(i))*firstProd/Math.exp(core.getLogCoalLater(i))*Math.exp(core.getLogCoalNow(j));
		}

		if (i > j) {
			z += Math.exp(core.getLogRecoDiffCoalSame(j));
		}

		return Math.log(z);
	}
	
	// test the probabilities of the old (quad) core to make sure they add to 1
	@Test
	public void testCoreQuad() throws IOException {
		
		// initialize descending/ascending factorial table so we don't have to recompute
		double[][] descendingAscendingTable = new double[NTRUNK][NTRUNK];
		for (int n=1; n <= NTRUNK; n++) {
			for (int i=1; i <= NTRUNK; i++) {
				descendingAscendingTable[n-1][i-1] = Utility.descendingAscendingFac(n, i);
			}
		}
		
		ParamSet pSet = new ParamSet(PSET_FILENAME);
		CoreQuad core = new CoreQuad(pSet, TIMES, SIZES, NTRUNK, descendingAscendingTable[NTRUNK-1]);
		int d = TIMES.length;
		
		// test transition probs
		for (int i = 0; i < d; i++) {
			double noRecMass = Math.exp(core.getLogNoRecombinationTransition(i));

			double recMass = 0;
			for (int j = 0; j < d; j++) {
				recMass += Math.exp(core.getLogRecombinationTransition(i, j));
			}
			assertTrue("Incorrect transition mass for quad core " + (recMass + noRecMass) + ".", Math.abs(recMass + noRecMass - 1) < PRECISION1);
		}
		
		// test emission probs
		for (int i = 0; i < d; i++) {
			for (int a = 0; a < pSet.numAlleles(); a++) {
				double checkRow = 0;
				for (int b = 0; b < pSet.numAlleles(); b++) {
					checkRow += Math.exp(core.getLogEmission(a, b, i));
				}
				assertTrue("Incorrect mutation mass for quad core " + checkRow + ".", Math.abs(checkRow - 1) < PRECISION1);
			}
		}
		
		// test transitions lead to the correct initial density
		double[] nbar = Discretization.computeNbarArray(TIMES, SIZES, NTRUNK, descendingAscendingTable[NTRUNK-1]);
		DiscretizationInterval[] tPoints = Discretization.makeDiscretizationPoints(TIMES, SIZES, NTRUNK, nbar);
		double[][] p = new double[d][d];
		for (int i = 0; i < d; i++) {
			for (int j = 0; j < d; j++) {
				if (i == j) {
					p[i][i] += Math.exp(core.getLogNoRecombinationTransition(i));
				}
				p[i][j] += Math.exp(core.getLogRecombinationTransition(i, j));
			}
		}
		
		for (int j = 0; j < d; j++) {
			double testSum = 0;
			
			for (int i = 0; i < d; i++) {
				testSum += tPoints[i].weight * p[i][j];
			}
			assertTrue("Stationary " + testSum + " does not match weight " + tPoints[j].weight + ".", Math.abs(testSum - tPoints[j].weight) < PRECISION1);
		}
	}
		
	// test the emission probabilities of the new (linear) core
	@Test
	public void testCoreLinear() throws IOException {
		
		// initialize descending/ascending factorial table so we don't have to recompute
		double[][] descendingAscendingTable = new double[NTRUNK][NTRUNK];
		for (int n=1; n <= NTRUNK; n++) {
			for (int i=1; i <= NTRUNK; i++) {
				descendingAscendingTable[n-1][i-1] = Utility.descendingAscendingFac(n, i);
			}
		}
				
		ParamSet pSet = new ParamSet(PSET_FILENAME);
		CoreLinear core = new CoreLinear(pSet, NTRUNK, TIMES, SIZES, descendingAscendingTable[NTRUNK-1], true);
		int d = TIMES.length;
			
		// test emission probs
		for (int i = 0; i < d; i++) {
			for (int a = 0; a < pSet.numAlleles(); a++) {
				double checkRow = 0;
				for (int b = 0; b < pSet.numAlleles(); b++) {
					checkRow += Math.exp(core.getLogEmission(a, b, i));
				}
				assertTrue("Incorrect mutation mass for linear core " + checkRow + ".", Math.abs(checkRow - 1) < PRECISION1);
			}
		}
	}	
	
	// test recombination, z^(i,j), no recombination, y^(i), and emission
	// emission is computed differently so this is good to check too
	@Test
	public void testAllCores() throws IOException {
		
		// initialize descending/ascending factorial table so we don't have to recompute
		double[][] descendingAscendingTable = new double[NTRUNK][NTRUNK];
		for (int n=1; n <= NTRUNK; n++) {
			for (int i=1; i <= NTRUNK; i++) {
				descendingAscendingTable[n-1][i-1] = Utility.descendingAscendingFac(n, i);
			}
		}
				
		ParamSet pSet = new ParamSet(PSET_FILENAME);
		CoreQuad oldCore = new CoreQuad(pSet, TIMES, SIZES, NTRUNK, descendingAscendingTable[NTRUNK-1]);
		CoreLinear newCore = new CoreLinear(pSet, NTRUNK, TIMES, SIZES, descendingAscendingTable[NTRUNK-1], true);
		int d = TIMES.length;
		
		// test recombination, z^(i,j)
		for (int i=0; i < d; i++) {
			for (int j=0; j < d; j++) {
				double oldProb = oldCore.getLogRecombinationTransition(i, j);
				double newProb = newLogRecoProb(newCore, i, j);
				assertTrue("New reco prob " + newProb + " does not match old prob " + oldProb + ".", Math.abs(newProb - oldProb) < PRECISION1);
			}
		}
		
		// test no recombination, y^(i)
		for (int i=0; i < d; i++) {
			double oldProb = oldCore.getLogNoRecombinationTransition(i);
			double newProb = newCore.getLogNoReco(i);
			assertTrue("New no reco prob " + newProb + " does not match old prob " + oldProb + ".", Math.abs(newProb - oldProb) < PRECISION1);
		}
		
		// test emissions
		for (int i=0; i < d; i++) {
			for (int a=0; a < pSet.numAlleles(); a++) {
				for (int b=0; b < pSet.numAlleles(); b++) {
					double oldProb = oldCore.getLogEmission(a, b, i);
					double newProb = newCore.getLogEmission(a, b, i);
					assertTrue("New mut prob " + newProb + " does not match old prob " + oldProb + ".", Math.abs(newProb - oldProb) < PRECISION1);
				}
			}
		}
	}
	
	// test the forward backward probabilities and the posterior probabilities
	@Test
	public void testPosteriorProbs() throws IOException, FunctionEvaluationException, IllegalArgumentException {
		
		// initialize descending/ascending factorial table so we don't have to recompute
		double[][] descendingAscendingTable = new double[NTRUNK][NTRUNK];
		for (int n=1; n <= NTRUNK; n++) {
			for (int i=1; i <= NTRUNK; i++) {
				descendingAscendingTable[n-1][i-1] = Utility.descendingAscendingFac(n, i);
			}
		}
		
		// set up the parameters and haplotype configuration
		ParamSet pSet = new ParamSet(PSET_FILENAME);
		List<Haplotype> hapListFasta = Utility.readHaplotypesFasta(FASTA_FILENAME, pSet, NTRUNK);
		List<Haplotype> hapListVcf   = Utility.readHaplotypesVcf(REF_FILENAME, VCF_FILENAME, pSet, NTRUNK).get(0);
		
		int L = hapListFasta.get(0).getNumLoci();
		for (int n=0; n < NTRUNK+1; n++) {
			for (int locus=0; locus < L; locus++) {
				assertTrue("Haplotypes in fasta and vcf are not equal at locus " + locus + ".", hapListFasta.get(n).getAllele(locus) == hapListVcf.get(n).getAllele(locus));
			}
		}
		
		// add in all but the first hap
		Haplotype leftOutHap = hapListFasta.get(0);
		HapConfig<Haplotype> hapConfig = new HapConfig<Haplotype>();
		for (int h=1; h < hapListFasta.size(); h++) {
			hapConfig.adjustType(hapListFasta.get(h), 1);
		}
		
		// define old/new cores and decodes
		CoreQuad oldCore = new CoreQuad(pSet, TIMES, SIZES, NTRUNK, descendingAscendingTable[NTRUNK-1]);
		DecodeQuad<Haplotype, HapConfig<Haplotype>> oldDecode = new DecodeQuad<Haplotype, HapConfig<Haplotype>>(oldCore, leftOutHap, hapConfig);
		oldDecode.computeForwardBackward(); // this initializes everything
		
		CoreLinear newCore = new CoreLinear(pSet, NTRUNK, TIMES, SIZES, descendingAscendingTable[NTRUNK-1], true);
		DecodeLinear<Haplotype, HapConfig<Haplotype>> newDecode = new DecodeLinear<Haplotype, HapConfig<Haplotype>>(newCore, leftOutHap, hapConfig);
		newDecode.computeForwardBackward();
		
		// compute all the types of expected transitions
		double[][] logRecoDiffCoalSameByLocus = newDecode.computeLogRecoDiffCoalSameByLocus();
		double[][] logRecoDiffCoalLaterByLocus = newDecode.computeLogRecoDiffCoalLaterByLocus();
		double[][] logRecoSameCoalSameByLocus = newDecode.computeLogRecoSameCoalSameByLocus();
		double[][] logRecoSameCoalLaterByLocus = newDecode.computeLogRecoSameCoalLaterByLocus();
		double[][] logNoRecoByLocus = newDecode.computeLogNoRecoByLocus();
		double[][] logCoalNowByLocus = newDecode.computeLogCoalNowByLocus();
		double[][] logCoalLaterByLocus = newDecode.computeLogCoalLaterByLocus();
		
		// test forward/backward first
		int d = TIMES.length;
		for (int locus=0; locus < L; locus++) {
			for (int i=0; i < d; i++) {
				for (int h=0; h < NTRUNK; h++) {
					double oldForward = oldDecode.getForwardLogProbs()[locus][i][h];
					double oldBackward = oldDecode.getBackwardLogProbs()[locus][i][h];
					double newForward = newDecode.getForwardLogProbs()[locus][i][h];
					double newBackward = newDecode.getBackwardLogProbs()[locus][i][h];
					assertTrue("New forward value " + newForward + " does not match old forward value " + oldForward + ".", Math.abs(newForward - oldForward) < PRECISION1);
					assertTrue("New backward value " + newBackward + " does not match old backward value " + oldBackward + ".",Math.abs(newBackward - oldBackward) < PRECISION1);
				}
			}
		}
		
		// then test decoding
		LogSum recoLogSum = new LogSum(5*d);
		LogSum coalLogSum = new LogSum(4*d);
		
		for (int locus=0; locus < L-1; locus++) {
			
			// ---------- test recombination in state i ----------
			recoLogSum.reset();
			for (int i=0; i < d; i++) {
				recoLogSum.addLogSummand(logRecoDiffCoalSameByLocus[locus][i]);
				recoLogSum.addLogSummand(logRecoDiffCoalLaterByLocus[locus][i]);
				recoLogSum.addLogSummand(logRecoSameCoalSameByLocus[locus][i]);
				recoLogSum.addLogSummand(logRecoSameCoalLaterByLocus[locus][i]);
				recoLogSum.addLogSummand(logNoRecoByLocus[locus][i]); // we do need to include no reco here to sum to 1
			}
			double reco = recoLogSum.retrieveLogSum();
			assertTrue("Recombination log sum " + reco + " is not 0.", Math.abs(reco) < PRECISION2);
			
			// ---------- test coalescence in state i ----------
			coalLogSum.reset();
			for (int i=0; i < d; i++) {
				coalLogSum.addLogSummand(logRecoDiffCoalSameByLocus[locus][i]);
				coalLogSum.addLogSummand(logRecoSameCoalSameByLocus[locus][i]);
				coalLogSum.addLogSummand(logNoRecoByLocus[locus][i]);
				coalLogSum.addLogSummand(logCoalNowByLocus[locus][i]);
			}
			double coal = coalLogSum.retrieveLogSum();
			assertTrue("Coalescent log sum " + coal + " is not 0.", Math.abs(coal) < PRECISION2);
			
			// ---------- test no coalescence ----------
			for (int i=0; i < d; i++) {
				LogSum noCoalLogSum = new LogSum(5*(d-i) + 4*i + 3);
				for (int j=0; j < d; j++) {
					
					if (j <= i) {
						noCoalLogSum.addLogSummand(logRecoDiffCoalSameByLocus[locus][j]);
						noCoalLogSum.addLogSummand(logRecoSameCoalSameByLocus[locus][j]);
						noCoalLogSum.addLogSummand(logNoRecoByLocus[locus][j]);
						noCoalLogSum.addLogSummand(logCoalNowByLocus[locus][j]);
					} else {
						noCoalLogSum.addLogSummand(logRecoDiffCoalSameByLocus[locus][j]);
						noCoalLogSum.addLogSummand(logRecoDiffCoalLaterByLocus[locus][j]);
						noCoalLogSum.addLogSummand(logRecoSameCoalSameByLocus[locus][j]);
						noCoalLogSum.addLogSummand(logRecoSameCoalLaterByLocus[locus][j]);
						noCoalLogSum.addLogSummand(logNoRecoByLocus[locus][j]);
					}
				}
				noCoalLogSum.addLogSummand(logRecoDiffCoalLaterByLocus[locus][i]);
				noCoalLogSum.addLogSummand(logRecoSameCoalLaterByLocus[locus][i]);
				noCoalLogSum.addLogSummand(logCoalLaterByLocus[locus][i]);
				
				double noCoal = noCoalLogSum.retrieveLogSum();
				assertTrue("No coalescence log sum " + noCoal + " is not 0.", Math.abs(noCoal) < PRECISION2);
			}
		}
		
		// now test the E-steps as well, first initializing the chromosomes and the hap2IndexMaps
		List<List<Haplotype>> chromosomes = new ArrayList<List<Haplotype>>();
		chromosomes.add(hapListFasta);
		List<Map<Haplotype, Integer>> hap2IndexMaps = new ArrayList<Map<Haplotype, Integer>>();
		Map<Haplotype, Integer> hap2IndexMap = new HashMap<Haplotype, Integer>();
		for (int i=0; i < hapListFasta.size(); i++) {
			hap2IndexMap.put(hapListFasta.get(i), i);
		}
		hap2IndexMaps.add(hap2IndexMap);
		int[][] allPerms = Utility.getRandomPermuations(NTRUNK+1, NUM_PERMS, SEED);
	
		// first LOL E-step
		Estep eStepLinearLol = new EstepLinearLol(chromosomes, pSet, TIMES, SIZES, NUM_CORES, descendingAscendingTable[NTRUNK-1], true);
		Estep eStepQuadLol = new EstepQuadLol(chromosomes, pSet, TIMES, SIZES, NUM_CORES, descendingAscendingTable[NTRUNK-1], DECODING, hap2IndexMaps);
		assertTrue("Estep LOL likelihoods do not match.", Math.abs(eStepLinearLol.getEstepLogLikelihood() - eStepQuadLol.getEstepLogLikelihood()) < PRECISION1);
		for (int tIdx=0; tIdx < d; tIdx++) {
			assertTrue("LOL expected segments do not match.", Math.abs(eStepLinearLol.getExpectedSegments()[tIdx] - eStepQuadLol.getExpectedSegments()[tIdx]) < PRECISION3);
		}
		
		// then PAC E-step
		Estep eStepLinearPac = new EstepLinearPac(chromosomes, pSet, TIMES, SIZES, allPerms, NUM_CORES, descendingAscendingTable, true);
		Estep eStepQuadPac = new EstepQuadPac(chromosomes, pSet, TIMES, SIZES, allPerms, NUM_CORES, descendingAscendingTable, DECODING, hap2IndexMaps);
		assertTrue("Estep PAC likelihoods do not match.", Math.abs(eStepLinearPac.getEstepLogLikelihood() - eStepQuadPac.getEstepLogLikelihood()) < PRECISION1);
		for (int tIdx=0; tIdx < d; tIdx++) {
			System.out.println(eStepLinearPac.getExpectedSegments()[tIdx] + " " + eStepQuadPac.getExpectedSegments()[tIdx]);
			assertTrue("PAC expected segments do not match.", Math.abs(eStepLinearPac.getExpectedSegments()[tIdx] - eStepQuadPac.getExpectedSegments()[tIdx]) < PRECISION3);
		}
		
		// first LOL M-step
		MstepLinear mStepLinearLol = new MstepLinearLol(pSet, TIMES, NTRUNK, PARAM_PATTERN, descendingAscendingTable[NTRUNK-1]);
		mStepLinearLol.updateEachIter(eStepLinearLol);
		mStepLinearLol.setDebug(true);
		MstepQuad mStepQuadLol = new MstepQuadLol(pSet, TIMES, NTRUNK, PARAM_PATTERN, MIN_SIZE, MAX_SIZE, descendingAscendingTable[NTRUNK-1]);
		mStepQuadLol.updateEachIter(eStepQuadLol);
		mStepQuadLol.setDebug(true);
		
		// linear method
		for (int sizeIdx=0; sizeIdx < d; sizeIdx++) {			
			mStepLinearLol.value(SIZES[sizeIdx]); // prints useful info, although not an explicit test
							
			if (sizeIdx < d-1) {
				mStepLinearLol.updateParamIdx(sizeIdx+1, SIZES[sizeIdx]);
			}
		}
		// quadratic method
		mStepQuadLol.value(SIZES);
					
		// then PAC M-step
		MstepLinear mStepLinearPac = new MstepLinearPac(pSet, TIMES, NTRUNK+1, PARAM_PATTERN, descendingAscendingTable);
		mStepLinearPac.updateEachIter(eStepLinearPac);
		mStepLinearPac.setDebug(true);
		MstepQuad mStepQuadPac = new MstepQuadPac(pSet, TIMES, NTRUNK+1, PARAM_PATTERN, MIN_SIZE, MAX_SIZE, descendingAscendingTable);
		mStepQuadPac.updateEachIter(eStepQuadPac);
		mStepQuadPac.setDebug(true);
		
		// linear method
		for (int sizeIdx=0; sizeIdx < d; sizeIdx++) {			
			mStepLinearPac.value(SIZES[sizeIdx]); // prints useful info, although not an explicit test
							
			if (sizeIdx < d-1) {
				mStepLinearPac.updateParamIdx(sizeIdx+1, SIZES[sizeIdx]);
			}
		}
		// quadratic method
		mStepQuadPac.value(SIZES);
	}
}
