/*
 * emfunction.cpp
 * This file is part of isoLasso library; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License as
 * published by the Free Software Foundation; version 2.1 of the License.
 *
 *  Created on: 2012-04-03
 *      Author: wei li, yingsheng gao
 */

#include "emfunction.h"
#include "predExpLevel.h"
#include "mathlib.h"
#include "library.h"
#include "NewInstance.h"

#include <algorithm>
#include <cmath>
#include <functional>
#include <iostream>
#include <limits>
#include <numeric>
#include <set>
#include <sstream>

#include <math.h>
#include <stdlib.h>
#include <string.h>

#include <gsl/gsl_matrix.h>
#include <gsl/gsl_linalg.h>

/* Sentinel meaning "no value yet" / "invalid": used as the starting best value in the
   maximisation searches below and as a failure marker. It is NOT the smallest
   representable double (see std::numeric_limits<double>::lowest() for that). */
double DOUBLE_MIN_VAL=-99999999999;

/* EM parameters, plus auxiliary values that EMloop reads and writes. */
struct EMPara{
  /* EM switches */
  bool useelimination;   // deselect isoforms while iterating (see EMloop)
  bool usebias;          // estimate the bias parameter tao inside EMloop

  bool singleonly;       // true when only single-end reads are used
  double cutrate;        // relative convergence threshold on the joint log-probability
  bool printdebuginfo;   // verbose tracing
  
  double minpicut;       // isoforms whose final probability is below this are filtered out

  //The weights for indistinguishable reads; should be <1
  double indistcoef;


  /* Negative dirichlet distribution parameters */
  double alphalow;       // prior for candidates compatible with a reference isoform
  double alphahigh;      // prior for all other candidates

  /* Loop parameters */
  //Maximum loop iterations
  long maxloopcount;
  //the minimum number of loops between two deletions
  int preventearlydeletion;

  /* EM auxiliary variables */
  //total probability (last joint log-probability computed by EMloop)
  double prob; 
  //the value of bias parameter
  double tao; 
  //the actual read count
  long actualReadCnt;

  // Default parameters. Every member is initialised, so copying a fresh EMPara
  // never reads an indeterminate value.
  EMPara()
    : useelimination(false), usebias(false), singleonly(true), cutrate(1e-6),
      printdebuginfo(false), minpicut(0.01), indistcoef(0.0),
      alphalow(0), alphahigh(5), maxloopcount(1000), preventearlydeletion(10),
      prob(0), tao(0), actualReadCnt(0)
  {}
};

/* Auxiliary function declarations (all defined later in this file). */
// Legacy EMloop signature, kept for reference only:
//void EMloop(vector<double> &isopi, bool singleonly, vector<vector<double> > &rdSupport, vector<long> &SGCounter, 
//            vector<long> &alphas, bool useelimination, long maxloopcount, double cutrate, bool usebias, 
//						bool printdebuginfo, double &tmpa, long &tao);
/* Main EM iteration: updates isopi in place; para.prob and para.tao receive the last joint
   log-probability and the bias parameter. */
void EMloop(vector<double> &isopi,  vector<vector<double> > &rdSupport, vector<long> &rdCounter, 
            vector<vector<double> > &sgSupport, vector<long> &sgCounter, 
            vector<long> &alphas, EMPara& para);

/* Calculate the value of tao using Newton-Raphson (calculatetao), the function whose root is
   sought (calculateftao) and a single Newton-Raphson step (newtonraphson). */
double calculatetao(double tau, vector<vector<double> > &pairendgamma, vector<vector<double> > &rdSupport, vector<long> &SGCounter);
double calculateftao(double tau, vector<vector<double> > &pairendgamma, vector<vector<double> > &rdSupport, vector<long> &SGCounter);
double newtonraphson(double tao, vector<vector<double> > &pairendgamma, vector<vector<double> > &rdSupport, vector<long> &SGCounter);
/* Binary search for a zero of calculateftao(); fallback when Newton-Raphson diverges. */
double calculatetaobybinsearch(double tau, vector<vector<double> > &pairendgamma, 
   vector<vector<double> > &rdSupport, vector<long> &SGCounter);

/* Calculate the value of tao by maximising the quasi-multinomial log-likelihood (used by EMloop). */
double calculatetao2(double tau, vector<double > &isopi, vector<vector<double> > &rdSupport, vector<long> &SGCounter, EMPara& para);



/*
 * EM function: select isoforms among the candidates in alliso and estimate their expression.
 *
 *   oneins    - the instance: exon lengths/boundaries, single-end read types (SGTypes/SGCounts),
 *               paired-end read types (PETypes/PETypesDistance/PETypesDistanceCount), references.
 *   alliso    - candidate isoforms; alliso[i][e] != 0 when candidate i uses exon e.
 *   varpara   - option list. Recognised: --stopcriteria <v>, --maxloop <n>, --elim, --usebias,
 *               --usecvg, --correctn, --verbose, --alpharef <v>, --alpha <v>, --min-frac <v>,
 *               --no-filter. A value-taking flag given as the last element is ignored.
 *   pred      - output: the selected isoforms.
 *   predexplv - output: expression level of each isoform in pred (same order).
 *   stat      - output: tau and the adjustment coefficient; negw/nw/ncandidate are reset to 0.
 *
 * If the total single-end support is ~0, every candidate is returned with expression 0 and
 * EM is not run.
 */
