##################################
#                                #
# Last modified 07/18/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from cistematic.core.geneinfo import geneinfoDB
from cistematic.genomes import Genome

# psyco is an optional accelerator: use it when it can be imported and carry
# on without it otherwise.  Only ImportError is ignored so that Ctrl-C or
# SystemExit raised during start-up is no longer swallowed.
try:
    import psyco
    psyco.full()
except ImportError:
    pass

def getSequence(genome,chromosome,start,stop,sense):
    """Fetch a stretch of genomic sequence, optionally reverse-complemented.

    genome     -- name handed to cistematic's Genome constructor
    chromosome -- chromosome name; its first three characters are dropped
                  before the lookup (presumably a 'chr' prefix -- confirm
                  against the naming used by the installed genome)
    start,stop -- interval bounds; hg.sequence() is called with start and
                  stop-start (presumably start and length -- confirm against
                  the cistematic API)
    sense      -- 'F' or '+' returns the sequence as fetched, 'R' or '-'
                  returns its reverse complement

    Raises ValueError for any other sense value (the original code failed
    with an UnboundLocalError in that case).  Raises KeyError if the fetched
    sequence contains a symbol other than A/C/G/T/N in either case when a
    reverse complement is requested.
    """
    if sense not in ('F', '+', 'R', '-'):
        raise ValueError("sense must be one of 'F', '+', 'R', '-'; got %r" % (sense,))

    complement = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n'}
    hg = Genome(genome)
    fetched = hg.sequence(chromosome[3:], start, stop - start)
    if sense in ('F', '+'):
        return fetched
    # Walk the fetched sequence backwards and complement each base; a single
    # join avoids rebuilding the string once per base.
    return ''.join([complement[base] for base in reversed(fetched)])

def reverseComplement(sequence,DNA):
    """Return the reverse complement of sequence.

    sequence -- the sequence to process
    DNA      -- mapping from each symbol to its complement

    Every symbol of sequence must be a key of DNA; a KeyError is raised
    otherwise.  An empty sequence yields an empty string.
    """
    # Complement while walking backwards; one join instead of growing a
    # string symbol by symbol.
    return ''.join([DNA[symbol] for symbol in reversed(sequence)])

def _readRefFlat(refFlatName):
    """Return a dict mapping the second tab-separated column of each refFlat line to the first."""
    names = {}
    handle = open(refFlatName)
    try:
        for line in handle:
            fields = line.split('\t')
            if len(fields) < 2:
                # blank or truncated line: nothing to map
                continue
            names[fields[1]] = fields[0]
    finally:
        handle.close()
    return names


def _readFasta(fastaList):
    """Load every record of the given FASTA files into a {header: sequence} dict.

    The header is everything after '>' on the header line.  A header seen
    twice (in one file or across files) aborts the program with exit code 1.
    Lines that precede the first header of a file are ignored.
    """
    chunks = {}
    for fa in fastaList:
        print('processing %s' % fa)
        handle = open(fa)
        try:
            name = None
            lineCount = 0
            for line in handle:
                lineCount += 1
                if lineCount % 100000 == 0:
                    print('%s lines processed' % lineCount)
                if line.startswith('>'):
                    name = line.strip().split('>')[1]
                    if name in chunks:
                        print('duplicate chromosome entries detected %s' % name)
                        sys.exit(1)
                    chunks[name] = []
                elif name is not None:
                    # Collect the pieces and join once at the end; appending
                    # to the stored string re-copies it on every line.
                    chunks[name].append(line.strip())
        finally:
            handle.close()
    sequences = {}
    for name in chunks:
        sequences[name] = ''.join(chunks[name])
    for name in sequences:
        print('%s %s bp' % (name, len(sequences[name])))
    return sequences


def _readGtf(gtfName, sequenceDict, refSeqToNames):
    """Collect the exon records of a GTF file as {geneID: {transcriptID: [(chr,left,right,strand)]}}.

    sequenceDict  -- when not None, exons on chromosomes missing from it are dropped
    refSeqToNames -- when not None, gene_id values found in it are replaced by the mapped name
    """
    genes = {}
    handle = open(gtfName)
    try:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            chrom = fields[0]
            if sequenceDict is not None and chrom not in sequenceDict:
                continue
            left = int(fields[3])
            right = int(fields[4])
            strand = fields[6]
            attributes = fields[8]
            rawGeneID = attributes.split('gene_id "')[1].split('";')[0]
            geneID = rawGeneID
            if refSeqToNames is not None and geneID in refSeqToNames:
                geneID = refSeqToNames[geneID]
            # The transcript is identified by its transcript_id attribute
            # (the original code re-read gene_id here, which merged all
            # isoforms of a gene).  Records without a transcript_id keep the
            # old behaviour of using the unmapped gene_id.
            if 'transcript_id "' in attributes:
                transcriptID = attributes.split('transcript_id "')[1].split('";')[0]
            else:
                transcriptID = rawGeneID
            exons = genes.setdefault(geneID, {}).setdefault(transcriptID, [])
            exons.append((chrom, left, right, strand))
    finally:
        handle.close()
    return genes


