##################################
#                                #
# Last modified 09/11/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set
import scipy.stats

def _parse_gtf(gtf_path):
    """Collect exon coordinates per gene from a GTF file.

    Returns a dict mapping (geneID, geneName) -> list of
    (chrom, left, right, strand) tuples, one per 'exon' line, in file order.
    geneName falls back to geneID when the attribute column has no
    gene_name. Comment lines ('#...'), blank/short lines, non-exon features
    and exon lines without a gene_id are skipped.
    """
    genes = {}
    with open(gtf_path) as gtf:
        for line in gtf:
            # GTF files may start with '#!genome-build ...' style headers.
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 9 or fields[2] != 'exon':
                continue
            attributes = fields[8]
            if 'gene_id "' not in attributes:
                continue
            gene_id = attributes.split('gene_id "')[1].split('"')[0]
            if 'gene_name "' in attributes:
                gene_name = attributes.split('gene_name "')[1].split('"')[0]
            else:
                gene_name = gene_id
            exon = (fields[0], int(fields[3]), int(fields[4]), fields[6])
            genes.setdefault((gene_id, gene_name), []).append(exon)
    return genes


def _parse_ase_files(list_path, min_reads):
    """Read every ASE file named in a tab-separated list file.

    Each non-blank list line is: label <TAB> ASE_file_name <TAB> fraction_field.
    Returns (labels, ase):
      labels -- sample labels in list-file order;
      ase    -- dict (geneID, geneName) -> {label: value}, where value is the
                string 'NS' when columns 4+5 (read counts) sum to fewer than
                min_reads, otherwise float(column fraction_field).
    ASE lines starting with '#' and blank lines are skipped.
    """
    labels = []
    ase = {}
    with open(list_path) as listing:
        for entry in listing:
            if not entry.strip():
                continue
            entry_fields = entry.strip().split('\t')
            label = entry_fields[0]
            ase_path = entry_fields[1]
            fraction_field = int(entry_fields[2])
            print(label)
            labels.append(label)
            with open(ase_path) as ase_file:
                for line in ase_file:
                    if line.startswith('#') or not line.strip():
                        continue
                    fields = line.strip().split('\t')
                    gene_key = (fields[1], fields[0])  # (geneID, geneName)
                    reads = int(fields[3]) + int(fields[4])
                    per_sample = ase.setdefault(gene_key, {})
                    if reads < min_reads:
                        per_sample[label] = 'NS'
                    else:
                        per_sample[label] = float(fields[fraction_field])
    return labels, ase


def _gene_position(exons):
    """Return (chrom, pos): the strand-aware 5' end of a gene.

    chrom is the chromosome of the gene's first exon in file order; only
    exons on that chromosome are considered. For '-' strand genes pos is the
    largest exon end; for any other strand it is the smallest exon start.
    """
    chrom = exons[0][0]
    same_chrom = [exon for exon in exons if exon[0] == chrom]
    strand = same_chrom[0][3]
    if strand == '-':
        # Max over all exons: the exon with the largest start is not
        # necessarily the one with the largest end when exons overlap.
        return chrom, max(exon[2] for exon in same_chrom)
    return chrom, min(exon[1] for exon in same_chrom)


def run():
    """Command-line entry point: tabulate allelic fractions per gene.

    Usage: python script list_of_ASE_files GTF min_reads outfile

    Writes a tab-separated table with header
    '#geneID, geneName, chr, pos, <label...>' and one row per gene seen in any
    ASE file, sorted by (geneID, geneName). Each label column holds the
    fraction from that sample, or 'NS' if the gene had fewer than min_reads
    reads or was absent from that sample. Genes absent from the GTF are
    reported on stderr and skipped.
    """
    # Four user arguments are required, so argv needs five entries.
    if len(sys.argv) < 5:
        print('usage: python %s list_of_ASE_files GTF min_reads outfile' % sys.argv[0])
        print('\tAssumed list_of_ASE_files format: label\tASE_file_name\tfraction_field')
        print('\tAssumed ASE_format: #geneName\tgeneID\tchr\t129S1_collapsed_reads\tC57BL_collapsed_reads\t129S1_fraction\tC57BL_fraction\tp-value')
        sys.exit(1)

    ase_list_path = sys.argv[1]
    gtf_path = sys.argv[2]
    min_reads = int(sys.argv[3])
    out_path = sys.argv[4]

    genes = _parse_gtf(gtf_path)
    print('finished parsing GTF')

    labels, ase = _parse_ase_files(ase_list_path, min_reads)

    with open(out_path, 'w') as outfile:
        outfile.write('\t'.join(['#geneID', 'geneName', 'chr', 'pos'] + labels) + '\n')
        for gene_id, gene_name in sorted(ase):
            exons = genes.get((gene_id, gene_name))
            if not exons:
                sys.stderr.write('warning: %s (%s) not found in GTF; skipped\n'
                                 % (gene_id, gene_name))
                continue
            chrom, pos = _gene_position(exons)
            values = ase[(gene_id, gene_name)]
            row = [gene_id, gene_name, chrom, str(pos)]
            row.extend(str(values.get(label, 'NS')) for label in labels)
            outfile.write('\t'.join(row) + '\n')

# Only run when executed as a script, so the module can be imported
# (e.g. to reuse run() or its helpers) without side effects.
if __name__ == '__main__':
    run()