void emselect(NewInstance &oneins, 
  vector<vector<long> >&alliso, 
  vector<string> &varpara, 
  vector<vector<long> > &pred, 
  vector<double> &predexplv, 
  STAT &stat)
{
  long usecvg = 0;
  bool correctn = false;

  // Set up parameters
  EMPara para;
  para.tao=0;

  bool nofilter=false;

  for (size_t i = 0; i < varpara.size(); ++i){
    // Value-taking options read varpara[i + 1]; never read past the end.
    const bool hasvalue = (i + 1 < varpara.size());
    if (varpara[i] == "--stopcriteria" && hasvalue)
      para.cutrate = atof(varpara[i + 1].c_str());
    if (varpara[i] == "--maxloop" && hasvalue)
      para.maxloopcount = atol(varpara[i + 1].c_str());
    if (varpara[i] == "--elim")
      para.useelimination = true;
    if (varpara[i] == "--usebias")
      para.usebias = true;
    if (varpara[i] == "--usecvg")
      usecvg = 1;
    if (varpara[i] == "--correctn")
      correctn = true;
    if (varpara[i] == "--verbose")
      para.printdebuginfo=true;
    // alphalow, alphahigh and minpicut are doubles: parse them with atof
    // (atol would turn e.g. "--min-frac 0.05" into 0).
    if (varpara[i] == "--alpharef" && hasvalue)
      para.alphalow = atof(varpara[i + 1].c_str());
    if (varpara[i] == "--alpha" && hasvalue)
      para.alphahigh = atof(varpara[i + 1].c_str());
    if (varpara[i] == "--min-frac" && hasvalue)
      para.minpicut = atof(varpara[i + 1].c_str());
    if (varpara[i] == "--no-filter")
      nofilter=true;
    //TODO useposbias
  }

  stat.negw = 0; stat.nw = 0; stat.ncandidate = 0; stat.tau = 0; stat.adjust = 1;

  // No candidates: nothing to select (also avoids touching alliso[0]).
  if (alliso.empty()){
    pred.clear();
    predexplv.clear();
    return;
  }
  long nIsos = alliso.size();
  long nSGTypes = oneins.SGTypes.size();

  // Exon lengths, and the number of valid read start positions of each isoform
  // (isoform length - read length + 1).
  vector<long> exonLen(oneins.exonLen.begin(), oneins.exonLen.end());
  vector<long> isoLen;
  for (long i = 0; i < nIsos; ++i){
    long tmp = 0;
    for (long j = 0; j < alliso[i].size(); ++j)
      tmp += alliso[i][j] * exonLen[j];
    isoLen.push_back(tmp - oneins.readLen + 1);
  }

  // Prior (alpha) per candidate: alphalow if it is compatible with a reference isoform,
  // alphahigh otherwise.
  vector<long> alphas;
  zeros(alphas, nIsos);
  for (long i = 0; i < nIsos; ++i){
    bool appearinref = false;
    for (long j = 0; j < oneins.Refs.size(); ++j)
      if (checkcompatibleofisoforms(oneins.Refs[j], alliso[i], oneins.exonBoundary)){
        appearinref = true;
        break;
      }
    alphas[i] = appearinref ? para.alphalow : para.alphahigh;
  }

  // Single-end read support table SGSupport (nSGTypes x nIsos):
  // element (a,b) is the support of single-end read type a to isoform b.
  vector<vector<double> > SGSupport;
  vector<vector<double> > CVGSGSupport;
  vector<long> CVGSGCount, SGCount;
  SGCount.assign(oneins.SGCounts.begin(), oneins.SGCounts.end());
  zeros(SGSupport, nSGTypes, nIsos);
  long nnentry = 0;
  long ncvgentry = 0;

  // NOTE(review): the number of coverage rows allocated here (sum of Coverage[i][j][1]) and the
  // number filled per isoform below (sum of Coverage[i][k].size()) are computed differently —
  // confirm they agree for the Coverage layout in use.
  if (usecvg){
    for (long i = 0; i < oneins.Coverage.size(); ++i){
      for (long j = 0; j < oneins.Coverage[i].size(); ++j)
        ncvgentry = ncvgentry + oneins.Coverage[i][j][1];
    }
    zeros(CVGSGSupport, ncvgentry, nIsos);
    zeros(CVGSGCount, ncvgentry);
    nnentry = ncvgentry;
    ncvgentry = 0;
  }

  for (long j = 0; j < nIsos; ++j){
    for (long i = 0; i < nSGTypes; ++i){
      if (checkcompatible(alliso[j], oneins.SGTypes[i]) == true){
        // The number of valid start positions for the current type: if the type covers exons
        // x1..xk and a is the start position within x1, then 1 <= a <= x1 and the last base
        // a+readlen-1 must satisfy x1+...+x_(k-1)+1 <= a+readlen-1 <= x1+...+x_k.
        vector<long> nelen;
        for (long k = 0; k < oneins.SGTypes[i].size(); ++k)
          if (oneins.SGTypes[i][k] > 0) nelen.push_back(exonLen[k]);
        if (!nelen.empty()){
          long leftp = 0;
          for (size_t e = 0; e + 1 < nelen.size(); ++e) leftp += nelen[e];
          long rightp = leftp + nelen[nelen.size() - 1];
          rightp = rightp - oneins.readLen + 1;
          leftp = leftp - oneins.readLen + 2;
          leftp = max(long(1), leftp);
          rightp = min(nelen[0], rightp);
          if (rightp < leftp) rightp = leftp;
          SGSupport[i][j] = (rightp - leftp + 1) / (double)isoLen[j];
        }
      }

      //fill in the coverage entry
      if (usecvg) {
        long availablepos = 0;
        for (long k = 0; k < exonLen.size(); ++k)
          availablepos += exonLen[k] * oneins.SGTypes[i][k];
        availablepos = availablepos - oneins.readLen + 1;
        if (availablepos < 1)  availablepos = 1;
        double currentprob = SGSupport[i][j] / (double)(availablepos);
        for (long k = 0; k < oneins.Coverage[i].size(); ++k)
          for (long l = 0; l < oneins.Coverage[i][k].size(); ++l){
            CVGSGSupport[ncvgentry][j] = currentprob;
            CVGSGCount[ncvgentry] = oneins.Coverage[i][k][1];
            ++ncvgentry;
          }
      }
    }
    ncvgentry = 0;
  }

  if (usecvg) {
    ncvgentry = nnentry;
    SGSupport.assign(CVGSGSupport.begin(), CVGSGSupport.begin() + ncvgentry); 
    SGCount.assign(CVGSGCount.begin(), CVGSGCount.begin() + ncvgentry);
  }

  // No read supports any candidate: report every candidate with expression 0.
  double sumSGSupport = 0;
  for (long i = 0; i < SGSupport.size(); ++i)
    for (long j = 0; j < SGSupport[i].size(); ++j)
      sumSGSupport += SGSupport[i][j];
  if (fabs((double)sumSGSupport) < smalldelta){
    pred.assign(alliso.begin(), alliso.end());
    zeros(predexplv, pred.size());
    return;
  }

  // Paired-end read support table PESupport (nentry x nIsos): element (a,b) is the support of
  // paired-end read a to isoform b. Pair type k owns PETypes[k][2] consecutive rows, one per
  // distinct distance; PEcount holds how many reads share that type and distance.
  vector<vector<double> > PESupport;
  vector<long> PEcount;
  long nentry = 0;
  for (long k = 0; k < oneins.PETypes.size(); ++k) nentry += oneins.PETypes[k][2];
  zeros(PESupport, nentry, nIsos);
  PEcount.assign(nentry, 1);

  long nstart = 0;
  for (long k = 0; k < oneins.PETypes.size(); ++k){
    const long nrows = oneins.PETypes[k][2];

    //get the types of its pairs (1-based indices into SGTypes)
    long nt1 = oneins.PETypes[k][0], nt2 = oneins.PETypes[k][1];
    vector<long> ctype1(oneins.SGTypes[nt1 - 1].begin(), oneins.SGTypes[nt1 - 1].end());
    vector<long> ctype2(oneins.SGTypes[nt2 - 1].begin(), oneins.SGTypes[nt2 - 1].end());
    vector<long> ed1, ed2;
    for (long l = 0; l < ctype1.size(); ++l)
      if (ctype1[l] != 0) ed1.push_back(l);
    for (long l = 0; l < ctype2.size(); ++l)
      if (ctype2[l] != 0) ed2.push_back(l);
    if (ed1.size() == 0 || ed2.size() == 0){
      printf("EM:Instance :Warning: types are all zero!\n");
      // Still skip this type's rows so later pair types stay aligned with their rows.
      nstart += nrows;
      continue;
    }
    long ec1 = ed1[ed1.size() - 1], ec2 = ed2[0];
    long theodist;
    if (ec1 == ec2)
      theodist = 0;
    else
      theodist = oneins.exonBoundary[ec2][0] - oneins.exonBoundary[ec1][1];

    // Exons strictly between the two mates; an isoform that uses them lengthens the fragment.
    vector<long> jumprange;
    for (long l = ec1 + 1; l <= ec2 - 1; ++l) jumprange.push_back(l);

    // iterate candidate isoforms
    for (long j = 0; j < nIsos; ++j){
      if (!checkcompatible(alliso[j], ctype1) || !checkcompatible(alliso[j], ctype2))
        continue;  //zero probability, do nothing
      long rdist = theodist;
      for (long l = 0; l < jumprange.size(); ++l)
        rdist = rdist - exonLen[jumprange[l]] * alliso[j][jumprange[l]];
      const vector<long> &distances = oneins.PETypesDistance[k];
      for (long l = 0; l < distances.size(); ++l){
        long readdist = distances[l] - rdist + 2 * oneins.peReadLen;
        double prob = normpdf(readdist, oneins.peReadDis, oneins.peReadDisSTD);
        if (fabs(double(prob)) < 1e-200) prob = 1e-200;
        PESupport[nstart + l][j] = prob / isoLen[j];
        PEcount[nstart + l] = oneins.PETypesDistanceCount[k][l];
      }
    }
    nstart += nrows;
  }

  // Pass to EM according to the read type.
  vector<vector<double> > rdSupport;
  vector<long> rdCount;
  if (oneins.PETypes.size() == 0){
    para.singleonly = true;
    rdSupport = SGSupport;
    rdCount = SGCount;
  }
  else {
    para.singleonly = false;
    rdSupport = PESupport;
    rdCount = PEcount;
  }

  // Number of reads that support at least one candidate.
  para.actualReadCnt = 0;
  for (long l = 0; l < rdSupport.size(); ++l){
    double rdSum = 0;
    for (long ll = 0; ll < rdSupport[l].size(); ++ll) rdSum += rdSupport[l][ll];
    if (rdSum > 0) para.actualReadCnt += rdCount[l];
  }

  // Random, normalised starting point for EM.
  vector<double> isopi(nIsos, 0);
  double sumisopi = 0;
  for (long k = 0; k < nIsos; ++k){
    isopi[k] = (double)rand() / double(RAND_MAX);
    sumisopi += isopi[k];
  }
  for (long k = 0; k < nIsos; ++k) isopi[k] /= sumisopi;

  //Initial EM trial (no bias, no elimination)
  EMPara para0=para;
  para0.usebias=false; para0.useelimination=false;
  EMloop(isopi,  rdSupport, rdCount, SGSupport, SGCount,  alphas,para0 );

  if (para.printdebuginfo) {
    printf("Begin formal loop...\n");
    printf("Initial prob:        ");
    for (long k = 0; k < isopi.size(); ++k) printf(" %lf", isopi[k]);
    printf("\n");
  }

  para.tao = 0;
  //Formal EM: elimination and bias estimation run as separate passes
  bool usebias=para.usebias;
  if (para.useelimination){
    para.usebias=false;
    EMloop(isopi,  rdSupport, rdCount, SGSupport, SGCount, alphas, para  );
  }
  if (usebias){
    para.useelimination=false;
    para.usebias=usebias;
    EMloop(isopi,  rdSupport, rdCount, SGSupport, SGCount, alphas, para  );
  }

  // Bias parameter from the last EM pass (kept as double; a long would truncate it to 0).
  double tao = para.tao;

  // Filter out low-probability isoforms (unless --no-filter), then distribute each single-end
  // read type's count among the kept, compatible isoforms in proportion to isopi.
  vector<double> distribreadcnt(nIsos, 0);
  vector<long> sel(nIsos, 1);
  for (long l = 0; l < nIsos; ++l)
    if (!nofilter && isopi[l] < para.minpicut)  sel[l] = 0;

  for (long i = 0; i < nSGTypes; ++i){
    vector<double> probs(nIsos, 0);
    double sumprobs = 0;
    for (long j = 0; j < nIsos; ++j)
      if (checkcompatible(alliso[j], oneins.SGTypes[i]) && sel[j] == 1){
        probs[j] = isopi[j];
        sumprobs += probs[j];
      }
    if (sumprobs > 0){
      for (long j = 0; j < probs.size(); ++j)
        distribreadcnt[j] = distribreadcnt[j] + probs[j] / sumprobs * oneins.SGCounts[i];
    }
  }

  double adjustcoef = correctn ? 1 / (1.0 + para.actualReadCnt * tao) : 1;

  pred.clear(); 
  predexplv.clear();
  for (long i = 0; i < sel.size(); ++i)
    if (sel[i] > 0){
      pred.push_back(alliso[i]);
      predexplv.push_back(distribreadcnt[i] * adjustcoef / isoLen[i]);
    }
  
  stat.tau = tao;
  stat.adjust = adjustcoef;
}

