##################################
#                                #
# Last modified 2019/08/27       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import gzip

def _open_text(path):
    """Open *path* for line-by-line reading, gunzipping it when it ends in '.gz'."""
    if path.endswith('.gz'):
        return gzip.open(path)
    return open(path)


def _parse_number(value, convert):
    """Return convert(value), or None when *value* is missing or not numeric (e.g. '.')."""
    try:
        return convert(value)
    except (TypeError, ValueError):
        return None


def _info_depth(info):
    """Return the DP value from a VCF INFO string, or None if DP is absent or unparsable."""
    for entry in info.split(';'):
        # Exact 'DP=' prefix per ';'-separated entry, so keys like 'MQDP=' never match.
        if entry.startswith('DP='):
            return _parse_number(entry[3:], int)
    return None


def _homozygous_alt_index(genotype):
    """Return the 1-based ALT index of a homozygous non-reference GT string, else None.

    Unphased ('/') and phased ('|') separators are both accepted and a haploid
    call such as '1' counts as homozygous. Missing ('.'), reference ('0') and
    heterozygous calls return None.
    """
    alleles = genotype.replace('|', '/').split('/')
    if len(set(alleles)) != 1:
        return None  # heterozygous
    index = _parse_number(alleles[0], int)
    if index is None or index < 1:
        return None  # missing call or homozygous reference
    return index


def _is_literal_allele(allele):
    """Return True when *allele* is a plain base string, not a symbolic VCF allele."""
    # '*' (spanning deletion), '.', '<DEL>'-style and breakend ('[', ']') alleles
    # carry no literal sequence and must not be pasted into the genome.
    return not (allele in ('*', '.') or allele.startswith('<')
                or '[' in allele or ']' in allele)


def _read_fasta(path):
    """Return {record name: sequence} for FASTA *path*.

    The record name is the first whitespace-delimited token of the header,
    which is what VCF CHROM values are compared against.
    """
    sequences = {}
    name = None
    chunks = []
    with open(path) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    sequences[name] = ''.join(chunks)
                tokens = line[1:].split()
                name = tokens[0] if tokens else ''
                chunks = []
            else:
                chunks.append(line.strip())
    if name is not None:
        sequences[name] = ''.join(chunks)
    return sequences


def _load_variants(vcf_path, field_id, min_dp=None, min_gq=None,
                   add_chr=False, no_indels=False):
    """Collect the homozygous-ALT calls of one sample column from a VCF file.

    Args:
        vcf_path: plain or '.gz' VCF file.
        field_id: 0-based index of the sample column in the tab-split line.
        min_dp: when set, drop records whose INFO DP is missing or below it.
        min_gq: when set, drop records whose sample GQ is missing or below it.
        add_chr: prefix CHROM with 'chr'.
        no_indels: drop records whose REF and chosen ALT differ in length.

    Returns:
        (variants, total, kept): variants maps chrom -> {0-based pos: (ref, alt)};
        total counts non-header records seen; kept counts records stored.
    """
    variants = {}
    total = 0
    kept = 0
    with _open_text(vcf_path) as handle:
        for raw in handle:
            # gzip yields bytes on Python 3 (str on Python 2); normalise to str.
            line = raw if isinstance(raw, str) else raw.decode('utf-8')
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            total += 1
            if min_dp is not None:
                depth = _info_depth(fields[7])
                if depth is None or depth < min_dp:
                    continue
            chrom = 'chr' + fields[0] if add_chr else fields[0]
            pos = int(fields[1]) - 1  # VCF POS is 1-based; string offsets are 0-based
            ref = fields[3]
            alts = fields[4].split(',')
            # Trailing FORMAT sub-fields may be dropped in a sample, hence dict.get below.
            sample = dict(zip(fields[8].split(':'), fields[field_id].split(':')))
            if min_gq is not None:
                quality = _parse_number(sample.get('GQ'), float)
                if quality is None or quality < min_gq:
                    continue
            allele = _homozygous_alt_index(sample.get('GT', '.'))
            if allele is None or allele > len(alts):
                continue
            alt = alts[allele - 1]
            if not _is_literal_allele(alt):
                continue
            if no_indels and len(ref) != len(alt):
                continue
            variants.setdefault(chrom, {})[pos] = (ref, alt)
            kept += 1
    return variants, total, kept


def _apply_variants(reference, chrom_variants):
    """Return *reference* with each (ref, alt) in *chrom_variants* substituted at its offset.

    A variant starting inside the REF span of an earlier, already applied
    variant is skipped; applying it would duplicate or scramble bases.
    """
    pieces = []
    current = 0
    for pos in sorted(chrom_variants):
        ref, alt = chrom_variants[pos]
        if pos < current:
            continue
        pieces.append(reference[current:pos])
        pieces.append(alt)
        current = pos + len(ref)
    pieces.append(reference[current:])
    return ''.join(pieces)


def run():
    """Command-line entry point: write a FASTA with a sample's homozygous ALT calls applied.

    Positional argv: genome_fasta vcf fieldID ID outfile, where fieldID is the
    0-based column of the sample in the tab-separated VCF line. Optional flags:
    -DP minDP, -GQ minGQ, -addChr, -noInDels. Each output record is named
    '<chrom>::<ID>'; records are sorted by name and wrapped at 50 bases per line.
    Exits with status 1 and a usage message when positional arguments are missing.
    """
    if len(sys.argv) < 6:
        print('usage: python %s genome_fasta vcf fieldID ID outfile [-DP minDP] [-GQ minGQ] [-addChr] [-noInDels]' % sys.argv[0])
        print('\toutput format for the heterozygous genome: chrN::ID')
        print('\tBe careful about 1- and 0-based genomes and annotations when working with new genomes')
        print('\tIMPORTANT: The script assumes an inbred strain and unphase calls (i.e. ./. sepatators) and will disregard all heterozygous positions in the VCF file; use the -DP and -GQ options to further filter variants')
        sys.exit(1)

    min_gq = None
    if '-GQ' in sys.argv:
        min_gq = int(sys.argv[sys.argv.index('-GQ') + 1])

    min_dp = None
    if '-DP' in sys.argv:
        min_dp = int(sys.argv[sys.argv.index('-DP') + 1])

    add_chr = '-addChr' in sys.argv

    no_indels = '-noInDels' in sys.argv
    if no_indels:
        print('will omit indels')

    fasta = sys.argv[1]
    vcf = sys.argv[2]
    field_id = int(sys.argv[3])
    sample_id = sys.argv[4]
    outfilename = sys.argv[5]

    sequences = _read_fasta(fasta)
    variants, total, kept = _load_variants(vcf, field_id, min_dp, min_gq,
                                           add_chr, no_indels)
    print('retained %d variants out of %d' % (kept, total))

    with open(outfilename, 'w') as outfile:
        for name in sorted(sequences):
            sequence = _apply_variants(sequences[name], variants.get(name, {}))
            outfile.write('>' + name + '::' + sample_id + '\n')
            for start in range(0, len(sequence), 50):
                outfile.write(sequence[start:start + 50] + '\n')

# Run only when executed as a script, so the module can be imported without side effects.
if __name__ == '__main__':
    run()