/*
 * License: FreeBSD (Berkeley Software Distribution)
 * Copyright (c) 2013, Sara Sheehan, Kelley Harris, Yun Song
 */

package edu.berkeley.smcsd;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import edu.berkeley.utility.HapConfig;
import edu.berkeley.utility.Haplotype;
import edu.berkeley.utility.LogSum;
import edu.berkeley.utility.ParamSet;
import edu.berkeley.utility.Utility;

public class EstepQuadPac implements Estep {	
	
	// fixed parameters
	private final List<List<Haplotype>> chromosomes;   // list of all chromosomes for all nTotal haplotypes
	private final ParamSet pSet;                       // parameter set
	private final double[] times;                      // discretization points
	private final double[] sizes;                      // the population sizes during each of the above intervals
	private final int d;                               // number of discretization intervals
	private final int nTotal;                          // total number of haplotypes
	private final int[][] allPerms;                    // permutations of haplotypes
	private final int numPerms;                        // number of permutations
	private final int numCores;                        // number of cores to use
	private final double[][] descendingAscendingTable; // precomputed table to help compute nbar
	
	// for decoding (assigned only in the constructor, so declared final)
	private final int decodingInt;                             // selects which decoding (if any) each HapThread prints; 0 means none
	private final List<Map<Haplotype, Integer>> hap2IndexMaps; // per-chromosome map from haplotype to its index (used for decodings 1 and 2)
	
	// values to compute (filled in by computeEstepPacMultiCore, exposed via getters)
	private double eStepLogLikelihood;             // total log likelihood, log-averaged over permutations
	private double[] allLogLikelihoods;            // likelihoods for each permutation
	private double[][][] allExpectedInitialCounts; // expected counts of coalescence for the first locus
	private double[][][][] allExpectedTrans;       // expected transitions from time i to j, for each permutation and sample size
	private double[][][][] allExpectedMuts;        // expected number of emissions of allele a from time i, for each permutation and sample size
	private double[] expectedSegments;             // number of expected recombination segments in each time i
	
	/**
	 * Constructs the E-step object and immediately runs the multi-core PAC E-step,
	 * so all posterior counts and likelihoods are available from the getters as
	 * soon as construction returns.
	 *
	 * @param chromosomes              for each chromosome, the list of all nTotal haplotypes
	 * @param pSet                     parameter set (mutation/recombination parameters, number of alleles)
	 * @param times                    discretization points; times.length defines the number of intervals d
	 * @param sizes                    population size within each discretization interval
	 * @param allPerms                 the haplotype permutations to average the PAC likelihood over
	 * @param numCores                 number of worker threads to use for the per-haplotype decodings
	 * @param descendingAscendingTable precomputed table (indexed by trunk size - 1) to help compute nbar
	 * @param decodingInt              which decoding to print for each left-out haplotype (see HapThread.call); 0 for none
	 * @param hap2IndexMaps            per-chromosome haplotype-to-index maps, used by decodings 1 and 2
	 */
	public EstepQuadPac(List<List<Haplotype>> chromosomes, ParamSet pSet, double[] times, double[] sizes, int[][] allPerms, int numCores, double[][] descendingAscendingTable, int decodingInt, List<Map<Haplotype, Integer>> hap2IndexMaps) {
		assert chromosomes.size() > 0;
		this.chromosomes = chromosomes;
		this.pSet = pSet;
		this.times = times;
		this.sizes = sizes;
		this.d = times.length;
		this.nTotal = chromosomes.get(0).size();
		this.allPerms = allPerms;
		this.numPerms = allPerms.length;
		this.numCores = numCores;
		this.descendingAscendingTable = descendingAscendingTable;
		this.decodingInt = decodingInt;
		this.hap2IndexMaps = hap2IndexMaps;
		
		// this initializes everything (allLogLikelihoods, expected counts, etc.)
		computeEstepPacMultiCore();
	}
	
	// getters for posterior counts
	public double getEstepLogLikelihood() { return this.eStepLogLikelihood; }
	public double[] getAllLogLikelihoods() { return this.allLogLikelihoods; }
	public double[][][] getAllExpectedInitialCounts() { return this.allExpectedInitialCounts; }
	public double[][][][] getAllExpectedTransitions() { return this.allExpectedTrans; }
	public double[][][][] getAllExpectedEmissions() { return this.allExpectedMuts; }
	public double[] getExpectedSegments() { return this.expectedSegments; }
	
