##################################
#                                #
# Last modified 05/12/2015       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import os

# Complement of each nucleotide code accepted by the cost model.
# The two-base IUPAC ambiguity codes complement as K<->M and R<->Y, while
# S, W and N are their own complements.
DNADict = {'A':'T','T':'A','G':'C','C':'G','N':'N','K':'M','Y':'R','S':'S','W':'W','M':'K','R':'Y'}

# Component names of each cost category, in the order they are written to the
# output table.
_REPLICATION_PARTS = ('Nucleotides', 'Processive', 'Okazaki', 'Repair', 'Histones')
_TRANSCRIPTION_PARTS = ('Nucleotides', 'Activation', 'Abortive', 'Initiation',
                        'Processive', 'Cap', 'CTD', 'Export', 'Splicing',
                        'Histone', 'PolyA', 'Readthrough')
_TRANSLATION_PARTS = ('AA', 'Processive', 'Initiation', 'Termination', 'Degradation')

# Header of the output table. The spellings are part of the file format and are
# kept exactly as they have always been written.
_OUTPUT_COLUMNS = (
    '#geneID', 'geneName', 'geneLength', 'LongestORFLength',
    'LongestORFTranscriptLength', 'LongestORFNumberIntrons', 'RNA_level',
    'Number_transcripts', 'RF_level', 'Number_proteins', 'ReplicationCost',
    'TranscriptionCost', 'TranslationCost', 'TotalCost',
    'ReplicationCostFraction', 'TranscriptionCostFraction',
    'TranslationCostFraction', 'Replication_Nucleotides',
    'Replication_Processive', 'Replication_Okazaki', 'Replication_Repair',
    'Replication_Histones', 'Transcription_Nucleotides',
    'Transcription_Activaiton', 'CostTranscription_Abortive',
    'Transcription_Initiation', 'Transcription_Processive',
    'Transcription_Cap', 'Transcription_CTD', 'Transcription_Export',
    'Transcription_Splicing', 'Transcription_Histone', 'Transcription_PolyA',
    'CostTranscription_Readthrough', 'Translation_AA',
    'Translation_Processive', 'Translation_Initiation',
    'Translation_Termination', 'Translation_Degradation',
)


class _CostTally(object):
    """Running total of one cost category together with its named components."""

    def __init__(self, names, skipped=False):
        self.names = names
        self.parts = dict.fromkeys(names, 0)
        self.skipped = skipped
        # '-' is what the output table shows for a category that was not requested.
        if skipped:
            self.total = '-'
        else:
            self.total = 0

    def add(self, name, amount):
        """Add amount to the category total and to the component called name."""
        self.total += amount
        self.parts[name] += amount

    def values(self):
        """Return the component values in the order the names were given."""
        return [self.parts[name] for name in self.names]


def _data_lines(path, skip_prefixes=('#',)):
    """Yield the tab-separated fields of each non-blank line not starting with skip_prefixes."""
    with open(path) as handle:
        for line in handle:
            if line.strip() == '' or line.startswith(skip_prefixes):
                continue
            yield line.strip().split('\t')


def _read_config(path):
    """Read the key<TAB>value configuration; histone entries stay strings, the rest become floats."""
    histones = ('H1', 'H2A', 'H2B', 'H3', 'H4')
    config = {}
    for fields in _data_lines(path):
        key = fields[0]
        print(fields)
        if len(fields) > 1:
            value = fields[1]
        else:
            value = ''
        if key in histones:
            # Histone values are later iterated character by character and
            # priced through the 'aa:' entries, so they are kept as text.
            config[key] = value
        elif value == '':
            config[key] = 0
        else:
            config[key] = float(value)
    return config


def _read_orfs(path):
    """Read the ORF table into {gene ID: {transcript ID: ORF sequence}}."""
    orfs = {}
    for fields in _data_lines(path):
        # Rows with fewer than 12 columns carry no ORF sequence.
        if len(fields) < 12:
            sequence = ''
        else:
            sequence = fields[11]
        orfs.setdefault(fields[0], {})[fields[2]] = sequence
    return orfs


def _read_fpkm(path, gene_col, fpkm_col, known_genes):
    """Read an FPKM table restricted to known_genes; return ({gene: FPKM}, sum of FPKMs)."""
    levels = {}
    total = 0
    for fields in _data_lines(path, ('#', 'tracking_id')):
        gene = fields[gene_col]
        if gene not in known_genes:
            continue
        value = float(fields[fpkm_col])
        levels[gene] = value
        # Every accepted row counts towards the total, including repeated gene IDs.
        total += value
    return levels, total


