##################################
#                                #
# Last modified 2018/05/13       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import pysam
import string
import math

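# SAM FLAG field bits: hex value, decimal value, meaning (from the SAM format specification).
# The trailing "1" / "1,2" markers are footnotes from that specification:
#   1 = only meaningful when 0x0001 (paired) is set;
#   2 = if the first/second-in-pair information was lost upstream, 0x0001 is set and 0x0040/0x0080 are both zero.
# 0x0001 1 the read is paired in sequencing, no matter whether it is mapped in a pair
# 0x0002 2 the read is mapped in a proper pair (depends on the protocol, normally inferred during alignment) 1
# 0x0004 4 the query sequence itself is unmapped
# 0x0008 8 the mate is unmapped 1
# 0x0010 16 strand of the query (0 for forward; 1 for reverse strand)
# 0x0020 32 strand of the mate 1
# 0x0040 64 the read is the first read in a pair 1,2
# 0x0080 128 the read is the second read in a pair 1,2
# 0x0100 256 the alignment is not primary (a read having split hits may have multiple primary alignment records)
# 0x0200 512 the read fails platform/vendor quality checks
# 0x0400 1024 the read is either a PCR duplicate or an optical duplicate
# 0x0800 2048 supplementary alignment (added in later versions of the SAM specification)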

def FLAG(FLAG):
    """Decompose a SAM FLAG value into its individual set bits.

    Returns the power-of-two components of FLAG in descending order,
    e.g. FLAG(99) == [64, 32, 2, 1]. A FLAG of 0 (or any non-positive
    value) yields an empty list.

    Every bit is handled, including 0x800 (2048, supplementary alignment)
    and higher, so large FLAG values decompose correctly instead of
    overrunning a fixed table of bit values.
    """
    if FLAG <= 0:
        return []
    # Walk from the highest set bit down to bit 0 so the output is descending.
    return [1 << bit
            for bit in range(FLAG.bit_length() - 1, -1, -1)
            if (FLAG >> bit) & 1]

def _iter_region_lines(regionsfile):
    """Yield the raw lines of *regionsfile*, decompressing it if needed.

    Files ending in .bz2, .gz or .zip are streamed through ``bzip2 -cd``,
    ``gunzip -c`` or ``unzip -p`` respectively; anything else is read
    directly. The commands are run as argument lists without a shell, so
    file names containing spaces or shell metacharacters are passed verbatim.

    Raises:
        IOError/OSError: a plain-text file or the decompression tool is missing.
        RuntimeError: the decompression command exits with a non-zero status.
    """
    import subprocess

    if regionsfile.endswith('.bz2'):
        cmd = ['bzip2', '-cd', regionsfile]
    elif regionsfile.endswith('.gz'):
        cmd = ['gunzip', '-c', regionsfile]
    elif regionsfile.endswith('.zip'):
        cmd = ['unzip', '-p', regionsfile]
    else:
        with open(regionsfile) as handle:
            for line in handle:
                yield line
        return

    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        proc.stdout.close()
        status = proc.wait()
    # A missing or corrupt archive must not silently yield a partial region list.
    if status != 0:
        raise RuntimeError('%s exited with status %d' % (' '.join(cmd), status))


def _read_regions(regionsfile, chrFieldID, doSFR):
    """Return the set of (chrom, left, right) regions listed in *regionsfile*.

    Blank lines and lines starting with '#' are skipped, as are regions on
    chrM. With *doSFR* the coordinates come from a single 'chrom:left-right'
    field at column *chrFieldID*; otherwise chrom, left and right are read
    from columns chrFieldID, chrFieldID + 1 and chrFieldID + 2 (0-based).
    """
    regions = set()
    for line in _iter_region_lines(regionsfile):
        line = line.strip()
        # Blank lines are skipped (not treated as end of input) so a stray
        # empty line cannot silently truncate the region list.
        if not line or line.startswith('#'):
            continue
        fields = line.split('\t')
        if doSFR:
            coords = fields[chrFieldID].split(':')
            chrom = coords[0]
            bounds = coords[1].split('-')
            left = int(bounds[0])
            right = int(bounds[1])
        else:
            chrom = fields[chrFieldID]
            left = int(fields[chrFieldID + 1])
            right = int(fields[chrFieldID + 2])
        if chrom == 'chrM':
            continue
        regions.add((chrom, left, right))
    return regions


def _read_chrom_sizes(chrominfo):
    """Return [(chrom, 0, size), ...] for every non-blank line of a chrom.sizes file."""
    chromInfoList = []
    with open(chrominfo) as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            fields = line.split('\t')
            chromInfoList.append((fields[0], 0, int(fields[1])))
    return chromInfoList


def _fragment_keys(samfile, chrom, start, end):
    """Return the set of fragment keys for reads fetched from chrom:start-end.

    A fragment key is (chrom, min(pos, pnext), max(pos, pnext)). Because
    min/max is symmetric, two reads whose pos/pnext values are swapped (as
    for the two mates of a pair) produce the same key and count once.
    """
    keys = set()
    for read in samfile.fetch(chrom, start, end):
        pos = read.pos
        matepos = read.pnext
        keys.add((chrom, min(pos, matepos), max(pos, matepos)))
    return keys


def _score_sample(bamfile, chromInfoList, regions):
    """Compute (total_fragments, FRiP, {region: fragment_count}) for one BAM file.

    The total counts unique fragments over the chrom.sizes contigs (chrM
    excluded, contigs absent from the BAM header skipped). FRiP is the number
    of unique fragments overlapping any region divided by that total (0 when
    the total is 0). Regions on contigs absent from the BAM header get a
    count of 0 instead of making pysam raise.

    NOTE: regions are not filtered against chrom.sizes, so fragments in
    regions on contigs omitted from chrom.sizes still count towards the
    FRiP numerator.
    """
    samfile = pysam.Samfile(bamfile, "rb")
    try:
        contigs = set(samfile.references)

        fragments = set()
        for (chrom, start, end) in chromInfoList:
            if chrom == 'chrM' or chrom not in contigs:
                continue
            fragments.update(_fragment_keys(samfile, chrom, start, end))

        # One pass over the regions gives both the FRiP numerator and the
        # per-region counts, so each region is fetched only once.
        fragmentsInPeaks = set()
        region_counts = {}
        for region in regions:
            chrom, start, end = region
            if chrom in contigs:
                keys = _fragment_keys(samfile, chrom, start, end)
            else:
                keys = set()
            fragmentsInPeaks.update(keys)
            region_counts[region] = len(keys)
    finally:
        samfile.close()

    if fragments:
        FRiP = float(len(fragmentsInPeaks)) / len(fragments)
    else:
        FRiP = 0
    return len(fragments), FRiP, region_counts


def run():
    """Command-line entry point: per-sample FRiP statistics and a region count matrix.

    Arguments (sys.argv): config regions chrFieldID chrom.sizes FRiP
    minFragments outprefix [-singleFieldRegions].

    For each 'label<TAB>BAMfilename' line of the config file, writes one line
    to <outprefix>.stats (label, file, total unique fragments, FRiP, yes/no).
    Samples with at least minFragments fragments and FRiP >= minFRiP get a
    column of per-region unique-fragment counts in <outprefix>.matrix; rows
    are regions sorted by (chrom, left, right), columns are labels sorted
    alphabetically.

    Only single-argument print(...) calls are used, so the output is the same
    under Python 2 and Python 3.
    """
    # argv[1]..argv[7] are all required, so at least 8 entries are needed.
    if len(sys.argv) < 8:
        print('usage: python %s config regions chrFieldID chrom.sizes FRiP minFragments outprefix [-singleFieldRegions]' % sys.argv[0])
        print('\tconfig file format:')
        print('\t\tlabel <tab> BAMfilename')
        print('\tNote: the script wil exclude contigs named chrM by default; if you want other contigs excluded, remove them from the chrom.sizes file')
        sys.exit(1)

    config = sys.argv[1]
    regionsfile = sys.argv[2]
    chrFieldID = int(sys.argv[3])
    chromInfoList = _read_chrom_sizes(sys.argv[4])
    minFRiP = float(sys.argv[5])
    minFragments = int(sys.argv[6])
    outprefix = sys.argv[7]

    print('minFRiP: ' + str(minFRiP))
    print('minFragments: ' + str(minFragments))

    doSFR = '-singleFieldRegions' in sys.argv

    # region -> {label: fragment count}, filled only for included samples.
    DataMatrix = dict((region, {}) for region in _read_regions(regionsfile, chrFieldID, doSFR))
    print('finished inputting regions')

    regions = sorted(DataMatrix)
    labels = []

    with open(outprefix + '.stats', 'w') as outfile1, open(outprefix + '.matrix', 'w') as outfile2:
        outfile1.write('#label\tfilename\tTotalUniqueFragments\tFRiP\tIncluded_or_not\n')

        with open(config) as configfile:
            for line in configfile:
                line = line.strip()
                if not line:
                    continue
                fields = line.split('\t')
                label = fields[0]
                bamfile = fields[1]

                total, FRiP, region_counts = _score_sample(bamfile, chromInfoList, regions)

                included = total >= minFragments and FRiP >= minFRiP
                outline = label + '\t' + bamfile + '\t' + str(total) + '\t' + str(FRiP)
                outline = outline + '\t' + ('yes' if included else 'no')
                print(outline)
                outfile1.write(outline + '\n')
                if not included:
                    continue

                for region in regions:
                    DataMatrix[region][label] = region_counts[region]
                labels.append(label)

        labels.sort()

        outfile2.write('#' + ''.join('\t' + label for label in labels) + '\n')
        for region in regions:
            (chrom, start, end) = region
            counts = DataMatrix[region]
            outline = chrom + ':' + str(start) + '-' + str(end)
            outline = outline + ''.join('\t' + str(counts[label]) for label in labels)
            outfile2.write(outline + '\n')
            
# Only run when executed as a script, so importing this module (e.g. to
# reuse FLAG) does not immediately parse sys.argv and start processing.
if __name__ == '__main__':
    run()
