##################################
#                                #
# Last modified 10/31/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def _printUsageAndExit():
    """Print the command-line usage message and exit with status 1."""
    print('usage: python %s  ExpressionTable fieldIDs FPKM_threshold outfile_prefix [-minDatasets number]' % sys.argv[0])
    print('\tfieldIDs should be comma separated')
    sys.exit(1)


def _loadExpressionTable(datafilename, IDfields, minFPKM, minDataSets):
    """Read the tab-separated expression table into memory.

    Lines starting with '#' are header lines: the text found in each of the
    IDfields columns becomes the label of that dataset.  Every other
    non-blank line is a transcript whose first column is the transcript ID;
    the gene ID is the part of the transcript ID before the first '-'.

    If minDataSets is not None, a transcript whose value is >= minFPKM in
    fewer than minDataSets datasets has all of its values replaced by 0.

    Returns (header, Labels, GeneDict): the last header line seen ('' if
    there was none), the sorted list of dataset labels, and a dict
        gene -> {'transcripts': {transcript: {label: value}},
                 'lines':       {transcript: original line}}.

    Raises KeyError if a data line appears before any header line, and
    IndexError/ValueError if a selected column is missing or not numeric.
    """
    header = ''
    LabelDict = {}
    labelsSeen = set()
    GeneDict = {}

    with open(datafilename) as infile:
        for i, line in enumerate(infile):
            if i % 10000 == 0:
                print('%d lines processed' % i)
            stripped = line.strip()
            if not stripped:
                # Blank lines carry no data; skip them instead of crashing.
                continue
            if not line.endswith('\n'):
                # The stored line is written back verbatim later on, so make
                # sure a final unterminated line cannot fuse with the next.
                line = line + '\n'
            fields = stripped.split('\t')
            if line[0] == '#':
                header = line
                for ID in IDfields:
                    LabelDict[ID] = fields[ID]
                    labelsSeen.add(fields[ID])
                continue

            transcript = fields[0]
            gene = transcript.split('-')[0]
            values = {}
            for ID in IDfields:
                values[LabelDict[ID]] = float(fields[ID])

            if minDataSets is not None:
                aboveThresholdIn = sum(1 for v in values.values() if v >= minFPKM)
                if aboveThresholdIn < minDataSets:
                    values = dict.fromkeys(values, 0)

            if gene not in GeneDict:
                GeneDict[gene] = {'transcripts': {}, 'lines': {}}
            GeneDict[gene]['transcripts'][transcript] = values
            GeneDict[gene]['lines'][transcript] = line

    return header, sorted(labelsSeen), GeneDict


def _writeGeneralStats(outfilename, Labels, GeneDict, minFPKM):
    """Write, per dataset, how many genes have N isoforms at or above minFPKM.

    One row is written for every isoform count N that occurs for at least one
    gene in at least one dataset (in ascending order of N); datasets in which
    no gene has exactly N expressed isoforms get a 0 in that row.
    """
    stats = {}
    for label in Labels:
        counts = {}
        for gene in GeneDict:
            isoforms = 0
            for values in GeneDict[gene]['transcripts'].values():
                if values[label] >= minFPKM:
                    isoforms += 1
            # Only the final per-gene count is a histogram bucket; running
            # totals reached while scanning the transcripts are not.
            counts[isoforms] = counts.get(isoforms, 0) + 1
        stats[label] = counts

    isoformCounts = set()
    for label in Labels:
        isoformCounts.update(stats[label])

    with open(outfilename, 'w') as outfile:
        outfile.write('\t'.join(['#Isoforms_Number'] + list(Labels)) + '\n')
        for isoforms in sorted(isoformCounts):
            row = [str(isoforms)]
            for label in Labels:
                row.append(str(stats[label].get(isoforms, 0)))
            outfile.write('\t'.join(row) + '\n')


def _writeMultipleIsoformGenes(outfileprefix, header, Labels, GeneDict, minFPKM):
    """Write genes with more than one transcript expressed in any dataset.

    A transcript counts as expressed if its value is >= minFPKM in at least
    one dataset.  Gene IDs go to <prefix>.MultipleIsoformGenes and the
    original table lines of their expressed transcripts (preceded by the
    header) go to <prefix>.MultipleIsoformTranscripts.
    """
    with open(outfileprefix + '.MultipleIsoformGenes', 'w') as genesFile:
        with open(outfileprefix + '.MultipleIsoformTranscripts', 'w') as transcriptsFile:
            transcriptsFile.write(header)
            for gene in sorted(GeneDict):
                transcripts = GeneDict[gene]['transcripts']
                expressed = []
                for transcript in sorted(transcripts):
                    values = transcripts[transcript]
                    if any(values[label] >= minFPKM for label in Labels):
                        expressed.append(transcript)
                if len(expressed) > 1:
                    genesFile.write(gene + '\n')
                    for transcript in expressed:
                        transcriptsFile.write(GeneDict[gene]['lines'][transcript])


