##################################
#                                #
# Last modified 06/17/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import pysam
import string
from sets import Set

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence string.

    Upper- and lower-case A, C, G, T and N are supported and their case is
    preserved ('a' -> 't', 'A' -> 'T', 'N' -> 'N').  Any other character
    raises KeyError.

    The result is assembled with a single ``join`` rather than by repeated
    string concatenation.  Repeated concatenation can be quadratic, and this
    function is called for every reverse-strand alignment.
    """
    DNA = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N',
           'a': 't', 't': 'a', 'g': 'c', 'c': 'g', 'n': 'n'}
    return ''.join([DNA[base] for base in reversed(preliminarysequence)])


def _alignmentKey(alignedread, lengthRange):
    """Return the (readID, sequence) key for one pysam alignment, or None.

    The read ID is the query name, with '/1' or '/2' appended when the
    alignment is flagged as read 1 or read 2.  The sequence of an alignment
    flagged ``is_reverse`` is reverse-complemented.

    None is returned when the alignment has no stored sequence.  None is also
    returned when lengthRange is a (min, max) pair and the sequence length
    falls outside it (bounds inclusive).  Both passes in run() build keys with
    this helper, so the keys always agree.
    """
    sequence = alignedread.seq
    # An alignment without a stored sequence (pysam may report it as None or
    # '') has no length or composition to contribute; skip it.
    if not sequence:
        return None
    if lengthRange is not None:
        minRL, maxRL = lengthRange
        if len(sequence) < minRL or len(sequence) > maxRL:
            return None
    ID = alignedread.qname
    if alignedread.is_read1:
        ID = ID + '/1'
    if alignedread.is_read2:
        ID = ID + '/2'
    if alignedread.is_reverse:
        sequence = getReverseComplement(sequence)
    return (ID, sequence)


def _countReads(samfile, chromInfoList, wantedChromosomes, lengthRange):
    """First pass: count alignments per read key over the chrom.sizes regions.

    Returns a tuple (ReadDict, sequenceLengths, scannedChromosomes):
      ReadDict           -- {(readID, sequence): number of alignments}
      sequenceLengths    -- sorted list of the distinct read lengths seen
      scannedChromosomes -- set of chromosomes actually fetched.  A
                            chromosome is skipped when -chr excludes it or
                            when it is not a reference in the BAM header.
    """
    ReadDict = {}
    lengthsSeen = set()
    scannedChromosomes = set()
    bamReferences = set(samfile.references)
    i = 0
    for (chrom, start, end) in chromInfoList:
        print('%s %s %s' % (chrom, start, end))
        if wantedChromosomes is not None and chrom not in wantedChromosomes:
            continue
        # Consult the BAM header instead of probing fetch() inside a bare
        # except, which could also swallow unrelated failures.
        if chrom not in bamReferences:
            continue
        scannedChromosomes.add(chrom)
        for alignedread in samfile.fetch(chrom, start, end):
            i += 1
            if i % 5000000 == 0:
                print('examining read multiplicity and inputting reads '
                      '%dM alignments processed %s %s %s %s'
                      % (i // 1000000, chrom, start, alignedread.pos, end))
            key = _alignmentKey(alignedread, lengthRange)
            if key is None:
                continue
            lengthsSeen.add(len(key[1]))
            ReadDict[key] = ReadDict.get(key, 0) + 1
    return ReadDict, sorted(lengthsSeen), scannedChromosomes


def _loadRepeats(repeatMasker, chrFieldID, repeatFieldIDs, wantedChromosomes):
    """Group repeat intervals by their identifying fields.

    Each non-comment, non-blank line is split on tabs.  Column chrFieldID is
    the chromosome, and the next two columns are the interval's left and
    right coordinates.  The values of the repeatFieldIDs columns (in sorted
    column order) form the repeat's key.

    Returns {repeat key tuple: [(chrom, left, right), ...]}.
    """
    RepeatDict = {}
    s = 0
    with open(repeatMasker) as linelist:
        for line in linelist:
            if line.startswith('#'):
                continue
            if not line.strip():
                continue
            s += 1
            if s % 100000 == 0:
                print('%d lines processed' % s)
            fields = line.strip().split('\t')
            chrom = fields[chrFieldID]
            if wantedChromosomes is not None and chrom not in wantedChromosomes:
                continue
            left = int(fields[chrFieldID + 1])
            right = int(fields[chrFieldID + 2])
            repeat = tuple([fields[ID] for ID in repeatFieldIDs])
            RepeatDict.setdefault(repeat, []).append((chrom, left, right))
    return RepeatDict


def _formatRepeatLine(repeat, reads, sequenceLengths, compositionLength):
    """Build one tab-separated output line for a repeat.

    Columns, in order:
      - the repeat key fields;
      - for each length in sequenceLengths, the fraction of the distinct
        reads with that length;
      - for each position 0..compositionLength-1, 'A:x,T:y,C:z,G:w', the
        fraction of reads carrying each base at that position.

    Bases other than A/C/G/T (e.g. N) are not counted, so a position's four
    fractions can sum to less than 1.  With no reads, every length cell is
    '0' and every composition cell is 'A:0,T:0,C:0,G:0'.
    """
    lengthCounts = dict((length, 0) for length in sequenceLengths)
    baseCounts = [{'A': 0, 'C': 0, 'G': 0, 'T': 0}
                  for _ in range(compositionLength)]
    for (ID, sequence) in reads:
        lengthCounts[len(sequence)] += 1
        for position in range(compositionLength):
            counts = baseCounts[position]
            base = sequence[position]
            if base in counts:
                counts[base] += 1

    numberreads = len(reads) + 0.0
    parts = list(repeat)
    for length in sequenceLengths:
        if numberreads == 0:
            parts.append(str(0))
        else:
            parts.append(str(lengthCounts[length] / numberreads))
    for counts in baseCounts:
        if numberreads == 0:
            parts.append('A:0,T:0,C:0,G:0')
        else:
            parts.append('A:%s,T:%s,C:%s,G:%s' % tuple(
                [str(counts[base] / numberreads) for base in 'ATCG']))
    return '\t'.join(parts).strip()


def run():
    """Report read-length and per-position base composition per repeat.

    Command line (see the usage message):
      BAM chrom.sizes repeatMaskFile chrFieldID repeatIDFields outfilename
      [-nomulti] [-chr chr1,...,chrN] [-readLength min max]

    Pass 1 fetches every chrom.sizes region from the BAM.  It counts the
    alignments of each (readID, sequence) key and records every read length
    seen.

    Pass 2 groups the repeatMask intervals by their repeatIDFields values.
    For each group it writes one line with the length and base-composition
    fractions of the distinct reads overlapping the group's intervals (see
    _formatRepeatLine).  Only reads that were counted in pass 1 are
    considered, so every read's multiplicity and length are known.  With
    -nomulti, reads whose key was seen more than once in pass 1 are dropped.
    """
    if len(sys.argv) < 7:
        print('usage: python %s BAM chrom.sizes repeatMaskFile chrFieldID repeatIDFields outfilename [-nomulti] [-chr chr1,...,chrN] [-readLength min max]' % sys.argv[0])
        print('       Note: the script assumes no duplicate readIDs')
        print('       repeatIDFields comma-separated')
        sys.exit(1)

    BAM = sys.argv[1]
    chrominfo = sys.argv[2]
    chromInfoList = []
    with open(chrominfo) as linelist:
        for line in linelist:
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            chromInfoList.append((fields[0], 0, int(fields[1])))
    repeatMasker = sys.argv[3]
    chrFieldID = int(sys.argv[4])
    repeatFieldIDs = sorted([int(ID) for ID in sys.argv[5].split(',')])
    outputfilename = sys.argv[6]

    noMulti = '-nomulti' in sys.argv
    if noMulti:
        print('will only consider uniquely mappable reads')

    wantedChromosomes = None
    if '-chr' in sys.argv:
        wantedChromosomes = set(sys.argv[sys.argv.index('-chr') + 1].split(','))

    lengthRange = None
    if '-readLength' in sys.argv:
        minRL = int(sys.argv[sys.argv.index('-readLength') + 1])
        maxRL = int(sys.argv[sys.argv.index('-readLength') + 2])
        lengthRange = (minRL, maxRL)
        print('will only consider reads between %s and %s bp length' % (minRL, maxRL))

    samfile = pysam.Samfile(BAM, "rb")
    try:
        ReadDict, sequenceLengths, scannedChromosomes = _countReads(
            samfile, chromInfoList, wantedChromosomes, lengthRange)
        RepeatDict = _loadRepeats(repeatMasker, chrFieldID, repeatFieldIDs,
                                  wantedChromosomes)

        # Composition is reported for the positions every read has, i.e. up
        # to the shortest length seen; with no reads there are no positions.
        compositionLength = min(sequenceLengths) if sequenceLengths else 0

        with open(outputfilename, "w") as outfile:
            header = (['repeat'] * len(repeatFieldIDs)
                      + [str(length) for length in sequenceLengths]
                      + [str(position) for position in range(compositionLength)])
            outfile.write(('#' + '\t'.join(header)).strip() + '\n')

            for s, repeat in enumerate(sorted(RepeatDict), 1):
                if s % 100 == 0:
                    print('%d repeats processed' % s)
                # A set: a read overlapping several intervals of the same
                # repeat is counted once.
                reads = set()
                for (chrom, left, right) in RepeatDict[repeat]:
                    # Only chromosomes fetched in pass 1 are guaranteed to
                    # exist in the BAM and to have their reads in ReadDict.
                    if chrom not in scannedChromosomes:
                        continue
                    for alignedread in samfile.fetch(chrom, left, right):
                        key = _alignmentKey(alignedread, lengthRange)
                        # Reads outside the pass-1 regions have no known
                        # multiplicity or length column; skip them.
                        if key is None or key not in ReadDict:
                            continue
                        if noMulti and ReadDict[key] > 1:
                            continue
                        reads.add(key)
                outfile.write(_formatRepeatLine(repeat, reads, sequenceLengths,
                                                compositionLength) + '\n')
    finally:
        samfile.close()

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()

