
import sys
import string
import pysam
import os

# SAM FLAG field bits (from the SAM format specification)
# 0x0001    1 the read is paired in sequencing, no matter whether it is mapped in a pair
# 0x0002    2 the read is mapped in a proper pair (depends on the protocol, normally inferred during alignment)
# 0x0004    4 the query sequence itself is unmapped
# 0x0008    8 the mate is unmapped
# 0x0010   16 strand of the query (0 for forward; 1 for reverse strand)
# 0x0020   32 strand of the mate
# 0x0040   64 the read is the first read in a pair
# 0x0080  128 the read is the second read in a pair
# 0x0100  256 the alignment is not primary (a read having split hits may have multiple primary alignment records)
# 0x0200  512 the read fails platform/vendor quality checks
# 0x0400 1024 the read is either a PCR duplicate or an optical duplicate
# 0x0800 2048 supplementary alignment (one part of a chimeric/split alignment; added in later SAM spec versions)
# Bits 0x0002, 0x0008, 0x0020, 0x0040 and 0x0080 are only meaningful when 0x0001 is set.
# If the first/second-read information was lost upstream, 0x0040 and 0x0080 may both be unset.

def FLAG(FLAG):
    """Split a SAM FLAG value into the individual bits that are set.

    The parameter shares the function's name; inside the body ``FLAG`` is
    the integer value being decoded.

    Returns the set bits as powers of two, highest first, e.g.
    ``FLAG(99) == [64, 32, 2, 1]``. A value of 0 (or a negative value)
    yields an empty list. Every bit is decoded, including 0x800
    (supplementary) and anything above it, so no spurious lower bits are
    reported for such flags.
    """
    if FLAG <= 0:
        return []
    # Walk the bit positions from the most significant set bit down to bit 0
    # so the result keeps the "largest first" ordering callers have seen.
    return [1 << bit
            for bit in range(FLAG.bit_length() - 1, -1, -1)
            if (FLAG >> bit) & 1]

def _parse_args(argv):
    """Parse the command line; print usage and exit with status 1 if it is incomplete.

    Returns a dict holding the seven positional settings plus the optional
    switches (-notdividemulti, -nomulti, -stranded fieldID, -chr list).
    """
    if len(argv) < 8:  # argv[1] .. argv[7] are all required
        print('usage: python %s BAMfilename positions_file chrFieldID positionFieldID radius maxFragmentLength outputfilename [-notdividemulti] [-stranded fieldID] [-nomulti] [-narrowPeak] [-chr chr1,...,chrN] ' % argv[0])
        print('\tNote: the regions file can be zipped')
        print('\tNote: positionFieldID must be an integer column index (the "middle" option is not implemented)')
        print('\tNote: this script is modified to divide multimapping reads by their multiplicity, unless -notdividemulti or -nomulti is specified.')
        print('\tNote: this script is also modified to normalize the reads from each position by the local accessibility.')
        sys.exit(1)

    try:
        posFieldID = int(argv[4])
    except ValueError:
        print('positionFieldID must be an integer column index, got %r ("middle" is not implemented)' % argv[4])
        sys.exit(1)

    opts = {
        'BAM': argv[1],
        'regions': argv[2],
        'chrFieldID': int(argv[3]),
        'posFieldID': posFieldID,
        'radius': int(argv[5]),
        'mFL': int(argv[6]),
        'outfilename': argv[7],
        'noMulti': '-nomulti' in argv,
        'dividemulti': True,
        'strandFieldID': None,
        'chromosomes': None,
    }
    if '-notdividemulti' in argv:
        print('treating multi-mapping reads as unique reads')
        opts['dividemulti'] = False
    if '-stranded' in argv:
        opts['strandFieldID'] = int(argv[argv.index('-stranded') + 1])
    if '-chr' in argv:
        opts['chromosomes'] = set(argv[argv.index('-chr') + 1].split(','))
    if '-narrowPeak' in argv:
        print('warning: -narrowPeak is not implemented and is ignored')
    return opts


def _iter_region_lines(regions):
    """Yield the text lines of the regions file, decompressing .bz2/.gz/.zip with external tools.

    The decompressor is started without a shell and receives the path as a
    single argument, so paths with spaces or shell metacharacters work.
    """
    import subprocess

    if regions.endswith('.bz2'):
        cmd = ['bzip2', '-cd', regions]
    elif regions.endswith('.gz'):
        cmd = ['gunzip', '-c', regions]
    elif regions.endswith('.zip'):
        cmd = ['unzip', '-p', regions]
    else:
        cmd = None

    if cmd is None:
        with open(regions) as handle:
            for line in handle:
                yield line
        return

    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        proc.stdout.close()
        proc.wait()  # reap the child process


