/*
 * License: FreeBSD (Berkeley Software Distribution)
 * Copyright (c) 2013, Sara Sheehan, Kelley Harris, Yun Song
 */

package edu.berkeley.smcsd;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import edu.berkeley.utility.HapConfig;
import edu.berkeley.utility.Haplotype;
import edu.berkeley.utility.LogSum;
import edu.berkeley.utility.ParamSet;
import edu.berkeley.utility.Utility;

/**
 * E-step for the linear PAC model.
 *
 * <p>All work happens in the constructor. For every chromosome, one task is run for each
 * (trunk size n, permutation) pair on a fixed-size thread pool. Each task leaves out one
 * haplotype, decodes it against the haplotypes that follow it in the permutation, and returns
 * posterior expected counts. These are summed into per-(permutation, trunk size, time interval)
 * tables that are exposed through the getters, together with per-permutation and overall
 * log-likelihoods.
 */
public class EstepLinearPac implements Estep {	
	
	// fixed parameters
	private final List<List<Haplotype>> chromosomes;   // list of all chromosomes for all nTotal haplotypes
	private final ParamSet pSet;                       // parameter set
	private final double[] times;                      // discretization points
	private final double[] sizes;                      // the current sizes during each of the above intervals
	private final int d;                               // number of discretization intervals
	private final int nTotal;                          // total number of haplotypes
	private final int[][] allPerms;                    // permutations of haplotypes
	private final int numPerms;                        // number of permutations
	private final int numCores;                        // number of cores to use
	private final double[][] descendingAscendingTable; // precomputed table to help compute nbar
	private final boolean printExpectedSegs;           // whether or not to print expected segments
	
	// values to compute, indexed by permutation, trunk size (stored at index nTrunk-1), time interval (then allele, for emissions)
	private double eStepLogLikelihood;
	private double[] allLogLikelihoods; // likelihood for each permutation
	private double[][][] allExpectedInitialCounts;
	private double[][][] allExpectedRecoDiffCoalSame;
	private double[][][] allExpectedRecoDiffCoalLater;
	private double[][][] allExpectedRecoSameCoalSame;
	private double[][][] allExpectedRecoSameCoalLater;
	private double[][][] allExpectedNoReco;
	private double[][][] allExpectedCoalNow;
	private double[][][] allExpectedCoalLater;
	private double[][][][] allExpectedEmissions;
	private double[][][] allExpectedSelfTrans; // number of times we expect to go from time i to i in each permutation, trunk size
	private double[] expectedSegments; // number of expected recombination segments in each time i
	
	/**
	 * Runs the full E-step immediately. All results are available through the getters afterwards.
	 *
	 * @param chromosomes per-chromosome haplotype lists; every list is indexed by haplotype and
	 *        the size of the first list is taken as the total number of haplotypes
	 * @param pSet parameter set passed to each {@code CoreLinear}
	 * @param times discretization points; its length is the number of time intervals
	 * @param sizes sizes for each of the discretization intervals
	 * @param allPerms haplotype permutations; each row is assumed to have length nTotal
	 * @param numCores number of threads in the worker pool
	 * @param descendingAscendingTable row n-1 is passed to the core for trunk size n
	 * @param printExpectedSegs whether to also compute self transitions and expected segments
	 * @throws IllegalStateException if the calling thread is interrupted while waiting for workers
	 */
	public EstepLinearPac(List<List<Haplotype>> chromosomes, ParamSet pSet, double[] times, double[] sizes, int[][] allPerms, int numCores, double[][] descendingAscendingTable, boolean printExpectedSegs) {
		assert chromosomes.size() > 0;
		this.chromosomes = chromosomes;
		this.pSet = pSet;
		this.times = times;
		this.sizes = sizes;
		this.d = times.length;
		this.nTotal = chromosomes.get(0).size();
		this.allPerms = allPerms;
		this.numPerms = allPerms.length;
		this.numCores = numCores;
		this.descendingAscendingTable = descendingAscendingTable;
		this.printExpectedSegs = printExpectedSegs;
		
		// this initializes all our posterior counts
		computeEstepPacMultiCore();
	}
	
