/*
 * License: FreeBSD (Berkeley Software Distribution)
 * Copyright (c) 2013, Sara Sheehan, Kelley Harris, Yun Song
 */

package edu.berkeley.smcsd;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.math.ConvergenceException;
import org.apache.commons.math.FunctionEvaluationException;
import org.apache.commons.math.optimization.GoalType;
import org.apache.commons.math.optimization.RealPointValuePair;
import org.apache.commons.math.optimization.direct.NelderMead;
import org.apache.commons.math.optimization.univariate.BrentOptimizer;

import com.martiansoftware.jsap.FlaggedOption;
import com.martiansoftware.jsap.JSAP;
import com.martiansoftware.jsap.JSAPException;
import com.martiansoftware.jsap.JSAPResult;
import com.martiansoftware.jsap.Parameter;
import com.martiansoftware.jsap.SimpleJSAP;

import edu.berkeley.utility.Haplotype;
import edu.berkeley.utility.ParamSet;
import edu.berkeley.utility.Utility;

public class EmWrapper {
	
	// globals for optimization
    private final static double STEP = 0.1;      // initial step size for Nelder-Mead
    private final static double MIN_SIZE = 0.01; // lower bound for the population size scaling factors
    private final static double MAX_SIZE = 1000; // upper bound for the population size scaling factors
    
    // globals for printing info
    private final static int BYTES_PER_DOUBLE = 8;        // number of bytes per double, used for estimating memory requirements
	private final static double GIGA_BYTE = 1000000000;   // number of bytes for one Gb, used for estimating memory requirments
	private final static String PROG = "diCal-v1.3"; // current version of diCal
	