def _count_fragments(reads, peak, strand, radius, mFL, noMulti, dividemulti):
    """Tally read-pair fragments from ``reads`` by (position relative to ``peak``, length).

    Returns ``(counts, missing_nh)``. ``counts`` maps ``(relativepos, IL)``
    to a fragment count that is fractional when multi-mapping reads are
    divided by their NH multiplicity. ``missing_nh`` is True if any read
    lacked an NH tag; such reads are treated as uniquely mapped.
    """
    counts = {}
    missing_nh = False
    for alignedread in reads:
        try:
            multiplicity = alignedread.opt('NH')
        except KeyError:
            multiplicity = 1
            missing_nh = True
        if noMulti and multiplicity > 1:
            continue

        flag = alignedread.flag
        # Skip unmapped reads (0x4) and reads whose mate is unmapped (0x8).
        if flag & (0x4 | 0x8):
            continue
        # The insert length below only makes sense if the mate is on the same
        # reference (this also drops unpaired reads, whose rnext is unset).
        if alignedread.rnext != alignedread.tid:
            continue

        pos = alignedread.pos
        matepos = alignedread.pnext  # mate start as recorded on this read
        # Count each pair once, from its right-most read. If both mates start
        # at the same position, skip read 2 (0x80) so the pair is not counted twice.
        if matepos > pos or (matepos == pos and flag & 0x80):
            continue

        IL = pos - matepos + len(alignedread.query)
        FP = int(matepos + IL / 2.)
        # Keep only fragments whose center is in the window and whose length fits the matrix.
        if FP < peak - radius or FP >= peak + radius or IL >= mFL:
            continue

        if strand == '-':
            relativepos = peak - FP
        else:
            relativepos = FP - peak

        weight = 1.0 / multiplicity if dividemulti else 1
        key = (relativepos, IL)
        counts[key] = counts.get(key, 0) + weight
    return counts, missing_nh


def run():
    """Write the average fragment-length x fragment-center matrix around a set of positions.

    Driven by ``sys.argv`` (see the usage text printed by ``_parse_args``).
    For every region line, each counted read pair contributes one fragment
    of length IL centered at FP. Fragments with FP in
    [peak - radius, peak + radius) and IL < maxFragmentLength are tallied by
    (FP - peak, IL), or (peak - FP, IL) for '-' regions when -stranded is
    given. Each region's tally is scaled so it sums to 1,000,000 fragments
    before it is added to a running sum. Regions with no counted fragments
    cannot be scaled and are left out. The output file has one header line
    of relative positions -radius .. radius-1, then one row per fragment
    length from maxFragmentLength-1 down to 0. Each value is the running sum
    divided by the number of contributing regions.
    """
    opts = _parse_args(sys.argv)
    radius = opts['radius']
    mFL = opts['mFL']

    # ILSumMatrix[relativepos][IL]. The relativepos == radius row is only
    # reachable for '-' regions and, like the original output, is not written.
    ILSumMatrix = dict((i, [0.0] * mFL) for i in range(-radius, radius + 1))

    samfile = pysam.Samfile(opts['BAM'], "rb")
    RP = 0
    warned_no_nh = False
    try:
        for line in _iter_region_lines(opts['regions']):
            line = line.strip()
            if line == '' or line.startswith('#'):
                continue
            fields = line.split('\t')
            chrom = fields[opts['chrFieldID']]
            if opts['chromosomes'] is not None and chrom not in opts['chromosomes']:
                continue
            peak = int(fields[opts['posFieldID']])
            strand = '+'
            if opts['strandFieldID'] is not None:
                strand = fields[opts['strandFieldID']]

            start = max(0, peak - radius - mFL)
            end = peak + radius + mFL

            # Probe the interval first so an unknown contig or bad region is reported and skipped.
            try:
                for _ in samfile.fetch(chrom, start, end):
                    break
            except Exception:
                print('problem with region:')
                print(fields)
                print('%s %d %d' % (chrom, start, end))
                continue

            counts, missing_nh = _count_fragments(
                samfile.fetch(chrom, start, end), peak, strand, radius, mFL,
                opts['noMulti'], opts['dividemulti'])
            if missing_nh and not warned_no_nh:
                print('no NH: tags in BAM file! continue by treating all as unique mapping reads.')
                warned_no_nh = True

            TotalFrags = float(sum(counts.values()))
            if TotalFrags == 0:
                continue  # nothing to normalize; excluded from the average
            NormFactor = TotalFrags / 1000000
            for (relativepos, IL), value in counts.items():
                ILSumMatrix[relativepos][IL] += value / NormFactor

            RP += 1
            if RP % 1000 == 0:
                print('%d regions processed' % RP)
    finally:
        samfile.close()

    if RP == 0:
        print('warning: no region contributed any fragments; writing an all-zero matrix')
    denominator = float(RP) if RP else 1.0

    positions = range(-radius, radius)
    with open(opts['outfilename'], 'w') as outfile:
        outfile.write('#' + ''.join('\t' + str(i) for i in positions) + '\n')
        for IL in range(mFL - 1, -1, -1):
            row = ''.join('\t' + str(ILSumMatrix[i][IL] / denominator) for i in positions)
            outfile.write(str(IL) + row + '\n')  # average per region
            
# Only run when executed as a script, so importing this module neither
# parses sys.argv nor calls sys.exit.
if __name__ == '__main__':
    run()

