##################################
#                                #
# Last modified 03/06/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math

def _load_cuffdiff(cuffdiff, threshold):
    """Collect significant fold changes from a Cuffdiff-style table.

    Each non-header row is split on tabs and read by position: column 0 is
    the test ID, columns 3 and 4 the two sample names, column 5 the test
    status, column 8 the log fold change (exponentiated with ``math.exp``)
    and column 11 the significance flag.  Only rows whose status is ``OK``
    and whose flag is ``yes`` are considered.

    Args:
        cuffdiff: path of the tab-delimited differential-expression file.
        threshold: fold-change cutoff; a test counts as "up" when its fold
            change exceeds ``threshold`` and as "down" when it is below
            ``1 / threshold``.

    Returns:
        A ``(results, totals)`` tuple.  ``results`` maps
        ``test_id -> {(sample1, sample2): fold_change}`` for tests passing
        the cutoff; ``totals`` maps every sample pair seen in the file
        (significant or not) to ``{'up': n, 'down': n}``.
    """
    results = {}
    totals = {}
    with open(cuffdiff) as handle:
        for lineno, line in enumerate(handle, 1):
            if lineno % 1000000 == 0:
                print('%d lines processed' % lineno)
            # Skip the header row and tolerate blank lines.
            if line.startswith('test_id') or not line.strip():
                continue
            fields = line.strip().split('\t')
            test_id = fields[0]
            pair = (fields[3], fields[4])
            # Register the comparison even when this row is filtered out, so
            # every pair present in the file shows up in the reports.
            counts = totals.setdefault(pair, {'up': 0, 'down': 0})
            if fields[5] != 'OK' or fields[11] != 'yes':
                continue
            log_fold = float(fields[8])
            try:
                fold = math.exp(log_fold)
            except OverflowError:
                # exp() overflows for huge log values (presumably the
                # sentinel written for an infinite fold change -- confirm
                # against the Cuffdiff version in use).  Keep the raw value
                # so the row is still classified as strongly up.
                fold = log_fold
            # The two classes are mutually exclusive: a test is counted once.
            if fold > threshold:
                direction = 'up'
            elif fold < 1 / threshold:
                direction = 'down'
            else:
                continue
            results.setdefault(test_id, {})[pair] = fold
            counts[direction] += 1
    return results, totals


def _write_reports(tss_filename, prefix, pairs, results, totals):
    """Annotate one TSS-distance file and write its two report files.

    Writes ``prefix + '.individual'`` (every input row followed by one column
    per sample pair holding the fold change, or ``-`` when the gene has no
    recorded change for that pair) and ``prefix + '.generalStats'`` (per
    pair: file-wide up/down totals and the up/down/unchanged counts among
    the rows of this TSS file).

    Args:
        tss_filename: tab-delimited file with at least 14 columns whose
            second column has the form ``Name::ID``.
        prefix: path prefix of the two output files.
        pairs: ordered list of ``(sample1, sample2)`` tuples.
        results: mapping ``ID -> {pair: fold_change}``.
        totals: mapping ``pair -> {'up': n, 'down': n}``.

    Raises:
        ValueError: if a non-blank row has fewer than 14 columns.
        IndexError: if the second column does not contain ``::``.
    """
    bound = {}
    for pair in pairs:
        bound[pair] = {'up': 0, 'down': 0, 'flat': 0}

    header = '#TSSDistance\tGeneName\tGeneID\tregionID\tchrom\tstart\tstop\tRPM\tfold\tmulti%\tplus%\tleftPlus%\tpeakPos\tpeakHeight\tpValue'
    for sample1, sample2 in pairs:
        header = header + '\t' + sample1 + '::' + sample2

    with open(prefix + '.individual', 'w') as individual:
        individual.write(header + '\n')
        with open(tss_filename) as tss:
            for line in tss:
                if not line.strip():
                    continue
                fields = line.strip().split('\t')
                if len(fields) < 14:
                    raise ValueError(
                        '%s: expected at least 14 tab-separated columns, got %d'
                        % (tss_filename, len(fields)))
                name_parts = fields[1].split('::')
                gene_name = name_parts[0]
                gene_id = name_parts[1]
                # Column 1 is expanded into separate name and ID columns;
                # columns 2..13 are copied through unchanged.
                row = [fields[0], gene_name, gene_id] + fields[2:14]
                gene_results = results.get(gene_id, {})
                for pair in pairs:
                    fold = gene_results.get(pair)
                    if fold is None:
                        row.append('-')
                        bound[pair]['flat'] += 1
                        continue
                    row.append(str(fold))
                    if fold > 1:
                        bound[pair]['up'] += 1
                    elif fold < 1:
                        bound[pair]['down'] += 1
                individual.write('\t'.join(row) + '\n')

    with open(prefix + '.generalStats', 'w') as summary:
        summary.write('#Sampl1\tSampl2\tTotalUp\tTotalDown\tBoundUp\tBoundDown\tBoundNoChange\n')
        for pair in pairs:
            sample1, sample2 = pair
            columns = [
                sample1,
                sample2,
                str(totals[pair]['up']),
                str(totals[pair]['down']),
                str(bound[pair]['up']),
                str(bound[pair]['down']),
                str(bound[pair]['flat']),
            ]
            summary.write('\t'.join(columns) + '\n')


def run():
    """Command-line entry point.

    Reads three positional arguments from ``sys.argv``: a list file (one
    ``TSS-Distance-Filename <tab> Outfile-Prefix`` pair per line), a numeric
    fold-change threshold and a Cuffdiff-style results file.  For every
    listed TSS-distance file it writes ``<prefix>.individual`` and
    ``<prefix>.generalStats``.

    Prints the usage text and exits with status 1 when fewer than three
    arguments are supplied.
    """
    # All three positional arguments are required (argv[0] is the script).
    if len(sys.argv) < 4:
        print('usage: python %s list_of_input_files threshold cuffdiff_file ' % sys.argv[0])
        print('       input format: TSS-Distance-Filename <tab> Outfile-Prefix')
        print('       TSS-Distance files derived from ERANGE calls and refFlat files are assumed, where the gene names are in the form of Name::refSeqID')
        sys.exit(1)

    inputfilename = sys.argv[1]
    threshold = float(sys.argv[2])
    cuffdiff = sys.argv[3]

    results, totals = _load_cuffdiff(cuffdiff, threshold)
    pairs = sorted(totals)

    with open(inputfilename) as listing:
        for entry in listing:
            if not entry.strip():
                continue
            fields = entry.strip().split('\t')
            print(fields)
            _write_reports(fields[0], fields[1], pairs, results, totals)

# Execute the pipeline only when this file is invoked as a script, so that
# importing the module does not parse sys.argv or call sys.exit().
if __name__ == '__main__':
    run()

