##################################
#                                #
# Last modified 07/02/2012       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import scipy.stats
from sets import Set
import os
import subprocess

def _parse_transcript_name(refName, do3IDFields):
    """Split a colon-separated reference name into its identifying fields.

    Returns (geneName, geneID, transcriptName, transcriptID, genotype).
    With do3IDFields the name has four fields
    (GeneName:GeneID:TranscriptName:genotype) and the single transcript
    field is used for both transcriptName and transcriptID.

    Raises ValueError when the name has an unexpected number of fields.
    """
    if do3IDFields:
        (geneName, geneID, transcriptID, genotype) = tuple(refName.split(':'))
        transcriptName = transcriptID
    else:
        (geneName, geneID, transcriptName, transcriptID, genotype) = tuple(refName.split(':'))
    return (geneName, geneID, transcriptName, transcriptID, genotype)


def _load_gene_chromosomes(GTF):
    """Return a dict mapping (geneName, geneID) to the chromosome of its exons.

    Only 'exon' records are used; geneName falls back to geneID when the
    attribute column has no gene_name entry.
    """
    GeneToChrDict = {}
    with open(GTF) as gtfFile:
        for line in gtfFile:
            # Blank lines carry no record; indexing them would raise.
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            if fields[2] != 'exon':
                continue
            chrom = fields[0]
            geneID = fields[8].split('gene_id "')[1].split('";')[0]
            if 'gene_name' in fields[8]:
                geneName = fields[8].split('gene_name "')[1].split('";')[0]
            else:
                geneName = geneID
            GeneToChrDict[(geneName, geneID)] = chrom
    return GeneToChrDict


def _record_read(alignments, GeneDict, parents):
    """Tally one read, given all of its (geneKey, genotype, sequence) alignments.

    Every genotype seen is added to `parents`, but the read's sequence is
    stored only when all alignments agree on one gene, one genotype and one
    sequence. Sequences are kept in sets so identical reads collapse.
    """
    genes = set()
    genotypes = set()
    sequences = set()
    for (geneKey, genotype, sequence) in alignments:
        genes.add(geneKey)
        genotypes.add(genotype)
        sequences.add(sequence)
        parents.add(genotype)
    if len(genes) == 1 and len(genotypes) == 1 and len(sequences) == 1:
        # All alignments agree, so the last loop values are the unique ones.
        GeneDict.setdefault(geneKey, {}).setdefault(genotype, set()).add(sequence)


