##################################
#                                #
# Last modified 11/11/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set


def _parse_gtf(path):
    """Collect the exons of every transcript listed in a GTF file.

    Only lines whose feature column (3rd field) is 'exon' are used; blank
    lines, '#' comment lines and lines with fewer than 9 tab-separated
    fields are skipped instead of crashing the parser.

    Returns {chromosome: {transcript_id: [(left, right, strand), ...]}} with
    the exons in file order and the coordinates converted to int.
    """
    transcripts = {}
    infile = open(path)
    try:
        for line in infile:
            fields = line.strip().split('\t')
            if len(fields) < 9 or fields[0].startswith('#') or fields[2] != 'exon':
                continue
            transcript_id = fields[8].split('transcript_id "')[1].split('";')[0]
            exon = (int(fields[3]), int(fields[4]), fields[6])
            by_id = transcripts.setdefault(fields[0], {})
            by_id.setdefault(transcript_id, []).append(exon)
    finally:
        infile.close()
    return transcripts


def _intron_chain(exons):
    """Return a hashable (strand, introns) signature of a multi-exon transcript.

    Exons are sorted by coordinate first so the signature does not depend on
    the order in which the GTF lists them. Every intron is recorded with BOTH
    of its ends (right end of the upstream exon, left end of the downstream
    exon), so two transcripts share a signature only when every splice
    junction agrees. The strand is taken from the first exon as listed.
    """
    ordered = sorted(exons)
    introns = tuple([(upstream[1], downstream[0])
                     for upstream, downstream in zip(ordered, ordered[1:])])
    return (exons[0][2], introns)


def _split_by_exon_count(transcripts):
    """Separate single-exon from multi-exon transcripts on every chromosome.

    Returns {chromosome: {'singleexon': {id: (left, right, strand)},
                          'multiexon': {id: (strand, introns)}}}.
    """
    split = {}
    for chrom, by_id in transcripts.items():
        single = {}
        multi = {}
        for transcript_id, exons in by_id.items():
            if len(exons) == 1:
                single[transcript_id] = exons[0]
            else:
                multi[transcript_id] = _intron_chain(exons)
        split[chrom] = {'singleexon': single, 'multiexon': multi}
    return split


def _parse_expression(path):
    """Read a tab-separated expression table.

    Returns {transcript_id: (FPKM, FPKM_lo, FPKM_hi)} taken, as strings, from
    columns 0, 5, 8 and 9. Lines with fewer than 10 fields (e.g. blank lines)
    are skipped.
    """
    expression = {}
    infile = open(path)
    try:
        for line in infile:
            fields = line.strip().split('\t')
            if len(fields) < 10:
                continue
            expression[fields[0]] = (fields[5], fields[8], fields[9])
    finally:
        infile.close()
    return expression


def _single_exons_overlap(exon1, exon2, min_bases, min_fraction, one_sided):
    """Decide whether two single-exon transcripts count as the same transcript.

    They must lie on the same strand and share at least min_bases bases. The
    shared length must also be at least min_fraction of the length of both
    exons, or of either one of them when one_sided is true.
    """
    left1, right1, strand1 = exon1
    left2, right2, strand2 = exon2
    if strand1 != strand2:
        return False
    # GTF coordinates are 1-based and inclusive, hence the +1. A positive
    # overlap also guarantees both lengths below are positive, so a 1-bp exon
    # can no longer cause a division by zero.
    overlap = min(right1, right2) - max(left1, left2) + 1
    if overlap <= 0 or overlap < min_bases:
        return False
    fraction1 = overlap / float(right1 - left1 + 1)
    fraction2 = overlap / float(right2 - left2 + 1)
    if one_sided:
        return fraction1 >= min_fraction or fraction2 >= min_fraction
    return fraction1 >= min_fraction and fraction2 >= min_fraction