def _writeIsoformSwitches(outfileprefix, header, Labels, GeneDict, minFPKM):
    """Write genes whose top-expressed transcript differs between datasets.

    For every dataset the transcript with the highest value is picked, as
    long as that value is >= minFPKM and >= 0.  Genes with exactly one
    transcript are skipped.  If more than one distinct transcript is picked
    across the datasets, the gene ID goes to <prefix>.IsoformSwitchesGenes
    and the table lines of the picked transcripts (preceded by the header)
    go to <prefix>.IsoformSwitches.
    """
    with open(outfileprefix + '.IsoformSwitches', 'w') as switchesFile:
        with open(outfileprefix + '.IsoformSwitchesGenes', 'w') as genesFile:
            switchesFile.write(header)
            for gene in sorted(GeneDict):
                transcripts = GeneDict[gene]['transcripts']
                if len(transcripts) == 1:
                    continue
                # Fixed scan order so that ties are always resolved the same
                # way (the last of the tied transcript IDs wins).
                orderedTranscripts = sorted(transcripts)
                topIsoforms = set()
                for label in Labels:
                    bestValue = 0
                    bestIsoform = None
                    for transcript in orderedTranscripts:
                        value = transcripts[transcript][label]
                        if value >= minFPKM and value >= bestValue:
                            bestValue = value
                            bestIsoform = transcript
                    if bestIsoform is not None:
                        topIsoforms.add(bestIsoform)
                if len(topIsoforms) > 1:
                    for transcript in sorted(topIsoforms):
                        switchesFile.write(GeneDict[gene]['lines'][transcript])
                    genesFile.write(gene + '\n')


def run():
    """Summarise isoform usage from an expression table given on the command line.

    Command line (read from sys.argv):
        ExpressionTable fieldIDs FPKM_threshold outfile_prefix [-minDatasets number]

    fieldIDs is a comma-separated list of 0-based column indices holding the
    per-dataset expression values.  Without -minDatasets no per-transcript
    dataset-count filter is applied.

    Output files, all starting with outfile_prefix:
        .GeneralStats                 histogram of expressed isoforms per gene
        .MultipleIsoformGenes         genes with more than one expressed isoform
        .MultipleIsoformTranscripts   table lines of those isoforms
        .IsoformSwitchesGenes         genes whose top isoform differs between datasets
        .IsoformSwitches              table lines of those top isoforms

    Genes and transcripts are written in sorted order so that repeated runs
    produce identical files.  Exits with status 1 and a usage message if
    required arguments are missing.
    """
    # Four positional arguments are required, so argv needs five entries.
    if len(sys.argv) < 5:
        _printUsageAndExit()

    datafilename = sys.argv[1]
    IDfields = [int(ID) for ID in sys.argv[2].split(',')]
    minFPKM = float(sys.argv[3])
    outfileprefix = sys.argv[4]

    # None means "no dataset-count filter"; the flag is optional.
    minDataSets = None
    if '-minDatasets' in sys.argv:
        valuePosition = sys.argv.index('-minDatasets') + 1
        if valuePosition >= len(sys.argv):
            _printUsageAndExit()
        minDataSets = int(sys.argv[valuePosition])

    header, Labels, GeneDict = _loadExpressionTable(datafilename, IDfields, minFPKM, minDataSets)

    _writeGeneralStats(outfileprefix + '.GeneralStats', Labels, GeneDict, minFPKM)
    print('finished outputting general stats')

    _writeMultipleIsoformGenes(outfileprefix, header, Labels, GeneDict, minFPKM)
    print('finished outputting multiple isoforms genes')

    _writeIsoformSwitches(outfileprefix, header, Labels, GeneDict, minFPKM)
    print('finished outputting instances of isoform switch')

        
# Run the analysis only when this file is executed as a script, so that
# importing it does not parse sys.argv or exit the interpreter.
if __name__ == '__main__':
    run()

