/*
 * License: FreeBSD (Berkeley Software Distribution)
 * Copyright (c) 2013, Sara Sheehan, Kelley Harris, Yun Song
 */

package edu.berkeley.utility;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

import Jama.Matrix;

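// Static utility functions: array/matrix arithmetic, permutation helpers, and readers/parsers
// for the input files (reference FASTA, segregating-site file, haplotype FASTA, times file, mutation matrix).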
public class Utility {

	/**
	 * Formats a matrix as text, one row per line, each row rendered by {@link Arrays#toString(double[])}.
	 *
	 * @param matrix the matrix to format (rows may have different lengths)
	 * @return every row followed by a newline; the empty string for a matrix with no rows
	 */
	public static String getMatrixString(double[][] matrix) {
		StringBuilder sb = new StringBuilder();
		for (double[] row : matrix) {
			sb.append(Arrays.toString(row)).append('\n');
		}
		return sb.toString();
	}

	/**
	 * Adds two matrices element-wise into a newly allocated matrix.
	 *
	 * @param m1 first addend
	 * @param m2 second addend; must have the same number of rows as m1, and each row the same length
	 * @return a new matrix with {@code total[i][j] = m1[i][j] + m2[i][j]}
	 * @throws IllegalArgumentException if the dimensions differ
	 */
	public static double[][] addMatrices(double[][] m1, double[][] m2) {
		if (m1.length != m2.length) {
			throw new IllegalArgumentException("matrices have different numbers of rows: " + m1.length + " vs " + m2.length);
		}
		double[][] total = new double[m1.length][];
		for (int i = 0; i < m1.length; i++) {
			total[i] = addArrays(m1[i], m2[i]);
		}
		return total;
	}

	/**
	 * Sums an array of ints.
	 *
	 * @param a the values to add (may be empty)
	 * @return the sum; overflow wraps silently as with ordinary int arithmetic
	 */
	public static int sumArray(int[] a) {
		int sum = 0;
		for (int value : a) {
			sum += value;
		}
		return sum;
	}

	/**
	 * Sums an array of doubles.
	 *
	 * @param a the values to add (may be empty)
	 * @return the sum, 0 for an empty array
	 */
	public static double sumArray(double[] a) {
		double sum = 0;
		for (double value : a) {
			sum += value;
		}
		return sum;
	}

	/**
	 * Adds two arrays element-wise into a newly allocated array.
	 *
	 * @param a1 first addend
	 * @param a2 second addend; must have the same length as a1
	 * @return a new array with {@code total[i] = a1[i] + a2[i]}
	 * @throws IllegalArgumentException if the lengths differ
	 */
	public static double[] addArrays(double[] a1, double[] a2) {
		if (a1.length != a2.length) {
			throw new IllegalArgumentException("arrays have different lengths: " + a1.length + " vs " + a2.length);
		}
		double[] total = new double[a1.length];
		for (int i = 0; i < a1.length; i++) {
			total[i] = a1[i] + a2[i];
		}
		return total;
	}

	/**
	 * Multiplies every entry of a matrix by a constant, returning a new matrix.
	 *
	 * @param mat the matrix to scale (rows may have different lengths; may have no rows)
	 * @param c the scale factor
	 * @return a new matrix with {@code newMat[i][j] = mat[i][j] * c}
	 */
	public static double[][] multMatrix(double[][] mat, double c) {
		double[][] newMat = new double[mat.length][];
		for (int i = 0; i < mat.length; i++) {
			newMat[i] = new double[mat[i].length];
			for (int j = 0; j < mat[i].length; j++) {
				newMat[i][j] = mat[i][j] * c;
			}
		}
		return newMat;
	}

	/** Small immutable struct holding the best time index, best haplotype index, and the corresponding value. */
	public static class BestViterbi {

		public final int timeIdx;
		public final int hapIdx;
		public final double prob;

		public BestViterbi(int timeIdx, int hapIdx, double prob) {
			this.timeIdx = timeIdx;
			this.hapIdx = hapIdx;
			this.prob = prob;
		}

		/**
		 * Returns a representative time for {@code timeIdx}.
		 *
		 * @param times the time points
		 * @return the midpoint of {@code times[timeIdx]} and {@code times[timeIdx+1]}, or {@code times[d-1]}
		 *         when timeIdx is the last index
		 */
		public double getTime(double[] times) {
			int d = times.length;
			assert this.timeIdx < d;
			if (this.timeIdx == d-1) {
				return times[d-1];
			} else {
				return (times[this.timeIdx] + times[this.timeIdx+1])/2;
			}
		}
	}

	/**
	 * Finds the maximum element of a matrix.
	 *
	 * Ties are resolved in favor of the first cell in row-major order. NaN entries never win.
	 * For a matrix with no (non-NaN) entries the result is (0, 0, -Infinity).
	 *
	 * @param matrix the matrix to search; rows may have different lengths
	 * @return the arg-max row (as {@code timeIdx}), arg-max column (as {@code hapIdx}) and the max value
	 */
	public static BestViterbi maxElement(double[][] matrix) {
		int argmaxRow = 0;
		int argmaxCol = 0;
		double maxValue = Double.NEGATIVE_INFINITY;

		for (int r = 0; r < matrix.length; r++) {
			for (int c = 0; c < matrix[r].length; c++) {
				if (matrix[r][c] > maxValue) {
					argmaxRow = r;
					argmaxCol = c;
					maxValue = matrix[r][c];
				}
			}
		}
		return new BestViterbi(argmaxRow, argmaxCol, maxValue);
	}

	/**
	 * Computes the ratio of the descending to the ascending factorial.
	 *
	 * The result is {@code prod_{j=0}^{i-1} (n-j)/(n+j)}, i.e. n(n-1)...(n-i+1) / (n(n+1)...(n+i-1)).
	 *
	 * @param n the base
	 * @param i the number of factors
	 * @return the ratio, or 1 when {@code i <= 0}
	 */
	public static double descendingAscendingFac(int n, int i) {
		double prod = 1;
		for (int j = 0; j < i; j++) {
			prod *= (double) (n - j) / (n + j);
		}
		return prod;
	}

	/**
	 * Expands a short list of population sizes according to a parameter pattern.
	 *
	 * Each {@code sizesShort[i]} is repeated {@code paramPattern[i]} times. The result has
	 * {@code sum(paramPattern)} entries, some of which may be equal.
	 *
	 * @param sizesShort one size per pattern group
	 * @param paramPattern how many consecutive entries each group spans
	 * @return the expanded sizes
	 * @throws IllegalArgumentException if the two arrays have different lengths
	 */
	public static double[] expandSizes(double[] sizesShort, int[] paramPattern) {
		int k = paramPattern.length;
		if (sizesShort.length != k) {
			throw new IllegalArgumentException("got " + sizesShort.length + " sizes for a pattern with " + k + " groups");
		}

		double[] popSizes = new double[sumArray(paramPattern)];
		int index = 0;
		for (int i = 0; i < k; i++) {
			for (int j = 0; j < paramPattern[i]; j++) {
				popSizes[index++] = sizesShort[i];
			}
		}
		return popSizes;
	}

	/**
	 * Checks whether any size lies outside the inclusive range [minSize, maxSize].
	 *
	 * @return true if some entry is below minSize or above maxSize (NaN entries are not flagged)
	 */
	public static boolean outOfBounds(double[] sizesShort, double minSize, double maxSize) {
		for (int s = 0; s < sizesShort.length; s++) {
			if (sizesShort[s] < minSize || sizesShort[s] > maxSize) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Builds an n x n square in which each number 0..n-1 appears once per row and once per column.
	 *
	 * The square is the cyclic one: {@code square[i][j] = (i + j) % n}.
	 */
	public static int[][] getPermuationSquare(int n) {
		int[][] allPerms = new int[n][n];
		for (int i = 0; i < n; i++) {
			for (int j = 0; j < n; j++) {
				allPerms[i][j] = (i + j) % n;
			}
		}
		return allPerms;
	}

	/**
	 * Generates random permutations of 0..n-1.
	 *
	 * @param n size of each permutation
	 * @param numPerms number of permutations to generate
	 * @param seed seed for reproducible output; 0 means "use an unseeded generator"
	 * @return {@code numPerms} rows of length {@code n}, each a permutation of 0..n-1
	 */
	public static int[][] getRandomPermuations(int n, int numPerms, int seed) {
		List<Integer> permutation = new ArrayList<Integer>(n);
		for (int j = 0; j < n; j++) {
			permutation.add(j);
		}

		// Create the generator ONCE. Re-seeding it inside the loop would replay the same swap
		// sequence on every shuffle, making the permutations powers of a single permutation
		// instead of independent draws.
		Random random = (seed != 0) ? new Random(seed) : new Random();

		int[][] allPerms = new int[numPerms][n];
		for (int p = 0; p < numPerms; p++) {
			Collections.shuffle(permutation, random);
			for (int i = 0; i < n; i++) {
				allPerms[p][i] = permutation.get(i);
			}
		}
		return allPerms;
	}

	/**
	 * Parses a parameter pattern such as "4+2+4" (or "4 2 4") into its integer groups.
	 *
	 * '+' and any whitespace are treated as separators; leading and trailing whitespace is ignored.
	 *
	 * @throws NumberFormatException if a group is not an integer
	 */
	public static int[] parseParamPattern(String paramPatternStr) {
		String[] patternStrs = paramPatternStr.replace("+", " ").trim().split("\\s+");
		int[] paramPattern = new int[patternStrs.length];
		for (int i = 0; i < patternStrs.length; i++) {
			paramPattern[i] = Integer.parseInt(patternStrs[i]);
		}
		return paramPattern;
	}

	/** Builds the map from base characters (A,C,G,T,N, either case) to allele codes 0,1,2,3,4. */
	public static Map<Character,Byte> makeBase2IntMap() {
		Map<Character,Byte> base2IntMap = new HashMap<Character,Byte>();
		base2IntMap.put('A',(byte)0);
		base2IntMap.put('a',(byte)0);
		base2IntMap.put('C',(byte)1);
		base2IntMap.put('c',(byte)1);
		base2IntMap.put('G',(byte)2);
		base2IntMap.put('g',(byte)2);
		base2IntMap.put('T',(byte)3);
		base2IntMap.put('t',(byte)3);
		base2IntMap.put('N',(byte)4);
		base2IntMap.put('n',(byte)4);
		return base2IntMap;
	}

	// Decodes one base character to its allele code and checks it against the mutation matrix size.
	private static byte parseAllele(char base, Map<Character,Byte> base2IntMap, ParamSet pSet) throws IOException {
		Byte allele = base2IntMap.get(base);
		if (allele == null) {
			throw new IOException("unknown allele value " + base + " (Should be A/a,C/c,G/g,T/t,N/n)");
		}
		// make sure we do not have more alleles than we planned for with our mutation matrix
		// allele values 0-3 represent A,C,G,T and allele value 4 represents an unknown base
		// NOTE(review): '>' admits a code equal to numAlleles. With numAlleles < 4 this accepts a real base
		// (e.g. G=2 when numAlleles=2) but rejects N=4. Confirm this is the intended rule.
		if (allele > pSet.numAlleles()) {
			throw new IOException("defined allele value " + allele + " does not match the mutation matrix size " + pSet.numAlleles());
		}
		return allele;
	}

	// Reads a FASTA reference (sequences may span several lines) into parallel lists of names and sequences.
	private static void readReferenceFasta(String refFilename, List<String> chromNames, List<String> chromStrings) throws IOException {
		BufferedReader refReader = new BufferedReader(new FileReader(refFilename));
		try {
			StringBuilder currChrom = null; // null until the first header line has been seen
			String refReadString;
			while ((refReadString = refReader.readLine()) != null) {
				String line = refReadString.trim();
				if (line.isEmpty()) {
					continue; // e.g. a trailing blank line; would otherwise crash on charAt(0)
				}

				if (line.charAt(0) == '>') {
					// just take the first token for the chrom name (so we can include other info in the fasta header)
					String chromName = line.split("\\s+")[0].substring(1);
					if (chromNames.contains(chromName)) {
						System.err.println("Error: duplicate reference sequence name: " + chromName + ".");
					}
					// close off the previous chromosome, even if it had no sequence, so names and sequences stay aligned
					if (currChrom != null) {
						chromStrings.add(currChrom.toString());
					}
					chromNames.add(chromName);
					currChrom = new StringBuilder();
				} else {
					if (currChrom == null) {
						throw new IOException("sequence data found before the first '>' header in " + refFilename);
					}
					currChrom.append(line);
				}
			}
			if (currChrom != null) {
				chromStrings.add(currChrom.toString());
			}
		} finally {
			refReader.close();
		}
	}

	/**
	 * Reads haplotypes from a reference FASTA plus a file of segregating sites, one list of haplotypes per chromosome.
	 *
	 * Each reference chromosome starts as nTrunk+1 identical copies of its sequence. Each site line then
	 * overwrites one position in every copy. A site line is whitespace-separated and has the form
	 * {@code chromName position alleles}:
	 * <ul>
	 * <li>position is 1-based;</li>
	 * <li>alleles holds one base character per haplotype.</li>
	 * </ul>
	 * Blank lines and lines starting with '#' are skipped.
	 *
	 * @param refFilename FASTA reference; the chromosome name is the first token of each '>' header
	 * @param vcfFilename segregating-site file
	 * @param pSet parameters; {@code numAlleles()} bounds the allowed allele codes
	 * @param nTrunk number of trunk haplotypes; nTrunk+1 haplotypes are built per chromosome
	 * @return one list of nTrunk+1 haplotypes per reference chromosome, in reference order
	 * @throws IOException on unreadable files, unknown bases, sites on unknown chromosomes or outside
	 *         the chromosome, or site lines with too few alleles
	 */
	public static List<List<Haplotype>> readHaplotypesVcf(String refFilename, String vcfFilename, ParamSet pSet, int nTrunk) throws IOException {

		Map<Character,Byte> base2IntMap = makeBase2IntMap();
		int numHaps = nTrunk + 1;

		// first read the reference to get the base sequences for each haplotype
		List<String> refChromNames = new ArrayList<String>();
		List<String> refChromStrings = new ArrayList<String>();
		readReferenceFasta(refFilename, refChromNames, refChromStrings);
		int numChroms = refChromNames.size();

		// chromosome name -> index; the first occurrence wins, so duplicate names resolve as List.indexOf would
		Map<String,Integer> chromIndices = new HashMap<String,Integer>();
		for (int c = 0; c < numChroms; c++) {
			if (!chromIndices.containsKey(refChromNames.get(c))) {
				chromIndices.put(refChromNames.get(c), c);
			}
		}

		// convert each reference to allele codes and start every haplotype as a copy of it
		List<List<byte[]>> chromosomeAlleleConfigs = new ArrayList<List<byte[]>>(numChroms);
		for (String chromString : refChromStrings) {
			int numLoci = chromString.length();
			byte[] chromSeq = new byte[numLoci];
			for (int locus = 0; locus < numLoci; locus++) {
				chromSeq[locus] = parseAllele(chromString.charAt(locus), base2IntMap, pSet);
			}
			List<byte[]> alleleConfigList = new ArrayList<byte[]>(numHaps);
			for (int n = 0; n < numHaps; n++) {
				// copy, so changing the seg sites of one haplotype does not change all the arrays
				alleleConfigList.add(Arrays.copyOf(chromSeq, numLoci));
			}
			chromosomeAlleleConfigs.add(alleleConfigList);
		}

		// then overwrite the segregating sites listed in the site file
		BufferedReader vcfReader = new BufferedReader(new FileReader(vcfFilename));
		try {
			String segSiteReadString;
			while ((segSiteReadString = vcfReader.readLine()) != null) {
				String line = segSiteReadString.trim();
				if (line.isEmpty() || line.startsWith("#")) {
					continue;
				}

				// parse out chromName, locus, and allele config
				String[] tokens = line.split("\\s+");
				if (tokens.length < 3) {
					throw new IOException("malformed site line \"" + line + "\" in " + vcfFilename + " (Should be: chromName position alleles)");
				}
				String chromName = tokens[0];
				int locus = Integer.parseInt(tokens[1]) - 1; // subtract 1 since positions are indexed from 1
				String alleleStr = tokens[2];

				// these are properties of the input files, so check them with exceptions rather than
				// asserts (which are disabled by default)
				Integer chromIdx = chromIndices.get(chromName);
				if (chromIdx == null) {
					throw new IOException("site on chromosome " + chromName + " which is not in the reference " + refFilename);
				}
				if (alleleStr.length() < numHaps) {
					throw new IOException("site " + chromName + ":" + tokens[1] + " has " + alleleStr.length() + " alleles (Should have at least " + numHaps + ")");
				}
				int chromLength = refChromStrings.get(chromIdx).length();
				if (locus < 0 || locus >= chromLength) {
					throw new IOException("site " + chromName + ":" + tokens[1] + " is outside the reference (length " + chromLength + ")");
				}

				// assign each allele to the right place
				List<byte[]> alleleConfigList = chromosomeAlleleConfigs.get(chromIdx);
				for (int n = 0; n < numHaps; n++) {
					alleleConfigList.get(n)[locus] = parseAllele(alleleStr.charAt(n), base2IntMap, pSet);
				}
			}
		} finally {
			vcfReader.close();
		}

		// finally create our haplotypes
		List<List<Haplotype>> chromosomes = new ArrayList<List<Haplotype>>(numChroms);
		for (List<byte[]> alleleConfigList : chromosomeAlleleConfigs) {
			List<Haplotype> hapList = new ArrayList<Haplotype>(numHaps);
			for (byte[] alleleConfig : alleleConfigList) {
				hapList.add(new Haplotype(alleleConfig));
			}
			chromosomes.add(hapList);
		}
		return chromosomes;
	}

	/**
	 * Reads the first nTrunk+1 haplotypes from a FASTA-style file in which each sequence is on ONE line.
	 *
	 * Lines starting with '#' or '>' and blank lines are skipped (see {@link #readLine}). Every haplotype
	 * must have the same length as the first one. Only the characters A,C,G,T,N (either case) are allowed,
	 * where N marks an unknown base.
	 *
	 * @throws IOException if the file is unreadable, has fewer than nTrunk+1 sequences, has sequences of
	 *         differing lengths, or contains an unknown or out-of-range allele
	 */
	public static List<Haplotype> readHaplotypesFasta(String fastaFilename, ParamSet pSet, int nTrunk) throws IOException {
		List<Haplotype> hapList = new ArrayList<Haplotype>(nTrunk+1);
		Map<Character,Byte> base2IntMap = makeBase2IntMap();

		BufferedReader reader = new BufferedReader(new FileReader(fastaFilename));
		try {
			int numLoci = -1;
			// only read in the first nTrunk+1 haplotypes
			for (int i = 0; i < nTrunk+1; i++) {
				String haplotypeString = readLine(reader);
				if (haplotypeString.isEmpty()) {
					// readLine returns "" at end of file
					throw new IOException("expected " + (nTrunk+1) + " haplotypes in " + fastaFilename + " but found only " + i);
				}

				if (i == 0) {
					numLoci = haplotypeString.length();
				} else if (haplotypeString.length() != numLoci) {
					throw new IOException("haplotype has length " + haplotypeString.length() + " (Should have length " + numLoci + ")");
				}

				byte[] alleleConfig = new byte[numLoci];
				for (int l = 0; l < numLoci; l++) {
					alleleConfig[l] = parseAllele(haplotypeString.charAt(l), base2IntMap, pSet);
				}
				hapList.add(new Haplotype(alleleConfig));
			}
		} finally {
			reader.close();
		}
		return hapList;
	}

	/**
	 * Reads a times file whose first content line holds the times and whose optional second line holds
	 * one population size per time. Values on a line are whitespace-separated.
	 *
	 * @return a 2 x numTimes array: row 0 holds the times, row 1 holds the sizes (all zeros if absent)
	 * @throws IOException if the file is unreadable, has no times, or the size count differs from the time count
	 */
	public static double[][] readTimesFile(String timesFile) throws IOException {
		String timesLine;
		String sizesLine;
		BufferedReader timesReader = new BufferedReader(new FileReader(timesFile));
		try {
			timesLine = readLine(timesReader);
			sizesLine = readLine(timesReader); // "" when no sizes are given
		} finally {
			timesReader.close();
		}

		if (timesLine.isEmpty()) {
			throw new IOException("no times found in " + timesFile);
		}
		String[] timeStrings = timesLine.split("\\s+");
		double[][] timesSizes = new double[2][timeStrings.length];
		for (int i = 0; i < timeStrings.length; i++) {
			timesSizes[0][i] = Double.parseDouble(timeStrings[i]);
		}

		// if population sizes are specified (even just one)
		if (!sizesLine.isEmpty()) {
			String[] sizeStrings = sizesLine.split("\\s+");
			if (sizeStrings.length != timeStrings.length) {
				// signal the caller instead of exiting the JVM with a success status
				throw new IOException("Error: number of sizes (" + sizeStrings.length + ") in " + timesFile + " does not match number of times (" + timeStrings.length + ").");
			}
			for (int i = 0; i < sizeStrings.length; i++) {
				timesSizes[1][i] = Double.parseDouble(sizeStrings[i]);
			}
		}
		return timesSizes;
	}

	/**
	 * Returns the next trimmed line that is not blank and does not start with '#' (parameter-file
	 * comments) or '>' (FASTA headers).
	 *
	 * @return the line, or "" at end of input
	 */
	public static String readLine(BufferedReader reader) throws IOException {
		String readString = null;
		while ((readString = reader.readLine()) != null) {
			readString = readString.trim();

			if (readString.startsWith("#")) continue; // for parameter format
			if (readString.startsWith(">")) continue; // for fasta
			if (readString.isEmpty()) continue;

			return readString;
		}

		return "";
	}

	/**
	 * Parses a square transition matrix written as rows separated by '|' and entries separated by ','.
	 *
	 * Each row is normalized to sum to 1. The number of rows determines the number of alleles. With
	 * fewer than four alleles they must be in order (e.g. two alleles means A/C).
	 *
	 * @throws IOException if a row has the wrong number of entries or does not have a positive sum
	 */
	public static Matrix readMatrix(String matrixString) throws IOException {
		String[] rowStrings = matrixString.split("\\|");
		int numAlleles = rowStrings.length;

		Matrix m = new Matrix(numAlleles, numAlleles);
		for (int a1 = 0; a1 < numAlleles; a1++) {
			String[] rElements = rowStrings[a1].split(",");

			if (rElements.length != numAlleles) { throw new IOException(rElements.length + " column(s) in transition matrix detected (Should be " + numAlleles + ")"); }

			double[] row = new double[numAlleles];
			double rowSum = 0;
			for (int a2 = 0; a2 < numAlleles; a2++) {
				row[a2] = Double.parseDouble(rElements[a2].trim());
				rowSum += row[a2];
			}

			// a zero (or NaN) row sum cannot be normalized and would silently fill the row with NaN
			if (!(rowSum > 0)) {
				throw new IOException("row " + a1 + " of transition matrix sums to " + rowSum + " (Should be positive)");
			}
			for (int a2 = 0; a2 < numAlleles; a2++) {
				m.set(a1, a2, row[a2] / rowSum);
			}
		}
		return m;
	}

	/**
	 * Computes d discretization points from the data.
	 *
	 * For every pair of haplotypes on every chromosome, the length of each run of matching sites
	 * (plus one) is recorded whenever a mismatching site ends it. Sites where either allele equals
	 * {@code unknown} count as matches. The recorded lengths are sorted. With L_i taken at quantile
	 * position i/d of the sorted list, the points are:
	 * <ul>
	 * <li>{@code endPoints[0] = 0};</li>
	 * <li>{@code endPoints[d-i] = 1 / (L_i * (rho + theta))} for i = 1..d-1.</li>
	 * </ul>
	 * Runs of equal points are then spread apart so that the points are distinct.
	 *
	 * @param chromosomes haplotypes per chromosome; each list must be non-empty
	 * @param d number of points (must be at least 1)
	 * @param pSet supplies the mutation (theta) and recombination (rho) rates
	 * @param unknown allele code that matches every other allele
	 * @throws IllegalArgumentException if d < 1, or if d > 1 and no mismatching site is found
	 */
	public static double[] computeDiscretization(List<List<Haplotype>> chromosomes, int d, ParamSet pSet, int unknown) {
		if (d < 1) {
			throw new IllegalArgumentException("number of discretization points must be positive, got " + d);
		}
		List<Double> ibsLengths = new ArrayList<Double>();
		double theta = pSet.getMutationRate();
		double rho = pSet.getRecombinationRate();

		double offset = 0.001; // spacing used for ties among the largest points (no larger point to move towards)

		for (List<Haplotype> hapList : chromosomes) { // accommodating multiple chromosomes
			assert hapList.size() > 0;
			int numLoci = hapList.get(0).getNumLoci();
			for (int i = 0; i < hapList.size(); i++) {
				Haplotype hap1 = hapList.get(i);
				for (int j = i+1; j < hapList.size(); j++) {
					Haplotype hap2 = hapList.get(j);
					double length = 1; // each recorded tract counts as at least one site
					for (int k = 0; k < numLoci; k++) {
						// unknown bases match everything (missing data)
						if (alleleMatch(hap1.getAllele(k), hap2.getAllele(k), unknown)) {
							length += 1;
						} else {
							ibsLengths.add(length);
							length = 1;
						}
					}
					// NOTE(review): the tract from the last mismatch to the end of the chromosome is never recorded;
					// confirm that dropping this censored tract is intended.
				}
			}
		}

		Collections.sort(ibsLengths);
		int numDiffs = ibsLengths.size();
		if (d > 1 && numDiffs == 0) {
			throw new IllegalArgumentException("cannot compute discretization: no mismatching sites between any pair of haplotypes");
		}

		double[] endPoints = new double[d];
		endPoints[0] = 0;
		for (int i = 1; i < d; i++) {
			// long arithmetic: i*numDiffs can overflow an int on large data sets
			double biggestLength = ibsLengths.get((int) ((long) i * numDiffs / d));
			endPoints[d-i] = 1.0/(biggestLength*(rho+theta));
		}

		// make sure no two points are the same
		separateTies(endPoints, offset);
		return endPoints;
	}

	// Spreads each run of equal adjacent points evenly toward the next larger point, or by multiples of offset at the end.
	private static void separateTies(double[] points, double offset) {
		int start = 0;
		while (start < points.length) {
			int end = start;
			while (end + 1 < points.length && points[end + 1] == points[start]) {
				end++;
			}
			int runLength = end - start + 1;
			if (runLength > 1) {
				double base = points[start];
				boolean hasNext = end + 1 < points.length;
				double next = hasNext ? points[end + 1] : 0;
				for (int k = 1; k < runLength; k++) {
					// for a pair this gives (base+next)/2 or base+offset, the same as the pairwise rule
					points[start + k] = hasNext ? (base * (runLength - k) + next * k) / runLength : base + k * offset;
				}
			}
			start = end + 1;
		}
	}

	// Allele comparison for missing data: the unknown code matches every allele.
	private static boolean alleleMatch(int a, int b, int unknown) {
		return a == unknown || b == unknown || a == b;
	}

	/**
	 * Rescales points so that the last one equals {@code endTime}.
	 *
	 * @param priorPoints the points to scale; the last one should be non-zero
	 * @param endTime the desired value of the last point
	 * @return a new array {@code priorPoints[i] * endTime / priorPoints[d-1]}; empty for empty input
	 */
	public static double[] rescaleQuadraturePoints(double[] priorPoints, double endTime) {
		int d = priorPoints.length;
		if (d == 0) {
			return new double[0];
		}
		double scalingFactor = endTime/priorPoints[d-1];

		double[] scaledPoints = new double[d];
		for (int i = 0; i < d; i++) {
			scaledPoints[i] = priorPoints[i]*scalingFactor;
		}
		return scaledPoints;
	}
}