/*
 * License: FreeBSD (Berkeley Software Distribution)
 * Copyright (c) 2013, Sara Sheehan, Kelley Harris, Yun Song
 */

package edu.berkeley.utility;

/**
 * Computes the "best" parameter pattern for a given array of expected segments.
 *
 * The input holds the expected number of segments for each of d intervals. A pattern assigns
 * consecutive intervals to p parameters: entry i is the number of intervals covered by parameter i.
 * The goal is that every parameter covers roughly the same share (total / p) of the expected
 * segments. Every parameter covers between {@code minIntervals} and {@code maxIntervals} intervals,
 * and the entries of a returned pattern sum to d.
 */
public class ParamPattern {

	/** Fewest intervals a single parameter may cover. */
	private final int minIntervals;
	/** Most intervals a single parameter may cover. */
	private final int maxIntervals;

	/**
	 * @param minIntervals fewest intervals a single parameter may cover
	 * @param maxIntervals most intervals a single parameter may cover
	 */
	public ParamPattern(int minIntervals, int maxIntervals) {
		this.minIntervals = minIntervals;
		this.maxIntervals = maxIntervals;
	}

	/**
	 * Builds cumulative sums of {@code values}. Entry k is {@code values[0] + ... + values[k-1]},
	 * accumulated left to right. Entry 0 is therefore 0, and the result has length
	 * {@code values.length + 1}.
	 */
	private static double[] prefixSums(double[] values) {
		double[] prefix = new double[values.length + 1];
		for (int i = 0; i < values.length; i++) {
			prefix[i + 1] = prefix[i] + values[i];
		}
		return prefix;
	}

	/**
	 * Scores a parameter pattern by its worst-off parameter (lower is better).
	 *
	 * Consecutive entries of expectedSegments are handed out in order, paramPattern[i] entries to
	 * parameter i. Pattern entries past the end of expectedSegments contribute nothing. Each
	 * parameter's total is compared with the even share sum(expectedSegments) / p.
	 *
	 * @param paramPattern number of intervals per parameter
	 * @param expectedSegments expected segments per interval
	 * @param addingSegs if true, a parameter's error is how far its total falls short of the share;
	 *                   if false, how far its total exceeds the share
	 * @return the largest per-parameter error, or negative infinity when paramPattern is empty
	 */
	private static double goodnessOfFit(int[] paramPattern, double[] expectedSegments, boolean addingSegs) {
		int p = paramPattern.length;
		int d = expectedSegments.length;
		double targetSegments = Utility.sumArray(expectedSegments) / p;
		double maxError = Double.NEGATIVE_INFINITY;

		int index = 0; // next entry of expectedSegments to hand out
		for (int i = 0; i < p; i++) {
			// total expected segments covered by parameter i
			double paramSegments = 0;
			for (int m = 0; m < paramPattern[i] && index < d; m++) {
				paramSegments += expectedSegments[index];
				index++;
			}
			double error = addingSegs ? targetSegments - paramSegments : paramSegments - targetSegments;
			if (error > maxError) {
				maxError = error;
			}
		}
		return maxError;
	}

	/**
	 * Chooses the parameter that should receive one more interval.
	 *
	 * Only parameters still below maxIntervals are candidates. The winner is the candidate whose
	 * increment gives the lowest "adding" goodnessOfFit error, and ties go to the later parameter.
	 * paramPattern is modified temporarily while candidates are scored, and it is restored before
	 * this method returns.
	 *
	 * @throws IllegalStateException if every parameter already has maxIntervals intervals
	 */
	private int minIndex(int[] paramPattern, double[] expectedSegments) {
		int bestIndex = -1; // -1 means no candidate scored yet
		double minError = Double.POSITIVE_INFINITY;
		for (int i = 0; i < paramPattern.length; i++) {
			if (paramPattern[i] >= this.maxIntervals) {
				continue; // this parameter cannot take another interval
			}
			paramPattern[i] += 1;
			double testError = goodnessOfFit(paramPattern, expectedSegments, true);
			paramPattern[i] -= 1;
			// <= moves intervals to the end, where they are generally needed more
			if (bestIndex < 0 || testError <= minError) {
				minError = testError;
				bestIndex = i;
			}
		}
		if (bestIndex < 0) {
			throw new IllegalStateException("cannot add an interval: all " + paramPattern.length
					+ " parameters already have the maximum of " + this.maxIntervals + " intervals");
		}
		return bestIndex;
	}