/* The E-step of EM.
   For read row i and isoform j, the unnormalised weight is
     (rdSupport[i][j] + rdCounter[i] * tao) * isopi[j],
   forced to 0 when rdSupport[i][j] < smalldelta (so tao cannot create support for an
   incompatible read) and clamped to 0 when negative. Each row is then divided by its sum
   (a row whose sum is ~0 is divided by 1 instead).
   pairendgamma is rebuilt to exactly the shape of rdSupport (stale rows are discarded).
   Returns the joint log-probability sum_i rdCounter[i] * log(row sum, or 1 for ~0 rows).
   para is currently unused. */
double EMCalculateGamma(vector<vector<double> >& pairendgamma, 
   vector<vector<double> >& rdSupport, vector<long>& rdCounter, 
   vector<double>& isopi, double tao, EMPara& para)
{
  pairendgamma.assign(rdSupport.size(), vector<double>());

  double jointprob = 0;
  for (size_t i = 0; i < rdSupport.size(); ++i){
    const vector<double> &support = rdSupport[i];
    vector<double> &gamma = pairendgamma[i];
    gamma.assign(support.size(), 0.0);

    //adjust the pairendgamma value here
    //Notice that this works for both single-end and paired-end reads
    double rowsum = 0;
    for (size_t j = 0; j < support.size(); ++j){
      double g = (support[j] + rdCounter[i] * tao) * isopi[j];
      if (support[j] < smalldelta || g < 0) g = 0;
      gamma[j] = g;
      rowsum += g;
    }

    // normalize the row; a row with (almost) no mass is left unscaled
    const double denom = (fabs(rowsum) < smalldelta) ? 1 : rowsum;
    for (size_t j = 0; j < gamma.size(); ++j)
      gamma[j] /= denom;

    jointprob += log(denom) * rdCounter[i];
  }
  return jointprob;
}
       

