##################################
#                                #
# Last modified 04/02/2013       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
import random
import string
from sets import Set

def _read_config(config_path):
    """Parse the configuration file.

    Each non-blank line holds three tab-separated fields: a data set label,
    the path of its CircRNA.genes file and the path of its
    CircRNA.backsplices file.

    Returns a (labels, files) pair: labels lists the labels in file order and
    files maps each label to its (genes_path, backsplices_path) tuple.
    """
    labels = []
    files = {}
    with open(config_path) as handle:
        for line in handle:
            stripped = line.strip()
            if not stripped:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            fields = stripped.split('\t')
            label = fields[0]
            labels.append(label)
            files[label] = (fields[1], fields[2])
    return labels, files


def _load_gene_counts(genes_path, label, read_counts):
    """Record the unique read count of every gene of one data set.

    Fills read_counts[gene][label] from column 2 of the genes file, ignoring
    comment lines, blank lines and the '*' pseudo-gene.

    Returns the sum of the recorded unique read counts, which is the
    denominator used for the RPM values of this data set.
    """
    total = 0
    with open(genes_path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            gene = fields[0]
            if gene == '*':
                continue
            unique_reads = int(fields[1])
            total += unique_reads
            read_counts.setdefault(gene, {})[label] = unique_reads
    return total


def _load_backsplices(backsplices_path, label, backsplice_counts):
    """Record the read count of every backsplice of one data set.

    Fills backsplice_counts[gene][backsplice][label], where backsplice is the
    tuple (chr, splice, backsplice, strand, splice_exon_seq,
    backsplice_exon_seq). Lines with fewer than eight fields are reported on
    stdout and skipped.
    """
    with open(backsplices_path) as handle:
        for line in handle:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            if len(fields) < 8:
                # Typically a record whose trailing sequence column is empty.
                print('skipping ' + line.strip())
                continue
            gene = fields[0]
            reads = int(fields[5])
            backsplice = (fields[1], fields[2], fields[3], fields[4],
                          fields[6], fields[7])
            per_gene = backsplice_counts.setdefault(gene, {})
            per_gene.setdefault(backsplice, {})[label] = reads


def _rpm(reads, total_reads):
    """Return reads per million of total_reads, or 'NaN' if total_reads is 0."""
    if total_reads == 0:
        return 'NaN'
    return (reads + 0.0) / (total_reads / 1000000.)


def _write_gene_table(path, labels, read_counts, backsplice_counts, total_reads):
    """Write the per-gene table (<prefix>.RPM.genes).

    For every data set in which a gene was seen, the reported count is the
    gene's unique reads minus the reads of all of its backsplices in that
    data set; data sets in which the gene is absent get 0 for both columns.
    """
    with open(path, 'w') as outfile:
        header = ['#gene']
        for label in labels:
            header.append(label + '_unique_reads')
            header.append(label + '_unique_reads_RPM')
        outfile.write('\t'.join(header) + '\n')

        for gene in sorted(read_counts):
            row = [gene]
            for label in labels:
                if label in read_counts[gene]:
                    reads = read_counts[gene][label]
                    # Remove the backsplice reads of this data set.
                    for per_label in backsplice_counts.get(gene, {}).values():
                        reads -= per_label.get(label, 0)
                    rpm = _rpm(reads, total_reads[label])
                else:
                    reads = 0
                    rpm = 0
                row.append(str(reads))
                row.append(str(rpm))
            outfile.write('\t'.join(row) + '\n')


def _write_backsplice_table(path, labels, read_counts, backsplice_counts, total_reads):
    """Write the per-backsplice table (<prefix>.RPM.backsplices).

    Each row carries, per data set, the backsplice reads, their RPM and the
    fraction they represent of the gene's unique reads ('NaN' for 0/0 and
    'Inf' for a positive count over 0).
    """
    with open(path, 'w') as outfile:
        header = ['#gene', 'chr', 'splice', 'backsplice', 'strand']
        for label in labels:
            header.append(label + '_unique_reads')
            header.append(label + '_unique_reads_RPM')
            header.append(label + '_fraction_of_total_unique_gene_reads')
        header.append('Shortest_Exon_splice')
        header.append('Shortest_Exon_backsplice')
        outfile.write('\t'.join(header) + '\n')

        for gene in sorted(backsplice_counts):
            # Sorted so that the row order is reproducible between runs.
            for backsplice in sorted(backsplice_counts[gene]):
                per_label = backsplice_counts[gene][backsplice]
                row = [gene] + list(backsplice[:4])
                for label in labels:
                    if label in per_label:
                        reads = per_label[label]
                        # A gene may appear only in the backsplices file
                        # (e.g. '*', which is skipped when reading genes).
                        gene_reads = read_counts.get(gene, {}).get(label, 0)
                        rpm = _rpm(reads, total_reads[label])
                        if reads == 0 and gene_reads == 0:
                            fraction = 'NaN'
                        elif reads > 0 and gene_reads == 0:
                            fraction = 'Inf'
                        else:
                            fraction = (reads + 0.0) / (gene_reads + 0.0)
                    else:
                        reads = 0
                        rpm = 0
                        fraction = 0
                    row.append(str(reads))
                    row.append(str(rpm))
                    row.append(str(fraction))
                row.append(backsplice[4])
                row.append(backsplice[5])
                outfile.write('\t'.join(row) + '\n')


def run():
    """Combine per-data-set circRNA count tables into RPM tables.

    Reads the configuration file named by sys.argv[1] and writes
    <sys.argv[2]>.RPM.genes and <sys.argv[2]>.RPM.backsplices. Prints a usage
    message and exits with status 1 when either argument is missing.
    """
    # Both the configuration file and the output prefix are required.
    if len(sys.argv) < 3:
        print('usage: python %s input_file.conf outfile_prefix' % sys.argv[0])
        print('\tNote: the input file should be formatted as follows:')
        print('\t\tLabel\tCircRNA.genes\tCircRNA.backsplices')
        print('\t\tCircRNA.genes format:')
        print('\t\t\t#GeneName\tTotal_Unique_Read_Pairs\tTotal_Unique_Circular_Pairs\tFraction_Unique\tTotal_Multi_Read_Pairs\tTotal_Multi_Circular_Pairs\tFraction_Multi')
        print('\t\tCircRNA.backsplices format:')
        print('\t\t\t#GeneName\tchr\tSplice\tbackSplice\tstrand\tReads\tShortest_Splce_exon_sequence\tShortest_backSplce_exon_sequence')
        print('\tNote: the script will use unique reads only for the denominator in the RPM calculation and it will use unique reads only for the denominator; if you want to run DESeq downstream; the suggested path for expression analysis would be to run eXpress on the data then combine with the output for the backsplices')
        sys.exit(1)

    config_path = sys.argv[1]
    outfile_prefix = sys.argv[2]

    labels, files = _read_config(config_path)

    read_counts = {}        # gene -> label -> unique reads
    backsplice_counts = {}  # gene -> backsplice tuple -> label -> reads
    total_reads = {}        # label -> sum of unique reads over all genes

    for label in files:
        genes_path, backsplices_path = files[label]
        total_reads[label] = _load_gene_counts(genes_path, label, read_counts)
        _load_backsplices(backsplices_path, label, backsplice_counts)

    labels.sort()

    _write_gene_table(outfile_prefix + '.RPM.genes', labels,
                      read_counts, backsplice_counts, total_reads)
    _write_backsplice_table(outfile_prefix + '.RPM.backsplices', labels,
                            read_counts, backsplice_counts, total_reads)
if __name__ == '__main__':
    # Only execute when invoked as a script, so that importing this module
    # does not read sys.argv or write any output files.
    run()

