##################################
#                                #
# Last modified 07/21/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _count_tests(ase_path):
    """Return the number of data lines (non-blank, non-'#gene') in the ASE file.

    This is the number of hypothesis tests used for the Bonferroni correction;
    header and blank lines are not tests and must not be counted.
    """
    with open(ase_path) as ase_file:
        return sum(1 for line in ase_file
                   if line.strip() and not line.startswith('#gene'))


def _new_chr_entry(parent1, parent2):
    """Return a zeroed per-chromosome counter record keyed by the parent names."""
    return {
        'Het': 0,
        'Expressed': 0,
        'ASE': {
            parent1 + '_ASE': 0,
            parent2 + '_ASE': 0,
            parent1 + '_complete_ASE': 0,
            parent2 + '_complete_ASE': 0,
        },
    }


def _get_chr_entry(chrDict, chrom, parent1, parent2):
    """Return chrDict[chrom], first inserting a zeroed record if it is absent."""
    if chrom not in chrDict:
        chrDict[chrom] = _new_chr_entry(parent1, parent2)
    return chrDict[chrom]


def run():
    """Tabulate heterozygous sites, expressed genes and ASE calls per chromosome.

    Command line (read from sys.argv):
        argv[1] het_transcriptome_fasta -- FASTA; every '>' header with at
                least three '::'-separated fields adds 0.5 to the Het count
                of the chromosome named in field 1 (text before ':').
                (Presumably each het site is present once per allele, hence
                the 0.5 -- confirm against the FASTA producer.)
        argv[2] ASE_file -- tab-separated. '#gene' header lines name the two
                parents in columns 5 and 6 (text before '_collapsed_reads').
                Data lines carry the chromosome in column 1, per-parent counts
                in columns 5 and 6, and a p-value in column 9.
        argv[3] p-value -- significance threshold; a call is counted when its
                p-value is not greater than the threshold.
        argv[4] cufflinks_genes_fpkm -- tab-separated, header line starting
                with 'tracking_id'; locus 'chr:...' in column 6, FPKM in
                column 10.
        argv[5] minFPKM -- genes whose FPKM is not below this are Expressed.
        argv[6] outfile -- tab-separated per-chromosome summary, rows sorted
                by chromosome name.
        -Bonferroni (anywhere in argv) -- divide the p-value threshold by the
                number of data lines in ASE_file.

    Prints a usage message and exits with status 1 when fewer than six
    positional arguments are given. Exits with an error message when the
    ASE file has no '#gene' header before its data.
    """
    # argv[0] is the script itself; six positional arguments are required
    # because argv[6] (outfile) is read below.
    if len(sys.argv) < 7:
        print('usage: python %s het_transcriptome_fasta ASE_file p-value cufflinks_genes_fpkm minFPKM outfile [-Bonferroni]' % sys.argv[0])
        sys.exit(1)

    fasta = sys.argv[1]
    ASE = sys.argv[2]
    minpvalue = float(sys.argv[3])
    genesFPKM = sys.argv[4]
    minFPKM = float(sys.argv[5])
    outfilename = sys.argv[6]

    if '-Bonferroni' in sys.argv:
        print('will apply Bonferroni multiple hypothesis testing correction')
        num_tests = _count_tests(ASE)
        # An ASE file without data lines has nothing to correct; avoid 1/0.
        if num_tests > 0:
            minpvalue = minpvalue / num_tests

    chrDict = {}

    # Pass 1: allele-specific expression calls from the ASE file.
    parent1 = parent2 = None
    with open(ASE) as ase_file:
        for line in ase_file:
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            if line.startswith('#gene'):
                parent1 = fields[5].split('_collapsed_reads')[0]
                parent2 = fields[6].split('_collapsed_reads')[0]
                continue
            if parent1 is None:
                sys.exit('error: %s has data before any #gene header line' % ASE)
            pvalue = float(fields[9])
            # Register the chromosome even for non-significant calls so it
            # still appears in the output (with zero ASE counts).
            ase = _get_chr_entry(chrDict, fields[1], parent1, parent2)['ASE']
            if pvalue > minpvalue:
                continue
            parent1Counts = int(fields[5])
            parent2Counts = int(fields[6])
            if parent1Counts > parent2Counts:
                ase[parent1 + '_ASE'] += 1
            elif parent1Counts < parent2Counts:
                ase[parent2 + '_ASE'] += 1
            # "Complete" ASE: all reads come from a single parent.
            if parent2Counts > 0 and parent1Counts == 0:
                ase[parent2 + '_complete_ASE'] += 1
            elif parent1Counts > 0 and parent2Counts == 0:
                ase[parent1 + '_complete_ASE'] += 1

    if parent1 is None:
        sys.exit('error: %s contains no #gene header line' % ASE)

    # Pass 2: count expressed genes per chromosome.
    with open(genesFPKM) as genes_file:
        for line in genes_file:
            if not line.strip() or line.startswith('tracking_id'):
                continue
            fields = line.strip().split('\t')
            entry = _get_chr_entry(chrDict, fields[6].split(':')[0],
                                   parent1, parent2)
            FPKM = float(fields[10])
            if FPKM < minFPKM:
                continue
            entry['Expressed'] += 1

    # Pass 3: heterozygous sites from the FASTA headers.
    with open(fasta) as fasta_file:
        for line in fasta_file:
            if not line.startswith('>'):
                continue
            fields = line.strip().split('::')
            if len(fields) < 3:
                continue
            entry = _get_chr_entry(chrDict, fields[1].split(':')[0],
                                   parent1, parent2)
            entry['Het'] += 0.5

    ase_columns = [parent1 + '_ASE', parent2 + '_ASE',
                   parent1 + '_complete_ASE', parent2 + '_complete_ASE']
    with open(outfilename, 'w') as outfile:
        outfile.write('\t'.join(['#chr', 'Het', 'Expressed'] + ase_columns) + '\n')
        for chrom in sorted(chrDict):
            entry = chrDict[chrom]
            values = [entry['Het'], entry['Expressed']]
            values += [entry['ASE'][column] for column in ase_columns]
            outfile.write('\t'.join([chrom] + [str(v) for v in values]) + '\n')

# Run the pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    run()