##################################
#                                #
# Last modified 2017/08/24       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def _parse_field_spec(spec):
    """Expand a column specification into a sorted list of field indices.

    ``spec`` is either ``start:end`` (an inclusive range) or a
    comma-separated list of integers.  The indices are used directly to
    subscript the tab-split fields of a line.
    """
    if ':' in spec:
        bounds = spec.split(':')
        indices = list(range(int(bounds[0]), int(bounds[1]) + 1))
    else:
        indices = [int(item) for item in spec.split(',')]
    indices.sort()
    return indices


def _gtf_attribute(attributes, key):
    """Return the quoted value that follows ``key`` in a GTF attribute string."""
    return attributes.split(key + ' "')[1].split('";')[0]


def _load_gtf(gtf):
    """Index the ``exon`` records of a GTF file.

    Returns a tuple ``(exon_boundaries, genes_by_strand)`` where

    * ``exon_boundaries[chrom][(position, strand)]`` is the set of
      ``(geneID, geneName)`` pairs having an exon that starts or ends at
      ``position`` on ``strand``;
    * ``genes_by_strand[chrom][strand]`` is a sorted list of
      ``(left, right, geneID, geneName)`` tuples, ``left``/``right`` being
      the smallest and largest exon coordinate seen for that gene.

    When a record has no ``gene_name`` attribute the gene ID is used as the
    name.  A gene's strand is the strand of its first exon in the file.
    """
    exon_boundaries = {}
    gene_extents = {}
    with open(gtf) as handle:
        for line in handle:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            # Blank or truncated lines cannot be exon records; skip them
            # instead of failing on a missing column.
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            attributes = fields[8]
            geneID = _gtf_attribute(attributes, 'gene_id')
            if 'gene_name "' in attributes:
                geneName = _gtf_attribute(attributes, 'gene_name')
            else:
                geneName = geneID
            gene = (geneID, geneName)
            chrom = fields[0]
            left = int(fields[3])
            right = int(fields[4])
            strand = fields[6]

            boundaries = exon_boundaries.setdefault(chrom, {})
            boundaries.setdefault((left, strand), set()).add(gene)
            boundaries.setdefault((right, strand), set()).add(gene)

            # Track only the running extent of each gene rather than every
            # exon coordinate.
            extents = gene_extents.setdefault(chrom, {})
            if gene in extents:
                known = extents[gene]
                known[1] = min(known[1], left, right)
                known[2] = max(known[2], left, right)
            else:
                extents[gene] = [strand, min(left, right), max(left, right)]

    genes_by_strand = {}
    for chrom, extents in gene_extents.items():
        per_strand = genes_by_strand.setdefault(chrom, {})
        for (geneID, geneName), (strand, left, right) in extents.items():
            # setdefault accepts any strand symbol (e.g. '.'), not only +/-.
            per_strand.setdefault(strand, []).append(
                (left, right, geneID, geneName))
        for genes in per_strand.values():
            genes.sort()
    return exon_boundaries, genes_by_strand


def _load_expression(expression_table, geneIDfieldID, geneNamefieldID,
                     expvalueFields):
    """Read the gene expression table.

    Lines starting with ``#`` are treated as headers: the columns listed in
    ``expvalueFields`` are collected as sample names.  Lines starting with
    ``tracking_id`` and blank lines are ignored.  Every other line
    contributes ``(geneID, geneName) -> [float values]``.

    Returns ``(sample_names, expression)``.
    """
    sample_names = []
    expression = {}
    with open(expression_table) as handle:
        for line in handle:
            stripped = line.strip()
            fields = stripped.split('\t')
            if line.startswith('#'):
                sample_names.extend(fields[i] for i in expvalueFields)
                continue
            if line.startswith('tracking_id') or not stripped:
                continue
            gene = (fields[geneIDfieldID], fields[geneNamefieldID])
            expression[gene] = [float(fields[i]) for i in expvalueFields]
    return sample_names, expression