	/**
	 * PAC E-step using multiple cores: for each chromosome, each permutation, and each
	 * trunk size n, decode one left-out haplotype against the n haplotypes that follow
	 * it in the permutation, accumulating expected initial counts, transitions, and
	 * emissions, plus the per-permutation log likelihoods.
	 */
	private void computeEstepPacMultiCore() {
	
		this.allLogLikelihoods = new double[this.numPerms];
		this.allExpectedInitialCounts = new double[this.numPerms][this.nTotal-1][this.d];
		this.allExpectedTrans = new double[this.numPerms][this.nTotal-1][this.d][this.d];
		this.allExpectedMuts = new double[this.numPerms][this.nTotal-1][this.d][this.pSet.numAlleles()];
		this.expectedSegments = new double[this.d];
		
		// first loop over each chromosome
		int numChroms = this.chromosomes.size();
		double[][] permChromLogLikelihoods = new double[this.numPerms][numChroms];
		for (int c=0; c < numChroms; c++) {
		
			// create list of tasks for each permutation and each hap we will leave out
			List<HapThread> allHapThreads = new ArrayList<HapThread>();
			for (int n=1; n < this.nTotal; n++) {
				// create the appropriate core probabilities for trunk size n (shared by all permutations)
				CoreQuad core = new CoreQuad(this.pSet, this.times, this.sizes, n, descendingAscendingTable[n-1]);
				for (int p=0; p < this.numPerms; p++) {
					// the left-out hap is the one in position nTotal-1-n of permutation p
					allHapThreads.add(new HapThread(c, this.allPerms[p][this.nTotal-1-n], core, p));
				}
			}
			
			// create task executor to manage threads
			ExecutorService taskExecutor = Executors.newFixedThreadPool(this.numCores);
			
			// run our tasks inside this try/catch/finally block
			try {
				
				List<Future<HapInfo>> futures = taskExecutor.invokeAll(allHapThreads);
				
				// for each future (i.e. hap), record the hap's expected transitions/emissions/likelihood
				for (int f=0; f < futures.size(); f++) {
					HapInfo hapInfo = futures.get(f).get();
					int permIdx = hapInfo.permIdx;
					int nTrunk = hapInfo.hapNtrunk;
					permChromLogLikelihoods[permIdx][c] += hapInfo.hapLikelihood;
					
					this.allExpectedInitialCounts[permIdx][nTrunk-1] = Utility.addArrays(this.allExpectedInitialCounts[permIdx][nTrunk-1], hapInfo.initialCounts);
					this.allExpectedTrans[permIdx][nTrunk-1] = Utility.addMatrices(this.allExpectedTrans[permIdx][nTrunk-1], hapInfo.expectedTransitions);
					this.allExpectedMuts[permIdx][nTrunk-1]  = Utility.addMatrices(this.allExpectedMuts[permIdx][nTrunk-1], hapInfo.expectedEmissions);
				}
				
			} catch (InterruptedException e) {
				// restore the interrupt flag so callers further up can observe the interruption
				Thread.currentThread().interrupt();
				System.out.println("interrupted exception in one of our threads");
				// an InterruptedException from invokeAll typically has no cause, so print the
				// exception itself (dereferencing a null cause here would throw NPE)
				e.printStackTrace();
			} catch (ExecutionException e) {
				System.out.println("execution exception in one of our threads");
				Throwable cause = e.getCause();
				if (cause != null) {
					cause.printStackTrace(); // this is more informative than printing the stack trace from e
				} else {
					e.printStackTrace();
				}
			} finally {
				// important to shutdown even when a task failed, otherwise the pool's
				// non-daemon worker threads leak and can keep the JVM alive
				taskExecutor.shutdown();
			}
		}
		
		// compute the expected segments (only add on segments where the time index has changed, i.e. recombination)
		// note that this (intentionally) does not include "hidden" recombination events which do not change the state
		// NOTE: this summation runs once, after ALL chromosomes have been accumulated into
		// allExpectedTrans; running it inside the chromosome loop (as previously done) re-adds
		// the transitions of earlier chromosomes on every iteration, over-counting them
	    for (int p=0; p < this.numPerms; p++) {
	    	for (int n=1; n < this.nTotal; n++) {
	    		for (int i=0; i < this.d; i++) {
	    			for (int j=0; j < this.d; j++) {
	    				if (i != j) {
	    					this.expectedSegments[j] += this.allExpectedTrans[p][n-1][i][j];
	    				}
	    			}
	    		}
	    	}
	    }
		
		// log sum over chroms to get perm likelihoods, then compute the total log prob: this is a sum over permutations (not product) so using logsum
		LogSum logLikelihoodSum = new LogSum(this.numPerms);
		for (int p=0; p < this.numPerms; p++) {
			LogSum permLogSum = new LogSum(numChroms);
			for (int c=0; c < numChroms; c++) {
				permLogSum.addLogSummand(permChromLogLikelihoods[p][c]);
			}
			// subtracting log(numChroms) turns the log-sum into a log-average over chromosomes
			double permLogLikelihood = permLogSum.retrieveLogSum() - Math.log(numChroms);
			this.allLogLikelihoods[p] = permLogLikelihood;
			logLikelihoodSum.addLogSummand(permLogLikelihood);
		}
	    
		// log-average over permutations
	    this.eStepLogLikelihood = logLikelihoodSum.retrieveLogSum() - Math.log(this.numPerms);
	}
	
	// small immutable struct to hold one left-out hap's likelihood and expected counts;
	// declared static so instances do not retain a hidden reference to the enclosing EstepQuadPac
	private static class HapInfo {
		private final double hapLikelihood;           // log likelihood of the left-out hap
		private final double[] initialCounts;         // expected coalescence counts at the first locus
		private final double[][] expectedTransitions; // expected transitions between time intervals
		private final double[][] expectedEmissions;   // expected emissions of each allele from each time
		private final int permIdx;                    // which permutation this hap belongs to
		private final int hapNtrunk;                  // trunk size (number of conditioned-on haps)
		
