////////////////////////////////////////////////////////////////////////////////
//
// fast_pmf_cycle.c
//
////////////////////////////////////////////////////////////////////////////////
#include "fast_pmf_cycle.h"
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

////////////////////////////////////////////////////////////////////////////////
//
// GLOBALS
//
////////////////////////////////////////////////////////////////////////////////
// Sentinel for a missing measurement in region[][]: compute_pmf skips such
// positions, and tally_votes seeds its running maximum with this value.
double NONE_VAL        = -9999.0;

// Cap applied to |Z| in compute_pmf, which bounds the per-position
// log-likelihood penalty -Z^2/2 at -ABS_Z_SCORE_MAX^2/2.
double ABS_Z_SCORE_MAX = 10.0;

// Log-ratio that get_votes substitutes when a motif weight is exactly zero
// (where log(0) would otherwise be taken).
double MIN_LOG_RATIO   = -100.0;

////////////////////////////////////////////////////////////////////////////////
//
// ACCESSOR FUNCTIONS
//
////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////
//
// get_DEBUG
//
////////////////////////////////////////
int get_DEBUG() {
  // Report the debug flag kept in the global Scratch struct.
  return (*s).DEBUG;
}

////////////////////////////////////////
//
// set_DEBUG
//
////////////////////////////////////////
void set_DEBUG(int DEBUG) {
  // Store the debug flag; compute_pmf, get_votes and tally_votes print trace
  // output while it is non-zero.
  (*s).DEBUG = DEBUG;
}

////////////////////////////////////////
//
// get_motif_mean
//
////////////////////////////////////////
double get_motif_mean(int mod, int pos) {
  // Motif mean for modification `mod` at motif position `pos` (unchecked).
  const double *row = d->motif_mean[mod];
  return row[pos];
}

////////////////////////////////////////
//
// set_motif_mean
//
////////////////////////////////////////
void set_motif_mean(int mod, int pos, double val) {
  // Overwrite the motif mean for modification `mod` at position `pos` (unchecked).
  double *row = d->motif_mean[mod];
  row[pos] = val;
}

////////////////////////////////////////
//
// get_motif_std
//
////////////////////////////////////////
double get_motif_std(int mod, int pos) {
  // Motif standard deviation for modification `mod` at position `pos` (unchecked).
  const double *row = d->motif_std[mod];
  return row[pos];
}

////////////////////////////////////////
//
// set_motif_std
//
////////////////////////////////////////
void set_motif_std(int mod, int pos, double val) {
  // Overwrite the motif standard deviation for `mod` at position `pos` (unchecked).
  double *row = d->motif_std[mod];
  row[pos] = val;
}

////////////////////////////////////////
//
// get_bg_mean
//
////////////////////////////////////////
double get_bg_mean(int mod) {
  // Background mean for modification `mod` (unchecked).
  return *(d->bg_mean + mod);
}

////////////////////////////////////////
//
// set_bg_mean
//
////////////////////////////////////////
void set_bg_mean(int mod, double val) {
  // Overwrite the background mean for modification `mod` (unchecked).
  *(d->bg_mean + mod) = val;
}

////////////////////////////////////////
//
// get_bg_std
//
////////////////////////////////////////
double get_bg_std(int mod) {
  // Background standard deviation for modification `mod` (unchecked).
  return *(d->bg_std + mod);
}

////////////////////////////////////////
//
// set_bg_std
//
////////////////////////////////////////
void set_bg_std(int mod, double val) {
  // Overwrite the background standard deviation for modification `mod` (unchecked).
  *(d->bg_std + mod) = val;
}

////////////////////////////////////////
//
// get_region
//
////////////////////////////////////////
double get_region(int mod, int pos) {
  // One value of the region that compute_pmf scores, for modification `mod`.
  const double *row = d->region[mod];
  return row[pos];
}

////////////////////////////////////////
//
// set_region
//
////////////////////////////////////////
void set_region(int mod, int pos, double val) {
  // Store one region value; compute_pmf treats NONE_VAL as a missing position.
  double *row = d->region[mod];
  row[pos] = val;
}

