/*##########################################################################*/
/*#                                                                         #*/
/*# C O P Y R I G H T   N O T I C E                                         #*/
/*#  Copyright (c) 2003-10 by:                                              #*/
/*#    * California Institute of Technology                                 #*/
/*#                                                                         #*/
/*#    All Rights Reserved.                                                 #*/
/*#                                                                         #*/
/*# Permission is hereby granted, free of charge, to any person             #*/
/*# obtaining a copy of this software and associated documentation files    #*/
/*# (the "Software"), to deal in the Software without restriction,          #*/
/*# including without limitation the rights to use, copy, modify, merge,    #*/
/*# publish, distribute, sublicense, and/or sell copies of the Software,    #*/
/*# and to permit persons to whom the Software is furnished to do so,       #*/
/*# subject to the following conditions:                                    #*/
/*#                                                                         #*/
/*# The above copyright notice and this permission notice shall be          #*/
/*# included in all copies or substantial portions of the Software.         #*/
/*#                                                                         #*/
/*# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,         #*/
/*# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF      #*/
/*# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND                   #*/
/*# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS     #*/
/*# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN      #*/
/*# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN       #*/
/*# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE        #*/
/*# SOFTWARE.                                                               #*/
/*###########################################################################*/
/*# C extension for commonly used motif methods.                             */

#include <Python.h>
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

/* Module-level docstring; handed to Py_InitModule3() in init_motif().
 * NOTE(review): the text names only locateMotif and correlateMotifs, but the
 * method table in this file also registers locateMarkov1 and locateMer. */
static char module_doc[] = 
"This module only implements the locateMotif and correlateMotifs in C for now.";

/*
 * Look up the weight a flat PWM assigns to the character aSeq[seqPos].
 *
 * PWM is indexed as PWM[4 * index + n] with n = 0..3 for A, C, G, T.
 * Upper- and lower-case letters are treated alike:
 *   - A/C/G/T return the single matching entry of column `index`;
 *   - two- and three-letter IUPAC codes (S, W, R, Y, M, K, B, D, H, V)
 *     return the sum of the entries they stand for;
 *   - N returns 1.0;
 *   - any other character returns 0.0.
 */
double probMatchPWM_func(double *PWM, int index, char *aSeq, long seqPos) {
    const double *column = PWM + 4 * index;
    const double pA = column[0];
    const double pC = column[1];
    const double pG = column[2];
    const double pT = column[3];
    double match = 0.0;

    switch (aSeq[seqPos]) {
        /* unambiguous bases */
        case 'A': case 'a': match = pA; break;
        case 'C': case 'c': match = pC; break;
        case 'G': case 'g': match = pG; break;
        case 'T': case 't': match = pT; break;

        /* any base */
        case 'N': case 'n': match = 1.0; break;

        /* two-base codes */
        case 'S': case 's': match = pC + pG; break;
        case 'W': case 'w': match = pA + pT; break;
        case 'R': case 'r': match = pA + pG; break;
        case 'Y': case 'y': match = pC + pT; break;
        case 'M': case 'm': match = pA + pC; break;
        case 'K': case 'k': match = pG + pT; break;

        /* three-base codes */
        case 'B': case 'b': match = pC + pG + pT; break;
        case 'D': case 'd': match = pA + pG + pT; break;
        case 'H': case 'h': match = pA + pC + pT; break;
        case 'V': case 'v': match = pA + pC + pG; break;

        default:
            break;
    }

    return match;
}

/* Map a nucleotide character (either case) to 0..3 for A, C, G, T;
 * every other character maps to -1. */
static int dpwmBaseSlot(char nt) {
    switch (nt) {
        case 'A': case 'a': return 0;
        case 'C': case 'c': return 1;
        case 'G': case 'g': return 2;
        case 'T': case 't': return 3;
        default:            return -1;
    }
}