	public static void main(String[] args) throws IOException, JSAPException, ConvergenceException, FunctionEvaluationException, IllegalArgumentException {
		SimpleJSAP jsap = new SimpleJSAP("diCal", "Use diCal to infer population sizes and/or HMM decoding.",
	        new Parameter[] {
				
				// ------------------- REQUIRED PARAMETERS -------------------
				
				// the files of haplotypes in fasta format (all must be the same length), now can accommodate more than one chromosome/scaffold, each in a separate file
	            new FlaggedOption( "fastaFiles", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'F', "fastaFiles", "The input fasta files, one for each chromosome/scaffold. All haplotypes within each file must all have the same length."),
	            
	            // the reference fasta file
	            new FlaggedOption( "referenceFile", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'R', "referenceFile", "The inpupt fasta reference file."),
	            
	            // the file of haplotypes in small vcf format (should be used in combination with a reference)
	            new FlaggedOption( "strippedVcfFile", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.NOT_REQUIRED, 'V', "strippedVcfFile", "The input stripped vcf file of haplotypes."),
	            
	            // the parameter file
	            new FlaggedOption( "paramFile", JSAP.STRING_PARSER, JSAP.NO_DEFAULT, JSAP.REQUIRED, 'I', "paramFile", "The input parameter file, see example file for formatting."),
	            
	            // ------------------- RECOMMENDED PARAMETERS -------------------
	                
	            // will take the first nTotal haplotypes, default is 2
	            new FlaggedOption( "nTotal", JSAP.INTEGER_PARSER, "2", JSAP.NOT_REQUIRED, 'n', "nTotal", "The number of haplogypes."),
	            
	            // flag for the parameter pattern, no longer using adaptive param pattern
	            new FlaggedOption( "paramPattern", JSAP.STRING_PARSER, "3 2 3", JSAP.NOT_REQUIRED, 'p', "paramPattern", "Pattern of parameters spanning intervals."),
	            
	            // flag for scaling times to have an end time, default it -1 to specify it is not used
	            new FlaggedOption( "endTime", JSAP.DOUBLE_PARSER, "-1", JSAP.NOT_REQUIRED, 't', "endTime", "Specify an end time in coalescence units, default is no time rescaling."),
	            
	            // ------------------- OPTIONAL PARAMETERS -------------------
	            
	            // flag for leave-out-out vs. PAC
	            new FlaggedOption( "leaveOneOut", JSAP.BOOLEAN_PARSER, "1", JSAP.NOT_REQUIRED, 'l', "leaveOneOut", "PAC method, default is leave-one-out, use 0 for traditional PAC (experimental)."),
	            
	            // flag for a seed to initialize the random generator for the permutations (only source of randomness)
	            new FlaggedOption( "seed", JSAP.INTEGER_PARSER, "0", JSAP.NOT_REQUIRED, 's', "seed", "Seed for the random number generator for the permutations (PAC only)."),
	            
	            // flag for the number of permutations for PAC
	            new FlaggedOption( "numPerms", JSAP.INTEGER_PARSER, "5", JSAP.NOT_REQUIRED, 'g', "numPerms", "The number of permutations (PAC only)."),
	           
	            // flag for linear (recommended) vs. quadratic time in d
	            new FlaggedOption( "linear", JSAP.BOOLEAN_PARSER, "1", JSAP.NOT_REQUIRED, 'u', "linear", "Linear runtime in the number of discretization points is recommended, use 0 for quadratic."),
	            
	            // the number of threads to use (default 1 thread)
	            new FlaggedOption( "numCores", JSAP.INTEGER_PARSER, "1", JSAP.NOT_REQUIRED, 'c', "numCores", "The number of cores to use."),
	            
	            // what type of decoding (if any) to print at the end
	            new FlaggedOption( "decodingInt", JSAP.INTEGER_PARSER, "0", JSAP.NOT_REQUIRED, 'd', "decodingInt", "0: no decoding\n1: posterior decoding\n2: Viterbi decoding\n3: posterior mean time\n4: posterior decoding time\n5: posterior decoding time/probability"),
	            
	            // adding flag for number of iterations, default 20
	            new FlaggedOption( "numIter", JSAP.INTEGER_PARSER, "20", JSAP.NOT_REQUIRED, 'N', "numIter", "Number of iterations."),
	            
	            // flag for a file of times (and sizes, optional)
	            new FlaggedOption( "timesFile", JSAP.STRING_PARSER, "", JSAP.NOT_REQUIRED, 'T', "timesFile", "File of times (and sizes, optional), see example for formatting."),
	            
	            // flag for whether or not we should print the expected segments
	            new FlaggedOption( "printExpectedSegs", JSAP.BOOLEAN_PARSER, "0", JSAP.NOT_REQUIRED, 'e', "printExpectedSegs", "Use 1 to print the expected number of coalescent events in each time interval."),
			}
		);
		
		// program parameters (compliments of JSAP)
		JSAPResult config = jsap.parse(args);
		
		// parse required parameters
		String fastaString = config.getString("fastaFiles");
		String refFilename = config.getString("referenceFile");
		String vcfFilename = config.getString("strippedVcfFile");
		String paramFile = config.getString("paramFile");
		
		// parse recommended parameters
		int nTotal = config.getInt("nTotal");
		int nTrunk = nTotal-1;
		String paramPatternStr = config.getString("paramPattern");
		double endTime = config.getDouble("endTime");
		
		// parse optional parameters
		boolean leaveOneOut = config.getBoolean("leaveOneOut");
		int seed = config.getInt("seed");
		int numPerms = config.getInt("numPerms");
		boolean linear = config.getBoolean("linear");
		int numCores = config.getInt("numCores");
		int decodingInt = config.getInt("decodingInt");
		int maxIter = config.getInt("numIter");
		String timesFile = config.getString("timesFile");
		boolean printExpectedSegs = config.getBoolean("printExpectedSegs");
		
		// ------------------- REQUIRED PARAMETERS -------------------
		
		// read in parameter data from file
		if (paramFile == null) { System.exit(0); }
		ParamSet pSet = new ParamSet(paramFile);
		int unknown = pSet.numAlleles(); // this integer will be used for unknown alleles
		
		// read in haplotype data from the files and create a list of hapConfigs
		List<List<Haplotype>> chromosomes = new ArrayList<List<Haplotype>>();
		String dataString = "";
		int L = 0; // this is the total number of loci over all chromosomes;
		if (fastaString == null && (refFilename == null || vcfFilename == null)) {
			System.err.println("Error: must either specify a fasta file of haplotypes (-f), or a fasta reference file (-r) and a stripped vcf file of haplotypes (-v).");
			System.exit(0);
		}
		else if (fastaString != null) {
			dataString = fastaString;
			String[] fastaFiles = fastaString.split(" ");
			for (String fileName : fastaFiles) {
				List<Haplotype> newChrom = Utility.readHaplotypesFasta(fileName, pSet, nTrunk);
				chromosomes.add(newChrom);
				L += newChrom.get(0).getNumLoci();
			}
		} else {
			dataString = refFilename + " " + vcfFilename;
			chromosomes = Utility.readHaplotypesVcf(refFilename, vcfFilename, pSet, nTrunk);
			for (List<Haplotype> chrom : chromosomes) {
				L += chrom.get(0).getNumLoci();
			}
		}
		
		// if decoding, create hap to index map
		List<Map<Haplotype, Integer>> hap2IndexMaps = new ArrayList<Map<Haplotype, Integer>>();
		for (List<Haplotype> hapList : chromosomes) {
			Map<Haplotype, Integer> hap2IndexMap = new HashMap<Haplotype, Integer>();
			if (decodingInt != 0) {
				for (int i=0; i < hapList.size(); i++) {
					hap2IndexMap.put(hapList.get(i), i);
				}
			}
			hap2IndexMaps.add(hap2IndexMap);
		}
		
		// ------------------- RECOMMENDED PARAMETERS -------------------
		
		// basic checks on the input
		if (nTrunk < 1) { System.err.println("Error: number of haplotypes should be >= 2."); System.exit(0); }
		if (endTime <= 0 && endTime != -1) { System.err.println("Error: end time should be positive."); System.exit(0); }
		
		// initialize the parameter pattern
		int[] paramPattern = Utility.parseParamPattern(paramPatternStr);
		int numParam = paramPattern.length;
		int d = Utility.sumArray(paramPattern);
		if (numParam <= 0) { System.err.println("Error: number of inferred parameters <= 0."); System.exit(0); }
		
		// initialize descending/ascending factorial table so we don't have to recompute
		double[][] descendingAscendingTable = new double[nTrunk][nTrunk];
		for (int n=1; n <= nTrunk; n++) {
			for (int i=1; i <= nTrunk; i++) {
				descendingAscendingTable[n-1][i-1] = Utility.descendingAscendingFac(n, i);
			}
		}
				
		// ------------------- OPTIONAL PARAMETERS -------------------
		
		if (numCores <= 0) { System.err.println("Error: number of threads <= 0."); System.exit(0); }
		if (maxIter == 0 && decodingInt == 0) { System.err.println("Error: num EM iterations is 0 and no decoding printed (program does nothing)."); System.exit(0); }
		
		// initialize sizes to 1
		double[] currentSizes = new double[d];
		double[] currentSizesShort = new double[numParam];
		Arrays.fill(currentSizes, 1); // initialize start sizes to 1 by default
		Arrays.fill(currentSizesShort, 1); // initialize to 1 even in the case when a history has been specified (usually would not want to infer a history)
		
		// if times have not been initialized through user specification, compute them
		double[] times = null;
		if (timesFile == "") {
			times = Utility.computeDiscretization(chromosomes, d, pSet, unknown);
			// rescale if an end time is specified
			if (endTime != -1) {
				times = Utility.rescaleQuadraturePoints(times, endTime);
			}
		
		// if times/sizes have been specified in a file, read these in
		} else {
			double[][] timesSizes = Utility.readTimesFile(timesFile);
			times = timesSizes[0];
			if (times.length != d) { System.err.println("Error: number of times in file does not match number of intervals (flag -d)."); System.exit(0); };
			if (timesSizes[1][0] != 0) {
				currentSizes = timesSizes[1];
			}
		}
		
		// print out info about our run
		System.out.println(getInfoString(dataString, paramFile, pSet,
			d, numParam, nTotal, paramPatternStr, endTime, 
			leaveOneOut, seed, numPerms, linear, numCores, decodingInt, maxIter, times, currentSizes, L, printExpectedSegs));
		
		// ------------------- SET UP OPTIMIZATION -------------------
		
		// optimization for linear mode
		BrentOptimizer optimizer = new BrentOptimizer();
		
		// optimization for quadratic mode
		double currentMstepValue = 0;
		double[] stepSizes = new double[numParam];
		Arrays.fill(stepSizes, STEP);
		NelderMead nelderMead = new NelderMead();
		nelderMead.setStartConfiguration(stepSizes); // set the initial step sizes for each parameter
		
		// set up the Mstep
		MstepLinear mStepLinear = null;
		MstepQuad mStepQuad = null;
		if (linear) {
			if (leaveOneOut) {
				mStepLinear = new MstepLinearLol(pSet, times, nTrunk, paramPattern, descendingAscendingTable[nTrunk-1]);
			} else {
				mStepLinear = new MstepLinearPac(pSet, times, nTotal, paramPattern, descendingAscendingTable);
			}
		} else {
			if (leaveOneOut) {
				mStepQuad = new MstepQuadLol(pSet, times, nTrunk, paramPattern, MIN_SIZE, MAX_SIZE, descendingAscendingTable[nTrunk-1]);
			} else {
				mStepQuad = new MstepQuadPac(pSet, times, nTotal, paramPattern, MIN_SIZE, MAX_SIZE, descendingAscendingTable);
			}
		}
		
		// if we are in PAC mode, create a list of random permutations
		int[][] allPerms = Utility.getRandomPermuations(nTotal, numPerms, seed);
		
		// ------------------- EM LOOP -------------------
		
		for (int iter=0; iter < maxIter; iter++) {
			System.out.println("\n--------\nITER=" + iter + "\n--------");
			
			// ESTEP------
			System.out.println("calling E-step with sizes: " + Arrays.toString(currentSizes));
			Estep eStep;
			if (linear) {
				if (leaveOneOut) {
					eStep = new EstepLinearLol(chromosomes, pSet, times, currentSizes, numCores, descendingAscendingTable[nTrunk-1], printExpectedSegs);
				} else {
					eStep = new EstepLinearPac(chromosomes, pSet, times, currentSizes, allPerms, numCores, descendingAscendingTable, printExpectedSegs);
				}
			} else {
				if (leaveOneOut) {
					eStep = new EstepQuadLol(chromosomes, pSet, times, currentSizes, numCores, descendingAscendingTable[nTrunk-1], 0, hap2IndexMaps); // not decoding during EM
				} else {
					eStep = new EstepQuadPac(chromosomes, pSet, times, currentSizes, allPerms, numCores, descendingAscendingTable, 0, hap2IndexMaps); // not decoding during EM
				}
			}
			
			if (printExpectedSegs) {
				System.out.println("expected segments: " + Arrays.toString(eStep.getExpectedSegments()));
			}
			System.out.println("E-step log likelihood: " + eStep.getEstepLogLikelihood());
			
			// MSTEP------
			System.out.println("calling M-step");
			
			// linear in d mode
			if (linear) {
				mStepLinear.updateEachIter(eStep);
				
				// update one parameter at a time, not setting last size to previous one in this case (should always have at least 2 intervals in last parameter)
				int startIdx = 0;
				int endIdx = paramPattern[0];
				for (int pIdx = 0; pIdx < numParam; pIdx++) {
					double bestSize = optimizer.optimize(mStepLinear, GoalType.MAXIMIZE, MIN_SIZE, MAX_SIZE);
					
					// update all sizes within parameter to the same (best) size
					for (int j = startIdx; j < endIdx; j++) {
						currentSizes[j] = bestSize;
					}
					
					if (pIdx < numParam-1) {
						mStepLinear.updateParamIdx(pIdx+1, bestSize);
						startIdx = endIdx;
						endIdx += paramPattern[pIdx+1];
					}
				}
			
			// quadratic in d mode
			} else {
				mStepQuad.updateEachIter(eStep);
					
				// the Nelder-Mead optimization (maximize log likelihood for LOL)
				RealPointValuePair pointValuePair = nelderMead.optimize(mStepQuad, GoalType.MAXIMIZE, currentSizesShort);
	
				// get the minimum value and parameters at minimum
				currentSizesShort = pointValuePair.getPoint();
				currentMstepValue = pointValuePair.getValue();
				
				// reset the current population sizes to the best ones the Mstep just found	
				currentSizes = Utility.expandSizes(currentSizesShort, paramPattern);
				System.out.println("M-step log likelihood: " + currentMstepValue);
			}
		       
			// print out the current/final sizes
			if (iter == maxIter-1) {
				System.out.println("final sizes: " + Arrays.toString(currentSizes));
			} else {
				System.out.println("current sizes: " + Arrays.toString(currentSizes));
			}
		}
		
		// ------------------- OPTIONAL DECODING -------------------
		
		// if we are decoding, call Estep one more time (outside EM loop) with final sizes, and numThreads = 1 so in right order
		// NOTE: this will not be linear in d
		if (decodingInt != 0) {
			System.out.println("\nBeginning decoding:");
			Estep decodeEstep;
			
			// LOL
			if (leaveOneOut) {
				decodeEstep = new EstepQuadLol(chromosomes, pSet, times, currentSizes, 1, descendingAscendingTable[nTrunk-1], decodingInt, hap2IndexMaps);
			// PAC
			} else {
				decodeEstep = new EstepQuadPac(chromosomes, pSet, times, currentSizes, allPerms, 1, descendingAscendingTable, decodingInt, hap2IndexMaps);
			}
			System.out.println("expected segments: " + Arrays.toString(decodeEstep.getExpectedSegments()));
		}
	}
	