def _match_genes(chrom, left, right, strand, exon_boundaries,
                 genes_by_strand):
    """Return the ``(geneID, geneName)`` pairs assigned to one circRNA.

    Matching order:

    1. If either end coincides with an exon boundary on the same strand,
       the genes shared by both ends are returned; when no gene is shared,
       the genes of the right end followed by those of the left end.
    2. Otherwise genes on the same strand whose extent overlaps the
       circRNA are considered; genes fully containing it are preferred
       over partially overlapping ones.

    An empty list means no gene could be assigned.  Results are sorted so
    that the output does not depend on set iteration order.
    """
    boundaries = exon_boundaries.get(chrom)
    if boundaries is None:
        return []

    genes_left = boundaries.get((left, strand), set())
    genes_right = boundaries.get((right, strand), set())
    if genes_left or genes_right:
        shared = genes_left & genes_right
        if shared:
            return sorted(shared)
        return sorted(genes_right) + sorted(genes_left)

    overlapping = []
    for candidate in genes_by_strand.get(chrom, {}).get(strand, []):
        gleft, gright = candidate[0], candidate[1]
        if gleft > right:
            # The list is sorted by left coordinate: nothing further on
            # can overlap.
            break
        if (gleft <= left <= gright or gleft <= right <= gright
                or (left <= gleft and right >= gright)):
            overlapping.append(candidate)

    containing = [c for c in overlapping if c[0] <= left and c[1] >= right]
    chosen = containing or overlapping
    return [(c[2], c[3]) for c in chosen]


def run():
    """Annotate a circRNA table with host genes and their expression values.

    Command-line arguments (read from ``sys.argv``)::

        circ_rna_table gtf gene_expression_table geneIDfieldID
        geneNameFieldID valuefields outputfilename
        [-MatchesOnly] [-maxCircLen bp]

    ``circ_rna_table`` is tab-delimited with ``chr, left, right, strand`` in
    its first four columns; lines starting with ``#`` are copied as headers
    with the gene and sample columns appended.  For every circRNA one output
    line is written per assigned gene found in the expression table; genes
    missing from the expression table are reported on stdout and skipped.
    circRNAs with no assignable gene get a row of ``-`` placeholders unless
    ``-MatchesOnly`` is given.  ``-maxCircLen`` drops circRNAs whose
    ``right - left`` exceeds the given length.

    Exits with status 1 after printing usage when too few arguments are
    supplied.
    """
    # Seven positional arguments are required, so argv needs 8 entries.
    if len(sys.argv) < 8:
        print('usage: python %s circ_rna_table gtf gene_expression_table geneIDfieldID geneNameFieldID valuefields outputfilename [-MatchesOnly] [-maxCircLen bp]' % sys.argv[0])
        print('\tcirc_rna_table format:')
        print('\tchr <tab> left <tab> right <tab> strand')
        sys.exit(1)

    circTable = sys.argv[1]
    gtf = sys.argv[2]
    expression_table = sys.argv[3]
    geneIDfieldID = int(sys.argv[4])
    geneNamefieldID = int(sys.argv[5])
    expvalueFields = _parse_field_spec(sys.argv[6])
    print('expression fields: %s' % expvalueFields)
    outfilename = sys.argv[7]

    maxCircLen = None
    if '-maxCircLen' in sys.argv:
        maxCircLen = int(sys.argv[sys.argv.index('-maxCircLen') + 1])
        print('will filter out circRNAs longer than %s' % maxCircLen)

    doMO = '-MatchesOnly' in sys.argv
    if doMO:
        print('will only output circRNAs with matches')

    exon_boundaries, genes_by_strand = _load_gtf(gtf)
    print('finished parsing gtf')
    print('finished sorting genes')

    sample_names, expression = _load_expression(
        expression_table, geneIDfieldID, geneNamefieldID, expvalueFields)
    print('finished parsing expression')

    # One '-' for geneID, one for geneName and one per value column, so that
    # unmatched rows are as wide as matched ones even when the expression
    # table carried no '#' header.
    placeholder = '\t-\t-' + '\t-' * len(expvalueFields)

    with open(outfilename, 'w') as outfile:
        with open(circTable) as circ_handle:
            for line in circ_handle:
                row = line.strip()
                if not row:
                    continue
                if line.startswith('#'):
                    header = [row, 'geneID', 'geneName'] + sample_names
                    outfile.write('\t'.join(header) + '\n')
                    continue
                fields = row.split('\t')
                chrom = fields[0]
                left = int(fields[1])
                right = int(fields[2])
                if maxCircLen is not None and right - left > maxCircLen:
                    continue
                strand = fields[3]

                genes = _match_genes(chrom, left, right, strand,
                                     exon_boundaries, genes_by_strand)
                if not genes:
                    if not doMO:
                        outfile.write(row + placeholder + '\n')
                    continue

                for gene in genes:
                    if gene in expression:
                        values = [str(value) for value in expression[gene]]
                        outfile.write(
                            '\t'.join([row, gene[0], gene[1]] + values) + '\n')
                    else:
                        print('%s %s not found in expression file, skipping' % gene)
   
# Script entry point: run() is invoked unconditionally at module level, so it
# executes (and reads sys.argv) whenever this file is run or imported.
run()
