##################################
#                                #
# Last modified 07/01/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string

# Base -> complementary base. Case is preserved and '-' (alignment gap)
# maps to itself, matching the table run() historically declared.
_COMPLEMENT = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N',
               'a': 't', 't': 'a', 'g': 'c', 'c': 'g', 'n': 'n',
               '-': '-'}


def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a nucleotide string.

    Accepts A/C/G/T/N in upper or lower case plus '-'; any other character
    raises KeyError.  Built with a single join (linear time) instead of
    repeated string concatenation.
    """
    return ''.join(_COMPLEMENT[base] for base in reversed(preliminarysequence))

def run():
    """Write ID1- and ID2-specific transcript sequences as FASTA.

    Arguments are read from sys.argv:
        genome_fasta gtf indels.vcf snps.vcf ID1,ID2 outfile [-GQ minGQ]

    The genome FASTA is loaded, GTF 'exon' records are grouped per
    transcript, and the alleles each sample carries at exonic positions are
    read from the indel VCF and then the SNP VCF (so a SNP at the same
    position replaces an indel).  Those alleles are spliced into every
    transcript, which is reverse-complemented when its last exon (after
    sorting) is on the '-' strand, and two FASTA records per transcript are
    written to outfile.  Every variant that lands on a transcript is listed
    in outfile + '.log'.
    """
    if len(sys.argv) < 7:  # argv[1]..argv[6] are all required
        print('usage: python %s genome_fasta gtf indels.vcf snps.vcf <ID1,ID2> outfile [-GQ minGQ]' % sys.argv[0])
        print('     NOTE: VCFv3.3 assumed for snps, VCFv4.0 for indels')
        print('     output format for heterozygous transcripts: GeneName:GeneID:TranscriptName:TranscriptID:ID1 or GeneName:GeneID:TranscriptName:TranscriptID:ID2')
        print('     Be careful about 1- and 0-based genomes and annotations when working with new genomes')
        sys.exit(1)

    # Optional genotype-quality threshold; None disables the filter.
    minGQ = None
    if '-GQ' in sys.argv:
        minGQ = int(sys.argv[sys.argv.index('-GQ') + 1])

    fasta = sys.argv[1]
    gtf = sys.argv[2]
    indels = sys.argv[3]
    snps = sys.argv[4]
    (ID1, ID2) = tuple(sys.argv[5].split(','))
    outfilename = sys.argv[6]

    SequenceDict = _read_fasta(fasta)
    TranscriptDict = _read_gtf(gtf, SequenceDict)

    # chromosome -> set of 1-based (VCF) positions contained in some exon
    # slice; see _collect_transcript_variants for the coordinate convention.
    CoverageDict = {}
    g = 0
    bp = 0
    for exons in TranscriptDict.values():
        g += 1  # counts transcripts, although the message says 'genes'
        for (chrom, left, right, _strand) in exons:
            CoverageDict.setdefault(chrom, set()).update(range(left + 1, right + 1))
            bp += max(0, right - left)
    print('%d genes considered %d bp in total' % (g, bp))

    # chromosome -> {1-based position: (REF, ALT)}, one mapping per sample.
    VariantDict1 = {}
    VariantDict2 = {}
    v1, v2 = _read_vcf(indels, ID1, ID2, SequenceDict, CoverageDict,
                       VariantDict1, VariantDict2, _indel_is_called, 1, minGQ)
    print('found %d coding indels in %s' % (v1, ID1))
    print('found %d coding indels in %s' % (v2, ID2))
    v1, v2 = _read_vcf(snps, ID1, ID2, SequenceDict, CoverageDict,
                       VariantDict1, VariantDict2, _snp_is_called, 4, minGQ)
    print('found %d coding SNPs in %s' % (v1, ID1))
    print('found %d coding SNPs in %s' % (v2, ID2))

    # Column order matches the rows written below.
    logheader = ('#GeneID\tGeneName\tTranscriptName\tTranscriptID\tchr\tmRNAPos\t'
                 'GenomicPos\tReference\t' + ID1 + '\t' + ID2)
    with open(outfilename, 'w') as outfile:
        with open(outfilename + '.log', 'w') as outfilelog:
            outfilelog.write(logheader + '\n')
            for key in sorted(TranscriptDict):
                (geneName, geneID, transcriptName, transcriptID) = key
                exons = sorted(TranscriptDict[key])
                sequence, hits = _collect_transcript_variants(
                    exons, SequenceDict, VariantDict1, VariantDict2)
                sequence = sequence.upper()
                for (index, chrom, vcfpos, allele1, allele2) in hits:
                    first = allele1 if allele1 is not None else allele2
                    row = [geneID, geneName, transcriptName, transcriptID, chrom,
                           str(index + 1),  # 1-based position along the mRNA
                           str(vcfpos), first[0],
                           allele1[1] if allele1 is not None else 'NONE',
                           allele2[1] if allele2 is not None else 'NONE']
                    outfilelog.write('\t'.join(row) + '\n')
                ID1sequence = _apply_alleles(sequence, [(h[0], h[3]) for h in hits])
                ID2sequence = _apply_alleles(sequence, [(h[0], h[4]) for h in hits])
                if exons[-1][3] == '-':
                    ID1sequence = getReverseComplement(ID1sequence)
                    ID2sequence = getReverseComplement(ID2sequence)
                prefix = ('>' + geneName + ':' + geneID + ':' + transcriptName +
                          ':' + transcriptID + ':')
                _write_fasta_record(outfile, prefix + ID1, ID1sequence)
                _write_fasta_record(outfile, prefix + ID2, ID2sequence)


def _read_fasta(path):
    """Return {name: sequence} for a FASTA file, printing each name.

    The name is the header text after '>' (up to any further '>'); sequence
    lines are stripped and concatenated.
    """
    sequences = {}
    name = None
    chunks = []
    with open(path) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    sequences[name] = ''.join(chunks)
                name = line.strip().split('>')[1]
                print(name)
                chunks = []
            else:
                chunks.append(line.strip())
    if name is not None:  # an empty file yields an empty dict
        sequences[name] = ''.join(chunks)
    return sequences


def _read_gtf(path, sequences):
    """Group GTF 'exon' records by transcript.

    Returns {(geneName, geneID, transcriptName, transcriptID):
    [(chrom, left, right, strand), ...]} with the raw integer start/end
    columns.  Exons on chromosomes missing from sequences, comment lines
    and blank/truncated lines are skipped.  gene_name / transcript_name fall
    back to the corresponding IDs when absent.
    """
    transcripts = {}
    with open(path) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            chrom = fields[0]
            if chrom not in sequences:
                continue
            attributes = fields[8]
            geneID = attributes.split('gene_id "')[1].split('";')[0]
            transcriptID = attributes.split('transcript_id "')[1].split('";')[0]
            geneName = geneID
            if 'gene_name' in attributes:
                geneName = attributes.split('gene_name "')[1].split('";')[0]
            transcriptName = transcriptID
            if 'transcript_name' in attributes:
                transcriptName = attributes.split('transcript_name "')[1].split('";')[0]
            key = (geneName, geneID, transcriptName, transcriptID)
            transcripts.setdefault(key, []).append(
                (chrom, int(fields[3]), int(fields[4]), fields[6]))
    return transcripts


def _indel_is_called(genotype):
    """Indel genotypes are used unless the sample column is '.'."""
    return genotype != '.'


def _snp_is_called(genotype):
    """SNP genotypes are used unless they start with '0/' or './'."""
    return genotype[0:2] != '0/' and genotype[0:2] != './'


def _read_vcf(path, ID1, ID2, sequences, coverage, variants1, variants2,
              is_called, gq_index, minGQ):
    """Record the ALT allele each sample carries at exon-covered positions.

    CHROM values get a 'chr' prefix to match the FASTA names.  For each of
    ID1/ID2 whose genotype passes is_called (and, when minGQ is not None,
    whose ':'-separated FORMAT value number gq_index is >= minGQ), the ALT
    allele selected by the first '/'-separated genotype allele is stored as
    variants[chrom][pos] = (REF, ALT).  A genotype that cannot be parsed is
    reported and skipped for that sample only; the other sample on the same
    line is still processed.

    Returns the number of alleles stored for ID1 and for ID2.
    Raises ValueError if a data line precedes the '#CHROM' header.
    """
    counts = [0, 0]
    columns = None
    stores = ((ID1, variants1), (ID2, variants2))
    with open(path) as handle:
        for lineno, line in enumerate(handle, 1):
            if lineno % 5000000 == 0:
                print('%d lines processed in %s' % (lineno, path))
            if line.startswith('##'):
                continue
            fields = line.strip().split('\t')
            if line.startswith('#CHROM'):
                print(fields)
                columns = (fields.index(ID1), fields.index(ID2))
                continue
            if not line.strip():
                continue
            if columns is None:
                raise ValueError('%s: data line found before the #CHROM header' % path)
            chrom = 'chr' + fields[0]
            if chrom not in sequences:
                continue
            pos = int(fields[1])
            if pos not in coverage.get(chrom, ()):
                continue
            reference = fields[3]
            alternates = fields[4].split(',')
            for k in (0, 1):
                sample, store = stores[k]
                genotype = fields[columns[k]]
                if not is_called(genotype):
                    continue
                try:
                    if minGQ is not None and int(genotype.split(':')[gq_index]) < minGQ:
                        continue
                    # NOTE(review): only the first genotype allele is used; a
                    # leading '0' gives index -1, i.e. the last ALT allele.
                    allele = alternates[int(genotype.split('/')[0]) - 1]
                except (ValueError, IndexError):
                    print('skipping %s %s' % (sample, genotype))
                    continue
                store.setdefault(chrom, {})[pos] = (reference, allele)
                counts[k] += 1
    return counts[0], counts[1]


def _collect_transcript_variants(exons, sequences, variants1, variants2):
    """Concatenate exon sequences and list the sample variants on them.

    Each exon contributes sequences[chrom][left:right] (the slice this
    script has always extracted).  As a Python slice that is 0-based
    left..right-1, i.e. 1-based VCF positions left+1..right, so variants are
    looked up at exactly those positions and every hit's index points at
    the base its REF describes -- including an exon's last base and
    variants at exon boundaries.

    Returns (sequence, hits); hits lists, in transcript order,
    (index, chrom, vcfpos, allele1, allele2) where index is the 0-based
    offset into sequence and allele1/allele2 are (REF, ALT) or None.
    """
    pieces = []
    hits = []
    offset = 0
    for (chrom, left, right, _strand) in exons:
        piece = sequences[chrom][left:right]
        pieces.append(piece)
        calls1 = variants1.get(chrom, {})
        calls2 = variants2.get(chrom, {})
        for k in range(len(piece)):
            vcfpos = left + 1 + k
            allele1 = calls1.get(vcfpos)
            allele2 = calls2.get(vcfpos)
            if allele1 is not None or allele2 is not None:
                hits.append((offset + k, chrom, vcfpos, allele1, allele2))
        offset += len(piece)
    return ''.join(pieces), hits


def _apply_alleles(sequence, edits):
    """Splice (REF, ALT) alleles into sequence at 0-based indices.

    edits is a list of (index, allele) with allele a (REF, ALT) tuple or
    None (no change).  Edits are applied from the highest index down so
    length-changing indels do not shift the indices still to be applied.
    """
    for index, allele in sorted(edits, key=lambda edit: edit[0], reverse=True):
        if allele is None:
            continue
        ref, alt = allele
        sequence = sequence[:index] + alt + sequence[index + len(ref):]
    return sequence


def _write_fasta_record(handle, header, sequence, width=50):
    """Write a FASTA header line and the sequence wrapped at width chars."""
    handle.write(header + '\n')
    for start in range(0, len(sequence), width):
        handle.write(sequence[start:start + width] + '\n')

# Run only when executed as a script, so importing this module (e.g. to reuse
# getReverseComplement) does not parse sys.argv or call sys.exit.
if __name__ == '__main__':
    run()