def _match_replicates(transcripts1, transcripts2, min_bases, min_fraction, one_sided):
    """Pair each replicate-1 transcript with at most one replicate-2 transcript.

    Multi-exon transcripts match when their intron chains are identical;
    single-exon transcripts match according to _single_exons_overlap. Each
    replicate-2 transcript is consumed by the first replicate-1 transcript
    that matches it. Per-chromosome counts are written to stdout.

    Returns (pairs, unmatched2): pairs is a list of (ID1, ID2) with ID2 set to
    '-' when nothing matched; unmatched2 lists replicate-2 IDs never paired.
    """
    no_transcripts = {'singleexon': {}, 'multiexon': {}}
    pairs = []
    matched2 = set()
    for chrom in sorted(transcripts1):
        rep1 = transcripts1[chrom]
        # A chromosome missing from replicate 2 simply offers no candidates;
        # its replicate-1 transcripts are still reported as unreplicated.
        rep2 = transcripts2.get(chrom, no_transcripts)
        sys.stdout.write('%s\n' % chrom)
        sys.stdout.write('multiexon %d %d\n'
                         % (len(rep1['multiexon']), len(rep2['multiexon'])))
        sys.stdout.write('singleexon %d %d\n'
                         % (len(rep1['singleexon']), len(rep2['singleexon'])))

        # Index replicate-2 intron chains so each lookup is a dict hit rather
        # than a scan over every replicate-2 transcript.
        chain_index = {}
        for id2 in sorted(rep2['multiexon']):
            chain_index.setdefault(rep2['multiexon'][id2], []).append(id2)
        for id1 in sorted(rep1['multiexon']):
            candidates = chain_index.get(rep1['multiexon'][id1])
            if candidates:
                id2 = candidates.pop(0)
                matched2.add((chrom, id2))
                pairs.append((id1, id2))
            else:
                pairs.append((id1, '-'))

        available = sorted(rep2['singleexon'])
        for id1 in sorted(rep1['singleexon']):
            exon1 = rep1['singleexon'][id1]
            partner = '-'
            for position, id2 in enumerate(available):
                if _single_exons_overlap(exon1, rep2['singleexon'][id2],
                                         min_bases, min_fraction, one_sided):
                    partner = id2
                    matched2.add((chrom, id2))
                    # Safe: the loop is left immediately after the deletion.
                    del available[position]
                    break
            pairs.append((id1, partner))

    unmatched2 = []
    for chrom in sorted(transcripts2):
        for kind in ('multiexon', 'singleexon'):
            for id2 in sorted(transcripts2[chrom][kind]):
                if (chrom, id2) not in matched2:
                    unmatched2.append(id2)
    return pairs, unmatched2


def run():
    """Command-line entry point: find transcripts replicated between two runs.

    Reads sys.argv as:
        <gtf #1> <transcript.expr #1> <gtf #2> <transcript.expr #2>
        <min overlap, bp> <min overlap, fraction> <output file>
        [-onesidedoverlap]
    and exits with status 1 after printing usage when fewer than 7 arguments
    are given. The two overlap thresholds apply only to single-exon
    transcripts; with -onesidedoverlap the fraction threshold needs to hold
    for only one of the two transcripts.

    Writes a tab-separated table: ID, FPKM, FPKM_lo and FPKM_hi of replicate
    1, then the same four columns of replicate 2. The side without a partner
    is written as ID '-' with values '-1'. Raises KeyError if a transcript
    present in a GTF has no row in the matching expression file.
    """
    # The output file name is sys.argv[7], so 8 entries are required.
    if len(sys.argv) < 8:
        sys.stdout.write(
            'usage: python %s <gtf #1> <transcript.expr #1> <gtf #2> '
            '<transcript.expr #2> <minimum overlap for single exon transcripts, bp> '
            '<minimum overlap for single exon transcripts, fraction> '
            'outputfilename [-onesidedoverlap]\n' % sys.argv[0])
        sys.stdout.write(
            '       Note: if the -onesidedoverlap option is used, a minimal overlap '
            'for single-exon transcripts will be required in only one of the '
            'replicates instead of in both\n')
        sys.exit(1)

    gtf1 = sys.argv[1]
    expr1 = sys.argv[2]
    gtf2 = sys.argv[3]
    expr2 = sys.argv[4]
    min_overlap_bases = int(sys.argv[5])
    min_overlap_fraction = float(sys.argv[6])
    outfilename = sys.argv[7]
    one_sided = '-onesidedoverlap' in sys.argv

    sys.stdout.write('processing gtf file %s\n' % gtf1)
    transcripts1 = _split_by_exon_count(_parse_gtf(gtf1))
    sys.stdout.write('processing gtf file %s\n' % gtf2)
    transcripts2 = _split_by_exon_count(_parse_gtf(gtf2))

    sys.stdout.write('processing expression file %s\n' % expr1)
    expression1 = _parse_expression(expr1)
    sys.stdout.write('processing expression file %s\n' % expr2)
    expression2 = _parse_expression(expr2)

    pairs, unmatched2 = _match_replicates(transcripts1, transcripts2,
                                          min_overlap_bases,
                                          min_overlap_fraction, one_sided)

    no_expression = ('-1', '-1', '-1')
    outfile = open(outfilename, 'w')
    try:
        # The header keeps its historical trailing tab for compatibility.
        outfile.write('ID1\tFPKM_#1\tFPKM_lo_#1\tFPKM_hi_#1\t'
                      'ID2\tFPKM_#2\tFPKM_lo_#2\tFPKM_hi_#2\t\n')
        for id1, id2 in pairs:
            if id2 == '-':
                fpkm2 = no_expression
            else:
                fpkm2 = expression2[id2]
            outfile.write('\t'.join((id1,) + expression1[id1] + (id2,) + fpkm2) + '\n')
        for id2 in unmatched2:
            outfile.write('\t'.join(('-',) + no_expression + (id2,) + expression2[id2]) + '\n')
    finally:
        outfile.close()
            
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()

