##################################
#                                #
# Last modified 07/19/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string

def _strip_chr(name):
    # Normalize a chromosome name so GTF-style 'chr1' and VCF-style '1' compare equal
    # (and so do files that both use, or both omit, the 'chr' prefix).
    if name.startswith('chr'):
        return name[3:]
    return name


def _load_exon_positions(gtf):
    """Return {normalized chromosome: set of positions} covered by 'exon' features in *gtf*.

    Columns 4 and 5 (start, end) are read as 1-based coordinates with an inclusive
    end, as in the GTF format, so both endpoints are included.  Comment lines
    ('#...') and blank or truncated lines are skipped.
    """
    coverage = {}
    with open(gtf) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 5 or fields[2] != 'exon':
                continue
            left = int(fields[3])
            right = int(fields[4])
            # GTF end coordinates are inclusive, hence right + 1.
            coverage.setdefault(_strip_chr(fields[0]), set()).update(range(left, right + 1))
    return coverage


def _count_differences(snps, coverage):
    """Count pairwise first-allele differences between all samples of the VCF *snps*.

    coverage -- None to use every variant line; otherwise the dict returned by
                _load_exon_positions, and lines whose CHROM/POS is not in it are skipped.

    Returns {sample: {sample: count}}: a symmetric table with a zero diagonal.
    Prints a progress line every 1,000,000 input lines and the freshly
    initialized table when the '#CHROM' header is read.

    Raises ValueError if a data line has more genotype columns than the
    '#CHROM' header names samples (including data before any header).
    """
    differences = {}
    sample_names = []
    with open(snps) as handle:
        for line_number, line in enumerate(handle, 1):
            if line_number % 1000000 == 0:
                print('%d lines processed in %s' % (line_number, snps))
            if line.startswith('##'):
                continue
            fields = line.strip().split('\t')
            if line.startswith('#CHROM'):
                # Sample columns start after the 9 fixed VCF columns (CHROM..FORMAT).
                sample_names = fields[9:]
                for name in sample_names:
                    differences[name] = dict.fromkeys(sample_names, 0)
                print(differences)
                continue
            if not line.strip():
                continue
            genotypes = fields[9:]
            if len(genotypes) > len(sample_names):
                raise ValueError('%s line %d: %d genotype columns but the #CHROM header names %d samples'
                                 % (snps, line_number, len(genotypes), len(sample_names)))
            if coverage is not None:
                positions = coverage.get(_strip_chr(fields[0]))
                if positions is None or int(fields[1]) not in positions:
                    continue
            # Only the text before the first '/' of each genotype is compared.
            # NOTE(review): phased calls ('0|1') are not split on '|' - confirm intended.
            alleles = [genotype.split('/')[0] for genotype in genotypes]
            for j, first in enumerate(alleles):
                for k in range(j + 1, len(alleles)):
                    if first != alleles[k]:
                        differences[sample_names[j]][sample_names[k]] += 1
                        differences[sample_names[k]][sample_names[j]] += 1
    return differences


def _write_matrix(outfilename, differences):
    """Write *differences* to *outfilename* as a tab-separated square matrix.

    Rows and columns are the sample names in sorted order; the header row starts with '#Genotype'.
    """
    names = sorted(differences)
    with open(outfilename, 'w') as outfile:
        outfile.write('\t'.join(['#Genotype'] + names) + '\n')
        for row in names:
            counts = [str(differences[row][col]) for col in names]
            outfile.write('\t'.join([row] + counts) + '\n')


def run():
    """Command-line entry point: python <script> snps.vcf outfile [-coding gtf]

    For every pair of samples in snps.vcf, count the variant lines at which the
    samples' genotypes differ in the first allele, and write the counts to outfile
    as a sorted, tab-separated square matrix.  With '-coding gtf', only VCF
    positions inside 'exon' features of the GTF file are considered.
    Exits with status 1 and a usage message when arguments are missing.
    """
    usage = 'usage: python %s snps.vcf outfile [-coding gtf]' % sys.argv[0]
    if len(sys.argv) < 3:
        print(usage)
        sys.exit(1)

    snps = sys.argv[1]
    outfilename = sys.argv[2]

    coverage = None
    if '-coding' in sys.argv:
        option_index = sys.argv.index('-coding')
        if option_index + 1 >= len(sys.argv):
            print(usage)
            sys.exit(1)
        coverage = _load_exon_positions(sys.argv[option_index + 1])

    differences = _count_differences(snps, coverage)
    print(differences)
    _write_matrix(outfilename, differences)

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()