/* 
The main EM loop function.
Notice the dimension of rdSupport and rdCounter is different for single-end and paired-end reads.
sgSupport and sgCounter are single-end read support and read counter, and are used to estimate bias parameters.
In the single-end read case, sgSupport=rdSupport, sgCounter=rdCounter.

  isopi  - in: starting isoform probabilities; out: final probabilities (0 for deselected isoforms).
  alphas - per-isoform amount subtracted from the expected read count when choosing deletions.
  para   - settings; para.prob receives the last joint log-probability and para.tao the bias
           parameter (reset to 0 on entry, updated only when para.usebias).

Alternates E-step (EMCalculateGamma) and M-step until the relative change of the joint
log-probability is below para.cutrate or para.maxloopcount is exceeded, but never earlier than
2*preventearlydeletion iterations after the last deletion. With para.useelimination, the selected
isoform with the most negative (expected reads - alpha) is deselected, at most once every
preventearlydeletion iterations; when one isoform remains it gets probability 1 and the loop ends.
Exits the program if isopi contains NaN.
*/
void EMloop(vector<double> &isopi,  
            vector<vector<double> > &rdSupport, vector<long> &rdCounter, 
            vector<vector<double> > &sgSupport, vector<long> &sgCounter, 
            vector<long> &alphas,  EMPara& para){

  double & currentprob=para.prob;
  double & tao=para.tao;
  tao = 0;

  double prevprob = DOUBLE_MIN_VAL;
  long loopcounter = 0;
  long nIsos = isopi.size();
  // selected[i] becomes 0 once isoform i is eliminated; deleteorder/deleteloop record which
  // isoform was removed and in which iteration.
  vector<long> selected, deleteorder, deleteloop;
  ones(selected, nIsos);
  vector<vector<double> > pairendgamma;
  // Never filled; only passed to calculateftao() in the diagnostic message below.
  vector<vector<double> > biasgamma;

  // With elimination, isoforms starting below cutrate are deselected from the start.
  if(para.useelimination) 
    for(long i=0;i<nIsos;i++)
      if(isopi[i]<para.cutrate) selected[i]=0;

  long  lastdeletion = 0;

  bool showwarning = true;

  while (true){
    // Comparing with quiet_NaN() is always false; std::isnan is the correct test.
    for (long i = 0; i < isopi.size(); ++i)
      if (std::isnan(isopi[i])) {
        printf("Error: nan of isopi\n");
        exit(1);
      }

    //E-step: update pairendgamma and calculate the joint probability
    double jointprob=EMCalculateGamma(pairendgamma,rdSupport,rdCounter,isopi,tao,para);

    //M Step: calculate expression level pi for each isoform
    //A read is "indistinguishable" when it has positive support on every selected isoform.
    long sumselected = 0;
    for (long i = 0; i < selected.size(); ++i)
      sumselected += selected[i];
    vector<long> indist(rdSupport.size(),0);
    for (long i = 0; i < rdSupport.size(); ++i){
      long sumindbin = 0;
      for (long j = 0; j < rdSupport[i].size(); ++j)
        if (rdSupport[i][j] > 0) sumindbin += selected[j];
      indist[i] = (sumindbin == sumselected) ? 1 : 0;
    }

    //Expected read counts per isoform, split into distinguishable and indistinguishable reads
    vector<double> sumindistgamma, sumdistgamma;
    if (pairendgamma.size() > 0) {
      zeros(sumindistgamma, pairendgamma[0].size());
      zeros(sumdistgamma, pairendgamma[0].size());
    }
    for (long i = 0; i < pairendgamma.size(); ++i){
      for (long j = 0; j < pairendgamma[i].size(); ++j){
        //rdCounter holds the number of reads sharing row i
        double ts = pairendgamma[i][j]*rdCounter[i];
        if (indist[i] == 1) 
          sumindistgamma[j]+=ts;
        else 
          sumdistgamma[j]+=ts;
      }
    }

    //Indistinguishable reads are down-weighted by indistcoef; fall back to all reads if nothing is left.
    vector<double> sumpairendgamma, sumallgamma;
    double ss = 0;
    for (long i = 0; i < sumindistgamma.size(); ++i){
      sumpairendgamma.push_back(para.indistcoef * sumindistgamma[i] + sumdistgamma[i]);
      ss += sumpairendgamma[i];
      sumallgamma.push_back(sumindistgamma[i] + sumdistgamma[i]);
    }

    if (fabs(ss) <= smalldelta)
      sumpairendgamma=sumallgamma;
     
    //Deletion step
    if (para.useelimination && (loopcounter > para.preventearlydeletion + lastdeletion)){
      ss = 0;
      int numselected=0;
      for (long i = 0; i < sumallgamma.size(); ++i){
        sumallgamma[i] = sumallgamma[i] - alphas[i];
        if (selected[i] == 0) sumallgamma[i] = 0;
        else numselected++;
        ss += sumallgamma[i];
      }
      if (fabs(ss) < smalldelta) {
        break;
      }
      double minv = 0;
      long mini = 0;
      for (long i = 0; i < sumallgamma.size(); ++i)
        if (sumallgamma[i] < minv){ minv = sumallgamma[i]; mini = i; }
      if (minv < 0  && numselected>1){
        selected[mini] = 0;
        int nrem=0;
        for(long i=0;i<selected.size();i++) if(selected[i]>0) nrem++;
        if (para.printdebuginfo){
          cout<<"Del "<<mini<<" at "<<loopcounter<<",remaining:"<<nrem<<",minv:"<<minv;
          cout<<endl;
        }

        deleteorder.push_back(mini);
        deleteloop.push_back(loopcounter);
        lastdeletion = loopcounter;
        //only 1 isoform left: it takes all the probability, and we exit
        if (nrem == 1){
          zeros(isopi, nIsos);
          for (long k = 0; k < selected.size(); ++k)  if (selected[k] != 0) isopi[k] = 1;
          break;
        }
      }
    }

    //Normalise the expected counts of selected isoforms into probabilities
    vector<double> totalgamma=sumpairendgamma;
    ss = 0;
    for (long i = 0; i < totalgamma.size(); ++i){
      if (selected[i] == 0) totalgamma[i] = 0;
      ss += totalgamma[i];
    }
    if (fabs(ss)<smalldelta) { 
      break;
    }

    if (ss != 0 && !std::isnan(ss)){
      for (long i = 0; i < totalgamma.size(); ++i)
        isopi[i] = totalgamma[i] / ss;
    }
    else  {
      printf("Error: the sum of gamma is 0 or NAN!\n");
    }

    //if we need to correct bias
    if (para.usebias){
      tao = calculatetao2(tao, isopi, sgSupport, sgCounter,para);
      if (fabs(double(tao)) > 10){
        if (showwarning){
          double R = 0;
          for (long i = 0; i < rdCounter.size(); ++i) R += rdCounter[i];
          printf("tau: %lf, R: %lf, f(tau): %lf, usebias:", tao, R, calculateftao(tao, biasgamma, sgSupport, sgCounter));
          if (para.usebias) printf(" True\n");
          else printf("False\n");
        }
        tao  = 0;
      }
    }
    currentprob = jointprob;

    if (std::isnan(currentprob)) printf("Error: Nan results\n");

    if (para.printdebuginfo){
      printf("Loop %ld\n", loopcounter);
      printf("Isoform prob: ");
      for (long i = 0; i < isopi.size(); ++i) printf(" %lf", isopi[i]);
      printf("\n");
    }

    ++loopcounter;
    if ((loopcounter > para.preventearlydeletion * 2 + lastdeletion) 
       && ( (fabs(currentprob - prevprob) < para.cutrate * fabs(prevprob))
       || (loopcounter > para.maxloopcount))){
      break;
    }
    prevprob = currentprob;

  }//end while loop

  if (para.printdebuginfo){
    printf("  final Isoform prob: ");
    for (long i = 0; i < isopi.size(); ++i) printf(" %lf", isopi[i]);
    printf(" iteration: %ld\n", loopcounter);
  }

}