	// format the command line input so we can print it for future reference
	private static String getRunString() {
		String commandStr = System.getProperty("sun.java.command");
		String runStr = PROG + ":" + commandStr.substring(commandStr.indexOf(" "));
		return runStr;
	}
	
	// estimate the memory requirements based on the input parameters, format into an -Xmx flag to the nearest giga-byte
	private static String getMemoryString(int L, int nTotal, int d, int numCores) {
		String memoryStr = "suggested JVM memory flag O(L*d*n*c): ";
		// 7 forward/backward tables total, but only the main forward/backward tables are indexed by n
		int forwardBackwardMemory = (int) Math.ceil((double)(2*L*nTotal*d + 5*L*d)*numCores*BYTES_PER_DOUBLE/GIGA_BYTE); // divide to get the floor, then add 1 to round up
		memoryStr += "-Xmx" + (2*forwardBackwardMemory) + "G"; // multiplying by 2 for security
		return memoryStr;
	}
	
	// util for printing out information about the run we are doing
	private static String getInfoString(String dataString, String paramFile, ParamSet pSet, 
		int d, int numParam, int nTotal, String paramPatternStr, double endTime, 
		boolean leaveOneOut, int seed, int numPerms, boolean linear, int numCores, int decodingInt, int maxIter, double[] times, double[] sizes, int L, boolean printExpectedSegs) {
		String infoStr = getRunString() + "\n";
		
		// required params
		infoStr += "dataset: " + dataString + "\n";
		infoStr += "parameter file: " + paramFile + "\n";
		infoStr += "number of alleles: " + pSet.numAlleles() + "\n";
		infoStr += "mutation rate theta = 4Nmu: " + pSet.getMutationRate() + "\n";
		infoStr += "mutation matrix:\n" + Utility.getMatrixString(pSet.getMutationMatrix().getArray());
		infoStr += "recombination rate rho = 4Nr: " + pSet.getRecombinationRate() + "\n";
		
		// recommended params
		infoStr += "number of parameters to estimate: " + numParam + "\n";
		infoStr += "parameter pattern: " + paramPatternStr + "\n";
		if (endTime == -1) {
			infoStr += "end time: default (none used)\n";
		} else {
			infoStr += "end time: " + endTime + "\n";
		}
		
		// optional params
		if (leaveOneOut) {
			infoStr += "PAC method: LOL\n";
		} else {
			infoStr += "PAC method: PAC with permutation weighting (experimental, leave-one-out is recommended)\n";
			if (seed != 0) {
				infoStr += "PAC seed: " + seed + "\n";
			} else {
				infoStr += "PAC seed: none specified\n";
			}
			infoStr += "PAC number of permutations: " + numPerms + "\n";
		}
		if (linear) {
			infoStr += "Runtime: linear in d\n";
		} else {
			infoStr += "Runtime: quadratic in d\n";
		}
		if (decodingInt == 0) {
			infoStr += "HMM decoding: (0) none\n";
		} else if (decodingInt == 1) {
			infoStr += "HMM decoding: (1) posterior decoding\n";
		} else if (decodingInt == 2) {
			infoStr += "HMM decoding: (2) Viterbi decoding\n";
		} else if (decodingInt == 3) {
			infoStr += "HMM decoding: (3) posterior mean time\n";
		} else if (decodingInt == 4) {
			infoStr += "HMM decoding: (4) posterior decoding time\n";
		} else if (decodingInt == 5) {
			infoStr += "HMM decoding: (5) posterior decoding time and probability\n";
		}
		infoStr += "number of EM iterations: " + maxIter + "\n";
		infoStr += "printing expected coalescent events: " + printExpectedSegs + "\n";
		
		// memory
		infoStr += "L = number of loci: " + L + "\n";
		infoStr += "d = number of discretization intervals: " + d + "\n";
		infoStr += "n = number of haplotypes: " + (nTotal) + "\n";
		infoStr += "c = number of cores: " + numCores + "\n";
		infoStr += getMemoryString(L, nTotal, d, numCores) + "\n";
		
		// initial times and sizes
		infoStr += "discretization times (units of 2N generations): " + Arrays.toString(times) + "\n";
		infoStr += "initial population size scaling factors: " + Arrays.toString(sizes);
		return infoStr;
	}
}