////////////////////////////////////////
//
// get_return_pol
//
////////////////////////////////////////
int get_return_pol() {
  // Polarity (+1 or -1) chosen by the most recent tally_votes() call.
  return (*s).return_pol;
}

////////////////////////////////////////
//
// get_return_loc
//
////////////////////////////////////////
int get_return_loc() {
  // Location chosen by the most recent tally_votes() call, expressed relative
  // to the middle of the wander range (center index - wander_dist/2).
  return (*s).return_loc;
}

////////////////////////////////////////////////////////////////////////////////
//
// init
//
////////////////////////////////////////////////////////////////////////////////
/*
 * pmf_xcalloc: zero-filled allocation of `count` elements of `size` bytes.
 * A non-positive count yields an empty (possibly NULL) allocation that the
 * callers never index. init() has no error return, so running out of memory
 * is reported on stderr and terminates the process rather than leaving NULL
 * pointers to be dereferenced later.
 */
static void *pmf_xcalloc(int count, size_t size) {
  size_t n = (count > 0) ? (size_t) count : 0;
  void *p = calloc(n, size);

  if (p == NULL && n != 0) {
    fprintf(stderr, "fast_pmf_cycle: out of memory\n");
    exit(EXIT_FAILURE);
  }
  return p;
}

/*
 * pmf_alloc_matrix: zero-filled `rows` x `cols` matrix of doubles, stored as
 * an array of row pointers (indexed m[row][col]).
 */
static double **pmf_alloc_matrix(int rows, int cols) {
  double **m = pmf_xcalloc(rows, sizeof *m);
  int r;

  for (r = 0; r < rows; r++) {
    m[r] = pmf_xcalloc(cols, sizeof **m);
  }
  return m;
}

/*
 * init: allocate and initialise the global Data (d) and Scratch (s) structs.
 *
 *   num_mod      number of modifications (row count of every per-mod array)
 *   width        the motif covers width + 1 positions
 *   wander_dist  candidate centers are 0 .. wander_dist (wander_dist + 1 of them)
 *   prior_model, prior_bg, std_factor
 *                stored in d and consumed by compute_pmf
 *
 * All memory is zero-filled. As a result, s->DEBUG starts at 0 (debug output
 * off), and weights, votes and return values are defined even before
 * compute_pmf/get_votes/tally_votes run. Previously these fields were
 * indeterminate, and compute_pmf read s->DEBUG unconditionally.
 * motif_mean/motif_std/bg_mean/region still have to be loaded through the
 * setters. compute_pmf divides by mean_of_stds, which is only filled by
 * set_motif_mean_of_std().
 *
 * NOTE(review): calling init() again overwrites d and s without freeing the
 * previous allocation.
 */
void init(int num_mod, int width, int wander_dist, double prior_model,
	  double prior_bg, double std_factor) {
  // Number of candidate center offsets (0 .. wander_dist inclusive).
  int num_centers = wander_dist + 1;
  // A window starting at the last center still needs width + 1 values, so the
  // region must hold width + wander_dist + 1 values per modification.
  int total_size = width + wander_dist + 1;

  ////////////////////
  // main struct alloc
  ////////////////////
  d = pmf_xcalloc(1, sizeof *d);

  d->num_mod     = num_mod;
  d->width       = width;
  d->wander_dist = wander_dist;
  d->std_factor  = std_factor;
  d->prior_model = prior_model;
  d->prior_bg    = prior_bg;

  d->motif_mean   = pmf_alloc_matrix(num_mod, width + 1);
  d->motif_std    = pmf_alloc_matrix(num_mod, width + 1);
  d->region       = pmf_alloc_matrix(num_mod, total_size);
  d->bg_mean      = pmf_xcalloc(num_mod, sizeof *d->bg_mean);
  d->bg_std       = pmf_xcalloc(num_mod, sizeof *d->bg_std);
  d->mean_of_stds = pmf_xcalloc(num_mod, sizeof *d->mean_of_stds);

  ////////////////
  // scratch alloc
  ////////////////
  s = pmf_xcalloc(1, sizeof *s);

  // Per-modification, per-center weights for each polarity (pos: pol = +1,
  // neg: pol = -1), filled by compute_pmf.
  s->weights_motif_pos = pmf_alloc_matrix(num_mod, num_centers);
  s->weights_motif_neg = pmf_alloc_matrix(num_mod, num_centers);
  s->weights_bg_pos    = pmf_alloc_matrix(num_mod, num_centers);
  s->weights_bg_neg    = pmf_alloc_matrix(num_mod, num_centers);

  // Per-center vote totals for each polarity, filled by get_votes.
  s->votes_pos = pmf_xcalloc(num_centers, sizeof *s->votes_pos);
  s->votes_neg = pmf_xcalloc(num_centers, sizeof *s->votes_neg);
}