//---------------------------------------------------------------------

/* Quasi-multinomial model: for each read row i, pi[i] = sum_j rdSupport[i][j] * isopi[j].
   pi is overwritten (resized to rdSupport.size()). Returns the sum of all pi[i].
   Negative support entries and negative pi values are reported on stderr. */
double calculatemultipval(vector<vector<double> >&rdSupport, vector<double> & isopi, vector<double>& pi){
  pi.assign(rdSupport.size(), 0);
  double total = 0;
  for (size_t row = 0; row < rdSupport.size(); ++row){
    const vector<double> &support = rdSupport[row];
    double &mix = pi[row];
    for (size_t col = 0; col < support.size(); ++col){
      mix += support[col] * isopi[col];
      if (support[col] < 0)
        cerr << "Error: rdSupport[" << row << "][" << col << "]=" << support[col] << endl;
    }
    total += mix;
    if (mix < 0)
      cerr << "Error: pi=" << mix << endl;
  }
  return total;
}

/* Log-likelihood of the quasi-multinomial model as a function of tau:
     f(tau) = sum_{i: pi_i > 0} [ (c_i - 1) log(pi_i + tau c_i) + log(pi_i) ] + (1 - R) log(1 + R tau)
   with c_i = SGCounter[i] and R = sum of c_i over rows with pi_i > 0 (rows with pi_i <= 0 are skipped).
   Returns 0 when tau is outside the valid domain (pi_i + tau c_i <= 0 for some pi_i > 0, or
   1 + R tau <= 0); callers such as taohalfsearch treat 0 as "invalid". NaN terms are reported on stderr. */
