##################################
#                                #
# Last modified 2018/09/25       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math

def _parse_field_ids(spec):
    """Convert a comma-separated string of 0-based column indexes to ints."""
    return [int(F) for F in spec.split(',')]


def _gtf_attribute(attributes, key):
    """Return the quoted value of `key` from a GTF attribute column."""
    return attributes.split(key + ' "')[1].split('";')[0]


def _read_transcripts(GTF):
    """Collect the CDS and exon intervals of every transcript in a GTF file.

    Returns a dict gene -> transcriptID -> {'CDS': [...], 'exon': [...]},
    where each list holds (chr, start, stop, strand) tuples.  Genes are keyed
    by gene_name when the attribute is present and by gene_id otherwise.
    """
    TranscriptDict = {}
    with open(GTF) as linelist:
        for i, line in enumerate(linelist, 1):
            if i % 100000 == 0:
                print('%d lines processed' % i)
            if line.startswith('#') or line.strip() == '':
                continue
            fields = line.strip().split('\t')
            feature = fields[2]
            if feature != 'CDS' and feature != 'exon':
                continue
            chrom = fields[0]
            start = int(fields[3])
            stop = int(fields[4])
            strand = fields[6]
            if 'gene_name "' in fields[8]:
                gene = _gtf_attribute(fields[8], 'gene_name')
            else:
                gene = _gtf_attribute(fields[8], 'gene_id')
            transcriptID = _gtf_attribute(fields[8], 'transcript_id')
            # One key (the gene name, falling back to the gene ID) is used for
            # creation, lookup and insertion alike, so entries are neither
            # reset nor missing when gene_name differs from gene_id.
            transcripts = TranscriptDict.setdefault(gene, {})
            features = transcripts.setdefault(transcriptID,
                                              {'CDS': [], 'exon': []})
            features[feature].append((chrom, start, stop, strand))
    return TranscriptDict


def _index_TSSs(TranscriptDict):
    """Map each (chr, TSS position) to its transcripts and gene.

    The TSS is the start of the first exon on the '+' strand and the end of
    the last exon on the '-' strand.  Transcripts without exon records or
    with any other strand value are skipped, since no TSS can be derived.
    When several genes share a TSS the last one processed is kept.
    """
    TSSDict = {}
    for gene in TranscriptDict:
        for transcriptID, features in TranscriptDict[gene].items():
            features['exon'].sort()
            features['CDS'].sort()
            exons = features['exon']
            if not exons:
                continue
            chrom = exons[0][0]
            strand = exons[0][3]
            if strand == '+':
                TSS = (chrom, exons[0][1])
            elif strand == '-':
                TSS = (chrom, exons[-1][2])
            else:
                continue
            entry = TSSDict.setdefault(TSS, {'transcripts': [], 'genes': gene})
            entry['transcripts'].append(transcriptID)
            entry['genes'] = gene
    return TSSDict


def _read_gene_guides(GGs, gFIDs):
    """Load gene-level guide values, sorted per gene by absolute mean.

    The gene is the second '_'-separated token of the first column.  Each
    record is (|mean|, mean, value_n, ..., value_1) over the columns listed
    in gFIDs, so the last record of a gene is its strongest guide.
    """
    GGDict = {}
    with open(GGs) as lineslist:
        for line in lineslist:
            if line.startswith('#') or line.strip() == '':
                continue
            fields = line.strip().split('\t')
            gene = fields[0].split('_')[1]
            values = [float(fields[ID]) for ID in gFIDs]
            mean = sum(values) / len(values)
            record = (math.fabs(mean), mean) + tuple(reversed(values))
            GGDict.setdefault(gene, []).append(record)
    # Sorting once after loading gives the same order as sorting per append.
    for records in GGDict.values():
        records.sort()
    return GGDict


