#include "signalFromBAM.h"

#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <regex>
#include <string>

void signalFromBAM(const string bamFileName, const string sigFileName, const bool flagStranded, const int signalType, const string outWigReferencesPrefix) {

    bam1_t *bamA;
    bamA=bam_init1();
    
    double nMult=0, nUniq=0;

    {//count reads in the BAM file

        BGZF *bamIn=bgzf_open(bamFileName.c_str(),"r");
        bam_hdr_t *bamHeader=bam_hdr_read(bamIn);
        while ( true ) {//until the end of file
            int bamBytes1=bam_read1(bamIn, bamA);
            if (bamBytes1<0) break; //end of file
            if (bamA->core.tid<0) continue; //unmapped read
//             if ( !std::regex_match(chrName.at(bamA->core.tid),std::regex(outWigReferencesPrefix))) continue; //reference does not mathc required references
            if ( outWigReferencesPrefix!="-" && (outWigReferencesPrefix.compare(0,outWigReferencesPrefix.size(),bamHeader->target_name[bamA->core.tid],outWigReferencesPrefix.size())!=0) ) continue; //reference does not match required references
            
            uint8_t* aNHp=bam_aux_get(bamA,"NH");
            if (aNHp!=NULL) {
                uint32_t aNH=bam_aux2i(aNHp);
                if (aNH==1) {//unique mappers
                    ++nUniq;
                } else if (aNH>1) {
                    nMult+=1.0/aNH;
                };
            };
        };
        bgzf_close(bamIn);
    };    
    
    BGZF *bamIn=bgzf_open(bamFileName.c_str(),"r");
    bam_hdr_t *bamHeader=bam_hdr_read(bamIn);

    int sigN=flagStranded ? 4 : 2;
    double *normFactor=new double[sigN];
    
    ofstream **sigOutAll=new ofstream* [sigN];
    
    sigOutAll[0]=new ofstream ( (sigFileName+".Unique.str1.out.bg").c_str() );
    normFactor[0]=1.0e6 / nUniq;
    sigOutAll[1]=new ofstream ( (sigFileName+".UniqueMultiple.str1.out.bg").c_str() );
    normFactor[1]=1.0e6 / (nUniq+nMult);    
    if (flagStranded) {
        sigOutAll[2]=new ofstream( (sigFileName+".Unique.str2.out.bg").c_str() );
        normFactor[2]=normFactor[0];
        sigOutAll[3]=new ofstream( (sigFileName+".UniqueMultiple.str2.out.bg").c_str() );
        normFactor[3]=normFactor[1];
    };

    for (uint32_t is=0;is<sigN;is++) {//formatting double output
        *sigOutAll[is]<<setiosflags(ios::fixed) << setprecision(5);
    };
    
    int iChr=-999;
    double *sigAll=NULL;
    while ( true ) {//until the end of file
        int bamBytes1=bam_read1(bamIn, bamA);
        uint32_t chrLen;
        if (bamA->core.tid!=iChr || bamBytes1<0) {
            //output to file
            if (iChr!=-999) {//iChr=-999 marks chromosomes that are not output, including unmapped reads
                for (uint32_t is=0;is<sigN;is++) {
                    double prevSig=0;
                    for (uint32_t ig=0;ig<chrLen;ig++) { 
                        double newSig=sigAll[sigN*ig+is];
                        if (newSig!=prevSig) {
                            if (prevSig!=0) {//finish previous record
                                *sigOutAll[is] <<ig<<"\t"<<prevSig*normFactor[is] <<"\n"; //1-based end
                            };
                            if (newSig!=0) {
                                *sigOutAll[is] << bamHeader->target_name[iChr] <<"\t"<< ig <<"\t"; //0-based beginning
                            };
                            prevSig=newSig;
                        };
                    };
                };
            };
            if (bamBytes1<0) {//no more reads
                break;
            };
            
            iChr=bamA->core.tid;
            if ( iChr==-1 || (outWigReferencesPrefix!="-" && (outWigReferencesPrefix.compare(0,outWigReferencesPrefix.size(),bamHeader->target_name[bamA->core.tid],outWigReferencesPrefix.size())!=0) ) ) {
                iChr=-999;
                continue; //reference does not match required references
            };
            
            chrLen=bamHeader->target_len[iChr]+1;//one extra base at the end which sohuld always be 0            
            delete [] sigAll;        
            sigAll= new double[sigN*chrLen];
            memset(sigAll, 0, sizeof(*sigAll)*sigN*chrLen);
        };
        
//         uint32_t nCigar =(bamA->core.flag<<16)>>16;
//         uint32_t mapFlag=bamA->core.flag>>16;
//         uint32_t mapQ=(bamA->core.flag<<16)>>24;
        
        #define BAM_CIGAR_OperationShift 4
        #define BAM_CIGAR_LengthBits 28
        #define BAM_CIGAR_M 0
        #define BAM_CIGAR_I 1
        #define BAM_CIGAR_D 2
        #define BAM_CIGAR_N 3
        #define BAM_CIGAR_S 4
        #define BAM_CIGAR_H 5
        #define BAM_CIGAR_P 6
        #define BAM_CIGAR_EQ 7
        #define BAM_CIGAR_X 8
        
        //NH attribute
        uint8_t* aNHp=bam_aux_get(bamA,"NH");
        if (aNHp==NULL) continue; //do not process lines without NH field
        uint32_t aNH=bam_aux2i(bam_aux_get(bamA,"NH")); //write a safer function allowing for lacking NH tag
        if (aNH==0) continue; //do not process lines without NH=0
        uint32_t aG=bamA->core.pos;
        uint32_t iStrand=0;
        if (flagStranded) {//strand for stranded data from SAM flag
            iStrand= ( (bamA->core.flag & 0x10) > 0 ) == ( (bamA->core.flag & 0x80) == 0 );//0/1 for +/-
        };          
        if (signalType==1) {//5' of the1st read signal only, RAMPAGE/CAGE
            if ( (bamA->core.flag & 0x80)>0) continue; //skip if this the second mate
            if (iStrand==0) {
                if (aNH==1) {//unique mappers
                    sigAll[aG*sigN+0+2*iStrand]++;
                };
                sigAll[aG*sigN+1+2*iStrand]+=1.0/aNH;//U+M, normalized by the number of multi-mapping loci
                continue; //record only the first position
            };
        };
        
        uint32_t* cigar=(uint32_t*) (bamA->data+bamA->core.l_qname);
       
        for (uint32_t ic=0; ic<bamA->core.n_cigar; ic++) {
            uint32_t cigOp=(cigar[ic]<<BAM_CIGAR_LengthBits)>>BAM_CIGAR_LengthBits;
            uint32_t cigL=cigar[ic]>>BAM_CIGAR_OperationShift;
            switch (cigOp) {
                case(BAM_CIGAR_D):
                case(BAM_CIGAR_N):
                    aG+=cigL;
                    break;
                case(BAM_CIGAR_M):
                    if (signalType==0) {//full signal
                        for (uint32_t ig=0;ig<cigL;ig++) {
                            if (aG>=chrLen) {
                                cerr << "BUG: alignment extends past chromosome in signalFromBAM.cpp\n";
                                exit(-1);
                            };
                            if (aNH==1) {//unique mappers
                                sigAll[aG*sigN+0+2*iStrand]++;
                            };
                            sigAll[aG*sigN+1+2*iStrand]+=1.0/aNH;//U+M, normalized by the number of multi-mapping loci
                            aG++;
                        };
                    } else {
                        aG+=cigL;
                    };
            };
        };
        if (signalType==1) {//full signal
            --aG;
            if (aNH==1) {//unique mappers
                sigAll[aG*sigN+0+2*iStrand]++;
            };
            sigAll[aG*sigN+1+2*iStrand]+=1.0/aNH;//U+M, normalized by the number of multi-mapping loci
        };
    };
    delete [] sigAll;        

    for (int is=0; is<sigN; is++) {// flush/close all signal files
        sigOutAll[is]->flush();
        sigOutAll[is]->close();
    };
};
