/*
 * License: FreeBSD (Berkeley Software Distribution)
 * Copyright (c) 2013, Sara Sheehan, Kelley Harris, Yun Song
 */

package edu.berkeley.smcsd;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import edu.berkeley.utility.HapConfig;
import edu.berkeley.utility.Haplotype;
import edu.berkeley.utility.LogSum;
import edu.berkeley.utility.ParamSet;
import edu.berkeley.utility.StationaryProbability;
import edu.berkeley.utility.Utility;

/**
 * E-step computed with the leave-one-out (LOL) scheme: for every chromosome, each considered
 * haplotype is in turn removed from the sample and decoded against the remaining ones with
 * {@link DecodeLinear}, and the resulting posterior counts are summed over haplotypes and
 * chromosomes. All work is done in the constructor; results are read through the getters.
 */
public class EstepLinearLol implements Estep {
	
	// fixed parameters
	private final List<List<Haplotype>> chromosomes; // list of all chromosomes for all nTotal haplotypes
	private final ParamSet pSet;                     // parameter set
	private final double[] times;                    // discretization points
	private final double[] sizes;                    // the current sizes during each of the above intervals
	private final int d;                             // number of discretization intervals
	private final int nTrunk;                        // number of haplotypes in the trunk
	private final int numCores;                      // number of cores to use (should be <= n)
	private final double[] descendingAscendingTable; // precomputed table to help compute nbar
	private final boolean printExpectedSegs;         // whether or not to print expected segments
	
	// values to compute
	private double eStepLogLikelihood;
	private double[] expectedInitialCounts;
	private double[] expectedRecoDiffCoalSame;
	private double[] expectedRecoDiffCoalLater;
	private double[] expectedRecoSameCoalSame;
	private double[] expectedRecoSameCoalLater;
	private double[] expectedNoReco;
	private double[] expectedCoalNow;
	private double[] expectedCoalLater;
	private double[][] expectedEmissions;
	private double[] expectedSelfTrans; // number of times we expect to go from time i to i
	private double[] expectedSegments;  // number of expected recombination segments in each time i
	
	/**
	 * Runs the leave-one-out E-step immediately.
	 *
	 * @param chromosomes              one list of haplotypes per chromosome; every chromosome is
	 *                                 assumed to hold the same number of haplotypes as the first
	 * @param pSet                     model parameters (mutation matrix, allele count, ...)
	 * @param times                    discretization points; its length defines the number of intervals
	 * @param sizes                    population sizes for each interval
	 * @param numCores                 size of the worker pool used to decode haplotypes in parallel
	 * @param descendingAscendingTable precomputed table passed through to {@link CoreLinear}
	 * @param printExpectedSegs        whether to also compute self transitions and expected segments
	 */
	public EstepLinearLol(List<List<Haplotype>> chromosomes, ParamSet pSet, double[] times, double[] sizes, int numCores, double[] descendingAscendingTable, boolean printExpectedSegs) {
		assert chromosomes.size() > 0;
		this.chromosomes = chromosomes;
		this.pSet = pSet;
		this.times = times;
		this.sizes = sizes;
		this.d = times.length;
		this.nTrunk = chromosomes.get(0).size() - 1;
		this.numCores = numCores;
		this.descendingAscendingTable = descendingAscendingTable;
		this.printExpectedSegs = printExpectedSegs;
		
		// this initializes all our posterior counts
		computeEstepLolMultiCore();
	}
	
	// getters for posterior counts
	public double getEstepLogLikelihood() { return this.eStepLogLikelihood; }
	public double[] getExpectedInitialCounts() { return this.expectedInitialCounts; }
	public double[] getExpectedRecoDiffCoalSame() { return this.expectedRecoDiffCoalSame; }
	public double[] getExpectedRecoDiffCoalLater() { return this.expectedRecoDiffCoalLater; }
	public double[] getExpectedRecoSameCoalSame() { return this.expectedRecoSameCoalSame; }
	public double[] getExpectedRecoSameCoalLater() { return this.expectedRecoSameCoalLater; }
	public double[] getExpectedNoReco() { return this.expectedNoReco; }
	public double[] getExpectedCoalNow() { return this.expectedCoalNow; }
	public double[] getExpectedCoalLater() { return this.expectedCoalLater; }
	public double[][] getExpectedEmissions() { return this.expectedEmissions; }
	public double[] getExpectedSegments() { return this.expectedSegments; }