	// getters for posterior counts (internal arrays are returned directly, not copied)
	public double getEstepLogLikelihood() { return this.eStepLogLikelihood; }
	public double[] getAllLogLikelihoods() { return this.allLogLikelihoods; }
	public double[][][] getAllExpectedInitialCounts() { return this.allExpectedInitialCounts; }
	public double[][][] getAllExpectedRecoDiffCoalSame() { return this.allExpectedRecoDiffCoalSame; }
	public double[][][] getAllExpectedRecoDiffCoalLater() { return this.allExpectedRecoDiffCoalLater; }
	public double[][][] getAllExpectedRecoSameCoalSame() { return this.allExpectedRecoSameCoalSame; }
	public double[][][] getAllExpectedRecoSameCoalLater() { return this.allExpectedRecoSameCoalLater; }
	public double[][][] getAllExpectedNoReco() { return this.allExpectedNoReco; }
	public double[][][] getAllExpectedCoalNow() { return this.allExpectedCoalNow; }
	public double[][][] getAllExpectedCoalLater() { return this.allExpectedCoalLater; }
	public double[][][][] getAllExpectedEmissions() { return this.allExpectedEmissions; }
	public double[] getExpectedSegments() { return this.expectedSegments; }
	
	// PAC method, using multiple cores
	private void computeEstepPacMultiCore() {
		allocateCounts();
		
		int numChroms = this.chromosomes.size();
		double[][] permChromLogLikelihoods = new double[this.numPerms][numChroms];
		
		// A single pool is reused for every chromosome. The finally clause guarantees it is shut
		// down even if invokeAll throws (previously the pool leaked in that case).
		ExecutorService taskExecutor = Executors.newFixedThreadPool(this.numCores);
		try {
			for (int c=0; c < numChroms; c++) {
				List<HapThread> allHapThreads = createHapThreads(c);
				try {
					List<Future<HapInfo>> futures = taskExecutor.invokeAll(allHapThreads);
					
					// for each future (i.e. hap), record the hap's expected transitions/emissions/likelihood
					for (Future<HapInfo> future : futures) {
						HapInfo hapInfo = future.get();
						permChromLogLikelihoods[hapInfo.permIdx][c] += hapInfo.hapLikelihood;
						accumulate(hapInfo);
					}
				} catch (InterruptedException e) {
					// InterruptedException normally has no cause, so the old getCause().printStackTrace()
					// threw an NPE here. Restore the interrupt flag and fail with the real exception.
					System.out.println("interrupted exception in one of our threads");
					Thread.currentThread().interrupt();
					throw new IllegalStateException("E-step interrupted while processing chromosome " + c, e);
				} catch (ExecutionException e) {
					// Best-effort behaviour is kept: report and move on to the next chromosome.
					// NOTE(review): the remaining futures of this chromosome are not accumulated,
					// so the counts for this chromosome are partial.
					System.out.println("execution exception in one of our threads");
					Throwable s = (e.getCause() != null) ? e.getCause() : e;
					s.printStackTrace(); // the cause is more informative than the wrapper's stack trace
				}
			}
		} finally {
			taskExecutor.shutdown(); // important to shutdown
		}
		
		if (this.printExpectedSegs) {
			computeExpectedSegments();
		}
		
		computeLogLikelihoods(permChromLogLikelihoods, numChroms);
	}
	
	// Allocates (zeroed) result arrays, indexed [permutation][nTrunk-1][time interval].
	private void allocateCounts() {
		this.allLogLikelihoods = new double[this.numPerms];
		this.allExpectedInitialCounts = new double[this.numPerms][this.nTotal-1][this.d];
		this.allExpectedRecoDiffCoalSame = new double[this.numPerms][this.nTotal-1][this.d];
		this.allExpectedRecoDiffCoalLater = new double[this.numPerms][this.nTotal-1][this.d];
		this.allExpectedRecoSameCoalSame = new double[this.numPerms][this.nTotal-1][this.d];
		this.allExpectedRecoSameCoalLater = new double[this.numPerms][this.nTotal-1][this.d];
		this.allExpectedNoReco = new double[this.numPerms][this.nTotal-1][this.d];
		this.allExpectedCoalNow = new double[this.numPerms][this.nTotal-1][this.d];
		this.allExpectedCoalLater = new double[this.numPerms][this.nTotal-1][this.d];
		this.allExpectedEmissions = new double[this.numPerms][this.nTotal-1][this.d][this.pSet.numAlleles()];
		this.allExpectedSelfTrans = new double[this.numPerms][this.nTotal-1][this.d];
		this.expectedSegments = new double[this.d];
	}
	