def _annotate_guide(line, tFIDs, TranscriptDict, TSSDict, GGDict):
    """Return one TSS-table line with the five annotation columns appended.

    The first column is expected to look like
    '<chr>:<left>-<right>_<x>_<chr>:<left>-<right>:<strand>'; the midpoint of
    the first interval must be a TSS present in TSSDict (KeyError otherwise).
    """
    fields = line.strip().split('\t')
    tokens = fields[0].split('_')
    region = tokens[0].split(':')
    chrom = region[0]
    left = int(region[1].split('-')[0])
    right = int(region[1].split('-')[1])
    # Floor division keeps the midpoint an integer on Python 2 and 3 alike.
    TSS = (chrom, (right + left) // 2)

    sgRNA = tokens[2].split(':')
    sgRNAleft = int(sgRNA[1].split('-')[0])
    sgRNAright = int(sgRNA[1].split('-')[1])
    sgRNAstrand = sgRNA[2]
    if sgRNAstrand == '+':
        TSSdist = (sgRNAright - 3) - TSS[1]
    elif sgRNAstrand == '-':
        TSSdist = TSS[1] - (sgRNAleft + 3)
    else:
        # Unknown strand: report 'nan' rather than reuse a stale distance.
        TSSdist = 'nan'

    gene = TSSDict[TSS]['genes']
    values = [float(fields[ID]) for ID in tFIDs]
    mean = sum(values) / len(values)
    if gene in GGDict:
        GMean = GGDict[gene][-1][1]
        # A zero reference mean makes the relative effect undefined.
        ratio = mean / GMean if GMean != 0 else 'nan'
    else:
        GMean = 'nan'
        ratio = 'nan'

    # Standard closed-interval overlap; also catches a CDS lying entirely
    # inside the sgRNA span.
    InCDS = any(start <= sgRNAright and stop >= sgRNAleft
                for features in TranscriptDict[gene].values()
                for (_, start, stop, _) in features['CDS'])

    return '\t'.join([line.strip(), str(GMean), str(ratio), str(InCDS),
                      gene, str(TSSdist)])


def run():
    """Annotate a TSS-level guide table with gene-level context.

    Command line (sys.argv):
        1  gene-level guide table
        2  comma-separated 0-based column indexes of its logFC fields
        3  TSS-level guide table
        4  comma-separated 0-based column indexes of its logFC fields
        5  GTF annotation
        6  output file name

    Every data line of the TSS-level table is written to the output file with
    five extra columns: the mean of the strongest gene-level guide of the
    gene, the guide mean relative to it, whether the sgRNA overlaps a CDS of
    the gene, the gene, and the cut-site distance to the TSS.  Lines starting
    with '#' receive the matching column names; blank lines are dropped.
    Progress messages go to stdout.  Exits with status 1 when fewer than six
    arguments are given.
    """
    if len(sys.argv) < 7:
        print('usage: python %s TSS_screen.RPM.ratios.with_CFD.genes-only.table gene-logFC_fieldIDs TSS_screen.RPM.ratios.with_CFD.TSS-only.table TSS-logFC_fieldIDs gtf outfilename' % sys.argv[0])
        print('\tfieldIDs are comma-separated 0-based column indexes')
        sys.exit(1)

    GGs = sys.argv[1]
    gFIDs = _parse_field_ids(sys.argv[2])
    TGs = sys.argv[3]
    tFIDs = _parse_field_ids(sys.argv[4])
    GTF = sys.argv[5]
    outfilename = sys.argv[6]

    TranscriptDict = _read_transcripts(GTF)
    print('finished reading GTF file')

    TSSDict = _index_TSSs(TranscriptDict)
    print('finished parsing TSSs')

    GGDict = _read_gene_guides(GGs, gFIDs)
    print('finished inputing gene-level guide data')

    with open(TGs) as lineslist, open(outfilename, 'w') as outfile:
        for line in lineslist:
            if line.startswith('#'):
                outline = line.strip() + '\ttop_gene_guide-mean\trelative_effect_to_top_gene_guide\toverlaps_CDS\tgene\tDistance_to_TSS'
                outfile.write(outline + '\n')
                continue
            if line.strip() == '':
                continue
            outline = _annotate_guide(line, tFIDs, TranscriptDict, TSSDict, GGDict)
            outfile.write(outline + '\n')

    print('finished inputing TSS-level guide data')

# Run the pipeline only when executed as a script, so that importing this
# module does not parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()