def _keepLongestIsoforms(genes):
    """Reduce every gene to its longest isoform and index the positions it covers.

    Length is the sum of right-left over a transcript's exons; on ties the
    first transcript met while iterating wins.  genes is modified in place.
    Returns (coverage, totalBp): coverage maps each chromosome to the set of
    positions in range(left, right) of the kept exons, totalBp is the number
    of positions visited (overlapping exons are counted repeatedly).
    """
    coverage = {}
    totalBp = 0
    for geneID in genes:
        transcripts = genes[geneID]
        longest = None
        longestLength = None
        for transcriptID in transcripts:
            length = 0
            for (chrom, left, right, strand) in transcripts[transcriptID]:
                length += right - left
            if longest is None or length > longestLength:
                longest = transcriptID
                longestLength = length
        # Re-binding an existing key does not resize the dict, so this is
        # safe while iterating over it.
        genes[geneID] = {longest: transcripts[longest]}
        for (chrom, left, right, strand) in transcripts[longest]:
            span = range(left, right)
            coverage.setdefault(chrom, set()).update(span)
            totalBp += len(span)
    return coverage, totalBp


def _indelIsCalled(sampleField):
    """Indel files: a sample carries a call unless its field is exactly '.'."""
    return sampleField != '.'


def _snpIsCalled(sampleField):
    """SNP files: a sample carries a call unless its field starts with '0/' or './'."""
    return sampleField[0:2] not in ('0/', './')


def _genotypeAllele(sampleField, alternates):
    """Return the ALT allele selected by the integer before the first '/' of a sample field.

    NOTE(review): an index of 0 selects alternates[-1], i.e. the last ALT
    allele; this is what the original code did and is kept unchanged --
    confirm that it is intended for genotypes such as 0/1.
    """
    return alternates[int(sampleField.split('/')[0]) - 1]


def _storeCall(sampleID, sampleField, reference, alternates, pos, store, gqIndex, minGQ, skipMalformed):
    """Record the allele called for one sample at one position.

    store         -- {pos: (reference, allele)} dict of the sample for this chromosome
    gqIndex       -- index of the genotype quality among the ':'-separated sample values
    minGQ         -- minimum genotype quality, or None for no filtering
    skipMalformed -- with a GQ filter, report and skip fields that cannot be
                     parsed instead of raising

    Returns 1 when a call was stored and 0 otherwise.
    """
    if minGQ is None:
        store[pos] = (reference, _genotypeAllele(sampleField, alternates))
        return 1
    try:
        if int(sampleField.split(':')[gqIndex]) < minGQ:
            return 0
        allele = _genotypeAllele(sampleField, alternates)
    except (ValueError, IndexError):
        if not skipMalformed:
            raise
        print('skipping %s %s' % (sampleID, sampleField))
        return 0
    store[pos] = (reference, allele)
    return 1


def _readVariants(filename, ID1, ID2, sequenceDict, coverage, variants1, variants2,
                  isCalled, gqIndex, minGQ, skipMalformed):
    """Read one VCF file and store the calls of samples ID1 and ID2 that fall on covered positions.

    Chromosome names are formed as 'chr' + the first column.  Records on
    chromosomes absent from sequenceDict (when it is not None) or from
    coverage are ignored.  Each sample is handled independently, so a
    malformed field for one sample no longer discards the other sample's call.
    Returns the number of calls stored for (ID1, ID2).
    """
    count1 = 0
    count2 = 0
    column1 = None
    column2 = None
    lineCount = 0
    handle = open(filename)
    try:
        for line in handle:
            lineCount += 1
            if lineCount % 5000000 == 0:
                print('%s lines processed in %s' % (lineCount, filename))
            if line.startswith('##') or not line.strip():
                continue
            fields = line.strip().split('\t')
            if line.startswith('#CHROM'):
                print(fields)
                column1 = fields.index(ID1)
                column2 = fields.index(ID2)
                continue
            if column1 is None:
                print('no #CHROM header line found before the first record in %s' % filename)
                sys.exit(1)
            chrom = 'chr' + fields[0]
            if sequenceDict is not None and chrom not in sequenceDict:
                continue
            pos = int(fields[1])
            if chrom not in coverage or pos not in coverage[chrom]:
                continue
            reference = fields[3]
            alternates = fields[4].split(',')
            if isCalled(fields[column1]):
                count1 += _storeCall(ID1, fields[column1], reference, alternates, pos,
                                     variants1.setdefault(chrom, {}), gqIndex, minGQ, skipMalformed)
            if isCalled(fields[column2]):
                count2 += _storeCall(ID2, fields[column2], reference, alternates, pos,
                                     variants2.setdefault(chrom, {}), gqIndex, minGQ, skipMalformed)
    finally:
        handle.close()
    return count1, count2