/*
 * Look up the value a flat dinucleotide matrix assigns to aSeq[seqPos].
 *
 * PWM is indexed as PWM[16 * index + 4 * previous + current], where
 * `previous` and `current` are 0..3 for A, C, G, T.
 *
 *   - index == 0: the preceding character is not consulted; the entries of
 *     the first 16-value block are summed over all four `previous` rows.
 *   - index  > 0: the row is chosen by aSeq[seqPos - 1].  If seqPos is 0 or
 *     the preceding character is not A/C/G/T, all four candidates stay 0.0.
 *
 * Returns the candidate for the current character, or 0.0001 when the
 * current character is not A/C/G/T.
 */
double probMatchDPWM_func(double *PWM, int index, char *aSeq, long seqPos) {
    double candidate[4] = {0.0, 0.0, 0.0, 0.0};
    int context, slot;

    if (index == 0) {
        for (context = 0; context < 4; context++) {
            for (slot = 0; slot < 4; slot++) {
                candidate[slot] += PWM[context * 4 + slot];
            }
        }
    } else if (seqPos > 0) {
        context = dpwmBaseSlot(aSeq[seqPos - 1]);
        if (context >= 0) {
            const double *row = PWM + 16 * index + 4 * context;

            for (slot = 0; slot < 4; slot++) {
                candidate[slot] = row[slot];
            }
        }
    }

    slot = dpwmBaseSlot(aSeq[seqPos]);
    if (slot < 0) {
        return 0.0001;
    }
    return candidate[slot];
}

/*
 * Score the window of motLen characters starting at theSeq[pos] against a
 * forward matrix (mPWM) and a second matrix (rPWM), using
 * probMatchPWM_func() for each position and summing the results.
 *
 * Outputs:
 *   *score  -1.0 when both totals, after adding *maxDiff, are still below
 *           *bestCons; otherwise the larger of the two totals.
 *   *sense  'F' unless the rPWM total is chosen, in which case 'R'.
 *           Ties go to 'R'; a rejected window leaves 'F'.
 */
void scoreMot_func(char *theSeq, double *mPWM, double *rPWM, long pos, long motLen, float *bestCons, float *maxDiff, float *score, char *sense) 
{
    float fwdTotal = 0.0;
    float revTotal = 0.0;
    long offset;

    *sense = 'F';

    for (offset = 0; offset < motLen; offset++) {
        fwdTotal += probMatchPWM_func(mPWM, offset, theSeq, pos + offset);
        revTotal += probMatchPWM_func(rPWM, offset, theSeq, pos + offset);
    }

    if (((fwdTotal + *maxDiff) < *bestCons) && ((revTotal + *maxDiff) < *bestCons)) {
        /* neither strand comes within maxDiff of the reference score */
        *score = -1.0;
    } else if (fwdTotal > revTotal) {
        *score = fwdTotal;
    } else {
        *score = revTotal;
        *sense = 'R';
    }
}

/*
 * Score the window of motLen characters starting at theSeq[pos] against two
 * dinucleotide matrices.  Each total is the sum over the window of
 * -log2(probMatchDPWM_func(...)), so a smaller total is a better match.
 *
 * Outputs:
 *   *score  -1.0 when both totals exceed *bestCons; otherwise the smaller
 *           of the two totals.
 *   *sense  'F' unless the rDPWM total is chosen, in which case 'R'.
 *           Ties go to 'R'; a rejected window leaves 'F'.
 */
void scoreDPWM_func(char *theSeq, double *mDPWM, double *rDPWM, long pos, long motLen, float *bestCons, float *score, char *sense) 
{
    const double lnTwo = log(2.0);
    float fwdBits = 0.0;
    float revBits = 0.0;
    long offset;

    *sense = 'F';

    for (offset = 0; offset < motLen; offset++) {
        fwdBits -= log(probMatchDPWM_func(mDPWM, offset, theSeq, pos + offset)) / lnTwo;
        revBits -= log(probMatchDPWM_func(rDPWM, offset, theSeq, pos + offset)) / lnTwo;
    }

    if ((fwdBits > *bestCons) && (revBits > *bestCons)) {
        /* both strands are worse than the allowed cost */
        *score = -1.0;
    } else if (fwdBits < revBits) {
        *score = fwdBits;
    } else {
        *score = revBits;
        *sense = 'R';
    }
}

