/*
 * License: FreeBSD (Berkeley Software Distribution)
 * Copyright (c) 2013, Sara Sheehan, Kelley Harris, Yun Song
 */

package edu.berkeley.smcsd;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import edu.berkeley.utility.HapConfig;
import edu.berkeley.utility.Haplotype;
import edu.berkeley.utility.LogSum;
import edu.berkeley.utility.ParamSet;
import edu.berkeley.utility.StationaryProbability;
import edu.berkeley.utility.Utility;

/**
 * E-step of the EM algorithm using the leave-one-out likelihood (LOL):
 * each haplotype is left out in turn and decoded against the remaining
 * "trunk" haplotypes, and expected transition/emission/initial counts are
 * summed over all leave-outs and all chromosomes. Work is parallelized
 * across leave-outs with a fixed-size thread pool.
 *
 * All results are computed eagerly in the constructor; the getters simply
 * return the cached values. Note: getters return internal arrays directly
 * (no defensive copy), matching the original contract.
 */
public class EstepQuadLol implements Estep {	
	
	// fixed parameters
	private final List<List<Haplotype>> chromosomes; // list of all chromosomes for all nTotal haplotypes
	private final ParamSet pSet;                     // parameter set
	private final double[] times;                    // discretization points
	private final double[] sizes;                    // the population sizes during each of the above intervals
	private final int d;                             // number of discretization intervals
	private final int nTrunk;                        // number of haplotypes in the trunk
	private final int numCores;                      // number of cores to use (should be <= n)
	private final double[] descendingAscendingTable; // precomputed table to help compute nbar
	
	// for decoding
	private final int decodingInt;                             // selects which decoding (if any) to print; values 1-5 handled, anything else = none
	private final List<Map<Haplotype, Integer>> hap2IndexMaps; // per-chromosome map from haplotype to its index (used only when decoding)
	
	// values to compute (filled in by computeEstepLolMultiCore)
	private double eStepLogLikelihood;
	private double[] expectedInitialCounts; // expected counts of coalescence for the first locus, for each time interval
	private double[][] expectedTransitions; // expected transitions from time i to j
	private double[][] expectedEmissions;   // expected number of emissions of allele a from time i
	private double[] expectedSegments;      // number of expected recombination segments in each time i
	
	/**
	 * Builds the E-step object and immediately runs the full (multi-core)
	 * leave-one-out computation, so the object is fully populated on return.
	 *
	 * @param chromosomes list of chromosomes, each a list of haplotypes; must be non-empty
	 * @param pSet current parameter set
	 * @param times discretization points (defines d = times.length)
	 * @param sizes population size in each time interval
	 * @param numCores thread-pool size (should be <= number of leave-outs)
	 * @param descendingAscendingTable precomputed table used by CoreQuad
	 * @param decodingInt which decoding to print (1-5), or any other value for none
	 * @param hap2IndexMaps per-chromosome haplotype-to-index maps for decoding output
	 */
	public EstepQuadLol(List<List<Haplotype>> chromosomes, ParamSet pSet, double[] times, double[] sizes, int numCores, double[] descendingAscendingTable, int decodingInt, List<Map<Haplotype, Integer>> hap2IndexMaps) {
		assert chromosomes.size() > 0;
		this.chromosomes = chromosomes;
		this.pSet = pSet;
		this.times = times;
		this.sizes = sizes;
		this.d = times.length;
		this.nTrunk = chromosomes.get(0).size() - 1;
		this.numCores = numCores;
		this.descendingAscendingTable = descendingAscendingTable;
		this.decodingInt = decodingInt;
		this.hap2IndexMaps = hap2IndexMaps;
		
		// this initializes everything
		computeEstepLolMultiCore();
	}
	
	public double getEstepLogLikelihood() { return this.eStepLogLikelihood; }
	public double[] getExpectedInitialCounts() { return this.expectedInitialCounts; }
	public double[][] getExpectedTransitions() { return this.expectedTransitions; }
	public double[][] getExpectedEmissions() { return this.expectedEmissions; }
	public double[] getExpectedSegments() { return this.expectedSegments; }
	
