##################################
#                                #
# Last modified 03/23/2014       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
import random
import string

def run():
    """Classify assembled vs. reference transcripts and write a summary table.

    All inputs come from the command line (``sys.argv[1:]``), in this order:

    1. GTFcomparison     -- tab-delimited comparison file; lines starting
                            with '#' are skipped.
    2. assemblyFieldID   -- 0-based column holding the assembly transcript ID.
    3. referenceFieldID  -- 0-based column holding the reference transcript ID.
    4. FPKMtable         -- tab-delimited table of transcript FPKM values;
                            lines starting with '#' are skipped.
    5. transcriptFieldID -- 0-based column of the transcript ID in FPKMtable.
    6. FPKMFieldID       -- 0-based column of the FPKM value in FPKMtable.
    7. minFPKM           -- threshold used for the "above MinFPKM" and
                            "False Negatives" counts (strict ``>``).
    8. outfilename       -- path of the two-column summary that is written.

    A value of '-' in an ID column means "no transcript".  Prints the usage
    text and exits with status 1 when fewer than eight arguments are given.
    """
    # Eight arguments follow the script name, so argv needs nine entries;
    # sys.argv[8] is read below.
    if len(sys.argv) < 9:
        print('usage: python %s GTFcomparison assemblyFieldID referenceFieldID FPKMtable transcriptFieldID FPKMFieldID minFPKM outfilename' % sys.argv[0])
        print('\tThe following GTFcomparison format ia assumed:')
        print('\t#TranscriptID1\tNumExons\tTranscriptID2\tNumExons\t5end_distance\t3end_distance\tPartialOverlapTranscript1\tPartialOverlapExonsTranscript1\tPartialOverlapTranscript2\tPartialOverlapExonsTranscript2')
        print('\tThe FPKM table refers to the table with simulated transcript FPKMs')
        sys.exit(1)

    GTFcomparison = sys.argv[1]
    assemblyFieldID = int(sys.argv[2])
    referenceFieldID = int(sys.argv[3])
    FPKMtable = sys.argv[4]
    transcriptFieldID = int(sys.argv[5])
    FPKMFieldID = int(sys.argv[6])
    minFPKM = float(sys.argv[7])
    outputfilename = sys.argv[8]

    # assembly ID -> {'matches': {refID: 1}, 'partials': {refID: 1}}
    AssemblyDict = {}
    # reference ID -> {'FPKM': value, 'match': 0 or 1}
    ReferenceDict = {}

    with open(GTFcomparison) as linelist:
        for line in linelist:
            if line.startswith('#'):
                continue
            stripped = line.strip()
            if not stripped:
                # Blank lines carry no fields; skip them instead of failing
                # on the column lookups below.
                continue
            fields = stripped.split('\t')
            assemblyID = fields[assemblyFieldID]
            referenceID = fields[referenceFieldID]
            hasAssembly = assemblyID != '-'
            hasReference = referenceID != '-'

            if hasAssembly and assemblyID not in AssemblyDict:
                AssemblyDict[assemblyID] = {'matches': {}, 'partials': {}}
            if hasReference and referenceID not in ReferenceDict:
                ReferenceDict[referenceID] = {'FPKM': 0, 'match': 0}

            if hasReference and hasAssembly:
                # Both IDs on one line: record the pair as a match.
                AssemblyDict[assemblyID]['matches'][referenceID] = 1
                ReferenceDict[referenceID]['match'] = 1
            elif hasReference:
                # Reference-only line: the column six places after the
                # reference ID may name a partially overlapping assembly
                # transcript.
                partialAssemblyID = fields[referenceFieldID + 6]
                if partialAssemblyID != '-':
                    if partialAssemblyID not in AssemblyDict:
                        AssemblyDict[partialAssemblyID] = {'matches': {}, 'partials': {}}
                    AssemblyDict[partialAssemblyID]['partials'][referenceID] = 1

    print('finished inputting GTF comparison')

    with open(FPKMtable) as linelist:
        for line in linelist:
            if line.startswith('#'):
                continue
            stripped = line.strip()
            if not stripped:
                continue
            fields = stripped.split('\t')
            transcript = fields[transcriptFieldID]
            FPKM = float(fields[FPKMFieldID])
            # Only transcripts seen as references in the comparison are kept.
            if transcript in ReferenceDict:
                ReferenceDict[transcript]['FPKM'] = FPKM

    print('finished inputting FPKMs')

    FalsePositives = 0
    Partials = 0
    FalseNegatives = 0
    ExpressedAndAssembledAboveMinFPKM = 0
    ExpressedAndAssembled = 0
    NotExpressedAndAssembled = 0

    # Assembly transcripts with no match are either partial overlaps or
    # false positives.
    for transcript in AssemblyDict:
        entry = AssemblyDict[transcript]
        if entry['matches']:
            continue
        if entry['partials']:
            Partials += 1
        else:
            FalsePositives += 1

    for transcript in ReferenceDict:
        entry = ReferenceDict[transcript]
        aboveMin = entry['FPKM'] > minFPKM
        if entry['match'] == 1:
            if entry['FPKM'] == 0:
                NotExpressedAndAssembled += 1
            else:
                ExpressedAndAssembled += 1
            if aboveMin:
                ExpressedAndAssembledAboveMinFPKM += 1
        elif aboveMin:
            # Above the threshold but never matched by an assembly transcript.
            FalseNegatives += 1

    summary = [
        ('False Positives', FalsePositives),
        ('Partials', Partials),
        ('False Negatives', FalseNegatives),
        ('Expressed and Assembled Above MinFPKM', ExpressedAndAssembledAboveMinFPKM),
        ('Expressed and Assembled', ExpressedAndAssembled),
        ('Not Expressed and Assembled', NotExpressedAndAssembled),
    ]

    with open(outputfilename, 'w') as outfile:
        outfile.write('#Class\tnumber\n')
        for label, count in summary:
            outfile.write(label + '\t' + str(count) + '\n')

# Script entry point: run() is invoked unconditionally whenever this file is
# executed (there is no __main__ guard, so an import triggers it as well).
run()