double calculateftao2(double tau,  vector<double>& pi,  vector<long> &SGCounter){
  //check if the value of tau is valid
  for (size_t i = 0; i < pi.size(); ++i)
    if (pi[i] > 0 && pi[i] + tau * SGCounter[i] <= 0) return 0;

  double R2 = 0;
  double f = 0;
  for (size_t i = 0; i < pi.size(); ++i){
    if (pi[i] <= 0) continue;
    R2 += SGCounter[i];
    double xf = (SGCounter[i] - 1) * log(pi[i] + tau * SGCounter[i]);
    // NaN never compares equal to anything (not even quiet_NaN()), so use std::isnan.
    if (std::isnan(xf)){
      cerr<<"Error: NaN for pi="<<pi[i]<<", tao="<<tau<<", c="<<SGCounter[i]<<endl;
    }
    double logp = log(pi[i]);
    if (std::isnan(logp)){
      cerr<<"Error: NaN for log(pi)="<<pi[i]<<endl;
    }
    f += xf + logp;
  }
  if (R2 * tau + 1 <= 0) return 0;
  f += (1 - R2) * log(1 + R2 * tau);
  return f;
}


/* Coarse-to-fine maximisation of calculateftao2() over tau in [a, b].
   Each round samples 21 evenly spaced points (20 steps), keeps the best one whose value is
   non-zero (0 marks an invalid tau), and narrows the interval to one step on either side of it
   (clipped to the current interval) until the interval is shorter than 1e-7.
   Returns 0 on success with taoval/taomaxv set to the chosen tau and its value.
   Returns -1 on failure: either no valid sample (taoval = 0, taomaxv = DOUBLE_MIN_VAL) or
   neither the final midpoint nor the right end is valid (taoval = right end, taomaxv = 0). */
int taohalfsearch(vector<double>& pi, vector<long>&SGCounter, 
  double a, double b, //left and right boundary
  double& taoval, double& taomaxv)
{
  const double stopcriteria = 1e-7;
  const double nstep = 20;
  double lo = a;
  double hi = b;

  for (;;){
    // Interval small enough: settle on its midpoint, or the right end if the midpoint is invalid.
    if (fabs(lo - hi) < stopcriteria){
      taoval = (lo + hi) / 2;
      taomaxv = calculateftao2(taoval, pi, SGCounter);
      if (taomaxv != 0) return 0;
      taoval = hi;
      taomaxv = calculateftao2(taoval, pi, SGCounter);
      return (taomaxv == 0) ? -1 : 0;
    }

    const double steplen = (hi - lo) / nstep;
    double besttau = -1;
    double bestval = DOUBLE_MIN_VAL;
    for (int k = 0; k <= nstep; k++){
      const double candidate = lo + k * steplen;
      const double value = calculateftao2(candidate, pi, SGCounter);
      if (value != 0 && value > bestval){
        bestval = value;
        besttau = candidate;
      }
    }

    if (bestval == DOUBLE_MIN_VAL){
      taoval = 0;
      taomaxv = DOUBLE_MIN_VAL;
      return -1;
    }

    // Zoom in around the best sample, staying inside the current interval.
    double newlo = besttau - steplen;
    double newhi = besttau + steplen;
    if (newlo < lo) newlo = lo;
    if (newhi > hi) newhi = hi;
    lo = newlo;
    hi = newhi;
  }
}

/* Calculate the value of tao in another way 
   It's fast too!
*/
double calculatetao2(double tau, vector<double > &isopi, vector<vector<double> > &rdSupport, vector<long> &SGCounter, EMPara& para){
  //hill climbing search
  vector<double> pi;
  calculatemultipval(rdSupport,isopi,pi);
  
  //begin from the minimum value 
  
  double R=para.actualReadCnt;
  double tao0=-0.99/R;
  double maxtaol,maxvtaol;
  double maxtaor,maxvtaor;
  int lr=taohalfsearch(pi,SGCounter,tao0,0, maxtaol,maxvtaol);
  //cout<<"Left ret "<<lr<<", tao="<<maxtaol<<",f="<<maxvtaol<<endl;
  int rr=taohalfsearch(pi,SGCounter,  0,1, maxtaor,maxvtaor);
  //cout<<"Right ret "<<rr<<", tao="<<maxtaor<<",f="<<maxvtaor<<endl;

  double tao=0;
  if(lr!=-1) tao=maxtaol;
  if(rr!=-1){
    if(maxvtaor>maxvtaol) tao=maxtaor;
  }
  if(para.printdebuginfo) cout<<"# R="<<R<<",tao="<<tao<<endl;
  return tao;
}

