##################################
#                                #
# Last modified 03/26/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
import random
import string
from sets import Set

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA string.

    Upper- and lower-case A/C/G/T/N are complemented with their case
    preserved.  Any other symbol raises KeyError.
    """
    complement = {'A':'T','T':'A','G':'C','C':'G','N':'N',
                  'a':'t','t':'a','g':'c','c':'g','n':'n'}
    # Walk the input back-to-front and complement each base on the way.
    return ''.join([complement[base] for base in reversed(preliminarysequence)])

def getCodon(GenomDict,chr,positions,i,strand,V):
    """Return the reference codon and the mutated codon around a coding SNV.

    Arguments:
        GenomDict -- dict mapping chromosome name to its sequence string.
        chr       -- chromosome name (key into GenomDict).
        positions -- 0-based genomic coordinates of the coding bases, listed
                     in transcript order (ascending for '+' strand,
                     descending for '-' strand).
        i         -- 1-based index into `positions` of the variant base,
                     i.e. the variant sits at positions[i-1].
        strand    -- '+' or '-'.
        V         -- alternative allele, given on the forward genomic strand.

    Returns a tuple (codon, newcodon) of upper-case RNA strings (T -> U) as
    read on the transcript strand.

    Raises IndexError when the codon runs past the end of `positions`
    (coding length not a multiple of three) and, for '-' strand, KeyError
    when a base is not one of A/C/G/T/N.

    The reading frame is taken to start at positions[0].
    """
    # Offset of the variant inside its codon (0, 1 or 2) and the index of
    # the first base of that codon.
    offset = (i - 1) % 3
    start = i - 1 - offset

    # Explicit indexing (not slicing) so that an incomplete codon raises
    # IndexError instead of silently yielding a short string.
    bases = [GenomDict[chr][positions[start + k]] for k in range(3)]
    codon = ''.join(bases).upper()
    bases[offset] = V
    newcodon = ''.join(bases).upper()

    if strand == '-':
        # `positions` is already in transcript (descending genomic) order,
        # so the bases only need complementing; reversing them as well
        # would return the codon backwards.
        complement = {'A':'T','T':'A','G':'C','C':'G','N':'N'}
        codon = ''.join([complement[base] for base in codon])
        newcodon = ''.join([complement[base] for base in newcodon])

    codon = codon.replace('T','U')
    newcodon = newcodon.replace('T','U')

    return (codon,newcodon)

# Standard genetic code keyed by RNA codon.  The 'START' entry is kept from
# the original table; it is never hit by a codon lookup.
_CODON_TABLE = {'GCU':'A', 'GCC':'A', 'GCA':'A', 'GCG':'A',
                'UUA':'L', 'UUG':'L', 'CUU':'L', 'CUC':'L', 'CUA':'L', 'CUG':'L',
                'CGU':'R', 'CGC':'R', 'CGA':'R', 'CGG':'R', 'AGA':'R', 'AGG':'R',
                'AAA':'K', 'AAG':'K',
                'AAU':'N', 'AAC':'N',
                'AUG':'M',
                'GAU':'D', 'GAC':'D',
                'UUU':'F', 'UUC':'F',
                'UGU':'C', 'UGC':'C',
                'CCU':'P', 'CCC':'P', 'CCA':'P', 'CCG':'P',
                'CAA':'Q', 'CAG':'Q',
                'UCU':'S', 'UCC':'S', 'UCA':'S', 'UCG':'S', 'AGU':'S', 'AGC':'S',
                'GAA':'E', 'GAG':'E',
                'ACU':'T', 'ACC':'T', 'ACA':'T', 'ACG':'T',
                'GGU':'G', 'GGC':'G', 'GGA':'G', 'GGG':'G',
                'UGG':'W',
                'CAU':'H', 'CAC':'H',
                'UAU':'Y', 'UAC':'Y',
                'AUU':'I', 'AUC':'I', 'AUA':'I',
                'GUU':'V', 'GUC':'V', 'GUA':'V', 'GUG':'V',
                'START':'AUG',
                'UAA':'STOP',
                'UGA':'STOP',
                'UAG':'STOP'}


def _readSNVs(vcf):
    """Parse the VCF into {chromosome: {0-based position: (ALT list, GT string)}}."""
    SNPDict = {}
    retained = 0
    with open(vcf) as vcffile:
        for (j, line) in enumerate(vcffile, 1):
            if j % 100000 == 0:
                print('%d lines processed' % j)
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            chrom = fields[0]
            # Chromosome names are normalised to carry a 'chr' prefix.
            if not chrom.startswith('chr'):
                chrom = 'chr' + chrom
            # VCF coordinates are 1-based; everything downstream is 0-based.
            pos = int(fields[1]) - 1
            REF = fields[3]
            ALT = fields[4].split(',')
            # Indels are ignored (judged on REF and the first ALT allele).
            if len(REF) > 1 or len(ALT[0]) > 1:
                continue
            retained += 1
            # Genotype of the first sample column only.
            GTID = fields[8].split(':').index('GT')
            GT = fields[9].split(':')[GTID]
            SNPDict.setdefault(chrom, {})[pos] = (ALT, GT)
    print('retained %d SNVs' % retained)
    return SNPDict


def _readGenome(fasta):
    """Parse the FASTA file into {sequence name: sequence string}."""
    GenomeDict = {}
    name = None
    chunks = []
    with open(fasta) as fastafile:
        for line in fastafile:
            if line.startswith('>'):
                if name is not None:
                    GenomeDict[name] = ''.join(chunks)
                # The whole header after '>' is used as the sequence name.
                name = line.strip().split('>')[1]
                print(name)
                chunks = []
            elif name is not None:
                chunks.append(line.strip())
    if name is not None:
        GenomeDict[name] = ''.join(chunks)
    return GenomeDict


def _gtfAttribute(attributes, key, default):
    """Return the quoted value of `key` in a GTF attribute column, or `default`."""
    tag = key + ' "'
    if tag in attributes:
        return attributes.split(tag)[1].split('";')[0]
    return default


def _readTranscripts(gtf, GenomeDict, SNPDict):
    """Collect exon and CDS intervals per transcript from the GTF file.

    Only chromosomes present in both GenomeDict and SNPDict are kept.
    Returns {(geneName, geneID, transcriptName, transcriptID):
             {'exons': [(chr, left, right, strand), ...], 'CDS': [...]}}
    with 1-based inclusive coordinates as found in the GTF.
    """
    TranscriptDict = {}
    with open(gtf) as gtffile:
        for line in gtffile:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9 or fields[2] not in ('exon', 'CDS'):
                continue
            chrom = fields[0]
            if chrom not in GenomeDict or chrom not in SNPDict:
                continue
            attributes = fields[8]
            transcriptID = attributes.split('transcript_id "')[1].split('";')[0]
            geneID = attributes.split('gene_id "')[1].split('";')[0]
            # Names fall back to the corresponding IDs when absent.
            geneName = _gtfAttribute(attributes, 'gene_name', geneID)
            transcriptName = _gtfAttribute(attributes, 'transcript_name', transcriptID)
            transcript = (geneName, geneID, transcriptName, transcriptID)
            if transcript not in TranscriptDict:
                TranscriptDict[transcript] = {'exons': [], 'CDS': []}
            interval = (chrom, int(fields[3]), int(fields[4]), fields[6])
            if fields[2] == 'exon':
                TranscriptDict[transcript]['exons'].append(interval)
            else:
                TranscriptDict[transcript]['CDS'].append(interval)
    return TranscriptDict


def _hasNearbySNV(exons, SNPDict):
    """Return True if any SNV lies in an exon or in its two flanking bases."""
    for (chrom, left, right, strand) in exons:
        chromSNVs = SNPDict.get(chrom, {})
        # 0-based window: two intronic bases upstream (left-3, left-2), the
        # exon itself (left-1 .. right-1) and two bases downstream.
        for p in range(left - 3, right + 2):
            if p in chromSNVs:
                return True
    return False


def _spliceSitePositions(exons):
    """List (chr, 0-based position) of the intronic bases flanking each intron.

    `exons` must be sorted by genomic coordinate.  The first exon contributes
    only the two bases after its right edge, the last exon only the two bases
    before its left edge, and internal exons both.  Single-exon transcripts
    yield nothing.
    """
    sites = []
    last = len(exons) - 1
    for (e, (chrom, left, right, strand)) in enumerate(exons):
        if e > 0:
            for p in range(left - 3, left - 1):
                sites.append((chrom, p))
        if e < last:
            for p in range(right, right + 2):
                sites.append((chrom, p))
    return sites


def _calledAlleles(ALT, GT):
    """Return the distinct single-base ALT alleles present in genotype GT.

    Both '/' and '|' separators are accepted; missing calls ('.') and allele
    indices outside ALT are skipped, as are alleles that are not A/C/G/T
    (e.g. an indel listed as a secondary ALT).
    """
    alleles = []
    for g in GT.replace('|', '/').split('/'):
        if not g.isdigit():
            continue
        g = int(g)
        if g == 0 or g > len(ALT):
            continue
        V = ALT[g - 1]
        if V.upper() in ('A', 'C', 'G', 'T') and V not in alleles:
            alleles.append(V)
    return alleles


def _snvLine(chrom, pos, refBase, snv, transcript, effect):
    """Format one tab-separated output record (with trailing newline)."""
    (ALT, GT) = snv
    (geneName, geneID, transcriptName, transcriptID) = transcript
    columns = [chrom, str(pos), refBase, ','.join(ALT), GT,
               geneName, geneID, transcriptName, transcriptID, effect]
    return '\t'.join(columns) + '\n'


def _annotateTranscript(outfile, transcript, model, SNPDict, GenomeDict):
    """Write splice-site, coding and UTR records for one transcript."""
    exons = sorted(model['exons'])
    CDS = sorted(model['CDS'])
    if not _hasNearbySNV(exons, SNPDict):
        return

    # Splice sites: the two intronic bases on each side of every intron.
    for (chrom, pos) in _spliceSitePositions(exons):
        if pos in SNPDict[chrom]:
            outfile.write(_snvLine(chrom, pos, GenomeDict[chrom][pos],
                                   SNPDict[chrom][pos], transcript, 'SpliceSite'))

    # Chromosome and strand are taken from the last exon in genomic order.
    (chrom, left, right, strand) = exons[-1]
    chromSNVs = SNPDict[chrom]

    # 0-based coding and exonic coordinates, listed in transcript order.
    CDSpositions = []
    for (c, left, right, s) in CDS:
        CDSpositions.extend(range(left - 1, right))
    Exonpositions = []
    for (c, left, right, s) in exons:
        Exonpositions.extend(range(left - 1, right))
    CDSpositions.sort()
    Exonpositions.sort()
    if strand == '-':
        CDSpositions.reverse()
        Exonpositions.reverse()
    codingPositions = set(CDSpositions)

    # Coding SNVs: one record per alternative allele carried by the sample.
    # getCodon expects the 1-based index of the variant within CDSpositions.
    for (i, pos) in enumerate(CDSpositions, 1):
        if pos not in chromSNVs:
            continue
        (ALT, GT) = chromSNVs[pos]
        for V in _calledAlleles(ALT, GT):
            try:
                (codon, newcodon) = getCodon(GenomeDict, chrom, CDSpositions, i, strand, V)
                oldAA = _CODON_TABLE[codon]
                newAA = _CODON_TABLE[newcodon]
            except (IndexError, KeyError):
                # Partial trailing codon, or a codon containing a base that
                # cannot be translated (e.g. N): no effect can be assigned,
                # so no record is written for this allele.
                continue
            if oldAA == newAA:
                effect = 'synonymous'
            elif newAA == 'STOP':
                effect = 'nonsense'
            else:
                effect = 'missense'
            outfile.write(_snvLine(chrom, pos, GenomeDict[chrom][pos],
                                   chromSNVs[pos], transcript, effect))

    # Non-coding exonic SNVs: labelled 5'UTR until the first coding base has
    # been passed and 3'UTR afterwards.  A transcript without any CDS
    # therefore has all of its exonic SNVs labelled 5'UTR.
    current = "5'UTR"
    for pos in Exonpositions:
        if pos in codingPositions:
            current = 'CDS'
            continue
        if current == 'CDS':
            current = "3'UTR"
        if pos in chromSNVs:
            outfile.write(_snvLine(chrom, pos, GenomeDict[chrom][pos],
                                   chromSNVs[pos], transcript, current))


def run():
    """Annotate the SNVs of a VCF file against the transcripts of a GTF file.

    Command line: fasta VCF gtf outfilename (read from sys.argv).

    For every transcript with at least one SNV in or next to its exons, the
    output file receives one tab-separated record per SNV classified as
    SpliceSite, synonymous, missense, nonsense, 5'UTR or 3'UTR.  The POS
    column is written 0-based (VCF position minus one).  Only the first
    sample column of the VCF is used, and indels are ignored.

    Prints usage and exits with status 1 when fewer than four arguments are
    given.
    """
    # Four positional arguments are required, so argv must hold five items.
    if len(sys.argv) < 5:
        print('usage: python %s fasta VCF gtf outfilename' % sys.argv[0])
        print('\tNote: this script assumes there is CDS annotatoin in the GTF file and ignores indels')
        sys.exit(1)

    fasta = sys.argv[1]
    vcf = sys.argv[2]
    gtf = sys.argv[3]
    outputfilename = sys.argv[4]

    SNPDict = _readSNVs(vcf)
    GenomeDict = _readGenome(fasta)
    TranscriptDict = _readTranscripts(gtf, GenomeDict, SNPDict)

    with open(outputfilename, 'w') as outfile:
        outfile.write('#CHROM\tPOS\tREF\tALT\tGT\tGeneName\tGeneID\tTranscriptName\tTranscriptID\tEffect\n')
        print('Found %d transcripts' % len(TranscriptDict))
        for (tr, transcript) in enumerate(sorted(TranscriptDict), 1):
            if tr % 1000 == 0:
                print('%d transcripts sequences processed' % tr)
            _annotateTranscript(outfile, transcript, TranscriptDict[transcript],
                                SNPDict, GenomeDict)

# Script entry point.  There is no __main__ guard, so the annotation also
# runs if this file is imported as a module.
run()