	// Builds one task per (trunk size n, permutation) for the given chromosome. The left-out hap is the
	// one at permutation position nTotal-1-n, so exactly n haps follow it.
	private List<HapThread> createHapThreads(int chromIdx) {
		List<HapThread> hapThreads = new ArrayList<HapThread>();
		for (int n=1; n < this.nTotal; n++) {
			// create the appropriate core probabilities (shared by all permutations for this n)
			CoreLinear core = new CoreLinear(this.pSet, n, this.times, this.sizes, this.descendingAscendingTable[n-1], this.printExpectedSegs);
			for (int p=0; p < this.numPerms; p++) {
				hapThreads.add(new HapThread(chromIdx, this.allPerms[p][this.nTotal-1-n], core, p)); // add in appropriate hap
			}
		}
		return hapThreads;
	}
	
	// Adds one hap's posterior counts into the running totals for its (permutation, trunk size).
	private void accumulate(HapInfo hapInfo) {
		int p = hapInfo.permIdx;
		int k = hapInfo.hapNtrunk - 1;
		this.allExpectedInitialCounts[p][k] = Utility.addArrays(this.allExpectedInitialCounts[p][k], hapInfo.initialCounts);
		this.allExpectedRecoDiffCoalSame[p][k] = Utility.addArrays(this.allExpectedRecoDiffCoalSame[p][k], hapInfo.recoDiffCoalSame);
		this.allExpectedRecoDiffCoalLater[p][k] = Utility.addArrays(this.allExpectedRecoDiffCoalLater[p][k], hapInfo.recoDiffCoalLater);
		this.allExpectedRecoSameCoalSame[p][k] = Utility.addArrays(this.allExpectedRecoSameCoalSame[p][k], hapInfo.recoSameCoalSame);
		this.allExpectedRecoSameCoalLater[p][k] = Utility.addArrays(this.allExpectedRecoSameCoalLater[p][k], hapInfo.recoSameCoalLater);
		this.allExpectedNoReco[p][k] = Utility.addArrays(this.allExpectedNoReco[p][k], hapInfo.noReco);
		this.allExpectedCoalNow[p][k] = Utility.addArrays(this.allExpectedCoalNow[p][k], hapInfo.coalNow);
		this.allExpectedCoalLater[p][k] = Utility.addArrays(this.allExpectedCoalLater[p][k], hapInfo.coalLater);
		this.allExpectedEmissions[p][k] = Utility.addMatrices(this.allExpectedEmissions[p][k], hapInfo.emissions);
		
		if (this.printExpectedSegs) {
			this.allExpectedSelfTrans[p][k] = Utility.addArrays(this.allExpectedSelfTrans[p][k], hapInfo.selfTrans);
		}
	}
	
	// Adds up the expected segments in each time interval: the total number of transitions into state j,
	// minus self transitions of any kind.
	private void computeExpectedSegments() {
		for (int p=0; p < this.numPerms; p++) {
			for (int n=1; n < this.nTotal; n++) {
				for (int tIdx=0; tIdx < this.d; tIdx++) {
					
					// first add on all the ways we could end up in state j
					this.expectedSegments[tIdx] += this.allExpectedRecoDiffCoalSame[p][n-1][tIdx];
					this.expectedSegments[tIdx] += this.allExpectedRecoSameCoalSame[p][n-1][tIdx];
					this.expectedSegments[tIdx] += this.allExpectedNoReco[p][n-1][tIdx];
					this.expectedSegments[tIdx] += this.allExpectedCoalNow[p][n-1][tIdx];
					
					// then subtract off self transitions of any kind
					this.expectedSegments[tIdx] -= this.allExpectedSelfTrans[p][n-1][tIdx];
				}
			}
		}
	}
	