/*
 * Copy a Python list of motLen four-element lists into a flat array laid out
 * as flat[4 * column + nucleotide].
 *
 * Returns 0 on success, or -1 with a Python exception set when an item is
 * missing, is not a list, or cannot be converted to a float.
 */
static int
locateMotif_loadPWM(PyObject *pwm, long motLen, double *flat)
{
    PyObject *column, *cell;
    double value;
    long col;
    int nt;

    for (col = 0; col < motLen; col++) {
        column = PyList_GetItem(pwm, col);          /* borrowed reference */
        if (column == NULL) {
            return -1;
        }
        for (nt = 0; nt < 4; nt++) {
            cell = PyList_GetItem(column, nt);      /* borrowed reference */
            if (cell == NULL) {
                return -1;
            }
            value = PyFloat_AsDouble(cell);
            if (value == -1.0 && PyErr_Occurred()) {
                return -1;
            }
            flat[4 * col + nt] = value;
        }
    }
    return 0;
}

/*
 * Python entry point registered as "locateMotif".
 *
 * Arguments (positional): sequence string, forward PWM (list of 4-element
 * lists), reverse PWM (same shape, at least as many columns as the forward
 * PWM), a score passed to scoreMot_func() as bestCons, and a tolerance
 * passed as maxDiff.
 *
 * Every window of len(forward PWM) characters is scored with
 * scoreMot_func(); windows containing 'N' or 'n' are skipped by jumping just
 * past the last such character.  A window is reported when its score is
 * greater than 1.0.
 *
 * Returns a new list of (position, sense) tuples, where sense is the
 * one-character string "F" or "R"; returns NULL with an exception set on
 * bad arguments or allocation failure.
 */
static PyObject*
the_func(PyObject *self, PyObject *args)
{
    PyObject *pySeq, *motifPWM, *revPWM, *mScore, *mDiff;
    PyObject *results = NULL, *hit;
    double *mPWM = NULL, *rPWM = NULL;
    double converted;
    float maxScore, maxDiff, seqScore;
    long pos, maxPos, index, motLen, seqLen, skipping;
    size_t cells;
    int status;
    char *seq, sense;

    if (!PyArg_UnpackTuple(args, "ref", 5, 5, &pySeq, &motifPWM, &revPWM, &mScore, &mDiff)) {
        return NULL;
    }

    seq = PyString_AsString(pySeq);
    if (seq == NULL) {
        return NULL;
    }
    seqLen = (long) PyString_Size(pySeq);

    motLen = (long) PyList_Size(motifPWM);
    if (motLen < 0) {
        return NULL;            /* not a list; exception already set */
    }

    converted = PyFloat_AsDouble(mScore);
    if (converted == -1.0 && PyErr_Occurred()) {
        return NULL;
    }
    maxScore = (float) converted;

    converted = PyFloat_AsDouble(mDiff);
    if (converted == -1.0 && PyErr_Occurred()) {
        return NULL;
    }
    maxDiff = (float) converted;

    /* Allocate at least one element so an empty motif does not look like
     * an allocation failure on platforms where malloc(0) returns NULL. */
    cells = 4 * (size_t) motLen;
    if (cells == 0) {
        cells = 1;
    }
    mPWM = malloc(cells * sizeof *mPWM);
    rPWM = malloc(cells * sizeof *rPWM);
    if (mPWM == NULL || rPWM == NULL) {
        PyErr_NoMemory();
        goto fail;
    }

    if (locateMotif_loadPWM(motifPWM, motLen, mPWM) < 0) {
        goto fail;
    }
    if (locateMotif_loadPWM(revPWM, motLen, rPWM) < 0) {
        goto fail;
    }

    results = Py_BuildValue("[]");
    if (results == NULL) {
        goto fail;
    }

    pos = 0;
    maxPos = seqLen - motLen;

    while (pos <= maxPos) {
        /* Find the last N in the window; restart right after it. */
        skipping = 0;
        for (index = 0; index < motLen; index++) {
            if ((seq[pos + index] == 'N') || (seq[pos + index] == 'n')) {
                skipping = index + 1;
            }
        }

        if (skipping) {
            pos += skipping;
            continue;
        }

        scoreMot_func(seq, mPWM, rPWM, pos, motLen, &maxScore, &maxDiff, &seqScore, &sense);

        if (seqScore > 1.0) {
            /* "l" matches the C long being passed.  PyList_Append() takes
             * its own reference, so ours must be released afterwards. */
            hit = Py_BuildValue("(lc)", pos, sense);
            if (hit == NULL) {
                goto fail;
            }
            status = PyList_Append(results, hit);
            Py_DECREF(hit);
            if (status < 0) {
                goto fail;
            }
        }
        pos++;
    }

    free(mPWM);
    free(rPWM);
    return results;

fail:
    Py_XDECREF(results);
    free(mPWM);
    free(rPWM);
    return NULL;
}

