##################################
#                                #
# Last modified 03/22/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os

def _iter_text_lines(path):
    """Yield the lines of *path* as text.

    Files ending in '.gz' or '.bz2' are decompressed on the fly (this is what
    the usage message promises for the VCF input); anything else is opened as
    a plain text file. The handle is always closed when iteration finishes.
    """
    if path.endswith('.gz'):
        import gzip
        handle = gzip.open(path, 'rb')
    elif path.endswith('.bz2'):
        import bz2
        handle = bz2.BZ2File(path, 'rb')
    else:
        handle = open(path)
    try:
        for line in handle:
            # Binary handles yield bytes on Python 3 (and str on Python 2);
            # normalise so the parsers below always see text.
            if not isinstance(line, str):
                line = line.decode('utf-8')
            yield line
    finally:
        handle.close()


def _allele_index(token):
    """Return the integer allele index of one GT token; '.' (no call) is treated as reference (0)."""
    if token == '.':
        return 0
    return int(token)


def _load_variants(vcf_path):
    """Read the first sample column of a VCF into per-parent variant tables.

    Returns {'pat': {chrom: {pos0: (REF, ALT)}}, 'mat': {...}} where pos0 is
    the 0-based position (VCF POS - 1). The first GT allele is stored under
    'pat' and the second under 'mat'; reference alleles (0) are not stored.
    Both '/' and '|' are accepted as the allele separator. A GT with a single
    allele raises IndexError.
    """
    variants = {'pat': {}, 'mat': {}}
    for line in _iter_text_lines(vcf_path):
        if line.startswith('#') or not line.strip():
            continue
        fields = line.strip().split('\t')
        chrom = fields[0]
        # Chromosome names are keyed with a 'chr' prefix; add it only when the
        # VCF does not already carry one, so 'chr1' does not become 'chrchr1'.
        if not chrom.startswith('chr'):
            chrom = 'chr' + chrom
        pos = int(fields[1]) - 1
        ref = fields[3]
        alts = fields[4].split(',')
        gt_index = fields[8].split(':').index('GT')
        genotype = fields[9].split(':')[gt_index]
        alleles = genotype.replace('|', '/').split('/')
        for parent, token in (('pat', alleles[0]), ('mat', alleles[1])):
            allele = _allele_index(token)
            if allele != 0:
                variants[parent].setdefault(chrom, {})[pos] = (ref, alts[allele - 1])
    return variants


def _load_genome(fasta_path):
    """Read a FASTA file into {name: sequence}, printing each record name as it is encountered.

    The record name is everything after the first '>' on the header line.
    """
    genome = {}
    chrom = None
    chunks = []
    for line in _iter_text_lines(fasta_path):
        if line.startswith('>'):
            if chrom is not None:
                genome[chrom] = ''.join(chunks)
            chrom = line.strip().split('>')[1]
            print(chrom)
            chunks = []
        elif chrom is not None:
            chunks.append(line.strip())
    if chrom is not None:
        genome[chrom] = ''.join(chunks)
    return genome


def _load_exons(gtf_path):
    """Collect the distinct (start, end) pairs of all 'exon' rows of a GTF, grouped by chromosome."""
    exons = {}
    for line in _iter_text_lines(gtf_path):
        if line.startswith('#'):
            continue
        fields = line.strip().split('\t')
        # Skip blank/short rows as well as every feature that is not an exon.
        if len(fields) < 5 or fields[2] != 'exon':
            continue
        exons.setdefault(fields[0], set()).add((int(fields[3]), int(fields[4])))
    return exons


def _merge_intervals(intervals):
    """Return the sorted list of intervals obtained by fusing every pair that overlaps or shares an endpoint."""
    ordered = sorted(intervals)
    merged = []
    current = ordered[0]
    for interval in ordered[1:]:
        if current[1] < interval[0]:
            merged.append(current)
            current = interval
        else:
            current = (min(current[0], interval[0]), max(current[1], interval[1]))
    merged.append(current)
    return merged


def _apply_variants(reference, variants, start, end):
    """Return reference[start:end] with every variant whose 0-based position lies in [start, end) substituted.

    *variants* maps 0-based position -> (REF, ALT); it may be None or empty.
    """
    if not variants:
        return reference[start:end]
    pieces = []
    tail = end
    # Walk right-to-left over exactly the positions covered by the slice, so
    # that a variant on the first base is applied and one at `end` is not.
    for i in range(end - 1, start - 1, -1):
        if i in variants:
            ref, alt = variants[i]
            pieces.append(reference[i + len(ref):tail])
            pieces.append(alt)
            tail = i
    pieces.append(reference[start:tail])
    # Pieces were collected right-to-left; flip them and join once.
    pieces.reverse()
    return ''.join(pieces)


def run():
    """Command-line entry point.

    Arguments (from sys.argv): VCF GTF genome.fa extension_radius outfile_prefix.

    Reads the variants of the first VCF sample, the genome FASTA and the GTF
    exons, merges overlapping exons per chromosome, extends each merged exon
    by `extension_radius` on both sides, and writes:

      * <outfile_prefix>.bed        - one line per extended region
      * <outfile_prefix>.mat_pat.fa - for each region, a '-maternal' and a
        '-paternal' record with that parent's alleles substituted, wrapped
        at 50 characters per line

    Exits with status 1 after printing usage when too few arguments are given.
    Raises KeyError if a GTF chromosome is absent from the FASTA.
    """
    # Five user arguments are required (sys.argv[5] is read below).
    if len(sys.argv) < 6:
        print('usage: python %s VCF GTF genome.fa extension_radius outfile_prefix' % sys.argv[0])
        print('\tThe script can read compressed VCF files as long as they have the correct suffix - .bz2 or .gz')
        sys.exit(1)

    vcf_path = sys.argv[1]
    gtf_path = sys.argv[2]
    fasta_path = sys.argv[3]
    radius = int(sys.argv[4])
    outfile_prefix = sys.argv[5]

    variants = _load_variants(vcf_path)
    genome = _load_genome(fasta_path)
    exons = _load_exons(gtf_path)

    # NOTE(review): the GTF start/end are used directly as a 0-based,
    # end-exclusive slice into the chromosome string. If the GTF is 1-based
    # inclusive (the usual convention) the left flank is one base shorter
    # than the right one - confirm whether that is intended.
    regions = []
    for chrom in sorted(exons):
        for left, right in _merge_intervals(exons[chrom]):
            # Clamp at 0: a negative start would wrap around when slicing and
            # is not a valid BED coordinate.
            regions.append((chrom, max(0, left - radius), right + radius))

    with open(outfile_prefix + '.bed', 'w') as bed:
        for chrom, start, end in regions:
            bed.write(chrom + '\t' + str(start) + '\t' + str(end) + '\n')

    with open(outfile_prefix + '.mat_pat.fa', 'w') as fasta_out:
        for chrom, start, end in regions:
            reference = genome[chrom]
            for parent, label in (('mat', 'maternal'), ('pat', 'paternal')):
                sequence = _apply_variants(reference, variants[parent].get(chrom), start, end)
                fasta_out.write('>' + chrom + ':' + str(start) + '-' + str(end) + '-' + label + '\n')
                for i in range(0, len(sequence), 50):
                    fasta_out.write(sequence[i:i + 50] + '\n')
            
# Run only when executed as a script, so that importing this module does not
# parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()