##################################
#                                #
# Last modified 10/06/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

# SAM FLAG field meaning (hex value, decimal value, description), from the SAM format specification.
# Markers in brackets refer to the notes at the end of this table.
# 0x0001 1 the read is paired in sequencing, no matter whether it is mapped in a pair
# 0x0002 2 the read is mapped in a proper pair (depends on the protocol, normally inferred during alignment) [1]
# 0x0004 4 the query sequence itself is unmapped
# 0x0008 8 the mate is unmapped [1]
# 0x0010 16 strand of the query (0 for forward; 1 for reverse strand)
# 0x0020 32 strand of the mate [1]
# 0x0040 64 the read is the first read in a pair [1,2]
# 0x0080 128 the read is the second read in a pair [1,2]
# 0x0100 256 the alignment is not primary (a read having split hits may have multiple primary alignment records)
# 0x0200 512 the read fails platform/vendor quality checks
# 0x0400 1024 the read is either a PCR duplicate or an optical duplicate
# [1] only meaningful when flag 0x0001 is set
# [2] if the read order within a pair is unknown, 0x0001 is set while 0x0040 and 0x0080 are both zero

def FLAG(FLAG):
    """Split a SAM FLAG integer into the individual bit values that are set.

    The bits are returned in descending order, e.g. ``FLAG(83)`` gives
    ``[64, 16, 2, 1]``. Any bit width is supported, including flags above
    0x400 such as 0x800. A value of zero or a negative value yields ``[]``.
    """
    FLAGList = []
    if FLAG <= 0:
        # Nothing is set. This also avoids an infinite set of bits from the
        # two's-complement representation of negative ints.
        return FLAGList
    bit = 1
    while bit <= FLAG:
        if FLAG & bit:
            FLAGList.append(bit)
        bit <<= 1
    # The bits were collected from low to high; callers historically
    # received them from high to low.
    FLAGList.reverse()
    return FLAGList

import bisect
import sys
from sets import Set

import pysam

# Optional speed-up: enable the legacy psyco JIT if it can be imported and
# started. If that fails for any ordinary reason, just run without it.
# Exception (not a bare except) keeps Ctrl-C and sys.exit working.
try:
    import psyco
    psyco.full()
except Exception:
    pass

def _data_fields(path):
    """Yield the tab-separated fields of each data line of the file at path.

    Blank lines and lines starting with '#' are skipped. The file is closed
    when the generator is exhausted.
    """
    with open(path) as infile:
        for line in infile:
            stripped = line.strip()
            if not stripped or stripped.startswith('#'):
                continue
            yield stripped.split('\t')


def _read_tss_table(path, chrFieldID, TSSFieldID, strandFieldID, labels):
    """Parse the TSS table.

    Returns (header_lines, TSSDict). header_lines are the '#' lines with one
    tab-separated column per label appended. TSSDict maps
    (chr, TSS, strand) -> {'line': original line, 'coverage': {label: 0}}.
    """
    header_lines = []
    TSSDict = {}
    with open(path) as infile:
        for line in infile:
            if line.startswith('#'):
                header_lines.append('\t'.join([line.strip()] + labels))
                continue
            stripped = line.strip()
            if not stripped:
                continue
            fields = stripped.split('\t')
            TSS = int(fields[TSSFieldID])
            key = (fields[chrFieldID], TSS, fields[strandFieldID])
            TSSDict[key] = {
                'line': stripped,
                'coverage': dict((label, 0) for label in labels),
            }
    return header_lines, TSSDict


def _index_tss_by_chrom(keys):
    """Group sorted (chr, TSS, strand) keys by chromosome.

    Returns {chr: (positions, chrom_keys)} with positions[i] == chrom_keys[i][1].
    Both lists are in ascending TSS order because keys is sorted.
    """
    index = {}
    for key in keys:
        positions, chrom_keys = index.setdefault(key[0], ([], []))
        positions.append(key[1])
        chrom_keys.append(key)
    return index


def _annotate_bed(TSSDict, tss_index, label, path):
    """Set coverage to 1 for every TSS with left <= TSS < right of a BED line."""
    for fields in _data_fields(path):
        chrom = fields[0]
        left = int(fields[1])
        right = int(fields[2])
        if chrom not in tss_index:
            continue
        positions, chrom_keys = tss_index[chrom]
        # Binary search replaces walking every base of the interval.
        first = bisect.bisect_left(positions, left)
        last = bisect.bisect_left(positions, right)
        for key in chrom_keys[first:last]:
            TSSDict[key]['coverage'][label] = 1


def _annotate_bam(TSSDict, keys, label, path):
    """Count reads from fetch(chr, TSS-1, TSS+1) whose strand matches the TSS."""
    samfile = pysam.Samfile(path, "rb")
    try:
        for key in keys:
            chrom, TSS, strand = key
            count = 0
            try:
                for alignedread in samfile.fetch(chrom, TSS - 1, TSS + 1):
                    readstrand = '-' if alignedread.is_reverse else '+'
                    if readstrand == strand:
                        count += 1
            except Exception:
                # The BAM could not be queried here (presumably an unknown
                # chromosome or invalid region); leave this TSS at 0.
                continue
            TSSDict[key]['coverage'][label] = count
    finally:
        samfile.close()