//-------------------------------------------------------------------

/* Solve for tau with Newton-Raphson steps (newtonraphson) starting from the given tau.
   Iteration stops once the relative change is no longer above 1e-6, |tau| reaches 100, or
   100 steps have run; every step is traced on stdout. If |tau| ends at or above 100
   (divergence), the result of calculatetaobybinsearch() is returned instead. */
double calculatetao(double tau, vector<vector<double> > &pairendgamma, vector<vector<double> > &rdSupport, vector<long> &SGCounter){
  const bool trace = true;
  const double divergence = 100;
  const long maxsteps = 100;
  const double reltol = 1e-6;

  double previous = 1;
  for (long step = 0; step < maxsteps; ++step){
    const bool moving = fabs(tau - previous) > reltol * fabs(previous);
    const bool bounded = fabs(tau) < divergence;
    if (!moving || !bounded) break;
    previous = tau;
    tau = newtonraphson(tau, pairendgamma, rdSupport, SGCounter);
    if (trace){
      printf("tau: %lf -> %lf , f: %lf -> %lf\n",previous, tau, 
        calculateftao(previous, pairendgamma, rdSupport, SGCounter), 
        calculateftao(tau, pairendgamma, rdSupport, SGCounter));
    }
  }

  // Newton-Raphson diverged: look for the zero point by bisection instead.
  if (fabs(tau) >= divergence)
    tau = calculatetaobybinsearch(tau, pairendgamma, rdSupport, SGCounter);
  return tau;
}

/* Bisection for a zero of calculateftao() in tau; fallback used when Newton-Raphson diverges.
   Only rows of rdSupport with positive total support take part; if fewer than two exist, 0 is
   returned. The bracket is [tmin, 1], where tmin is the smallest positive
   rdSupport/SGCounter ratio capped at 1.0/R (R = total count of the supported rows).
   If f changes sign on the bracket, at most 100 bisection steps run and the midpoint of the
   final bracket is returned; otherwise 0. Progress is printed to stdout. The tau argument is
   only used in the first message. */
double calculatetaobybinsearch(double tau, vector<vector<double> > &pairendgamma, 
   vector<vector<double> > &rdSupport, vector<long> &SGCounter){

  const bool showdebug = true;
  const long maxloop = 100;
  const double stopcriteria = 1e-6; 

  if (showdebug){
    printf("Divergent of tau: %lf, use binary search\n", tau);
  }

  // rows with positive total support
  vector<long> selectedrow;
  for (long l = 0; l < rdSupport.size(); ++l){
    double rowSum = 0;
    for (long k = 0; k < rdSupport[l].size(); ++k) rowSum += rdSupport[l][k];
    selectedrow.push_back(rowSum > 0 ? 1 : 0);
  }
  long sumselectedrow = 0;
  for (long l = 0; l < selectedrow.size(); ++l)
    sumselectedrow += selectedrow[l];
  if (sumselectedrow < 2){
    if (showdebug){
      printf("empty selection of rows, set to 0.\n");
    }
    return 0;
  }
  long R = 0;
  for (long k = 0; k < selectedrow.size(); ++k)
    if (selectedrow[k] == 1) R += SGCounter[k];

  // Left end of the bracket: smallest positive support/count ratio, capped at 1/R.
  double tmin = DOUBLE_MIN_VAL;
  for (long k = 0; k < rdSupport.size(); ++k)
    for (long l = 0; l < rdSupport[k].size(); ++l){
      double ratio = rdSupport[k][l] / SGCounter[k];
      if (ratio > 0 && (ratio < tmin || tmin == DOUBLE_MIN_VAL)) tmin = ratio;
    }
  // 1.0 / R in floating point: integer 1 / R is 0 for R > 1 and divides by zero for R == 0.
  if (R > 0) tmin = min(tmin, 1.0 / R);

  vector<vector<double> > pairendgamma_sel;
  vector<vector<double> > rdSupport_sel;
  vector<long> SGCounter_sel;
  double tmax = 1;
  for (long k = 0; k < selectedrow.size(); ++k)
    if (selectedrow[k] > 0){
      pairendgamma_sel.push_back(pairendgamma[k]);
      rdSupport_sel.push_back(rdSupport[k]);
      SGCounter_sel.push_back(SGCounter[k]);
    }

  double ftmin = calculateftao(tmin, pairendgamma_sel, rdSupport_sel, SGCounter_sel);
  double ftmax = calculateftao(tmax, pairendgamma_sel, rdSupport_sel, SGCounter_sel);
  if ((ftmin * ftmax <= 0) && !((ftmin == 0) && (ftmax == 0))){
    long nloop = 0;
    while ((fabs(tmax - tmin) > max((double)stopcriteria, (double)(min(fabs(tmax), fabs(tmin))))) && (nloop < maxloop)){
      ++nloop;
      double mid = (tmin + tmax) / 2;
      double f = calculateftao(mid, pairendgamma_sel, rdSupport_sel, SGCounter_sel);
      if (showdebug){
        printf("tau: %lf , f: %lf\n", mid, f);
      }
      // keep the half whose ends still have opposite signs
      if ((f * ftmin > 0) || (f == 0 && ftmin == 0)){
        tmin = mid;
        ftmin = f;
      }
      else {
        tmax = mid;
        ftmax = f;
      }
      if ((ftmax * ftmin > 0) || (ftmax == 0 && ftmin == 0)){
        if (showdebug){
          printf("Warning: ftmin and ftmax have the same sign in the binary search\n");
        }
      }
    }
    tau = (tmin + tmax) / 2;
  }
  else {
    if (showdebug){
      printf("Warning: ftmin and ftmax have the same sign. Set to 0.\n");
    }
    tau = 0;
  }
  return tau;

}