def _count_allelic_reads(samtools, BAM, do3IDFields):
    """Stream `samtools view BAM` and collect unambiguous reads per gene.

    Alignments are grouped by runs of consecutive lines sharing the same
    read name; unmapped records (reference name '*') are ignored.

    Returns (GeneDict, parents): GeneDict maps (geneName, geneID) ->
    genotype -> set of distinct read sequences, and parents is the set of
    every genotype label encountered.
    """
    GeneDict = {}
    parents = set()
    # Argument list instead of a shell string: paths containing spaces or
    # shell metacharacters are passed to samtools verbatim.
    proc = subprocess.Popen([samtools, 'view', BAM],
                            stdout=subprocess.PIPE,
                            universal_newlines=True)
    CurrentID = None
    CurrentIDList = []
    i = 0
    for line in proc.stdout:
        i += 1
        if i % 1000000 == 0:
            print(str(i // 1000000) + 'M alignments processed')
        fields = line.strip().split('\t')
        if fields[2] == '*':
            continue
        ID = fields[0]
        (geneName, geneID, transcriptName, transcriptID, genotype) = \
            _parse_transcript_name(fields[2], do3IDFields)
        if ID != CurrentID:
            # A new read starts: settle the previous one first.
            _record_read(CurrentIDList, GeneDict, parents)
            CurrentID = ID
            CurrentIDList = []
        CurrentIDList.append(((geneName, geneID), genotype, fields[9]))
    # The final read has no successor line to trigger its tally.
    _record_read(CurrentIDList, GeneDict, parents)
    proc.stdout.close()
    status = proc.wait()
    if status != 0:
        sys.stderr.write('warning: samtools exited with status %s\n' % status)
    return GeneDict, parents


def _binomial_pvalue(successes, trials):
    """Return the binomial test p-value for `successes` of `trials` at p=0.5.

    Uses scipy.stats.binomtest when the installed SciPy provides it and
    falls back to the older scipy.stats.binom_test otherwise.
    """
    if hasattr(scipy.stats, 'binomtest'):
        return scipy.stats.binomtest(successes, trials, 0.5).pvalue
    return scipy.stats.binom_test(successes, trials, 0.5)


def run():
    """Count allele-specific collapsed reads per gene and test for imbalance.

    Command line: BAM samtools GTF outfile [-3IDfields]

    Reads are taken from `samtools view BAM`, whose reference names must be
    of the form GeneName:GeneID:TranscriptName:TranscriptID:genotype (or
    GeneName:GeneID:TranscriptName:genotype with -3IDfields). For every gene
    the number of distinct read sequences assigned unambiguously to each of
    the two genotypes is written to outfile as a tab-separated table, with
    the two fractions and a binomial test p-value against 0.5.

    Exits with status 1 on missing arguments or when the data do not
    contain exactly two genotypes.
    """
    # Four positional arguments are read below, so argv needs 5 entries.
    if len(sys.argv) < 5:
        print('usage: python %s BAM samtools GTF outfile [-3IDfields]' % sys.argv[0])
        print('\tThe script assumes that reads have been mapped against the output of makeHetTranscriptome2.py, i.e. transcripts in the following format: GeneName:GeneID:TranscriptName:TranscriptID:genotype')
        print('\tUse the -3IDfields option if transcript names are specified by three fields: GeneName:GeneID:TranscriptName:genotype')
        print('\tIdentical reads will be collapsed')
        sys.exit(1)

    BAM = sys.argv[1]
    samtools = sys.argv[2]
    GTF = sys.argv[3]
    outfilename = sys.argv[4]
    do3IDFields = '-3IDfields' in sys.argv

    GeneToChrDict = _load_gene_chromosomes(GTF)
    GeneDict, parentSet = _count_allelic_reads(samtools, BAM, do3IDFields)

    parents = sorted(parentSet)
    if len(parents) > 2:
        print('more than 2 parents detected, exiting')
        print(parents)
        sys.exit(1)
    if len(parents) < 2:
        # Both column headers need a genotype label; fail cleanly instead
        # of with an IndexError.
        print('fewer than 2 parents detected, exiting')
        print(parents)
        sys.exit(1)

    parent1 = parents[0]
    parent2 = parents[1]

    print(parents)

    with open(outfilename, 'w') as outfile:
        header = '\t'.join(['#geneName', 'geneID', 'chr',
                            parent1 + '_collapsed_reads',
                            parent2 + '_collapsed_reads',
                            parent1 + '_fraction',
                            parent2 + '_fraction',
                            '_pvalue'])
        outfile.write(header + '\n')

        for (geneName, geneID) in sorted(GeneDict):
            byGenotype = GeneDict[(geneName, geneID)]
            parent1Counts = len(byGenotype.get(parent1, ()))
            parent2Counts = len(byGenotype.get(parent2, ()))
            # A gene is only present once it has a stored read, so total >= 1.
            total = parent1Counts + parent2Counts
            pvalue = _binomial_pvalue(parent1Counts, total)
            outline = '\t'.join([geneName, geneID,
                                 GeneToChrDict[(geneName, geneID)],
                                 str(parent1Counts),
                                 str(parent2Counts),
                                 str(parent1Counts / (total + 0.0)),
                                 str(parent2Counts / (total + 0.0)),
                                 str(pvalue)])
            outfile.write(outline + '\n')

# Script entry point: this call is unguarded, so run() executes (and reads
# sys.argv) whenever the file is executed or imported.
run()