def _annotate_motif(TSSDict, keys, label, path, spec):
    """Set coverage to 1 when a motif midpoint lies in [TSS+pos1, TSS+pos2).

    spec is 'motif,pos1,pos2'. Each motif line's first field is
    'chr:left-right', and its midpoint is int((left + right) / 2.0).
    NOTE(review): the window is not mirrored for '-' strand TSSs; confirm
    that this is intended.
    """
    parts = spec.split(',')
    window_left = int(parts[1])
    window_right = int(parts[2])
    midpoints = {}
    for fields in _data_fields(path):
        coords = fields[0].split(':')
        span = coords[1].split('-')
        left = int(span[0])
        right = int(span[1])
        midpoints.setdefault(coords[0], set()).add(int((left + right) / 2.0))
    for chrom in midpoints:
        midpoints[chrom] = sorted(midpoints[chrom])
    for key in keys:
        mids = midpoints.get(key[0])
        if not mids:
            continue
        start = key[1] + window_left
        end = key[1] + window_right
        # The first midpoint >= start decides whether any lies in [start, end).
        i = bisect.bisect_left(mids, start)
        if i < len(mids) and mids[i] < end:
            TSSDict[key]['coverage'][label] = 1


def _annotate_cpg(TSSDict, keys, label, path, spec):
    """Set coverage to 1 when a CpG island overlaps [TSS-radius, TSS+radius).

    spec is 'CpG,radius'. Island lines supply the chromosome, start and end
    in fields 1, 2 and 3 as a half-open interval [start, end).
    """
    radius = int(spec.split(',')[1])
    islands = {}
    for fields in _data_fields(path):
        left = int(fields[2])
        right = int(fields[3])
        if left < right:  # empty intervals cover no positions
            islands.setdefault(fields[1], []).append((left, right))
    # Per chromosome: island starts sorted ascending, plus the running
    # maximum of island ends over that sorted prefix.
    index = {}
    for chrom, intervals in islands.items():
        intervals.sort()
        lefts = []
        max_rights = []
        furthest = None
        for left, right in intervals:
            furthest = right if furthest is None else max(furthest, right)
            lefts.append(left)
            max_rights.append(furthest)
        index[chrom] = (lefts, max_rights)
    for key in keys:
        if key[0] not in index:
            continue
        start = key[1] - radius
        end = key[1] + radius
        if start >= end:
            continue
        lefts, max_rights = index[key[0]]
        # Islands starting before `end` overlap iff one also ends after `start`.
        n = bisect.bisect_left(lefts, end)
        if n and max_rights[n - 1] > start:
            TSSDict[key]['coverage'][label] = 1


def run():
    """Annotate each TSS in a table with one column per input data set.

    Command line: TSS-table chrFieldID TSSFieldID strandFieldID
    input_files_list outfilename. Each data line of input_files_list is
    'label <tab> file <tab> type', where type is one of:
      bed             -> 1 if the TSS lies inside any BED interval, else 0
      bam             -> number of same-strand reads near the TSS
      motif,pos1,pos2 -> 1 if a motif midpoint is in [TSS+pos1, TSS+pos2)
      CpG,radius      -> 1 if a CpG island overlaps [TSS-radius, TSS+radius)
    The output contains the TSS table's '#' header lines, each extended with
    the labels, followed by the data lines sorted by (chr, TSS, strand),
    each extended with one value per label.
    """
    # Six arguments plus the program name; argv[6] is read below.
    if len(sys.argv) < 7:
        print('usage: python %s TSS-table chrFieldID poschrFieldID strandFieldID input_files_list outfilename' % sys.argv[0])
        print('       input_files_list format: label <tab> file <tab> bed | bam | motif,pos1,pos2 | CpG,radius')
        print('       the script will look for coverage by the bed regions of the TSS, for exact read match on the same strand in the bam file,')
        print('       for presence of motif in the specified region relative to the TSS (needs getallsites output), or for CpG island within the specified radius of the TSS')
        sys.exit(1)

    TSStable = sys.argv[1]
    chrFieldID = int(sys.argv[2])
    TSSFieldID = int(sys.argv[3])
    strandFieldID = int(sys.argv[4])
    input_files = sys.argv[5]
    outfilename = sys.argv[6]

    DataList = [(fields[0], fields[1], fields[2])
                for fields in _data_fields(input_files)]
    labels = [label for (label, _, _) in DataList]

    header_lines, TSSDict = _read_tss_table(
        TSStable, chrFieldID, TSSFieldID, strandFieldID, labels)
    keys = sorted(TSSDict)
    tss_index = _index_tss_by_chrom(keys)

    for (label, path, spec) in DataList:
        sys.stdout.write('%s %s\n' % (label, spec))
        if spec == 'bed':
            _annotate_bed(TSSDict, tss_index, label, path)
        elif spec == 'bam':
            _annotate_bam(TSSDict, keys, label, path)
        elif spec.startswith('motif'):
            _annotate_motif(TSSDict, keys, label, path, spec)
        elif spec.startswith('CpG'):
            _annotate_cpg(TSSDict, keys, label, path, spec)

    with open(outfilename, 'w') as outfile:
        for header in header_lines:
            outfile.write(header + '\n')
        for key in keys:
            entry = TSSDict[key]
            values = [str(entry['coverage'][label]) for label in labels]
            outfile.write('\t'.join([entry['line']] + values) + '\n')

# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    run()
