##################################
#                                #
# Last modified 07/20/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math

# psyco is an optional accelerator: use it when it is available and carry on
# without it otherwise.  Catching Exception (rather than using a bare except)
# keeps that best-effort behaviour while letting KeyboardInterrupt and
# SystemExit propagate.
try:
    import psyco
    psyco.full()
except Exception:
    pass

# Column indexes used by each input layout, in the order
# (sample_1, sample_2, status, log fold change, significant).
_EXPRESSION_COLUMNS = (4, 5, 6, 9, 13)
_ISOFORM_SWITCH_COLUMNS = (3, 4, 5, 8, 11)


def _fold_change(log_text):
    """Return the linear fold change for a log fold change column value.

    The value is exponentiated with math.exp.  math.exp raises OverflowError
    for inputs above roughly 709; in that case the raw number is returned
    instead, which keeps such rows on the "increased" side of the threshold
    test.  Text that is not a number raises ValueError.
    """
    # NOTE(review): math.exp assumes the column holds a natural log; confirm
    # against the cuffdiff version that produced the input file.
    log_value = float(log_text)
    try:
        return math.exp(log_value)
    except OverflowError:
        return log_value


def _significant_rows(lines, columns):
    """Yield (fields, sample1, sample2, foldchange) for each usable row.

    A row is usable when its status column is 'OK' and its significant
    column is 'yes'.  The header line is dropped by the same test.  Blank
    lines are skipped.  A progress message is printed every million lines.
    """
    sample1_col, sample2_col, status_col, fold_col, significant_col = columns
    for count, line in enumerate(lines, 1):
        if count % 1000000 == 0:
            print('%d lines processed' % count)
        fields = line.strip().split('\t')
        if fields == ['']:
            continue
        if fields[status_col] != 'OK' or fields[significant_col] != 'yes':
            continue
        yield (fields, fields[sample1_col], fields[sample2_col],
               _fold_change(fields[fold_col]))


def _write_matrix(outfile, samples, counts):
    """Write a tab-separated samples x samples table.

    counts maps (row_sample, column_sample) to an integer; pairs that are
    missing from it are written as 0.
    """
    outfile.write('\t'.join(['#Sample'] + samples) + '\n')
    for row in samples:
        cells = [str(counts.get((row, column), 0)) for column in samples]
        outfile.write('\t'.join([row] + cells) + '\n')


def _count_differential(rows, threshold):
    """Count significant elements per ordered sample pair.

    A fold change above threshold adds one to (sample2, sample1); a fold
    change below 1/threshold adds one to (sample1, sample2).
    Returns (sorted sample names, counts keyed by (row, column)).
    """
    inverse = 1.0 / threshold
    samples = set()
    counts = {}
    for _fields, sample1, sample2, foldchange in rows:
        samples.add(sample1)
        samples.add(sample2)
        if foldchange > threshold:
            counts[(sample2, sample1)] = counts.get((sample2, sample1), 0) + 1
        if foldchange < inverse:
            counts[(sample1, sample2)] = counts.get((sample1, sample2), 0) + 1
    return sorted(samples), counts


def _count_isoform_switches(rows, threshold):
    """Count genes with isoform switches for each ordered sample pair.

    For every ordered pair the transcripts passing the threshold in that
    direction are grouped by gene (column 1, transcript id in column 2).  A
    gene counts as switched between two samples when it has transcripts in
    both directions.
    Returns (sorted sample names, switched-gene counts, isoform counts); both
    count mappings are keyed by (row, column).
    """
    inverse = 1.0 / threshold
    samples = set()
    directional = {}  # (row, column) -> {gene: set of transcript ids}
    for fields, sample1, sample2, foldchange in rows:
        samples.add(sample1)
        samples.add(sample2)
        gene = fields[1]
        transcript = fields[2]
        if foldchange > threshold:
            genes = directional.setdefault((sample2, sample1), {})
            genes.setdefault(gene, set()).add(transcript)
        if foldchange < inverse:
            genes = directional.setdefault((sample1, sample2), {})
            genes.setdefault(gene, set()).add(transcript)

    gene_counts = {}
    isoform_counts = {}
    for (first, second), forward in directional.items():
        if first == second:
            continue  # the diagonal is always reported as 0
        backward = directional.get((second, first), {})
        switched = [gene for gene in forward if gene in backward]
        gene_counts[(first, second)] = len(switched)
        isoform_counts[(first, second)] = sum(
            [len(forward[gene]) for gene in switched])
    return sorted(samples), gene_counts, isoform_counts


def run():
    """Summarise a cuffdiff file as pairwise sample matrices.

    Command line: cuffdiff_file foldchange_threshold element_type
    outputfilename.  Only rows with status 'OK' and significant 'yes' are
    used.

    For element_type 'genes' or 'isoforms' one matrix is written, counting
    the rows whose fold change passes the threshold in each direction.  For
    'isoform_switch' two matrices are written: the number of genes with
    isoform switches, then the number of switched isoforms.  Any other
    element_type leaves the output file empty and prints a warning to stderr.

    Exits with status 1 when arguments are missing or the threshold is not
    positive.  Raises ValueError for a non-numeric threshold or fold change
    and IndexError for a non-blank row with too few columns.
    """
    # Four arguments are read below, so argv needs five entries.
    if len(sys.argv) < 5:
        print('usage: python %s cuffdiff_file foldchange_threshold element_type outputfilename' % sys.argv[0])
        print('       element_type options: genes | isoforms | CDS | isoform_switch | splicing | promoter | TSS_group')
        sys.exit(1)

    inputfilename = sys.argv[1]
    threshold = float(sys.argv[2])
    element_type = sys.argv[3]
    outputfilename = sys.argv[4]

    # 1/threshold is needed for the "decreased" test.
    if threshold <= 0:
        print('foldchange_threshold must be greater than 0')
        sys.exit(1)

    # The input is opened first so a missing input does not truncate an
    # existing output file.  Nested with statements keep Python 2.6 working.
    with open(inputfilename) as infile:
        with open(outputfilename, 'w') as outfile:
            if element_type in ('genes', 'isoforms'):
                rows = _significant_rows(infile, _EXPRESSION_COLUMNS)
                samples, counts = _count_differential(rows, threshold)
                _write_matrix(outfile, samples, counts)
            elif element_type == 'isoform_switch':
                rows = _significant_rows(infile, _ISOFORM_SWITCH_COLUMNS)
                samples, gene_counts, isoform_counts = \
                    _count_isoform_switches(rows, threshold)
                outfile.write('#Genes with isoform switches\n')
                _write_matrix(outfile, samples, gene_counts)
                outfile.write('#Isoform switched\n')
                _write_matrix(outfile, samples, isoform_counts)
            else:
                sys.stderr.write(
                    "warning: element_type '%s' is not implemented; "
                    "%s is left empty\n" % (element_type, outputfilename))

# Run only when executed as a script, so importing this module does not
# parse sys.argv or touch any files.
if __name__ == '__main__':
    run()

