##################################
#                                #
# Last modified 03/24/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import random
import os
from sets import Set

def _iter_records(path):
    """Yield ``(record_number, line)`` for every data line of the VCF at *path*.

    Header/meta lines (starting with '#') and blank lines are skipped; the
    remaining lines are numbered consecutively from 1.  Both passes in run()
    go through this helper so their numbering can never disagree.  The file
    is closed once the generator is exhausted.
    """
    with open(path) as handle:
        number = 0
        for line in handle:
            # A blank (e.g. trailing) line is not a record; counting it would
            # let it be sampled and then fail when its columns are read.
            if line.startswith('#') or not line.strip():
                continue
            number += 1
            yield number, line


def _genotype_line(line, genotype):
    """Return the output record for *line*.

    The record keeps the first seven tab-separated columns (CHROM..FILTER),
    replaces INFO with '.', and appends a 'GT' FORMAT column and *genotype*.

    Raises:
        ValueError: if *line* has fewer than seven tab-separated columns.
    """
    fields = line.rstrip('\r\n').split('\t')
    if len(fields) < 7:
        raise ValueError('VCF record has fewer than 7 columns: %r' % line)
    return '\t'.join(fields[:7] + ['.', 'GT', genotype]) + '\n'


def run():
    """Write a simulated single-sample genotype file from a VCF.

    Command line (read from sys.argv):
        VCF outfile-independent order:
        argv[1]  input VCF path
        argv[2]  number of heterozygous variants to pick
        argv[3]  number of homozygous variants to pick
        argv[4]  output path

    Disjoint, uniformly random sets of data records are chosen as
    heterozygous and homozygous.  Only those records are written, in input
    order, to the output with a minimal header.  Homozygous records get
    '1/1'; each heterozygous record gets '0/1' or '1/0' with equal
    probability.  Prints "<homozygous count> <heterozygous count>" to stdout.

    Exits with status 1 after printing usage when an argument is missing,
    or after an error message when the requested counts are negative or
    exceed the number of data records.
    """
    # Four positional arguments (argv[1]..argv[4]) are required.
    if len(sys.argv) < 5:
        sys.stdout.write('usage: python %s VCF number_heterozygous_variants '
                         'number_homozygous_variants outfile\n' % sys.argv[0])
        sys.stdout.write('Note: Heterozygous variants will be split 50/50 '
                         'by chance between the two parents\n')
        sys.exit(1)

    vcf_path = sys.argv[1]
    n_het = int(sys.argv[2])
    n_homo = int(sys.argv[3])
    out_path = sys.argv[4]

    total = sum(1 for _ in _iter_records(vcf_path))

    if n_het < 0 or n_homo < 0 or n_het + n_homo > total:
        sys.stderr.write('error: cannot choose %d heterozygous and %d '
                         'homozygous variants from %d records\n'
                         % (n_het, n_homo, total))
        sys.exit(1)

    # A single draw without replacement makes the two groups disjoint; its
    # first n_het picks are a uniform subset, the rest a uniform subset of
    # the remainder.
    chosen = random.sample(range(1, total + 1), n_het + n_homo)
    het = set(chosen[:n_het])
    homo = set(chosen[n_het:])

    sys.stdout.write('%d %d\n' % (len(homo), len(het)))

    with open(out_path, 'w') as out:
        # NOTE(review): this is a VCF 3.x-style FORMAT line and there is no
        # ##fileformat line; VCF 4.x tools may reject it - confirm consumer.
        out.write('##FORMAT=GT,1,String,"Genotype"\n')
        out.write('#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSAMPLE\n')
        for number, line in _iter_records(vcf_path):
            if number in homo:
                genotype = '1/1'
            elif number in het:
                # random() >= 0.5 is an exact 50/50 split (the earlier
                # randint(0, 1000) >= 500 favoured '0/1' by 501:500).
                # NOTE(review): with '/' (unphased) '0/1' and '1/0' are
                # equivalent in VCF; if the parent assignment must survive,
                # the phased separator '|' would be needed - confirm intent.
                genotype = '0/1' if random.random() >= 0.5 else '1/0'
            else:
                continue
            out.write(_genotype_line(line, genotype))

# Run only when executed as a script, so importing this module does not
# read sys.argv, write files, or call sys.exit.
if __name__ == '__main__':
    run()