/*
 * Copy a Python list of motLen 4x4 nested lists into a flat array laid out
 * as flat[16 * position + 4 * previous + current].
 *
 * Returns 0 on success, or -1 with a Python exception set when an item is
 * missing, is not a list, or cannot be converted to a float.
 */
static int
locateMarkov1_loadDPWM(PyObject *dpwm, long motLen, double *flat)
{
    PyObject *position, *context, *cell;
    double value;
    long col;
    int prev, cur;

    for (col = 0; col < motLen; col++) {
        position = PyList_GetItem(dpwm, col);           /* borrowed reference */
        if (position == NULL) {
            return -1;
        }
        for (prev = 0; prev < 4; prev++) {
            context = PyList_GetItem(position, prev);   /* borrowed reference */
            if (context == NULL) {
                return -1;
            }
            for (cur = 0; cur < 4; cur++) {
                cell = PyList_GetItem(context, cur);    /* borrowed reference */
                if (cell == NULL) {
                    return -1;
                }
                value = PyFloat_AsDouble(cell);
                if (value == -1.0 && PyErr_Occurred()) {
                    return -1;
                }
                flat[16 * col + 4 * prev + cur] = value;
            }
        }
    }
    return 0;
}

/*
 * Python entry point registered as "locateMarkov1".
 *
 * Arguments (positional): sequence string, forward dinucleotide matrix
 * (list of 4x4 nested lists), reverse matrix (same shape, at least as many
 * positions as the forward matrix), and a score passed to scoreDPWM_func()
 * as bestCons.
 *
 * Every window of len(forward matrix) characters is scored with
 * scoreDPWM_func(); windows containing 'N' or 'n' are skipped by jumping
 * just past the last such character.  A window is reported when its score
 * is greater than 1.0.
 *
 * Returns a new list of (position, sense) tuples, where sense is the
 * one-character string "F" or "R"; returns NULL with an exception set on
 * bad arguments or allocation failure.
 */