	// Reduces per-(permutation, chromosome) log-likelihoods to per-permutation values and the overall
	// E-step log-likelihood. Both reductions take the log of the mean of the likelihoods.
	private void computeLogLikelihoods(double[][] permChromLogLikelihoods, int numChroms) {
		// NOTE(review): chromosomes are combined by log-sum-exp minus log(numChroms), which averages
		// likelihoods instead of summing log-likelihoods. Confirm this is intended if chromosomes are
		// meant to be treated as independent.
		LogSum logLikelihoodSum = new LogSum(this.numPerms);
		for (int p=0; p < this.numPerms; p++) {
			LogSum permLogSum = new LogSum(numChroms);
			for (int c=0; c < numChroms; c++) {
				permLogSum.addLogSummand(permChromLogLikelihoods[p][c]);
			}
			double permLogLikelihood = permLogSum.retrieveLogSum() - Math.log(numChroms);
			this.allLogLikelihoods[p] = permLogLikelihood;
			logLikelihoodSum.addLogSummand(permLogLikelihood);
		}
		
		// a sum over permutations (not product), hence the log-sum
		this.eStepLogLikelihood = logLikelihoodSum.retrieveLogSum() - Math.log(this.numPerms);
	}
	
	// small immutable struct to hold one hap's expected transitions, emissions and likelihood
	private static class HapInfo {
		private final double hapLikelihood;
		private final double[] initialCounts;
		private final double[] recoDiffCoalSame;
		private final double[] recoDiffCoalLater; 
		private final double[] recoSameCoalSame;
		private final double[] recoSameCoalLater;
		private final double[] noReco;
		private final double[] coalNow;
		private final double[] coalLater;
		private final double[][] emissions;
		private final double[] selfTrans;
		private final int permIdx;
		private final int hapNtrunk;
		
		public HapInfo(double hapLikelihood, double[] initialCounts, double[] recoDiffCoalSame, double[] recoDiffCoalLater, double[] recoSameCoalSame, double[] recoSameCoalLater, 
				double[] noReco, double[] coalNow, double[] coalLater, double[][] emissions, double[] selfTrans, int permIdx, int hapNtrunk) {
			this.hapLikelihood = hapLikelihood;
			this.initialCounts = initialCounts;
			this.recoDiffCoalSame = recoDiffCoalSame;
			this.recoDiffCoalLater = recoDiffCoalLater;
			this.recoSameCoalSame = recoSameCoalSame;
			this.recoSameCoalLater = recoSameCoalLater;
			this.noReco = noReco;
			this.coalNow = coalNow;
			this.coalLater = coalLater;
			this.emissions = emissions;
			this.selfTrans = selfTrans;
			this.permIdx = permIdx;
			this.hapNtrunk = hapNtrunk;
		}
	}
	
	// Task that decodes one left-out haplotype for one permutation and chromosome; returns its HapInfo.
	private class HapThread implements Callable<HapInfo> {
		
		private final int chromIdx;
		private final int hapIdx;
		private final CoreLinear core;
		private final int permIdx; // hap needs to know its position in the permutation
		private final int[] permutation;
		
		public HapThread(int chromIdx, int hapIdx, CoreLinear core, int permIdx) {
			this.chromIdx = chromIdx;
			this.hapIdx = hapIdx;
			this.core = core;
			this.permIdx = permIdx;
			this.permutation = allPerms[permIdx];
		}
	