def _applyCalls(sequence, calls):
    """Substitute (reference, allele) calls into sequence and return the result.

    calls is a list of (mRNAPos, (reference, allele)).  Calls are applied from
    the highest position down so that length changes do not shift positions
    still to be edited.  The reference is taken to start at index mRNAPos-1.
    """
    for (mRNAPos, (reference, allele)) in sorted(calls, reverse=True):
        start = mRNAPos - 1
        if start < 0:
            # The reference base lies before the extracted sequence; a
            # negative slice bound would wrap around and garble the result.
            continue
        sequence = sequence[:start] + allele + sequence[start + len(reference):]
    return sequence


def _writeTranscripts(genes, sequenceDict, genome, variants1, variants2, ID1, ID2, outfile, outfilelog):
    """Write the reference or per-sample sequence of every transcript, logging each variant used.

    NOTE(review): exon sequence is taken as [left:right] and a variant at
    genomic position i is applied at mRNA index (offset of i) - 1, as in the
    original code; check this against the 0-/1-based conventions of the
    annotation and genome in use.
    """
    complement = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n','-':'-'}
    for geneID in list(genes.keys()):
        transcripts = genes[geneID]
        if len(transcripts) > 1:
            print('more than transcritps found %s' % (list(transcripts.keys()),))
        for transcriptID in list(transcripts.keys()):
            exons = transcripts[transcriptID]
            exons.sort()
            mRNAPos = 0
            mRNAVariants = {}
            pieces = []
            coordinates = []
            for (chrom, left, right, strand) in exons:
                coordinates.append(left)
                coordinates.append(right)
                if sequenceDict is not None:
                    pieces.append(sequenceDict[chrom][left:right])
                else:
                    try:
                        pieces.append(getSequence(genome, chrom, left, right, '+'))
                    except Exception:
                        # best effort: an exon whose sequence cannot be
                        # fetched is left out of the transcript
                        print('skipping %s %s %s' % (chrom, left, right))
                        continue
                calls1 = variants1.get(chrom, {})
                calls2 = variants2.get(chrom, {})
                for i in range(left, right):
                    call1 = calls1.get(i)
                    call2 = calls2.get(i)
                    if call1 is not None or call2 is not None:
                        mRNAVariants[mRNAPos] = (call1, call2)
                        if call1 is not None:
                            reference = call1[0]
                            allele1 = call1[1]
                        else:
                            reference = call2[0]
                            allele1 = 'NONE'
                        if call2 is not None:
                            allele2 = call2[1]
                        else:
                            allele2 = 'NONE'
                        outline = geneID + '\t' + chrom + '\t' + str(mRNAPos) + '\t' + str(i) + '\t' + reference + '\t' + allele1 + '\t' + allele2
                        outfilelog.write(outline + '\n')
                    mRNAPos += 1
            sequence = ''.join(pieces).upper()
            if sequence == '':
                continue
            # chrom and strand are those of the last exon; the locus spans the
            # smallest to the largest exon coordinate in both output forms.
            locus = '>' + geneID + '::' + chrom + ':' + str(min(coordinates)) + '-' + str(max(coordinates)) + strand
            if len(mRNAVariants) == 0:
                outfile.write(locus + '\n')
                if strand == '-':
                    sequence = reverseComplement(sequence, complement)
                outfile.write(sequence + '\n')
                continue
            for (sampleID, sampleIndex) in ((ID1, 0), (ID2, 1)):
                calls = []
                for position in mRNAVariants:
                    call = mRNAVariants[position][sampleIndex]
                    if call is not None:
                        calls.append((position, call))
                sampleSequence = _applyCalls(sequence, calls)
                outfile.write(locus + '::' + sampleID + '\n')
                if strand == '-':
                    sampleSequence = reverseComplement(sampleSequence, complement)
                outfile.write(sampleSequence + '\n')


