##################################
#                                #
# Last modified 2023/09/24       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import pysam

def _usage_and_exit():
    """Print the command-line synopsis and terminate with exit status 1."""
    print('usage: python %s GTF outputfilename [-stranded + | -] [-strandSum] [-TPM]' % sys.argv[0])
    print('       Note: the script is designed to work with Cufflinks GTF output')
    sys.exit(1)


def _attribute_score(attributes, key):
    """Return the float stored as ``key "<number>";`` in a GTF attribute column.

    Raises ValueError (naming the attribute and the offending column) when the
    attribute is absent, or when its value cannot be converted by float().
    """
    pieces = attributes.split(key + ' "')
    if len(pieces) < 2:
        raise ValueError('no %s attribute found in: %s' % (key, attributes))
    return float(pieces[1].split('";')[0])


def _read_records(gtf_path, feature):
    """Group the tab-split GTF lines of type `feature` by chromosome name.

    Lines starting with '#' and blank lines are skipped. Returns a dict that
    maps the first column (chromosome) to the list of field lists, in file
    order.
    """
    by_chrom = {}
    with open(gtf_path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            if fields[2] != feature:
                continue
            by_chrom.setdefault(fields[0], []).append(fields)
    return by_chrom


def _build_coverage(records, score_key, keep_strand, signed):
    """Sum the per-record score over every position each record covers.

    records     -- field lists for one chromosome
    score_key   -- attribute to read the score from ('FPKM' or 'TPM')
    keep_strand -- if not None, records on any other strand are ignored
    signed      -- if true, scores of '-' strand records are negated before
                   being added, so the two strands can cancel out

    Returns a dict {position: summed score}. Records whose score is exactly
    0.0 contribute nothing.
    """
    coverage = {}
    for fields in records:
        strand = fields[6]
        if keep_strand is not None and strand != keep_strand:
            continue
        value = _attribute_score(fields[8], score_key)
        if signed and strand == '-':
            value = -value
        if value == 0.0:
            continue
        # NOTE(review): positions start..end-1 are filled, exactly as before.
        # GTF coordinates are conventionally 1-based inclusive, so if the
        # output is meant to be a 0-based half-open track the start may be
        # one base late -- confirm the intended coordinate system.
        for pos in range(int(fields[3]), int(fields[4])):
            coverage[pos] = coverage.get(pos, 0.0) + value
    return coverage


def _write_runs(outfile, chrom, coverage, negate):
    """Write maximal runs of adjacent, equal-valued positions for one chromosome.

    Each run becomes one line: chrom, first position, last position + 1 and
    the value, tab-separated. With `negate` true the value is written as
    -abs(value); for positive scores that equals the former '-' prefix, and
    it cannot produce a doubled sign when the value is already negative.
    """
    positions = sorted(coverage)
    if not positions:
        return

    def emit(first, last, value):
        if negate:
            value = -abs(value)
        outfile.write('\t'.join([chrom, str(first), str(last + 1), str(value)]) + '\n')

    run_start = run_end = positions[0]
    run_value = coverage[run_start]
    for pos in positions[1:]:
        if pos == run_end + 1 and coverage[pos] == run_value:
            run_end = pos
            continue
        emit(run_start, run_end, run_value)
        run_start = run_end = pos
        run_value = coverage[pos]
    # The run still open when the positions are exhausted must be written as
    # well; otherwise the last interval of every chromosome is lost.
    emit(run_start, run_end, run_value)


def run():
    """Convert a GTF file into a per-chromosome coverage track of FPKM/TPM values.

    Command line (read from sys.argv):
        GTF outputfilename [-stranded + | -] [-strandSum] [-TPM]

    -stranded S  only use records on strand S; with S == '-' the values are
                 written as negative numbers
    -strandSum   subtract '-' strand scores instead of adding them
    -TPM         read the TPM attribute of 'transcript' records instead of
                 the FPKM attribute of 'exon' records

    Prints the usage text and exits with status 1 when the two positional
    arguments, or the value after -stranded, are missing.
    """
    if len(sys.argv) < 3:
        _usage_and_exit()

    gtf_path = sys.argv[1]
    outfilename = sys.argv[2]

    keep_strand = None
    if '-stranded' in sys.argv:
        value_index = sys.argv.index('-stranded') + 1
        if value_index >= len(sys.argv):
            _usage_and_exit()
        keep_strand = sys.argv[value_index]
        print('will only consider %s strand reads' % keep_strand)

    signed = '-strandSum' in sys.argv
    if signed:
        print('will output an integrated +/- stranded scores')

    use_tpm = '-TPM' in sys.argv
    if use_tpm:
        print('will use TPM values instead of FPKM')

    # TPM is read from transcript records, FPKM from exon records.
    if use_tpm:
        feature, score_key = 'transcript', 'TPM'
    else:
        feature, score_key = 'exon', 'FPKM'

    records_by_chrom = _read_records(gtf_path, feature)
    print('finished inputting GTF')

    # The sign of the output depends on the requested strand, not on whichever
    # record happened to be read last for the chromosome.
    negate = keep_strand == '-'

    with open(outfilename, 'w') as outfile:
        for chrom in sorted(records_by_chrom):
            print(chrom)
            coverage = _build_coverage(records_by_chrom[chrom], score_key, keep_strand, signed)
            _write_runs(outfile, chrom, coverage, negate)
            
# Only start the conversion when executed as a script, so that importing this
# module does not read sys.argv or write any file.
if __name__ == '__main__':
    run()
