##################################
#                                #
# Last modified 03/19/2009       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import re
import sys

try:
	import psyco
	psyco.full()
except:
	pass

# Attribute matchers for the GTF attribute column (column 9). The leading \b
# keeps attributes such as "ref_gene_id" / "ref_transcript_id" from being
# mistaken for the plain gene_id / transcript_id keys.
_GENE_ID_RE = re.compile(r'\bgene_id "[^"]*"')
_TRANSCRIPT_ID_RE = re.compile(r'\btranscript_id "([^"]*)"')


def _collect_transcripts(inputfilename, classes, nsamples, minPresent,
                         minConfLo=None):
    """Select transcripts from the tab-separated input file, grouped by ID.

    A line is kept when its class-code column (index 3) is one of *classes*
    and at least *minPresent* of the *nsamples* sample columns that follow it
    (indices 4 .. 3 + nsamples) count as present. A sample column counts as
    present when it is not '-' and, if *minConfLo* is not None, when its fifth
    '|'-separated field parses as a float >= minConfLo.

    Returns {ID: {transcriptID: annotation}}. transcriptID is the second
    '|'-separated field of the last non-'-' sample column on the line, and ID
    is the part of transcriptID before its first '.'. annotation is the tuple
    (second, first) of the '|'-separated column 2 for class 'j' lines
    (used as (geneID, geneName)), and '' for every other class.
    """
    found = {}
    with open(inputfilename) as tracking:
        for line in tracking:
            fields = line.strip().split('\t')
            # Blank or truncated lines have no class-code column.
            if len(fields) < 4 or fields[3] not in classes:
                continue
            present = 0
            # Reset per line so an ID from an earlier line can never leak in.
            foundID = None
            for sample in fields[4:4 + nsamples]:
                if sample == '-':
                    continue
                parts = sample.split('|')
                # The representative transcript is the last non-'-' sample,
                # whether or not it passed the confidence threshold.
                foundID = parts[1]
                if minConfLo is None or float(parts[4]) >= minConfLo:
                    present += 1
            if foundID is None or present < minPresent:
                continue
            transcripts = found.setdefault(foundID.split('.')[0], {})
            if fields[3] == 'j':
                ref = fields[2].split('|')
                transcripts[foundID] = (ref[1], ref[0])
            else:
                transcripts[foundID] = ''
    return found


def _annotate_gtf_line(fields, geneID, geneName, transcriptID):
    """Return a GTF line rebuilt from *fields* with reference annotation.

    Columns 1-8 are copied unchanged. In column 9 the first gene_id value is
    replaced with *geneID*; if column 9 has no gene_id attribute, one is
    prepended. The attributes gene_name "<geneName>" and
    transcript_name "<geneName>-<transcriptID>" are appended. Columns after
    the ninth are not copied. The result ends with a newline.
    """
    # A function replacement inserts geneID literally; a plain replacement
    # string would interpret any backslashes it contains.
    attributes, replaced = _GENE_ID_RE.subn(
        lambda match: 'gene_id "%s"' % geneID, fields[8], 1)
    if not replaced:
        attributes = 'gene_id "%s"; %s' % (geneID, attributes)
    transcriptName = geneName + '-' + transcriptID
    return ('\t'.join(fields[:8]) + '\t' + attributes
            + ' gene_name "' + geneName + '"; transcript_name "'
            + transcriptName + '";\n')


def run():
    """Command-line entry point: filter transcripts and write their GTF lines.

    usage: python <script> inputfilename <class1,...,classN>
           <ID1:gtf1,...,IDN:gtfN> minPresent outfilename
           [-minLoConfRPKM num]

    Lines of inputfilename are selected by _collect_transcripts. The input is
    presumably a cuffcompare .tracking file; TODO confirm. The number of
    sample columns examined equals the number of ID:gtf pairs.

    For every selected transcript, each line of the GTF file mapped to its ID
    whose transcript_id equals the transcript is written to outfilename.
    Class 'j' transcripts are rewritten by _annotate_gtf_line; other lines are
    copied verbatim. GTF lines with fewer than nine columns or no
    transcript_id are skipped.

    Prints each ID:gtf argument, then one 'processing <gtf>' line per GTF read
    (IDs in sorted order), and finally the number of IDs with selected
    transcripts.

    Exits with status 1 and a usage message when fewer than five positional
    arguments are given. Exits with an error message when an ID:gtf argument
    has no ':', or when a selected transcript's ID has no GTF file.
    """
    # Five positional arguments are read (argv[1]..argv[5]), so argv needs at
    # least six entries.
    if len(sys.argv) < 6:
        sys.stderr.write('usage: python %s inputfilename <class1,class2,...,classN> <ID1:gtf1,ID2:gtf2,....,IDN:gtfN> minPresent outfilename [-minLoConfRPKM num]\n' % sys.argv[0])
        sys.exit(1)

    inputfilename = sys.argv[1]
    CCclasses = sys.argv[2].split(',')
    IDfields = sys.argv[3].split(',')
    minPresent = int(sys.argv[4])
    outputfilename = sys.argv[5]

    gtfDict = {}
    for field in IDfields:
        sys.stdout.write(field + '\n')
        # Split on the first ':' only, so a GTF path containing ':' survives.
        ID, sep, gtf = field.partition(':')
        if not sep:
            sys.exit('expected ID:gtf, got %r' % field)
        gtfDict[ID] = gtf

    minConfLo = None
    if '-minLoConfRPKM' in sys.argv:
        minConfLo = float(sys.argv[sys.argv.index('-minLoConfRPKM') + 1])

    found = _collect_transcripts(inputfilename, CCclasses, len(IDfields),
                                 minPresent, minConfLo)

    # Fail before creating the output file rather than part-way through it.
    missing = sorted(set(found) - set(gtfDict))
    if missing:
        sys.exit('no GTF file given for ID(s): ' + ', '.join(missing))

    with open(outputfilename, 'w') as outfile:
        # sorted() makes the output order reproducible across runs.
        for ID in sorted(found):
            transcripts = found[ID]
            sys.stdout.write('processing %s\n' % gtfDict[ID])
            with open(gtfDict[ID]) as gtffile:
                for line in gtffile:
                    fields = line.strip().split('\t')
                    # Comment/header lines have no attribute column.
                    if len(fields) < 9:
                        continue
                    match = _TRANSCRIPT_ID_RE.search(fields[8])
                    if match is None or match.group(1) not in transcripts:
                        continue
                    transcriptID = match.group(1)
                    annotation = transcripts[transcriptID]
                    if annotation:
                        geneID, geneName = annotation
                        outfile.write(_annotate_gtf_line(
                            fields, geneID, geneName, transcriptID))
                    else:
                        outfile.write(line)

    sys.stdout.write('%d\n' % len(found))

# Run only when executed as a script, so the module can be imported
# (e.g. by tests) without parsing sys.argv or exiting.
if __name__ == '__main__':
    run()