		public HapInfo(double likelihood, double[] initialCounts, double[][] expectedTransitions, double[][] expectedEmissions, int permIdx, int hapNtrunk) {
			this.hapLikelihood = likelihood;
			this.initialCounts = initialCounts;
			this.expectedTransitions = expectedTransitions;
			this.expectedEmissions = expectedEmissions;
			this.permIdx = permIdx;
			this.hapNtrunk = hapNtrunk;
		}
	}
	
	// our hap task, returns type HapInfo (runs on the executor's worker threads)
	private class HapThread implements Callable<HapInfo> {
		
		private final int chromIdx;      // which chromosome to decode on
		private final int hapIdx;        // index of the haplotype we leave out
		private final CoreQuad core;     // precomputed core probabilities for this trunk size
		private final int permIdx;       // hap needs to know its position in the permutation
		private final int[] permutation; // the permutation itself (allPerms[permIdx])
		
		public HapThread(int chromIdx, int hapIdx, CoreQuad core, int permIdx) {
			this.chromIdx = chromIdx;
			this.hapIdx = hapIdx;
			this.core = core;
			this.permIdx = permIdx;
			this.permutation = allPerms[permIdx];
		}
	
		// compute info for a left-out haplotype (this is the method that gets its own thread)
		public HapInfo call() {
			System.out.println("starting hap " + this.hapIdx + " in permutation " + Arrays.toString(permutation) + " in chrom " + this.chromIdx);
			
			// get the hapList and associated map for our current chromosome
			List<Haplotype> hapList = chromosomes.get(this.chromIdx);
			
			// add in all the haps *after* the hap at hapIdx (i.e. the trunk we condition on)
			HapConfig<Haplotype> hapConfig = new HapConfig<Haplotype>();
			boolean startAdding = false;
			for (int i = 0; i < hapList.size(); i++) {
				if (this.permutation[i] == this.hapIdx) {
					startAdding = true; 
				} else if (startAdding) {
					hapConfig.adjustType(hapList.get(this.permutation[i]), 1); // make sure we add the index from the permutation
				}
			}
			
			// assert that the number of haps we have added matches that of the core
			assert (hapConfig.totalGeneticTypes() == this.core.nTrunk());
			
			// "take out" the desired haplotype and compute its likelihood, expected transition matrix, and expected emission matrix
			Haplotype sHap = hapList.get(this.hapIdx);
			DecodeQuad<Haplotype, HapConfig<Haplotype>> decode = new DecodeQuad<Haplotype, HapConfig<Haplotype>>(this.core, sHap, hapConfig);
			double hapLikelihood = decode.computeForwardBackward(); // this initializes everything
			double[] hapExpectedInitialCounts = decode.computeExpectedInitialCounts();
			double[][] hapExpectedTrans = decode.computeExpectedTransitions();
			double[][] hapExpectedMuts = decode.computeExpectedEmissions();
			
			// if we want to, compute/print decoding:
			// 1 = posterior decoding, 2 = Viterbi decoding, 3 = posterior mean time,
			// 4/5 = posterior decoding time (probabilities printed two different ways)
			if (decodingInt == 1) {
				Map<Haplotype, Integer> hap2IndexMap = hap2IndexMaps.get(this.chromIdx);
				double[][] absorptionTimesHaps = decode.computePosteriorDecoding(hap2IndexMap);
				DecodeQuad.printPosteriorDecoding(this.hapIdx, absorptionTimesHaps);
			} else if (decodingInt == 2) {
				Map<Haplotype, Integer> hap2IndexMap = hap2IndexMaps.get(this.chromIdx);
				double[][] absorptionTimesHaps = decode.computeViterbiDecoding(hap2IndexMap);
				// NOTE(review): the Viterbi result is printed with printPosteriorDecoding — confirm intended
				DecodeQuad.printPosteriorDecoding(this.hapIdx, absorptionTimesHaps);
			} else if (decodingInt == 3) {
				double[] decodingTimes = decode.computePosteriorMeanTime();
				DecodeQuad.printPosteriorMeanTime(this.hapIdx, decodingTimes);
			} else if (decodingInt == 4) {
				double[][] decodingTimesProbs = decode.computePosteriorDecodingTime();
				DecodeQuad.printPosteriorDecodingTime(this.hapIdx, decodingTimesProbs);
			} else if (decodingInt == 5) {
				double[][] decodingTimesProbs = decode.computePosteriorDecodingTime();
				DecodeQuad.printPosteriorDecodingTimeProb(this.hapIdx, decodingTimesProbs);
			}
			
			// store hapInfo
			System.out.println("finished hap " + this.hapIdx + " in permutation " + Arrays.toString(this.permutation) + ": " + hapLikelihood);
			return new HapInfo(hapLikelihood, hapExpectedInitialCounts, hapExpectedTrans, hapExpectedMuts, this.permIdx, this.core.nTrunk());
		}
	}
}