def run():
    """Build per-sample transcript sequences from a GTF annotation plus indel and SNP VCF files.

    Reads its arguments from sys.argv:
        genome gtf indels snps ID1,ID2 outfile
        [-fasta f1,f2,...] [-GQ minGQ] [-refFlat refFlat_file]

    For every gene only the longest isoform is kept.  Transcripts without
    variants are written once to outfile; transcripts with variants are
    written twice, once per sample, with that sample's calls substituted.
    Every variant used is also listed in outfile + '.log'.  Sequence comes
    from the -fasta files when given, otherwise from getSequence().  With -GQ,
    calls whose genotype quality (second ':'-separated sample value for
    indels, fifth for SNPs) is below minGQ are ignored.

    Exits with status 1 after printing usage when fewer than six positional
    arguments are given.
    """
    if len(sys.argv) < 7:
        print('usage: python %s genome gtf indels.vjf snps.vcf <ID1,ID2> outfile [-fasta <comma-separated list of fasta files>] [-GQ minGQ] [-refFlat refFlat_file]' % sys.argv[0])
        print('     NOTE: VCFv3.3 assumed for snps, VCFv4.0 for indels')
        print('     NOTE: recommended to use RefSeq genes; the script will only output the longest isoform of any gene as measured in bp')
        print('     Be careful about 1- and 0-based genomes and annotations when working with new genomes')
        sys.exit(1)

    FastaList = None
    if '-fasta' in sys.argv:
        FastaList = sys.argv[sys.argv.index('-fasta')+1].split(',')

    minGQ = None
    if '-GQ' in sys.argv:
        minGQ = int(sys.argv[sys.argv.index('-GQ')+1])

    RefSeqToNamesDict = None
    if '-refFlat' in sys.argv:
        RefSeqToNamesDict = _readRefFlat(sys.argv[sys.argv.index('-refFlat')+1])

    genome = sys.argv[1]
    gtf = sys.argv[2]
    indels = sys.argv[3]
    snps = sys.argv[4]
    (ID1,ID2) = tuple(sys.argv[5].split(','))
    outfilename = sys.argv[6]

    if FastaList is not None:
        SequenceDict = _readFasta(FastaList)
    else:
        SequenceDict = None
        # NOTE(review): neither object is used later in this function
        # (getSequence opens its own Genome); the calls are kept unchanged in
        # case their construction is relied upon.
        hg = Genome(genome)
        idb = geneinfoDB()

    GeneDict = _readGtf(gtf, SequenceDict, RefSeqToNamesDict)
    (CoverageDict, bp) = _keepLongestIsoforms(GeneDict)
    print('%s genes considered %s bp in total' % (len(GeneDict), bp))

    VariantDict1 = {}
    VariantDict2 = {}
    for chrom in CoverageDict:
        VariantDict1[chrom] = {}
        VariantDict2[chrom] = {}

    (v1, v2) = _readVariants(indels, ID1, ID2, SequenceDict, CoverageDict,
                             VariantDict1, VariantDict2, _indelIsCalled, 1, minGQ, True)
    print('found %s coding indels in %s' % (v1, ID1))
    print('found %s coding indels in %s' % (v2, ID2))

    (v1, v2) = _readVariants(snps, ID1, ID2, SequenceDict, CoverageDict,
                             VariantDict1, VariantDict2, _snpIsCalled, 4, minGQ, False)
    print('found %s coding SNPs in %s' % (v1, ID1))
    print('found %s coding SNPs in %s' % (v2, ID2))

    outfile = open(outfilename, 'w')
    try:
        outfilelog = open(outfilename + '.log', 'w')
        try:
            outline = '#GeneID \tchr\tmRNAPos\tGenomicPos\tReference\t' + ID1 + '\t' + ID2
            outfilelog.write(outline + '\n')
            _writeTranscripts(GeneDict, SequenceDict, genome, VariantDict1, VariantDict2,
                              ID1, ID2, outfile, outfilelog)
        finally:
            outfilelog.close()
    finally:
        outfile.close()

# Run the pipeline only when the file is executed as a script, so that the
# module can be imported (e.g. to reuse reverseComplement) without parsing
# sys.argv or exiting.
if __name__ == '__main__':
    run()