	/**
	 * LOL (leave one out likelihood), using multiple cores.
	 *
	 * For each chromosome the considered haplotypes are decoded in parallel. Their posterior
	 * counts are added to the running totals and their log-likelihoods are summed into one
	 * per-chromosome term. The reported log-likelihood is the LogSum of the per-chromosome
	 * terms minus log(numChroms).
	 */
	private void computeEstepLolMultiCore() {
		
		this.expectedInitialCounts = new double[this.d];
		this.expectedRecoDiffCoalSame = new double[this.d];
		this.expectedRecoDiffCoalLater = new double[this.d];
		this.expectedRecoSameCoalSame = new double[this.d];
		this.expectedRecoSameCoalLater = new double[this.d];
		this.expectedNoReco = new double[this.d];
		this.expectedCoalNow = new double[this.d];
		this.expectedCoalLater = new double[this.d];
		this.expectedEmissions = new double[this.d][this.pSet.numAlleles()];
		this.expectedSelfTrans = new double[this.d];
		this.expectedSegments = new double[this.d];
			
		// construct the conditional sampling distribution based on params
		CoreLinear core = new CoreLinear(this.pSet, this.nTrunk, this.times, this.sizes, this.descendingAscendingTable, this.printExpectedSegs);
		
		// set up likelihood calculation (add up the likelihoods from each chromosome)
		int numChroms = this.chromosomes.size();
		LogSum logLikelihoodSum = new LogSum(numChroms);
		
		// special case if we have sample size 2 and symmetric matrix, don't redo work
		// NOTE(review): in this case only haplotype 0 is left out, so counts and likelihood come from that single decoding
		int numConsidered;
		if (this.nTrunk == 1 && StationaryProbability.isSymmetric(this.pSet.getMutationMatrix().getArray())) {
			numConsidered = 1;
		} else {
			numConsidered = this.nTrunk + 1;
		}
		
		// a single pool serves every chromosome; shutting it down in finally guarantees its
		// worker threads do not outlive this call even if invokeAll throws
		ExecutorService taskExecutor = Executors.newFixedThreadPool(this.numCores);
		try {
			for (int c = 0; c < numChroms; c++) {
				
				// one task per haplotype, so that each is left out in turn
				List<HapThread> allHapThreads = new ArrayList<HapThread>(numConsidered);
				for (int i = 0; i < numConsidered; i++) {
					allHapThreads.add(new HapThread(c, i, core));
				}
				
				try {
					List<Future<HapInfo>> futures = taskExecutor.invokeAll(allHapThreads);
					
					// gather every result first, so a failure part-way through a chromosome
					// cannot leave it half-counted in the totals
					List<HapInfo> results = new ArrayList<HapInfo>(numConsidered);
					for (Future<HapInfo> future : futures) {
						results.add(future.get());
					}
					
					// add up all our expected transitions, emissions, and likelihoods
					double chromLogLikelihood = 0;
					for (HapInfo hapInfo : results) {
						chromLogLikelihood += hapInfo.likelihood;
						accumulate(hapInfo);
					}
					logLikelihoodSum.addLogSummand(chromLogLikelihood);
					
				} catch (InterruptedException e) {
					System.out.println("interrupted excpetion in one of our threads");
					e.printStackTrace(); // an InterruptedException normally carries no cause
					Thread.currentThread().interrupt(); // preserve the interrupt for our caller
					break; // further invokeAll calls would fail immediately
				} catch (ExecutionException e) {
					System.out.println("execution exception in one of our threads");
					Throwable s = (e.getCause() != null) ? e.getCause() : e;
					s.printStackTrace(); // the cause is more informative than the wrapper
				}
			}
		} finally {
			taskExecutor.shutdown(); // important to shutdown
		}
		
		// derived from the totals over ALL chromosomes, so it must be computed exactly once
		if (this.printExpectedSegs) {
			computeExpectedSegments();
		}
		
		this.eStepLogLikelihood = logLikelihoodSum.retrieveLogSum() - Math.log(numChroms);
	}
	
