##################################
#                                #
# Last modified 04/29/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set
import os

def _vcf_lines(path):
    """Yield the raw text lines of the VCF file at *path*.

    Files ending in .bz2 or .gz are decompressed by piping them through
    ``bzip2 -cd`` / ``gunzip -c``. Any other file is read directly.
    The command is passed as an argument list (no shell), so filenames
    containing spaces or shell metacharacters are handled safely.

    Raises:
        IOError/OSError: if an uncompressed file cannot be opened or the
            decompressor cannot be started.
        RuntimeError: if the decompressor exits with a non-zero status after
            the stream has been fully consumed.
    """
    import subprocess

    if path.endswith('.bz2'):
        cmd = ['bzip2', '-cd', path]
    elif path.endswith('.gz'):
        cmd = ['gunzip', '-c', path]
    else:
        with open(path) as handle:
            for line in handle:
                yield line
        return

    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            universal_newlines=True)
    try:
        for line in proc.stdout:
            yield line
    finally:
        # Always release the pipe and reap the child, even if the consumer
        # stops early or an exception propagates.
        proc.stdout.close()
        status = proc.wait()
    if status != 0:
        raise RuntimeError('%s exited with status %d' % (' '.join(cmd), status))


def _parse_record(line):
    """Parse a VCF data line into ``((chrom, pos), ref, alt_set)``.

    A leading 'chr' prefix is removed from the chromosome name, so 'chr1'
    and '1' compare equal. POS is converted to int. ALT is split on commas
    into a set, so allele order does not affect comparisons.
    """
    fields = line.strip().split('\t')
    chrom = fields[0]
    if chrom.startswith('chr'):
        chrom = chrom[3:]
    return (chrom, int(fields[1])), fields[3], set(fields[4].split(','))


def _is_indel(ref, alts):
    """Return True when REF or any ALT allele is longer than one character."""
    return len(ref) > 1 or any(len(alt) > 1 for alt in alts)


def _report_progress(count):
    """Print a progress message after every millionth input line."""
    if count % 1000000 == 0:
        print(str(count // 1000000) + 'M lines processed')


def run():
    """Compare two VCF files site by site, as driven by ``sys.argv``.

    Usage: vcf1 vcf2 outfile_prefix [-noVCF1-only] [-noVCF2-only] [-ignoreIndels]

    A site is keyed by (chromosome without 'chr' prefix, position). It is
    "common" when both files have a record there with the same REF and the
    same set of ALT alleles.

    Output files:
      <prefix>.common.vcf     one minimal CHROM/POS/ID/REF/ALT line per
                              common site, sorted by (chromosome, position),
                              with ALT alleles sorted.
      <prefix>.vcf2-only.vcf  vcf2 header lines plus vcf2 records that are
                              not common (skipped with -noVCF2-only).
      <prefix>.vcf1-only.vcf  vcf1 header lines plus vcf1 records whose site
                              is not common (skipped with -noVCF1-only).

    With -ignoreIndels, records where REF or any ALT is longer than one base
    are dropped from every output. vcf1 is indexed in memory, so pass the
    smaller file first. Exits with status 1 and a usage message when fewer
    than three positional arguments are given.
    """
    if len(sys.argv) < 4:
        print('usage: python %s vcf1 vcf2 outfile_prefix [-noVCF1-only] [-noVCF2-only] [-ignoreIndels]' % sys.argv[0])
        print('\tNote: the script can read compressed files, as long as they end with .bz2 or .gz')
        print('\tNote: have the smaller file be the first one in order to save memory')
        sys.exit(1)

    vcf1, vcf2, outfile_prefix = sys.argv[1:4]
    do_vcf1 = '-noVCF1-only' not in sys.argv
    do_vcf2 = '-noVCF2-only' not in sys.argv
    ignore_indels = '-ignoreIndels' in sys.argv

    lines_seen = 0

    # Pass 1: index vcf1 records by site.
    vcf1_sites = {}
    for line in _vcf_lines(vcf1):
        lines_seen += 1
        _report_progress(lines_seen)
        if line.startswith('#') or not line.strip():
            continue
        site, ref, alts = _parse_record(line)
        if ignore_indels and _is_indel(ref, alts):
            continue
        vcf1_sites[site] = (ref, alts)

    # Pass 2: stream vcf2. Collect sites identical to vcf1 and write the rest.
    common = {}
    outfile2 = open(outfile_prefix + '.vcf2-only.vcf', 'w') if do_vcf2 else None
    try:
        for line in _vcf_lines(vcf2):
            lines_seen += 1
            _report_progress(lines_seen)
            if not line.strip():
                continue
            if line.startswith('#'):
                # Keep the header so the vcf2-only output is a valid VCF,
                # matching the vcf1-only output below.
                if outfile2 is not None:
                    outfile2.write(line)
                continue
            site, ref, alts = _parse_record(line)
            if ignore_indels and _is_indel(ref, alts):
                continue
            if vcf1_sites.get(site) == (ref, alts):
                common[site] = (ref, alts)
            elif outfile2 is not None:
                outfile2.write(line)
    finally:
        if outfile2 is not None:
            outfile2.close()
    vcf1_sites = None  # no longer needed; release it before the final pass

    # Common sites: one line per site, with ALT alleles joined by commas.
    with open(outfile_prefix + '.common.vcf', 'w') as outfile3:
        for site in sorted(common):
            chrom, pos = site
            ref, alts = common[site]
            fields = [chrom, str(pos), '.', ref, ','.join(sorted(alts))]
            outfile3.write('\t'.join(fields) + '\n')

    # Pass 3: re-stream vcf1 and write every record whose site is not common.
    if do_vcf1:
        with open(outfile_prefix + '.vcf1-only.vcf', 'w') as outfile1:
            for line in _vcf_lines(vcf1):
                if not line.strip():
                    continue
                if line.startswith('#'):
                    outfile1.write(line)
                    continue
                site, ref, alts = _parse_record(line)
                if ignore_indels and _is_indel(ref, alts):
                    continue
                # Matching is by site only, as in the original script.
                if site not in common:
                    outfile1.write(line)

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()