static PyObject*
dpwm_func(PyObject *self, PyObject *args)
{
    PyObject *pySeq, *motifPWM, *revPWM, *mScore;
    PyObject *results = NULL, *hit;
    double *mDPWM = NULL, *rDPWM = NULL;
    double converted;
    float maxScore, seqScore;
    long pos, maxPos, index, motLen, seqLen, skipping;
    size_t cells;
    int status;
    char *seq, sense;

    if (!PyArg_UnpackTuple(args, "ref", 4, 4, &pySeq, &motifPWM, &revPWM, &mScore)) {
        return NULL;
    }

    seq = PyString_AsString(pySeq);
    if (seq == NULL) {
        return NULL;
    }
    seqLen = (long) PyString_Size(pySeq);

    motLen = (long) PyList_Size(motifPWM);
    if (motLen < 0) {
        return NULL;            /* not a list; exception already set */
    }

    converted = PyFloat_AsDouble(mScore);
    if (converted == -1.0 && PyErr_Occurred()) {
        return NULL;
    }
    maxScore = (float) converted;

    /* Allocate at least one element so an empty motif does not look like
     * an allocation failure on platforms where malloc(0) returns NULL. */
    cells = 16 * (size_t) motLen;
    if (cells == 0) {
        cells = 1;
    }
    mDPWM = malloc(cells * sizeof *mDPWM);
    rDPWM = malloc(cells * sizeof *rDPWM);
    if (mDPWM == NULL || rDPWM == NULL) {
        PyErr_NoMemory();
        goto fail;
    }

    if (locateMarkov1_loadDPWM(motifPWM, motLen, mDPWM) < 0) {
        goto fail;
    }
    if (locateMarkov1_loadDPWM(revPWM, motLen, rDPWM) < 0) {
        goto fail;
    }

    results = Py_BuildValue("[]");
    if (results == NULL) {
        goto fail;
    }

    pos = 0;
    maxPos = seqLen - motLen;

    while (pos <= maxPos) {
        /* Find the last N in the window; restart right after it. */
        skipping = 0;
        for (index = 0; index < motLen; index++) {
            if ((seq[pos + index] == 'N') || (seq[pos + index] == 'n')) {
                skipping = index + 1;
            }
        }

        if (skipping) {
            pos += skipping;
            continue;
        }

        scoreDPWM_func(seq, mDPWM, rDPWM, pos, motLen, &maxScore, &seqScore, &sense);

        if (seqScore > 1.0) {
            /* "l" matches the C long being passed.  PyList_Append() takes
             * its own reference, so ours must be released afterwards. */
            hit = Py_BuildValue("(lc)", pos, sense);
            if (hit == NULL) {
                goto fail;
            }
            status = PyList_Append(results, hit);
            Py_DECREF(hit);
            if (status < 0) {
                goto fail;
            }
        }
        pos++;
    }

    free(mDPWM);
    free(rDPWM);
    return results;

fail:
    Py_XDECREF(results);
    free(mDPWM);
    free(rDPWM);
    return NULL;
}

/*
 * Python entry point registered as "locateMer".
 *
 * Arguments (positional): sequence string, forward mer string,
 * reverse-complement mer string, and the maximum number of mismatches.
 *
 * Every window of len(forward mer) characters is compared byte-for-byte
 * (case-sensitively) against both mers; windows containing 'N' or 'n' are
 * skipped by jumping just past the last such character.  A window is
 * reported when either mer is within the mismatch limit; sense is "R" only
 * when the reverse-complement mer has strictly fewer mismatches.
 *
 * If the reverse-complement string is shorter than the forward mer, the
 * positions it lacks count as mismatches rather than being read from past
 * the end of the string.
 *
 * Returns a new list of (position, sense) tuples; returns NULL with an
 * exception set on bad arguments.
 */