	// adds one haplotype's posterior counts into the running totals
	private void accumulate(HapInfo hapInfo) {
		this.expectedInitialCounts = Utility.addArrays(this.expectedInitialCounts, hapInfo.initialCounts);
		this.expectedRecoDiffCoalSame = Utility.addArrays(this.expectedRecoDiffCoalSame, hapInfo.recoDiffCoalSame);
		this.expectedRecoDiffCoalLater = Utility.addArrays(this.expectedRecoDiffCoalLater, hapInfo.recoDiffCoalLater);
		this.expectedRecoSameCoalSame = Utility.addArrays(this.expectedRecoSameCoalSame, hapInfo.recoSameCoalSame);
		this.expectedRecoSameCoalLater = Utility.addArrays(this.expectedRecoSameCoalLater, hapInfo.recoSameCoalLater);
		this.expectedNoReco = Utility.addArrays(this.expectedNoReco, hapInfo.noReco);
		this.expectedCoalNow = Utility.addArrays(this.expectedCoalNow, hapInfo.coalNow);
		this.expectedCoalLater = Utility.addArrays(this.expectedCoalLater, hapInfo.coalLater);
		this.expectedEmissions = Utility.addMatrices(this.expectedEmissions, hapInfo.emissions);
		
		if (this.printExpectedSegs) {
			this.expectedSelfTrans = Utility.addArrays(this.expectedSelfTrans, hapInfo.selfTrans);
		}
	}
	
	// expected segments per time interval: sum of the RecoDiffCoalSame, RecoSameCoalSame, NoReco
	// and CoalNow totals for that interval, minus the expected self transitions
	private void computeExpectedSegments() {
		for (int tIdx = 0; tIdx < this.d; tIdx++) {
			this.expectedSegments[tIdx] = this.expectedRecoDiffCoalSame[tIdx]
					+ this.expectedRecoSameCoalSame[tIdx]
					+ this.expectedNoReco[tIdx]
					+ this.expectedCoalNow[tIdx]
					- this.expectedSelfTrans[tIdx];
		}
	}
	
	// small immutable struct to hold the posterior counts of one left-out haplotype
	private static class HapInfo {
		private final double likelihood;
		private final double[] initialCounts;
		private final double[] recoDiffCoalSame;
		private final double[] recoDiffCoalLater;
		private final double[] recoSameCoalSame;
		private final double[] recoSameCoalLater;
		private final double[] noReco;
		private final double[] coalNow;
		private final double[] coalLater;
		private final double[][] emissions;
		private final double[] selfTrans;
		
		public HapInfo(double likelihood, double[] initialCounts, double[] recoDiffCoalSame, double[] recoDiffCoalLater, double[] recoSameCoalSame, double[] recoSameCoalLater, 
				double[] noReco, double[] coalNow, double[] coalLater, double[][] emissions, double[] selfTrans) {
			this.likelihood = likelihood;
			this.initialCounts = initialCounts;
			this.recoDiffCoalSame = recoDiffCoalSame;
			this.recoDiffCoalLater = recoDiffCoalLater;
			this.recoSameCoalSame = recoSameCoalSame;
			this.recoSameCoalLater = recoSameCoalLater;
			this.noReco = noReco;
			this.coalNow = coalNow;
			this.coalLater = coalLater;
			this.emissions = emissions;
			this.selfTrans = selfTrans;
		}
	}
	
	// task that decodes one left-out haplotype of one chromosome and returns its HapInfo
	private class HapThread implements Callable<HapInfo> {
		