def _gtf_attribute(attributes, key):
    """Return the quoted value stored under key in a GTF attribute column."""
    return attributes.split(key + ' "')[1].split('"')[0]


def _read_gtf_exons(path):
    """Read exon records into {gene: {'gene_name': name, transcript: [(chrom, left, right, strand)]}}."""
    genes = {}
    for fields in _data_lines(path):
        if fields[2] != 'exon':
            continue
        attributes = fields[8]
        gene = _gtf_attribute(attributes, 'gene_id')
        transcript = _gtf_attribute(attributes, 'transcript_id')
        if 'gene_name "' in attributes:
            name = _gtf_attribute(attributes, 'gene_name')
        else:
            name = gene
        if gene not in genes:
            genes[gene] = {'gene_name': name}
        exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
        genes[gene].setdefault(transcript, []).append(exon)
    return genes


def _read_fasta(path):
    """Read a FASTA file into {header text: upper-cased sequence}, printing each header."""
    genome = {}
    chrom = None
    pieces = []
    with open(path) as handle:
        for line in handle:
            if line.startswith('>'):
                if chrom is not None:
                    genome[chrom] = ''.join(pieces)
                chrom = line.strip().split('>')[1]
                print(chrom)
                pieces = []
            else:
                # Sequence lines found before the first header are discarded
                # when that header is reached.
                pieces.append(line.strip().upper())
    if chrom is not None:
        genome[chrom] = ''.join(pieces)
    return genome


def _histone_cost(config, histone):
    """Return the cost of making one copy of the histone whose sequence is config[histone]."""
    sequence = config[histone]
    cost = 0
    for residue in sequence:
        cost += config['aa:' + residue]
    cost += config['Translation_initiation']
    cost += config['Translation_termination']
    cost += len(sequence)*config['Translation_Elongation_per_AA']
    return cost


def _replication_costs(config, genome, chrom, gene_left, gene_right, do_atp, nucleosome_cost):
    """Tally the replication cost of genome[chrom][gene_left:gene_right].

    nucleosome_cost is the histone cost of one nucleosome, or None when no
    histone cost applies.
    """
    tally = _CostTally(_REPLICATION_PARTS)
    span = gene_right - gene_left

    # NOTE(review): GTF coordinates are used directly as 0-based half-open
    # string indices here; confirm that this matches the annotation in use.
    nucleotide_cost = 0
    for i in range(gene_left, gene_right):
        base = genome[chrom][i]
        nucleotide_cost += (config['nuc:' + base] + config['nuc:' + DNADict[base]] + 2*1.5)
    tally.add('Nucleotides', nucleotide_cost)

    # The -ATP switch halves the per-bp and per-primer-nucleotide constants.
    if do_atp:
        per_bp, per_primer_nt = 3, 1
    else:
        per_bp, per_primer_nt = 6, 2
    tally.add('Processive', per_bp*span)
    tally.add('Okazaki', per_primer_nt*(span/config['Okazaki_fragment_length'])*config['Replication_primer_length'])
    tally.add('Repair', span*config['DNA_repair_per_bp'])
    if nucleosome_cost is not None:
        tally.add('Histones', (span/(config['Nucleosome_length'] + config['Linker_length']))*nucleosome_cost)
    return tally


def _transcription_costs(config, genome, exons, rna_level, average_rna_level, number_rnas, is_euk):
    """Tally the transcription cost of one transcript given its sorted exon list."""
    tally = _CostTally(_TRANSCRIPTION_PARTS)

    nucleotide_cost = 0
    for chrom, left, right, strand in exons:
        # Exons with any other strand value contribute no nucleotide cost.
        if strand not in ('+', '-'):
            continue
        for i in range(left, right):
            base = genome[chrom][i]
            if strand == '-':
                base = DNADict[base]
            nucleotide_cost += number_rnas*config['nuc:' + base]
    tally.add('Nucleotides', nucleotide_cost)

    rate = (rna_level/average_rna_level)*config['Average_Transcription_rate']
    cycles = config['time']*rate
    tally.add('Activation', cycles*config['Transcription_activation']/config['initiations_per_transcription_cycle'])
    tally.add('Initiation', cycles*config['Transcription_initiation'])
    tally.add('Abortive', cycles*config['Abortive_transcripts'])

    # Elongation covers the span from the first exon's start to the last
    # exon's end (introns included) plus the read-through stretch.
    left = exons[0][1]
    right = exons[-1][2]
    readthrough = config['readthrough_length']
    elongation = config['Transcription_elongation']
    tally.total += cycles*(right - left + readthrough)*elongation
    tally.parts['Processive'] += cycles*(right - left)*elongation
    tally.parts['Readthrough'] += cycles*(readthrough)*elongation

    if is_euk:
        tally.add('Cap', number_rnas*config['nuc:G'])
        tally.add('PolyA', number_rnas*config['PolyA_length']*config['nuc:A'])
        tally.add('Cap', cycles*config['mRNA_capping'])
        tally.add('CTD', cycles*config['CTD_cost'])
        tally.add('PolyA', cycles*config['PolyA_length']*elongation)
        tally.add('Export', cycles*config['mRNA_export'])
        tally.add('Splicing', cycles*(len(exons) - 1)*config['Splicing_cost'])
        tally.add('Histone', cycles*((right - left + readthrough)/(config['Nucleosome_length'] + config['Linker_length']))*config['Elongation_histone_modificiations'])
    return tally