static PyObject*
mer_func(PyObject *self, PyObject *args)
{
    PyObject *pySeq, *forwardMer, *revcompMer, *maxMismatches;
    PyObject *results, *hit;
    long pos, maxPos, index, motLen, revLen, seqLen, skipping;
    long mismatches, fmis, rmis;
    int status;
    char *fMer, *rMer;
    char *seq, sense, base;

    if (!PyArg_UnpackTuple(args, "ref", 4, 4, &pySeq, &forwardMer, &revcompMer, &maxMismatches)) {
        return NULL;
    }

    seq = PyString_AsString(pySeq);
    if (seq == NULL) {
        return NULL;
    }
    fMer = PyString_AsString(forwardMer);
    if (fMer == NULL) {
        return NULL;
    }
    rMer = PyString_AsString(revcompMer);
    if (rMer == NULL) {
        return NULL;
    }

    motLen = (long) PyString_Size(forwardMer);
    revLen = (long) PyString_Size(revcompMer);
    seqLen = (long) PyString_Size(pySeq);

    mismatches = PyInt_AsLong(maxMismatches);
    if (mismatches == -1 && PyErr_Occurred()) {
        return NULL;
    }

    results = Py_BuildValue("[]");
    if (results == NULL) {
        return NULL;
    }

    pos = 0;
    maxPos = seqLen - motLen;

    while (pos <= maxPos) {
        /* Find the last N in the window; restart right after it. */
        skipping = 0;
        for (index = 0; index < motLen; index++) {
            if ((seq[pos + index] == 'N') || (seq[pos + index] == 'n')) {
                skipping = index + 1;
            }
        }

        if (skipping) {
            pos += skipping;
            continue;
        }

        fmis = 0;
        rmis = 0;
        for (index = 0; index < motLen; index++) {
            base = seq[pos + index];
            if (base != fMer[index]) {
                fmis++;
            }
            /* never index rMer beyond its own length */
            if ((index >= revLen) || (base != rMer[index])) {
                rmis++;
            }
            if ((fmis > mismatches) && (rmis > mismatches)) {
                break;          /* neither strand can still qualify */
            }
        }

        if ((fmis <= mismatches) || (rmis <= mismatches)) {
            sense = 'F';
            if (rmis < fmis) {
                sense = 'R';
            }
            /* "l" matches the C long being passed.  PyList_Append() takes
             * its own reference, so ours must be released afterwards. */
            hit = Py_BuildValue("(lc)", pos, sense);
            if (hit == NULL) {
                Py_DECREF(results);
                return NULL;
            }
            status = PyList_Append(results, hit);
            Py_DECREF(hit);
            if (status < 0) {
                Py_DECREF(results);
                return NULL;
            }
        }
        pos++;
    }

    return results;
}

double pearson_func(double *colA, double *colB, int pos)
{
    double c, numerator;
    double meanA, denominatorA;
    double meanB, denominatorB;
    long index;
    
    meanA = 0.0;
    meanB = 0.0;
    
    for (index = 0; index < 4; index++) {
        meanA += colA[pos + index];
        meanB += colB[pos + index];
    }
    
    meanA /= 4;
    meanB /= 4;
    
    denominatorA = 0.0;
    denominatorB = 0.0;
    numerator = 0.0;
    
    for (index = 0; index < 4; index++) {
        numerator += (colA[pos + index] - meanA) * (colB[pos + index] - meanB);
        denominatorA += (colA[pos + index] - meanA) * (colA[pos + index] - meanA);
        denominatorB += (colB[pos + index] - meanB) * (colB[pos + index] - meanB);
    }
    
    if (denominatorA == 0.0 || denominatorB == 0.0) {
        c = 0.0;
    } else {
        c = numerator / sqrt(denominatorA * denominatorB);
    }
    
    return c;
}

/*
 * Copy a Python list of `columns` four-element lists into a flat array laid
 * out as flat[4 * column + nucleotide].
 *
 * Returns 0 on success, or -1 with a Python exception set when an item is
 * missing, is not a list, or cannot be converted to a float.
 */
static int
correlateMotifs_loadPWM(PyObject *pwm, int columns, double *flat)
{
    PyObject *column, *cell;
    double value;
    int col, nt;

    for (col = 0; col < columns; col++) {
        column = PyList_GetItem(pwm, col);          /* borrowed reference */
        if (column == NULL) {
            return -1;
        }
        for (nt = 0; nt < 4; nt++) {
            cell = PyList_GetItem(column, nt);      /* borrowed reference */
            if (cell == NULL) {
                return -1;
            }
            value = PyFloat_AsDouble(cell);
            if (value == -1.0 && PyErr_Occurred()) {
                return -1;
            }
            flat[4 * col + nt] = value;
        }
    }
    return 0;
}

/*
 * Lay out one aligned matrix in dest: `lead` values of 0.25, then the
 * srcLen values of src, then 0.25 up to (but not including) index `total`.
 * The source is always copied in full, even if that runs past `total`.
 */
static void
correlateMotifs_layout(double *dest, const double *src, int srcLen, int lead, int total)
{
    int i;

    for (i = 0; i < lead; i++) {
        dest[i] = 0.25;
    }
    for (i = 0; i < srcLen; i++) {
        dest[lead + i] = src[i];
    }
    for (i = lead + srcLen; i < total; i++) {
        dest[i] = 0.25;
    }
}