////////////////////////////////////////////////////////////////////////////////
//
// set_motif_mean_of_std
//
////////////////////////////////////////////////////////////////////////////////
void set_motif_mean_of_std() {
  // For every modification, store the arithmetic mean of its width + 1 motif
  // standard deviations in d->mean_of_stds; compute_pmf uses that value as the
  // background model's standard deviation.
  int m;
  int n_positions = d->width + 1;

  for (m = 0; m < d->num_mod; m++) {
    double total = 0;
    int k;

    for (k = 0; k < n_positions; k++) {
      total += d->motif_std[m][k];
    }
    d->mean_of_stds[m] = total / n_positions;
  }
}

////////////////////////////////////////////////////////////////////////////////
//
// compute_pmf
//
////////////////////////////////////////////////////////////////////////////////
/*
 * compute_pmf: score every candidate alignment of the motif against region.
 *
 * For each modification `mod`, polarity `pol` and center offset
 * 0 .. wander_dist, the window region[mod][center .. center + width] is
 * compared with the motif. pol = +1 pairs motif position j with window
 * position j; pol = -1 pairs it with window position width - j.
 *
 * Three log-likelihood sums are formed, with each position contributing
 * -Z^2/2 and |Z| capped at ABS_Z_SCORE_MAX:
 *   P  motif model       Z = (x - motif_mean[mod][j]) / motif_std[mod][j]
 *   B  background model  Z = (x - bg_mean[mod]) / mean_of_stds[mod]
 *   A  constant term     (width + 1) * -(std_factor^2 / 2), same for all windows
 *
 * Positions equal to NONE_VAL are skipped. P and B are then scaled up to the
 * full window by adding the mean contribution of the present positions once
 * per missing one.
 *
 * Results (pos = pol +1, neg = pol -1):
 *   s->weights_motif_*[mod][center] = prior_model * e^P
 *   s->weights_bg_*[mod][center]    = prior_model * e^A + prior_bg * e^B
 */