	/**
	 * LOL (leave one out likelihood), using multiple cores.
	 *
	 * For each chromosome, one task per leave-out haplotype is submitted to a
	 * shared fixed thread pool; each task returns a HapInfo whose counts are
	 * accumulated into the class-level expectation arrays. The expected number
	 * of recombination segments is derived once, at the end, from the final
	 * accumulated transition matrix.
	 */
	private void computeEstepLolMultiCore() {
		
		this.expectedInitialCounts = new double[this.d];
		this.expectedTransitions = new double[this.d][this.d];
		this.expectedEmissions = new double[this.d][this.pSet.numAlleles()];
		this.expectedSegments = new double[this.d];
			
		// construct the conditional sampling distribution based on params
		CoreQuad core = new CoreQuad(this.pSet, this.times, this.sizes, this.nTrunk, this.descendingAscendingTable);
		
		// set up likelihood calculation (add up the likelihoods from each chromosome)
		int numChroms = this.chromosomes.size();
		LogSum logLikelihoodSum = new LogSum(numChroms);
		
		// special case: with sample size 2 and a symmetric mutation matrix, leaving
		// out either haplotype gives the same answer, so only do the work once
		final int numConsidered;
		if (this.nTrunk == 1 && StationaryProbability.isSymmetric(this.pSet.getMutationMatrix().getArray())) {
			numConsidered = 1;
		} else {
			numConsidered = this.nTrunk + 1;
		}
		
		// one thread pool reused across all chromosomes (previously a new pool was
		// created per chromosome); shut down in finally so worker threads never leak
		ExecutorService taskExecutor = Executors.newFixedThreadPool(this.numCores);
		try {
			// first loop over each chromosome
			for (int c=0; c < numChroms; c++) {
			
				// create list of tasks (loop over haplotypes so we leave each out in turn)
				List<HapThread> allHapThreads = new ArrayList<HapThread>();
				for (int i=0; i < numConsidered; i++) {
					allHapThreads.add(new HapThread(c, i, core));
				}
				
				// run our tasks inside this try/catch block; on failure we print and
				// continue with the remaining chromosomes (best-effort, as before)
				try {
					
					List<Future<HapInfo>> futures = taskExecutor.invokeAll(allHapThreads);
					
					// add up all our expected transitions, emissions, and likelihoods
					double chromLogLikelihood = 0;
					for (int i=0; i < numConsidered; i++) {
						HapInfo hapInfo = futures.get(i).get();
						chromLogLikelihood += hapInfo.likelihood;
						this.expectedInitialCounts = Utility.addArrays(this.expectedInitialCounts, hapInfo.initialCounts);
						this.expectedTransitions = Utility.addMatrices(this.expectedTransitions, hapInfo.expectedTransitions);
						this.expectedEmissions = Utility.addMatrices(this.expectedEmissions, hapInfo.expectedEmissions);
					}
					logLikelihoodSum.addLogSummand(chromLogLikelihood);
					
				} catch (InterruptedException e) {
					// restore the interrupt status so callers further up can react
					Thread.currentThread().interrupt();
					System.out.println("interrupted exception in one of our threads");
					// InterruptedException usually has no cause, so print the exception
					// itself (the old code called getCause().printStackTrace() and NPE'd)
					e.printStackTrace();
				} catch (ExecutionException e) {
					System.out.println("execution exception in one of our threads");
					// the cause is the exception thrown inside the task; more informative
					// than the wrapper, but fall back to the wrapper if it is null
					Throwable cause = e.getCause();
					(cause != null ? cause : e).printStackTrace();
				}
			}
		} finally {
			taskExecutor.shutdown(); // important to shutdown
		}
		
		// compute the expected segments (only add on segments where the time index has
		// changed, i.e. recombination). note that this (intentionally) does not include
		// "hidden" recombination events which do not change the state.
		// computed ONCE from the final accumulated transition matrix: the previous code
		// ran this inside the chromosome loop over the running total, which double
		// counted earlier chromosomes whenever numChroms > 1
		for (int i=0; i < this.d; i++) {
			for (int j=0; j < this.d; j++) {
				if (i != j) {
					this.expectedSegments[j] += this.expectedTransitions[i][j];
				}
			}
		}
		
		this.eStepLogLikelihood = logLikelihoodSum.retrieveLogSum() - Math.log(numChroms);
	}
	
