##################################
#                                #
# Last modified 11/17/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import pysam
from sets import Set

def _read_chrom_sizes(path):
    """Read a chrom.sizes file into a list of (chrom, 0, length) regions.

    Each non-blank line is tab-separated, with the chromosome name in the
    first column and its length in the second. File order is preserved.
    """
    regions = []
    with open(path) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if not fields[0]:
                continue  # blank line
            regions.append((fields[0], 0, int(fields[1])))
    return regions


def _gtf_attribute(attributes, key):
    """Return the quoted value of `key` in a GTF attribute column, or None if absent."""
    marker = key + ' "'
    if marker not in attributes:
        return None
    return attributes.split(marker, 1)[1].split('";')[0]


def _read_gtf(path):
    """Group GTF intervals by chromosome and transcript.

    Returns {chrom: {(geneID, geneName, transcriptID, transcriptName):
    set of (left, right)}}. Here left/right are GTF columns 4 and 5
    (1-based, inclusive).

    All feature types are used as-is. These lines are skipped: comment
    lines, lines with fewer than 9 columns, and records lacking gene_id or
    transcript_id. A missing gene_name / transcript_name falls back to the
    corresponding ID.
    """
    transcripts = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9:
                continue  # blank or malformed line
            attributes = fields[8]
            geneID = _gtf_attribute(attributes, 'gene_id')
            transcriptID = _gtf_attribute(attributes, 'transcript_id')
            if geneID is None or transcriptID is None:
                # Cannot be keyed by transcript (some annotations contain
                # gene-level lines with no transcript_id).
                continue
            geneName = _gtf_attribute(attributes, 'gene_name')
            if geneName is None:
                geneName = geneID
            transcriptName = _gtf_attribute(attributes, 'transcript_name')
            if transcriptName is None:
                transcriptName = transcriptID
            ID = (geneID, geneName, transcriptID, transcriptName)
            interval = (int(fields[3]), int(fields[4]))
            transcripts.setdefault(fields[0], {}).setdefault(ID, set()).add(interval)
    return transcripts


def _annotate_positions(transcripts, positions):
    """Return {pos: set of transcript ID tuples} for read-start positions inside an interval.

    `transcripts` is one chromosome's entry from _read_gtf. `positions` is a
    collection of 0-based read start positions.
    """
    annotations = {}
    for ID, intervals in transcripts.items():
        for left, right in intervals:
            # GTF coordinates are 1-based inclusive, while pysam's `pos` is
            # 0-based. So [left, right] covers 0-based positions left-1 .. right-1.
            for pos in range(left - 1, right):
                if pos in positions:
                    annotations.setdefault(pos, set()).add(ID)
    return annotations


def _join_or_dash(values):
    """Comma-join the sorted distinct values, or return '-' when there are none."""
    distinct = sorted(set(values))
    if not distinct:
        return '-'
    return ','.join(distinct)


def run():
    """Command-line entry point: annotate read start positions with GTF transcripts.

    Usage: python script.py BAMfilename chrom.sizes GTF outputfilename
    (the BAM file has to be indexed).

    For every chromosome listed in both chrom.sizes and the GTF, all
    alignments are fetched from the BAM. Each distinct read start position
    (pysam `pos`, 0-based) is matched against the GTF intervals covering it.
    One line is written per (position, read ID) with the columns:
    geneID(s), geneName(s), transcriptID(s), transcriptName(s), chr, Pos,
    ReadID. Multiple values are comma-joined and sorted; '-' means no
    overlap. Any '|' in the read ID is replaced by a tab.
    """
    # Four arguments are required, so argv must have at least five entries.
    if len(sys.argv) < 5:
        print('usage: python %s BAMfilename chrom.sizes GTF outputfilename' % sys.argv[0])
        print('       BAM file has to be indexed')
        sys.exit(1)

    SAM = sys.argv[1]
    chrominfo = sys.argv[2]
    GTF = sys.argv[3]
    outputfilename = sys.argv[4]

    chromInfoList = _read_chrom_sizes(chrominfo)
    TranscriptDict = _read_gtf(GTF)

    samfile = pysam.Samfile(SAM, "rb")
    processed = 0
    try:
        with open(outputfilename, 'w') as outfile:
            outfile.write('#geneID(s)\tgeneName(s)\ttranscriptID(s)\ttranscriptName(s)\tchr\tPos\tReadID\n')
            for (chrom, start, end) in chromInfoList:
                if chrom not in TranscriptDict:
                    continue
                print('bam %s %s %s' % (chrom, start, end))
                # 0-based read start position -> set of read IDs starting there
                positions = {}
                try:
                    for alignedread in samfile.fetch(chrom, start, end):
                        processed += 1
                        if processed % 5000000 == 0:
                            print(str(processed // 1000000) + 'M alignments processed')
                        # The read name is the first tab-separated field of the read's text form.
                        readID = str(alignedread).split('\t')[0]
                        positions.setdefault(alignedread.pos, set()).add(readID)
                except Exception as err:
                    # Best-effort: skip regions pysam cannot fetch, but do not
                    # swallow KeyboardInterrupt/SystemExit.
                    print('problem with region: %s %s %s skipping (%s)' % (chrom, start, end, err))
                    continue

                annotations = _annotate_positions(TranscriptDict[chrom], positions)
                for pos in sorted(positions):
                    IDs = annotations.get(pos, ())
                    prefix = '\t'.join([
                        _join_or_dash([ID[0] for ID in IDs]),
                        _join_or_dash([ID[1] for ID in IDs]),
                        _join_or_dash([ID[2] for ID in IDs]),
                        _join_or_dash([ID[3] for ID in IDs]),
                        chrom,
                        str(pos),
                    ])
                    # Emit every read starting at this position, not just the last one.
                    for readID in sorted(positions[pos]):
                        outfile.write(prefix + '\t' + readID.replace('|', '\t') + '\n')
    finally:
        samfile.close()

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()