void compute_pmf() {
  int mod;
  int center_index;
  int pol;
  int motif_index;

  double temp_sum_log_A = (d->width + 1) * -1 * (d->std_factor * d->std_factor / 2);

  if (s->DEBUG) {
    printf("\nstd factor = %.4g\n", d->std_factor);
    printf("width = %d\n", d->width);
    printf("\ncumulative weights\n");
    // Make the remaining debug trace unbuffered.
    // NOTE(review): ISO C only allows setvbuf before any other operation on
    // the stream, and stdout has just been written to. Common libcs tolerate
    // this, but it is not portable.
    setvbuf ( stdout , NULL , _IONBF , 1024 );
  }

  for (mod = 0 ; mod < d->num_mod ; mod++) {
    for (pol = -1 ; pol <= 1 ; pol += 2) {
      for (center_index = 0 ; center_index <= d->wander_dist ; center_index++) {

        if (s->DEBUG) {
          printf("center_index = %d, pol = %d, mod = %d\n", center_index, pol, mod);
        }

        int num_none = 0;
        double temp_sum_log_P = 0;
        double temp_sum_log_B = 0;

        for (motif_index = 0 ; motif_index <= d->width ; motif_index++) {
          // pol = -1 walks the window backwards against the motif.
          int index = (pol == 1) ? motif_index : (d->width - motif_index);
          int loc = center_index + index;

          double logr = d->region[mod][loc];
          if (logr == NONE_VAL) {
            num_none++;
            continue;
          }

          // P: per-position motif model. Only Z^2 is used below, so the
          // sign discarded by the cap does not matter.
          double mean = d->motif_mean[mod][motif_index];
          double std  = d->motif_std[mod][motif_index];
          double Z_score_P_j = (logr - mean) / std;
          if (fabs(Z_score_P_j) > ABS_Z_SCORE_MAX) {
            Z_score_P_j = ABS_Z_SCORE_MAX;
          }
          temp_sum_log_P += -(Z_score_P_j * Z_score_P_j / 2);

          // B: background model (per-mod bg mean, motif's average std).
          double bg_mean = d->bg_mean[mod];
          double bg_std  = d->mean_of_stds[mod];
          double Z_score_B_j = (logr - bg_mean) / bg_std;
          if (fabs(Z_score_B_j) > ABS_Z_SCORE_MAX) {
            Z_score_B_j = ABS_Z_SCORE_MAX;
          }
          temp_sum_log_B += -(Z_score_B_j * Z_score_B_j / 2);

          if (s->DEBUG) {
            printf("\t%d\t%.4g\t%.4g\t%.4g\t%.4g\t%.4g\t%.4g\t%.4g\n",
                   motif_index, logr, mean, std, Z_score_P_j, exp(-(Z_score_P_j * Z_score_P_j / 2)),
                   temp_sum_log_P, temp_sum_log_B);
          }
        }

        // Scale P and B up to the full window by extrapolating the mean of the
        // present positions over the missing ones. A is data-independent and
        // already covers the full window. If every position is missing there
        // is nothing to average: leave the sums at 0 instead of computing
        // 0/0 = NaN, which would poison these weights and every vote built
        // from them.
        int num_not_none = d->width + 1 - num_none;
        if (num_none > 0 && num_not_none > 0) {
          double ave_P = temp_sum_log_P / num_not_none;
          double ave_B = temp_sum_log_B / num_not_none;

          temp_sum_log_P += ave_P * num_none;
          temp_sum_log_B += ave_B * num_none;
        }

        if (s->DEBUG) {
          printf("CC\t%d\t%d\t%d\t%.4g\t%.4g\t%.4g\n",
                 mod,
                 pol,
                 center_index,
                 temp_sum_log_P,
                 temp_sum_log_A,
                 temp_sum_log_B);
        }

        double weight_motif = d->prior_model * exp(temp_sum_log_P);
        double weight_bg    = d->prior_model * exp(temp_sum_log_A) +
                              d->prior_bg    * exp(temp_sum_log_B);

        if (pol == 1) {
          s->weights_motif_pos[mod][center_index] = weight_motif;
          s->weights_bg_pos[mod][center_index]    = weight_bg;
        } else {
          s->weights_motif_neg[mod][center_index] = weight_motif;
          s->weights_bg_neg[mod][center_index]    = weight_bg;
        }
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////
//
// get_votes
//
////////////////////////////////////////////////////////////////////////////////
/*
 * votes_mod_beats_bg: 1 if modification `mod` has at least one center offset,
 * on either polarity, whose motif weight strictly exceeds its background
 * weight; 0 otherwise.
 */
static int votes_mod_beats_bg(int mod) {
  int c;

  for (c = 0; c <= d->wander_dist; c++) {
    if (s->weights_motif_pos[mod][c] > s->weights_bg_pos[mod][c] ||
        s->weights_motif_neg[mod][c] > s->weights_bg_neg[mod][c]) {
      return 1;
    }
  }
  return 0;
}

/*
 * votes_log_ratio: log(motif_w / bg_w). MIN_LOG_RATIO is substituted when the
 * motif weight is exactly zero.
 */
static double votes_log_ratio(double motif_w, double bg_w) {
  return (motif_w != 0) ? log(motif_w / bg_w) : MIN_LOG_RATIO;
}

/*
 * get_votes: turn the weights from compute_pmf into per-center votes.
 *
 * Returns 1 (reject) as soon as some modification has no center offset on
 * either polarity where the motif outweighs the background; s->votes_* are
 * then left untouched. Otherwise returns 0 after setting
 *   s->votes_pos[c] = sum over mods of log(weights_motif_pos / weights_bg_pos)
 *   s->votes_neg[c] = sum over mods of log(weights_motif_neg / weights_bg_neg)
 * for every center offset c.
 */
int get_votes() {
  int mod;
  int c;

  for (mod = 0; mod < d->num_mod; mod++) {
    if (!votes_mod_beats_bg(mod)) {
      if (s->DEBUG) {
        printf ("rejecting on mod = %d\n", mod);
      }
      return 1;
    }
  }

  for (c = 0; c <= d->wander_dist; c++) {
    double pos_total = 0;
    double neg_total = 0;

    for (mod = 0; mod < d->num_mod; mod++) {
      pos_total += votes_log_ratio(s->weights_motif_pos[mod][c], s->weights_bg_pos[mod][c]);
      neg_total += votes_log_ratio(s->weights_motif_neg[mod][c], s->weights_bg_neg[mod][c]);
    }

    s->votes_pos[c] = pos_total;
    s->votes_neg[c] = neg_total;
  }

  return 0;
}

////////////////////////////////////////////////////////////////////////////////
//
// tally_votes
//
////////////////////////////////////////////////////////////////////////////////
/*
 * tally_votes: choose the polarity/center with the highest vote.
 *
 * Scans s->votes_neg (polarity -1) first, then s->votes_pos (polarity +1).
 * A later candidate replaces the current best only if it is strictly
 * greater, so ties go to the earliest candidate, favouring -1 over +1.
 * Writes:
 *   s->return_pol = chosen polarity (-1 or +1)
 *   s->return_loc = chosen center index - wander_dist / 2 (integer division)
 *
 * An explicit flag tracks whether any candidate has been taken yet.
 * Previously max_val == NONE_VAL served that purpose, so a real vote of
 * exactly NONE_VAL (-9999) was always overwritten by the next candidate,
 * even a smaller one.
 * NOTE(review): assumes votes are not NaN; a NaN vote never compares greater.
 */
void tally_votes() {
  int center_index;
  int have_max = 0;

  // The NONE_VAL seeds are kept so the DEBUG trace prints the same
  // "nothing chosen yet" values, and so the result is unchanged when no
  // center exists (wander_dist < 0). have_max, not these values, decides
  // whether a candidate has been seen.
  int max_index = (int) NONE_VAL;
  int max_pol   = (int) NONE_VAL;
  double max_val   = NONE_VAL;

  // pol = -1
  for (center_index = 0 ; center_index <= d->wander_dist ; center_index++) {

    if (s->DEBUG) {
      printf("-1\t%d\t%.4g\t%d\t%d\t%.4g\n", center_index, s->votes_neg[center_index],
             max_index, max_pol, max_val);
    }

    if (!have_max || (max_val < s->votes_neg[center_index])) {
      have_max  = 1;
      max_index = center_index;
      max_pol   = -1;
      max_val   = s->votes_neg[center_index];
    }
  }

  // pol = 1
  for (center_index = 0 ; center_index <= d->wander_dist ; center_index++) {

    if (s->DEBUG) {
      printf("1\t%d\t%.4g\t%d\t%d\t%.4g\n", center_index, s->votes_pos[center_index],
             max_index, max_pol, max_val);
    }

    if (!have_max || (max_val < s->votes_pos[center_index])) {
      have_max  = 1;
      max_index = center_index;
      max_pol   = 1;
      max_val   = s->votes_pos[center_index];
    }
  }

  if (s->DEBUG) {
    printf("max\t%d\t%d\t%.4g\n", max_index, max_pol, max_val);
  }

  s->return_pol = max_pol;
  // Report the location relative to the middle of the wander range.
  s->return_loc = max_index - d->wander_dist/2;
}