	// small immutable struct to hold one leave-out's likelihood and expected counts.
	// static: it never touches the enclosing instance, so no hidden outer reference
	private static class HapInfo {
		private final double likelihood;
		private final double[] initialCounts;
		private final double[][] expectedTransitions;
		private final double[][] expectedEmissions;
		
		public HapInfo(double likelihood, double[] initialCounts, double[][] expectedTransitions, double[][] expectedEmissions) {
			this.likelihood = likelihood;
			this.initialCounts = initialCounts;
			this.expectedTransitions = expectedTransitions;
			this.expectedEmissions = expectedEmissions;
		}
	}
	
	// one leave-out task, returns type HapInfo. non-static on purpose: it reads the
	// enclosing instance's chromosomes, decodingInt, and hap2IndexMaps
	private class HapThread implements Callable<HapInfo> {
		
		private final int chromIdx;
		private final int hapIdx;
		private final CoreQuad core;
		
		public HapThread(int chromIdx, int hapIdx, CoreQuad core) {
			this.chromIdx = chromIdx;
			this.hapIdx = hapIdx;
			this.core = core;
		}
	
		// compute expected counts for a haplotype (this is the method that gets its own task for each hap)
		public HapInfo call() {
			System.out.println("starting hap " + this.hapIdx + " in chrom " + this.chromIdx);
			
			// build the trunk configuration for our chromosome: all haplotypes
			// except the one being left out
			List<Haplotype> hapList = chromosomes.get(this.chromIdx);
			HapConfig<Haplotype> hapConfig = new HapConfig<Haplotype>();
			for (int i = 0; i < hapList.size(); i++) {
				if (i != this.hapIdx) {
					hapConfig.adjustType(hapList.get(i), 1); // add in all but the left out one
				}
			}
			
			// "take out" the desired haplotype and compute its likelihood, expected transition matrix, and expected emission matrix
			Haplotype sHap = hapList.get(this.hapIdx);
			DecodeQuad<Haplotype, HapConfig<Haplotype>> decode = new DecodeQuad<Haplotype, HapConfig<Haplotype>>(this.core, sHap, hapConfig);
			double hapLikelihood = decode.computeForwardBackward(); // this initializes everything
			double[] hapExpectedInitialCounts = decode.computeExpectedInitialCounts();
			double[][] hapExpectedTransitions = decode.computeExpectedTransitions();
			double[][] hapExpectedEmissions = decode.computeExpectedEmissions();
			
			// if we want to, compute/print decoding
			if (decodingInt == 1) {
				Map<Haplotype, Integer> hap2IndexMap = hap2IndexMaps.get(this.chromIdx);
				double[][] absorptionTimesHaps = decode.computePosteriorDecoding(hap2IndexMap);
				DecodeQuad.printPosteriorDecoding(this.hapIdx, absorptionTimesHaps);
			} else if (decodingInt == 2) {
				Map<Haplotype, Integer> hap2IndexMap = hap2IndexMaps.get(this.chromIdx);
				double[][] absorptionTimesHaps = decode.computeViterbiDecoding(hap2IndexMap);
				// NOTE(review): Viterbi result is printed via printPosteriorDecoding —
				// presumably both decodings share one output format; confirm intended
				DecodeQuad.printPosteriorDecoding(this.hapIdx, absorptionTimesHaps);
			} else if (decodingInt == 3) {
				double[] decodingTimes = decode.computePosteriorMeanTime();
				DecodeQuad.printPosteriorMeanTime(this.hapIdx, decodingTimes);
			} else if (decodingInt == 4) {
				double[][] decodingTimesProbs = decode.computePosteriorDecodingTime();
				DecodeQuad.printPosteriorDecodingTime(this.hapIdx, decodingTimesProbs);
			} else if (decodingInt == 5) {
				double[][] decodingTimesProbs = decode.computePosteriorDecodingTime();
				DecodeQuad.printPosteriorDecodingTimeProb(this.hapIdx, decodingTimesProbs);
			}
			
			// store hapInfo
			System.out.println("finished hap " + this.hapIdx + " : " + hapLikelihood);
			return new HapInfo(hapLikelihood, hapExpectedInitialCounts, hapExpectedTransitions, hapExpectedEmissions);
		}
	}
}