	/**
	 * Chooses the parameter that should give up one interval.
	 *
	 * Only parameters still above minIntervals are candidates. The winner is the candidate whose
	 * decrement gives the lowest "removing" goodnessOfFit error. Ties go to the earliest parameter,
	 * so intervals are removed from the front first. paramPattern is modified temporarily while
	 * candidates are scored, and it is restored before this method returns.
	 *
	 * @throws IllegalStateException if every parameter already has only minIntervals intervals
	 */
	private int maxIndex(int[] paramPattern, double[] expectedSegments) {
		int bestIndex = -1; // -1 means no candidate scored yet
		double minError = Double.POSITIVE_INFINITY;
		for (int i = 0; i < paramPattern.length; i++) {
			if (paramPattern[i] <= this.minIntervals) {
				continue; // this parameter cannot give up an interval
			}
			paramPattern[i] -= 1;
			double testError = goodnessOfFit(paramPattern, expectedSegments, false);
			paramPattern[i] += 1;
			// strict < keeps the earliest parameter on ties
			if (bestIndex < 0 || testError < minError) {
				minError = testError;
				bestIndex = i;
			}
		}
		if (bestIndex < 0) {
			throw new IllegalStateException("cannot remove an interval: all " + paramPattern.length
					+ " parameters already have the minimum of " + this.minIntervals + " intervals");
		}
		return bestIndex;
	}

	/**
	 * Removes intervals one at a time until paramPattern sums to d. Each interval is taken from the
	 * parameter chosen by maxIndex. Expects the pattern to currently sum to at least d.
	 * paramPattern is modified in place and also returned.
	 */
	private int[] removeIntervals(int d, int[] paramPattern, double[] expectedSegments) {
		int excess = Utility.sumArray(paramPattern) - d;

		// remove intervals until we get down to d
		for (int i = 0; i < excess; i++) {
			paramPattern[maxIndex(paramPattern, expectedSegments)] -= 1;
		}

		assert d == Utility.sumArray(paramPattern);
		return paramPattern;
	}

	/**
	 * Adds intervals one at a time until paramPattern sums to d. Each interval goes to the
	 * parameter chosen by minIndex. Expects the pattern to currently sum to at most d.
	 * paramPattern is modified in place and also returned.
	 */
	private int[] addIntervals(int d, int[] paramPattern, double[] expectedSegments) {
		int shortfall = d - Utility.sumArray(paramPattern);

		// add on intervals until we get up to d
		for (int i = 0; i < shortfall; i++) {
			paramPattern[minIndex(paramPattern, expectedSegments)] += 1;
		}

		assert d == Utility.sumArray(paramPattern);
		return paramPattern;
	}

	/**
	 * Finds a parameter pattern for the given expected segments.
	 *
	 * A greedy pass first visits the parameters in order. Parameter j gets the number of intervals
	 * (starting at minIntervals, capped by maxIntervals) that brings the running total of expected
	 * segments closest to its cumulative target, (j + 1) * sum / p. If the resulting pattern does
	 * not cover exactly d intervals, intervals are then added or removed one at a time using
	 * goodnessOfFit. The per-parameter target is also printed to standard output.
	 *
	 * @param expectedSegments expected segments in each of the d intervals
	 * @param p number of parameters
	 * @return an array of length p whose entries sum to expectedSegments.length
	 * @throws IllegalStateException if the pattern cannot be made to sum to d without moving some
	 *         parameter outside [minIntervals, maxIntervals] (for example when p * maxIntervals < d,
	 *         or when p == 0 and d > 0)
	 */
	public int[] rebalanceParams(double[] expectedSegments, int p) {
		int d = expectedSegments.length;

		// compute the target number of expected segments per parameter
		double targetSum = Utility.sumArray(expectedSegments) / p;
		System.out.println("target segments per parameter " + targetSum);

		// prefix[k] is expectedSegments[0] + ... + expectedSegments[k-1]
		double[] prefix = prefixSums(expectedSegments);

		int[] paramPattern = new int[p];
		int currIndex = 0; // first interval not yet assigned to a parameter
		for (int j = 0; j < p; j++) {
			double currTargetSum = targetSum * (j + 1);
			int bestIntervals = this.minIntervals; // used as-is when there is nothing to test

			// Cap the candidates so that at least one interval remains after this parameter.
			// Any shortfall this causes is repaired by addIntervals below. The comparison is
			// written as a subtraction so that a very large maxIntervals cannot overflow.
			int currMaxIntervals = this.maxIntervals;
			if (this.maxIntervals >= d - currIndex) {
				currMaxIntervals = d - currIndex - 1;
			}

			// pick the count whose cumulative sum is closest to the cumulative target;
			// ties keep the smaller count
			double bestDiff = Double.POSITIVE_INFINITY;
			for (int testIntervals = this.minIntervals; testIntervals <= currMaxIntervals; testIntervals++) {
				int end = Math.max(currIndex + testIntervals, 0); // end <= 0 means an empty prefix
				double testDiff = Math.abs(prefix[end] - currTargetSum);
				if (testDiff < bestDiff) {
					bestDiff = testDiff;
					bestIntervals = testIntervals;
				}
			}

			paramPattern[j] = bestIntervals;
			currIndex += bestIntervals;
		}

		// double check our parameter pattern sums to d, add/remove intervals if not
		int paramSum = Utility.sumArray(paramPattern);
		if (paramSum < d) {
			return addIntervals(d, paramPattern, expectedSegments);
		} else if (paramSum > d) {
			return removeIntervals(d, paramPattern, expectedSegments);
		} else {
			return paramPattern;
		}
	}
}