		private final int chromIdx;
		private final int hapIdx;
		private final CoreLinear core;
		
		public HapThread(int chromIdx, int hapIdx, CoreLinear core) {
			this.chromIdx = chromIdx;
			this.hapIdx = hapIdx;
			this.core = core;
		}
	
		// compute expected trans for a haplotype (each haplotype is submitted as its own task)
		@Override
		public HapInfo call() {
			System.out.println("starting hap " + this.hapIdx + " in chrom " + this.chromIdx);
			
			// get the hapList and associated map for our current chromosome
			List<Haplotype> hapList = chromosomes.get(this.chromIdx);
			HapConfig<Haplotype> hapConfig = new HapConfig<Haplotype>();
			for (int i = 0; i < hapList.size(); i++) {
				if (i != this.hapIdx) {
					hapConfig.adjustType(hapList.get(i), 1); // add in all but the left out one
				}
			}
			
			// "take out" the desired haplotype and compute its likelihood, expected transition matrix, and expected emission matrix
			Haplotype sHap = hapList.get(this.hapIdx);
			DecodeLinear<Haplotype, HapConfig<Haplotype>> decode = new DecodeLinear<Haplotype, HapConfig<Haplotype>>(this.core, sHap, hapConfig);
			double hapLikelihood = decode.computeForwardBackward(); // this initializes everything
			
			// posterior counts from the forward-backward pass
			double[] hapInitialCounts = decode.computePosteriorInitialCounts();
			double[] hapRecoDiffCoalSame = decode.computePosteriorRecoDiffCoalSame();
			double[] hapRecoDiffCoalLater = decode.computePosteriorRecoDiffCoalLater();
			double[] hapRecoSameCoalSame = decode.computePosteriorRecoSameCoalSame();
			double[] hapRecoSameCoalLater = decode.computePosteriorRecoSameCoalLater();
			double[] hapNoReco = decode.computePosteriorNoReco();
			double[] hapCoalNow = decode.computePosteriorCoalNow();
			double[] hapCoalLater = decode.computePosteriorCoalLater();
			double[][] hapEmissions = decode.computePosteriorEmissions();
			
			// self transitions are only needed for the expected-segments computation
			double[] hapSelfTrans = printExpectedSegs ? decode.computeExpectedSelfTrans() : new double[d];
			
			// store hapInfo
			System.out.println("finished hap " + this.hapIdx + " : " + hapLikelihood);
			return new HapInfo(hapLikelihood, hapInitialCounts, hapRecoDiffCoalSame, hapRecoDiffCoalLater, hapRecoSameCoalSame, hapRecoSameCoalLater, hapNoReco, 
					hapCoalNow, hapCoalLater, hapEmissions, hapSelfTrans);
		}
	}

	/**
	 * Formats every per-interval expected count for one time index.
	 *
	 * @param tIdx time interval index, in [0, number of intervals)
	 * @return multi-line, human-readable summary
	 */
	public String print(int tIdx) {
		String str = "time index: " + tIdx + "\n";
		str += "expectedInitialCounts: " + getExpectedInitialCounts()[tIdx] + "\n";
		str += "expectedRecoDiffCoalSame: " + getExpectedRecoDiffCoalSame()[tIdx] + "\n";
		str += "expectedRecoDiffCoalLater: " + getExpectedRecoDiffCoalLater()[tIdx] + "\n";
		str += "expectedRecoSameCoalSame: " + getExpectedRecoSameCoalSame()[tIdx] + "\n";
		str += "expectedRecoSameCoalLater: " + getExpectedRecoSameCoalLater()[tIdx] + "\n";
		str += "expectedNoReco: " + getExpectedNoReco()[tIdx] + "\n";
		str += "expectedCoalNow: " + getExpectedCoalNow()[tIdx] + "\n";
		str += "expectedCoalLater: " + getExpectedCoalLater()[tIdx] + "\n";
		str += "expectedEmissions: " + Arrays.toString(getExpectedEmissions()[tIdx]);
		return str;
	}
}