		// compute info for a left-out haplotype (this is the method that runs on a pool thread)
		@Override
		public HapInfo call() {
			System.out.println("starting hap " + this.hapIdx + " in permutation " + Arrays.toString(this.permutation) + " in chrom " + this.chromIdx);
			
			// get the hapList for our current chromosome
			List<Haplotype> hapList = chromosomes.get(chromIdx);
			
			// add in all the haps *after* the hap at hapIdx (in permutation order)
			HapConfig<Haplotype> hapConfig = new HapConfig<Haplotype>();
			boolean startAdding = false;
			for (int i = 0; i < hapList.size(); i++) {
				if (this.permutation[i] == this.hapIdx) {
					startAdding = true; 
				} else if (startAdding) {
					hapConfig.adjustType(hapList.get(this.permutation[i]), 1); // make sure we add the index from the permutation
				}
			}
			
			// assert that the number of haps we have added matches that of the core
			assert (hapConfig.totalGeneticTypes() == this.core.nTrunk());
			
			// "take out" the desired haplotype and compute its likelihood, expected transition matrix, and expected emission matrix
			Haplotype sHap = hapList.get(this.hapIdx);
			DecodeLinear<Haplotype, HapConfig<Haplotype>> decode = new DecodeLinear<Haplotype, HapConfig<Haplotype>>(this.core, sHap, hapConfig);
			double hapLikelihood = decode.computeForwardBackward(); // must run before the posterior queries below
			
			double[] hapInitialCounts = decode.computePosteriorInitialCounts();
			double[] hapRecoDiffCoalSame = decode.computePosteriorRecoDiffCoalSame();
			double[] hapRecoDiffCoalLater = decode.computePosteriorRecoDiffCoalLater(); 
			double[] hapRecoSameCoalSame = decode.computePosteriorRecoSameCoalSame();
			double[] hapRecoSameCoalLater = decode.computePosteriorRecoSameCoalLater();
			double[] hapNoReco = decode.computePosteriorNoReco();
			double[] hapCoalNow = decode.computePosteriorCoalNow();
			double[] hapCoalLater = decode.computePosteriorCoalLater();
			double[][] hapEmissions = decode.computePosteriorEmissions();
			
			// self transitions are only needed to compute the expected segments; zeros otherwise
			double[] hapSelfTrans = printExpectedSegs ? decode.computeExpectedSelfTrans() : new double[d];
			
			System.out.println("finished hap " + this.hapIdx + " in permutation " + Arrays.toString(this.permutation) + ": " + hapLikelihood);
			return new HapInfo(hapLikelihood, hapInitialCounts, hapRecoDiffCoalSame, hapRecoDiffCoalLater, hapRecoSameCoalSame, hapRecoSameCoalLater, hapNoReco, 
					hapCoalNow, hapCoalLater, hapEmissions, hapSelfTrans, this.permIdx, this.core.nTrunk());
		}
	}
	
	/**
	 * Formats all posterior values for one cell of the result tables.
	 *
	 * @param permIdx permutation index
	 * @param nTrunk trunk size, expected in [1, nTotal)
	 * @param tIdx time interval index
	 * @return a multi-line human-readable description (no trailing newline)
	 */
	public String print(int permIdx, int nTrunk, int tIdx) {
		assert nTrunk > 0 && nTrunk < this.nTotal;
		int k = nTrunk - 1;
		StringBuilder sb = new StringBuilder();
		sb.append("permutation index: ").append(permIdx).append(", nTrunk: ").append(nTrunk).append(", time index: ").append(tIdx).append("\n");
		sb.append("expectedInitialCounts: ").append(getAllExpectedInitialCounts()[permIdx][k][tIdx]).append("\n");
		sb.append("expectedRecoDiffCoalSame: ").append(getAllExpectedRecoDiffCoalSame()[permIdx][k][tIdx]).append("\n");
		sb.append("expectedRecoDiffCoalLater: ").append(getAllExpectedRecoDiffCoalLater()[permIdx][k][tIdx]).append("\n");
		sb.append("expectedRecoSameCoalSame: ").append(getAllExpectedRecoSameCoalSame()[permIdx][k][tIdx]).append("\n");
		sb.append("expectedRecoSameCoalLater: ").append(getAllExpectedRecoSameCoalLater()[permIdx][k][tIdx]).append("\n");
		sb.append("expectedNoReco: ").append(getAllExpectedNoReco()[permIdx][k][tIdx]).append("\n");
		sb.append("expectedCoalNow: ").append(getAllExpectedCoalNow()[permIdx][k][tIdx]).append("\n");
		sb.append("expectedCoalLater: ").append(getAllExpectedCoalLater()[permIdx][k][tIdx]).append("\n");
		sb.append("expectedEmissions: ").append(Arrays.toString(getAllExpectedEmissions()[permIdx][k][tIdx]));
		return sb.toString();
	}
}