/*
 * Python entry point registered as "correlateMotifs".
 *
 * Arguments (positional): PWM a, PWM b, PWM c (each a list of 4-element
 * lists; c must have at least as many columns as b), and an integer
 * maximum slide.  b must not be longer than a.
 *
 * For every offset from -maxSlide to maxSlide + (len(a) - len(b)), a is
 * aligned against b and against c, with uncovered columns filled with
 * 0.25.  The per-column pearson_func() values are averaged for each
 * alignment, and the largest average seen (never less than the initial
 * 0.0) is returned as a Python float.
 *
 * Returns NULL with an exception set on bad arguments or allocation
 * failure.
 */
static PyObject*
corr_func(PyObject *self, PyObject *args)
{
    PyObject *PyaPWM, *PybPWM, *PycPWM, *PyMaxSlide;
    PyObject *result = NULL;
    double *aPWM = NULL, *bPWM = NULL, *cPWM = NULL;
    double *tempA = NULL, *tempB = NULL, *tempC = NULL;
    float fscore, rscore, bestScore;
    long maxSlide, aLen, bLen;
    int index, indexMax, bIndexMax, motLen, padLen, slide;
    int adjustedSlide, tempSize, tempMax, leadA, leadB;
    size_t aCells, bCells, tempCells;

    bestScore = 0.0;

    if (!PyArg_UnpackTuple(args, "corr", 4, 4, &PyaPWM, &PybPWM, &PycPWM, &PyMaxSlide)) {
        return NULL;
    }

    aLen = (long) PyList_Size(PyaPWM);
    if (aLen < 0) {
        return NULL;            /* not a list; exception already set */
    }
    bLen = (long) PyList_Size(PybPWM);
    if (bLen < 0) {
        return NULL;
    }
    if (bLen > aLen) {
        /* The alignment below pads b up to the length of a, never the
         * other way round; a longer b would overrun the a buffer. */
        PyErr_SetString(PyExc_ValueError,
                        "correlateMotifs: second motif must not be longer than the first");
        return NULL;
    }

    maxSlide = PyInt_AsLong(PyMaxSlide);
    if (maxSlide == -1 && PyErr_Occurred()) {
        return NULL;
    }

    motLen = (int) aLen;
    padLen = (int) (aLen - bLen);

    if (maxSlide > motLen) {
        maxSlide = motLen - 1;
    }

    indexMax = 4 * motLen;
    bIndexMax = 4 * (motLen - padLen);

    /* Allocate at least one element so empty motifs do not look like an
     * allocation failure on platforms where malloc(0) returns NULL. */
    aCells = indexMax > 0 ? (size_t) indexMax : 1;
    bCells = bIndexMax > 0 ? (size_t) bIndexMax : 1;
    /* An aligned matrix never needs more than 8 * motLen values (maxSlide
     * is at most motLen here), so 3 * indexMax leaves room to spare.  The
     * region that is scored is rewritten in full for every slide, so one
     * set of scratch buffers can serve the whole loop. */
    tempCells = 3 * aCells;

    aPWM = malloc(aCells * sizeof *aPWM);
    bPWM = malloc(bCells * sizeof *bPWM);
    cPWM = malloc(bCells * sizeof *cPWM);
    tempA = malloc(tempCells * sizeof *tempA);
    tempB = malloc(tempCells * sizeof *tempB);
    tempC = malloc(tempCells * sizeof *tempC);
    if (aPWM == NULL || bPWM == NULL || cPWM == NULL ||
        tempA == NULL || tempB == NULL || tempC == NULL) {
        PyErr_NoMemory();
        goto done;
    }

    if (correlateMotifs_loadPWM(PyaPWM, motLen, aPWM) < 0) {
        goto done;
    }
    if (correlateMotifs_loadPWM(PybPWM, motLen - padLen, bPWM) < 0) {
        goto done;
    }
    if (correlateMotifs_loadPWM(PycPWM, motLen - padLen, cPWM) < 0) {
        goto done;
    }

    for (slide = -1 * maxSlide; slide < (maxSlide + padLen + 1); slide++ ) {
        if (slide < 0) {
            /* a is pushed right by |slide| columns; b and c start at 0.
             * The copy of a begins at its own first value, right after
             * the leading filler columns. */
            tempSize = abs(slide) + motLen;
            leadA = 4 * abs(slide);
            leadB = 0;
        } else if (slide > maxSlide) {
            /* b and c are pushed right beyond maxSlide (only reachable
             * when b is shorter than a); the compared width is held at
             * motLen + maxSlide. */
            tempSize = motLen + maxSlide;
            leadA = 0;
            leadB = 4 * slide;
        } else {
            /* 0 <= slide <= maxSlide: b and c are pushed right; the width
             * grows only by the part of the shift that sticks out past
             * the end of a. */
            if (padLen >= slide) {
                adjustedSlide = 0;
            } else {
                adjustedSlide = slide - padLen;
            }
            tempSize = motLen + adjustedSlide;
            leadA = 0;
            leadB = 4 * slide;
        }
        tempMax = 4 * tempSize;

        correlateMotifs_layout(tempA, aPWM, indexMax, leadA, tempMax);
        correlateMotifs_layout(tempB, bPWM, bIndexMax, leadB, tempMax);
        correlateMotifs_layout(tempC, cPWM, bIndexMax, leadB, tempMax);

        fscore = 0.0;
        rscore = 0.0;

        for (index = 0; index < tempSize; index++) {
            fscore += pearson_func(tempA, tempB, 4 * index);
            rscore += pearson_func(tempA, tempC, 4 * index);
        }
        fscore = fscore / tempSize;
        rscore = rscore / tempSize;

        if ((fscore < rscore) && (rscore > bestScore)) {
            bestScore = rscore;
        }  else if (fscore > bestScore) {
            bestScore = fscore;
        }
    }

    result = PyFloat_FromDouble(bestScore);

done:
    free(tempA);
    free(tempB);
    free(tempC);
    free(aPWM);
    free(bPWM);
    free(cPWM);

    return result;
}

