##################################
#                                #
# Last modified 05/15/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string

def _load_gtf_coordinates(gtf_path):
    """Collect every start/end coordinate in the GTF, grouped by chromosome.

    Returns a dict mapping the first GTF column (chromosome name) to the set
    of integer values found in columns 4 and 5 (0-based fields 3 and 4).
    """
    coordinates = {}
    with open(gtf_path) as gtf:
        for line in gtf:
            fields = line.strip().split('\t')
            wanted = coordinates.setdefault(fields[0], set())
            wanted.add(int(fields[3]))
            wanted.add(int(fields[4]))
    return coordinates


def _map_chromosome(map_path, column, wanted):
    """Build {reference position: [allele positions]} from one map file.

    map_path -- tab-delimited file whose '#' header line names the columns;
                the first column of every data row is the reference position
                and the column called ``column`` is the allele position.
    column   -- header name of the allele column to read ('PAT' or 'MAT').
    wanted   -- container of reference positions that need a mapping; rows
                in between are only materialised for these positions.

    A row with a zero reference position sets the "insertion" flag, a row
    with a zero allele position sets the "deletion" flag; the next row in
    which both are non-zero resolves the flag and clears it.

    Raises ValueError if a data row appears before the header line, or if
    the header does not contain ``column``.
    """
    mapping = {}
    # All parser state is local to this call so nothing leaks from the
    # previously processed chromosome.
    column_index = None
    ref = 0
    allele = 0
    ref_current = None
    allele_current = None
    deletion = False
    insertion = False

    with open(map_path) as stream:
        for line in stream:
            fields = line.strip().split('\t')
            if line.startswith('#'):
                column_index = fields.index(column)
                continue
            if column_index is None:
                raise ValueError('%s: data row found before the #REF header line'
                                 % map_path)
            ref_current = int(fields[0])
            allele_current = int(fields[column_index])

            if ref_current != 0 and allele_current != 0:
                if insertion:
                    # The anchor row receives every allele position since the
                    # last recorded allele coordinate.
                    mapping[ref_current] = list(range(allele, allele_current))
                    for i in range(ref, ref_current):
                        if i in wanted:
                            mapping[i] = [allele - (ref_current - i)]
                elif deletion:
                    # Every reference position since the deletion row maps
                    # onto the single allele position of this row.
                    for i in range(ref, ref_current):
                        mapping[i] = [allele_current]
                else:
                    for i in range(ref, ref_current):
                        if i in wanted:
                            mapping[i] = [allele_current - (ref_current - i)]
                    mapping[ref_current] = [allele_current]
                ref = ref_current
                allele = allele_current
                deletion = False
                insertion = False
            elif ref_current == 0 and allele_current != 0:
                insertion = True
                allele = allele_current
            elif ref_current != 0:
                # allele_current == 0 here.
                deletion = True
                for i in range(ref, ref_current):
                    if i in wanted:
                        mapping[i] = [allele + (i - ref)]
                ref = ref_current
            # Rows where both positions are zero carry no information.

    # Extend the last row's offset to wanted positions beyond the final row.
    # Skipped when the file had no data rows or no positions are wanted.
    if ref_current is not None and wanted and max(wanted) > ref_current:
        for position in wanted:
            if position >= ref_current:
                mapping[position] = [allele_current + (position - ref_current)]
    return mapping


def _build_correspondence(map_file_list, column, coordinates):
    """Run _map_chromosome for every 'map_file <tab> chr' row of the table.

    Prints 'chr map_file' for each row and returns a dict mapping chromosome
    name to its position mapping for the requested allele ``column``.
    """
    correspondence = {}
    with open(map_file_list) as table:
        for entry in table:
            fields = entry.strip().split('\t')
            chrom = fields[1]
            map_path = fields[0]
            print('%s %s' % (chrom, map_path))
            # A chromosome absent from the GTF has no wanted positions.
            correspondence[chrom] = _map_chromosome(
                map_path, column, coordinates.get(chrom, ()))
    return correspondence


def run():
    """Rewrite GTF start/end coordinates using per-chromosome map files.

    Command line: gtf map_files_table outfile_prefix

    Writes <outfile_prefix>_maternal.GTF and <outfile_prefix>_paternal.GTF,
    in which columns 4 and 5 of each GTF line are replaced by the positions
    looked up through the MAT and PAT columns of the map files. GTF lines
    whose chromosome is not listed in map_files_table are not written.
    Exits with status 1 after printing usage if an argument is missing.
    """
    # Three positional arguments are read below, so argv needs 4 entries.
    if len(sys.argv) < 4:
        print('usage: python %s gtf map_files_table outfile_prefix' % sys.argv[0])
        print('       map_files_table format:   map_file <tab> chr')
        print('       map_file format:   #REF\tPAT\tMAT')
        sys.exit(1)

    GTF = sys.argv[1]
    map_file_list = sys.argv[2]
    outfile_prefix = sys.argv[3]

    coordinates = _load_gtf_coordinates(GTF)

    print('finished inputting GTF')

    # Same order as before: all PAT maps first, then all MAT maps.
    pat_map = _build_correspondence(map_file_list, 'PAT', coordinates)
    mat_map = _build_correspondence(map_file_list, 'MAT', coordinates)

    seen = set()

    with open(GTF) as gtf, \
            open(outfile_prefix + '_maternal.GTF', 'w') as outfile_mat, \
            open(outfile_prefix + '_paternal.GTF', 'w') as outfile_pat:
        for line in gtf:
            fields = line.strip().split('\t')
            chrom = fields[0]
            if chrom not in seen:
                # Report each chromosome once, the first time it is met.
                seen.add(chrom)
                print(chrom)
            if chrom not in pat_map:
                # Chromosome was not listed in the map files table.
                continue
            left = int(fields[3])
            right = int(fields[4])
            # A position may map to several allele positions; use the last.
            matleft = mat_map[chrom][left][-1]
            matright = mat_map[chrom][right][-1]
            patleft = pat_map[chrom][left][-1]
            patright = pat_map[chrom][right][-1]
            head = [fields[0], fields[1], fields[2]]
            tail = [fields[5], fields[6], fields[7], fields[8]]
            outfile_mat.write(
                '\t'.join(head + [str(matleft), str(matright)] + tail) + '\n')
            outfile_pat.write(
                '\t'.join(head + [str(patleft), str(patright)] + tail) + '\n')

# Script entry point: run() is called unconditionally whenever this module is
# executed or imported (there is no __main__ guard).
run()