/* The function whose zero calculatetao / calculatetaobybinsearch look for:
     f(tao) = sum_{i,j} x_i * gamma[i][j] / (rdSupport[i][j] + c_i * s_i * tao)  -  R (R - 1) / (1 + R tao)
   where c_i = SGCounter[i], s_i = 1 if row i of rdSupport has positive total support (else 0),
   x_i = c_i (c_i - 1) s_i, and R = sum of c_i over supported rows. A denominator with
   |value| < 1e-20 is replaced by 1. Columns are taken from pairendgamma[0].size(). */
double calculateftao(double tao, vector<vector<double> > &pairendgamma, vector<vector<double> > &rdSupport, vector<long> &SGCounter){
  const double radj = 1;

  // selectedrow[i] = 1 for rows with positive total support, else 0
  vector<double> selectedrow(rdSupport.size(), 0);
  for (size_t i = 0; i < rdSupport.size(); ++i){
    double rowsum = 0;
    for (size_t j = 0; j < rdSupport[i].size(); ++j) rowsum += rdSupport[i][j];
    selectedrow[i] = (rowsum > 0) ? 1 : 0;
  }

  // total count of supported reads; must start from 0
  long R = 0;
  for (size_t i = 0; i < selectedrow.size(); ++i)
    if (selectedrow[i] > 0)
      R += SGCounter[i];

  const size_t ni = pairendgamma.empty() ? 0 : pairendgamma[0].size();

  double f = 0;
  for (size_t i = 0; ni > 0 && i < SGCounter.size(); ++i){
    const long xjj = (selectedrow[i] > 0) ? SGCounter[i] * (SGCounter[i] - 1) : 0;
    for (size_t j = 0; j < ni; ++j){
      double rs = rdSupport[i][j] + SGCounter[i] * selectedrow[i] * tao;
      if (fabs(rs) < 1e-20) rs = 1;
      f += xjj * pairendgamma[i][j] / rs;
    }
  }
  f = f - R * (R - radj) / (1 + R * tao);
  return f;
}

/* One Newton-Raphson step for the bias parameter tao.
 *
 * Computes t = tao - ftao / ftao1, where
 *   ftao  = sum_{i,j} x_i * pairendgamma[i][j] / rs_ij  -  R * (R - 1) / (1 + R * tao)
 *   ftao1 = R^2 * (R - 1) / (1 + R * tao)^2  -  sum_{i,j} n_i^2 * (n_i - 1) * sel_i * pairendgamma[i][j] / rs_ij^2
 * with n_i = SGCounter[i], sel_i = 1 if row i of rdSupport has a positive sum (else 0),
 * x_i = n_i * (n_i - 1) * sel_i, rs_ij = rdSupport[i][j] + n_i * sel_i * tao (replaced by 1
 * when |rs_ij| < 1e-20), and R = sum of n_i over the rows with sel_i = 1.
 * ftao1 is the derivative of ftao with respect to tao (ignoring the near-zero clamp on rs_ij).
 *
 * Returns tao unchanged when 1 + R*tao < 0, when |ftao1| < smalldelta, or when ftao1 is
 * infinite. Prints an error message when the resulting step is NaN (the value is still returned).
 *
 * Preconditions (not checked): pairendgamma, rdSupport and SGCounter have the same number
 * of rows, and each row of pairendgamma / rdSupport has at least pairendgamma[0].size() entries.
 */
double newtonraphson(double tao, std::vector<std::vector<double> > &pairendgamma, std::vector<std::vector<double> > &rdSupport, std::vector<long> &SGCounter){
  const double radj = 1;

  // selectedrow[i] = 1 when row i of rdSupport has positive total support, 0 otherwise
  std::vector<double> selectedrow(rdSupport.size(), 0.0);
  for (size_t i = 0; i < rdSupport.size(); ++i){
    double rowsum = 0;
    for (size_t j = 0; j < rdSupport[i].size(); ++j) rowsum += rdSupport[i][j];
    selectedrow[i] = (rowsum > 0) ? 1 : 0;
  }

  // R: total count over the selected rows (must start from zero; it was uninitialized before)
  long R = 0;
  for (size_t i = 0; i < selectedrow.size(); ++i)
    if (selectedrow[i] > 0)
      R += SGCounter[i];
  if (1 + R * tao < 0){
    return tao;
  }

  // floating-point copy of R: R*R*(R-1) can overflow a long for large read counts
  const double Rd = (double)R;
  const size_t ni = pairendgamma.empty() ? 0 : pairendgamma[0].size();

  double ftao = 0;
  double ftao1 = Rd * Rd * (Rd - radj) / (1 + Rd * tao) / (1 + Rd * tao);
  if (ni > 0){
    for (size_t i = 0; i < SGCounter.size(); ++i){
      const double n = (double)SGCounter[i];
      const double xjj = n * (n - 1) * selectedrow[i];       // coefficient in ftao
      const double xjj2 = n * n * (n - 1) * selectedrow[i];  // coefficient in ftao1
      for (size_t j = 0; j < ni; ++j){
        double rs = rdSupport[i][j] + n * selectedrow[i] * tao;
        if (fabs(rs) < 1e-20) rs = 1;  // avoid dividing by a (near-)zero denominator
        ftao += xjj * pairendgamma[i][j] / rs;
        ftao1 -= xjj2 * pairendgamma[i][j] / (rs * rs);
      }
    }
  }
  ftao = ftao - Rd * (Rd - radj) / (1 + Rd * tao);

  double t;
  // skip the step when the derivative is (near) zero or infinite of either sign
  if ((fabs(ftao1) < smalldelta) || (fabs(ftao1) == std::numeric_limits<double>::infinity()))
    t = tao;
  else t = tao - ftao/ftao1;
  // NaN is the only value that compares unequal to itself; `t == quiet_NaN()` is always false
  if (t != t)
    printf("Error: Nan result for tao!\n");
  return t;
}