def _translation_costs(config, orf, rf_level, average_rf_level, number_proteins, is_euk):
    """Tally the translation cost of the protein sequence orf.

    Exits the program with status 1 when fewer translation cycles than
    proteins are estimated.
    """
    tally = _CostTally(_TRANSLATION_PARTS)
    rate = (rf_level/average_rf_level)*config['Average_Translation_rate']
    cycles = config['time']*rate

    for residue in orf:
        tally.add('AA', number_proteins*config['aa:' + residue])
    if is_euk:
        tally.add('Initiation', cycles*config['Translation_mRNA_remodelling']/config['Translation_rounds_per_mRNA_remodelling'])
    tally.add('Initiation', cycles*config['Translation_initiation'])
    tally.add('Termination', cycles*config['Translation_termination'])

    degradation = (cycles - number_proteins)*len(orf)*config['Translation_degradation_per_AA']
    if cycles < number_proteins:
        print('translation cycles fewere than the number of proteins, exiting')
        sys.exit(1)
    tally.add('Degradation', degradation)
    tally.add('Processive', cycles*len(orf)*config['Translation_Elongation_per_AA'])
    return tally


def run():
    """Write a per-gene table of replication, transcription and translation costs.

    All inputs come from sys.argv, in this order: config file, RTT selector,
    'prok' or 'euk', GTF file, ORF table, genome FASTA, RNA-seq FPKM table with
    the 0-based indices of its gene-ID and FPKM columns, ribosome-footprint
    FPKM table with the same two indices, and the output file name. An
    optional '-ATP' anywhere on the command line selects the alternative
    replication constants.

    RTT is a three-character selector: a '-' in position 1, 2 or 3 skips the
    replication, transcription or translation cost respectively. A skipped
    cost and its fraction are written as '-', its components as 0, and the
    total is the sum of the costs that were computed.

    Only genes present in the ORF table and in both FPKM tables are reported.
    Genes without exon records in the GTF, or without a non-empty ORF on an
    annotated transcript, are reported on stdout and skipped. For each gene
    the transcript carrying the longest ORF is used for the transcription and
    translation costs.

    Exits with status 1 after printing usage when too few arguments are given.
    """
    # Thirteen positional arguments follow the script name.
    if len(sys.argv) < 14:
        print('usage: python %s config RTT <prok|euk> GTF ORFs genomic_fasta RNA-seq-FPKM_file RNA-seq-geneIDFieldID RNA-seq-FPKMfieldID RF-FPKM_file RF-geneIDFieldID RF-FPKMfieldID outfile [-ATP]' % sys.argv[0])
        print('\tNote: the script works with bibitem lists only')
        print('\tNote: RTT parameter: R-- will only calculate Replication cost. RT- will only calulate replicaiton and transcription cost, etc.')
        sys.exit(1)

    doATP = '-ATP' in sys.argv

    config_path = sys.argv[1]
    RTT = sys.argv[2]
    is_euk = sys.argv[3] == 'euk'
    gtf_path = sys.argv[4]
    orf_path = sys.argv[5]
    fasta_path = sys.argv[6]
    rna_path = sys.argv[7]
    RNAgeneFieldID = int(sys.argv[8])
    RNAFPKMFieldID = int(sys.argv[9])
    rf_path = sys.argv[10]
    RFgeneFieldID = int(sys.argv[11])
    RFFPKMFieldID = int(sys.argv[12])
    outfilename = sys.argv[13]

    if len(RTT) < 3:
        print('RTT parameter must be three characters long, e.g. RTT, RT- or R--')
        sys.exit(1)

    ConfigDict = _read_config(config_path)
    ORFDict = _read_orfs(orf_path)

    RNAFPKMDict, TotalRNAFPKM = _read_fpkm(rna_path, RNAgeneFieldID, RNAFPKMFieldID, ORFDict)
    AverageRNAFPKM = TotalRNAFPKM/len(RNAFPKMDict)
    RFFPKMDict, TotalRFFPKM = _read_fpkm(rf_path, RFgeneFieldID, RFFPKMFieldID, ORFDict)
    AverageRFFPKM = TotalRFFPKM/len(RFFPKMDict)

    GeneDict = _read_gtf_exons(gtf_path)
    GenomeDict = _read_fasta(fasta_path)

    # Every supported ambiguity code is priced as the mean of the four bases.
    mean_nucleotide = (ConfigDict['nuc:A'] + ConfigDict['nuc:G'] + ConfigDict['nuc:T'] + ConfigDict['nuc:C'])/4
    for code in 'NKYSWRM':
        ConfigDict['nuc:' + code] = mean_nucleotide

    # One nucleosome is priced as one H1 plus two each of H2A, H2B, H3 and H4.
    nucleosome_cost = None
    if is_euk:
        SCH1 = _histone_cost(ConfigDict, 'H1')
        SCH2A = _histone_cost(ConfigDict, 'H2A')
        SCH2B = _histone_cost(ConfigDict, 'H2B')
        SCH3 = _histone_cost(ConfigDict, 'H3')
        SCH4 = _histone_cost(ConfigDict, 'H4')
        nucleosome_cost = SCH1 + 2*(SCH2A + SCH2B + SCH3 + SCH4)

    with open(outfilename, 'w') as outfile:
        outfile.write('\t'.join(_OUTPUT_COLUMNS) + '\n')

        for geneID in sorted(ORFDict):
            if geneID not in RFFPKMDict:
                continue
            if geneID not in RNAFPKMDict or geneID not in GeneDict:
                print('problem with %s skipping' % geneID)
                continue

            rna_level = RNAFPKMDict[geneID]
            rf_level = RFFPKMDict[geneID]
            NumberRNAs = (rna_level/TotalRNAFPKM)*ConfigDict['Total_Transcripts_per_cell']
            NumberProteins = (rf_level/TotalRFFPKM)*ConfigDict['Total_Proteins_per_cell']

            # Gene extent over all transcripts, and the transcript with the longest ORF.
            longestORF = ('', 0)
            coordinates = []
            gene_chrom = None
            for transcriptID, transcript_exons in GeneDict[geneID].items():
                if transcriptID == 'gene_name':
                    continue
                for (chrom, left, right, strand) in transcript_exons:
                    coordinates.append(left)
                    coordinates.append(right)
                    gene_chrom = chrom
                orf = ORFDict[geneID].get(transcriptID)
                if orf is not None and len(orf) > longestORF[1]:
                    longestORF = (transcriptID, len(orf))
            geneLeft = min(coordinates)
            geneRight = max(coordinates)

            orf_transcript = longestORF[0]
            if orf_transcript == '':
                # No annotated transcript of this gene carries a non-empty ORF.
                print('problem with %s skipping' % geneID)
                continue

            exons = GeneDict[geneID][orf_transcript]
            exons.sort()
            LongestORFTranscriptLength = sum(right - left for (chrom, left, right, strand) in exons)

            if RTT[0] == '-':
                replication = _CostTally(_REPLICATION_PARTS, skipped=True)
            else:
                replication = _replication_costs(ConfigDict, GenomeDict, gene_chrom, geneLeft, geneRight, doATP, nucleosome_cost)

            if RTT[1] == '-':
                transcription = _CostTally(_TRANSCRIPTION_PARTS, skipped=True)
            else:
                transcription = _transcription_costs(ConfigDict, GenomeDict, exons, rna_level, AverageRNAFPKM, NumberRNAs, is_euk)

            if RTT[2] == '-':
                translation = _CostTally(_TRANSLATION_PARTS, skipped=True)
            else:
                translation = _translation_costs(ConfigDict, ORFDict[geneID][orf_transcript], rf_level, AverageRFFPKM, NumberProteins, is_euk)

            categories = (replication, transcription, translation)
            TotalCost = 0
            for tally in categories:
                if not tally.skipped:
                    TotalCost += tally.total

            row = [geneID, GeneDict[geneID]['gene_name'], geneRight - geneLeft,
                   longestORF[1], LongestORFTranscriptLength, len(exons) - 1,
                   rna_level, NumberRNAs, rf_level, NumberProteins]
            row.extend(tally.total for tally in categories)
            row.append(TotalCost)
            for tally in categories:
                if tally.skipped:
                    row.append('-')
                else:
                    row.append(tally.total/TotalCost)
            for tally in categories:
                row.extend(tally.values())
            outfile.write('\t'.join(str(value) for value in row) + '\n')


# Unguarded module-level call: the whole calculation runs whenever this file
# is executed or imported.
run()

