##################################
#                                #
# Last modified 06/02/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string

def _filtered_group(read_id, alignments, err):
    """Return the subset of one read's alignments that should be written out.

    read_id    -- QNAME (first column) shared by every line in alignments;
                  only used in the diagnostic message
    alignments -- stripped SAM record lines for that read, in input order
    err        -- writable text stream for the fatal-error diagnostic

    RNAME (third tab-separated column) is expected as chr::strain.
      * no alignments            -> nothing is kept
      * one alignment            -> kept unless RNAME is '*'
      * two alignments           -> both kept only if they are on the same
                                    chromosome in two different strains
                                    (positions are not compared)
      * more than two alignments -> diagnostic written to err, then sys.exit(1)
    """
    if not alignments:
        return []
    if len(alignments) > 2:
        err.write('more than two alignments found for read %s exiting\n' % read_id)
        err.write('%s\n' % (alignments,))
        sys.exit(1)
    if len(alignments) == 1:
        if alignments[0].split('\t')[2] != '*':
            return alignments
        return []
    rname1 = alignments[0].split('\t')[2].split('::')
    rname2 = alignments[1].split('\t')[2].split('::')
    # [0] is the chromosome, [1] the haplotype/strain.
    if rname1[0] == rname2[0] and rname1[1] != rname2[1]:
        return alignments
    return []


def run(argv=None, stdin=None, stdout=None, stderr=None):
    """Filter SAM records from stdin and write the surviving ones to stdout.

    Input must keep each read's alignments on consecutive lines (unsorted
    aligner output). '@' header lines are passed through, '#' lines and blank
    lines are skipped, and each read's group of alignments is filtered by
    _filtered_group.

    argv, stdin, stdout and stderr default to the corresponding sys
    attributes; they can be supplied to drive the filter programmatically.
    Usage text and fatal diagnostics go to stderr so they never get mixed
    into the SAM stream on stdout. Exits with status 1 on a usage error or
    when a read has more than two alignments.
    """
    argv = sys.argv if argv is None else argv
    stdin = sys.stdin if stdin is None else stdin
    stdout = sys.stdout if stdout is None else stdout
    stderr = sys.stderr if stderr is None else stderr

    if len(argv) < 2:
        usage = (
            'usage: python %s - -' % argv[0],
            '\tthis script will take SAM output from an aligner run with multiplicity up to 2 allowed (note: do not sort it before that) and filter out reads that map to more than two locations but not on the same chromosome in the two haplotypes',
            '\tit runs on standard input and prints to stadard output; run it with two dashes as parameters',
            '\tit is assumed chromosomes are specified as follows: chr::haplotype/strain',
        )
        stderr.write('\n'.join(usage) + '\n')
        sys.exit(1)

    current_read = None
    alignments = []
    for line in stdin:
        if line.startswith('#'):
            continue
        record = line.strip()
        if line.startswith('@'):
            stdout.write(record + '\n')
            continue
        if not record:
            # A blank line would otherwise be treated as a read with an empty
            # name and crash on its missing columns.
            continue
        read_id = record.split('\t', 1)[0]
        if read_id != current_read:
            # A new QNAME closes the previous read's group.
            stdout.writelines(
                kept + '\n'
                for kept in _filtered_group(current_read, alignments, stderr))
            current_read = read_id
            alignments = []
        alignments.append(record)

    # No further QNAME follows the last read, so its group must be flushed
    # explicitly or its alignments would be lost.
    stdout.writelines(
        kept + '\n'
        for kept in _filtered_group(current_read, alignments, stderr))

# Run the filter only when executed as a script, so importing this module
# does not consume stdin or exit the interpreter.
if __name__ == '__main__':
    run()