/* Docstrings for the Python-callable functions; each is attached to its
 * function through the module_methods table below. */

/* Docstring for the_func, exposed to Python as "locateMotif". */
static char the_func_doc[] =
"returns a list of positions on aSeq that match the PWM within a Threshold, given as a percentage of the optimal consensus score.";

/* Docstring for dpwm_func, exposed to Python as "locateMarkov1". */
static char dpwm_func_doc[] =
"returns a list of positions on aSeq that match the DPWM within a given Fold of the optimal consensus score.";

/* Docstring for mer_func, exposed to Python as "locateMer". */
static char mer_func_doc[] =
"returns a list of positions on aSeq that match an N-mer within M mismatches. Assumes Mers and Seq are in the same case.";

/* Docstring for corr_func, exposed to Python as "correlateMotifs". */
static char corr_func_doc[] = 
"returns a pearson-correlation coefficient based similarity value between -1 (anti-correlated) and +1 (identical) for two motifs.";

/* Method table: Python-visible name, C implementation, calling convention
 * (every entry takes a positional-argument tuple, METH_VARARGS) and
 * docstring.  The all-NULL entry terminates the table. */
static PyMethodDef module_methods[] = {
    {"locateMotif", the_func, METH_VARARGS, the_func_doc},
    {"locateMarkov1", dpwm_func, METH_VARARGS, dpwm_func_doc},
    {"locateMer", mer_func, METH_VARARGS, mer_func_doc},
    {"correlateMotifs", corr_func, METH_VARARGS, corr_func_doc},
    {NULL, NULL}
};

/*
 * Module initialisation function for the "_motif" extension.
 *
 * Creates the module with the functions listed in module_methods and the
 * docstring in module_doc.  The module object returned by Py_InitModule3()
 * is not needed here and is deliberately discarded.
 */
PyMODINIT_FUNC
init_motif(void)
{
    (void) Py_InitModule3("_motif", module_methods, module_doc);
}
