##################################
#                                #
# Last modified 2018/09/14       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import pysam
import os
from sets import Set
import Levenshtein
import numpy as np

# FLAG field meaning
# 0x0001 1 the read is paired in sequencing, no matter whether it is mapped in a pair
# 0x0002 2 the read is mapped in a proper pair (depends on the protocol, normally inferred during alignment) 1
# 0x0004 4 the query sequence itself is unmapped
# 0x0008 8 the mate is unmapped 1
# 0x0010 16 strand of the query (0 for forward; 1 for reverse strand)
# 0x0020 32 strand of the mate 1
# 0x0040 64 the read is the first read in a pair 1,2
# 0x0080 128 the read is the second read in a pair 1,2
# 0x0100 256 the alignment is not primary (a read having split hits may have multiple primary alignment records)
# 0x0200 512 the read fails platform/vendor quality checks
# 0x0400 1024 the read is either a PCR duplicate or an optical duplicate

def FLAG(FLAG):
    """Decompose a SAM FLAG bitfield into its component flag values.

    Returns the standard SAM flag bits (1..1024) that are set in
    ``FLAG``, in descending order (e.g. ``FLAG(99) -> [64, 32, 2, 1]``).

    Bits above 1024 (e.g. 2048, supplementary alignment) are ignored.
    The previous greedy-subtraction implementation entered an infinite
    loop / raised IndexError on such values because the residual could
    never reach zero; direct bit-testing avoids that entirely.
    """

    # Standard SAM flag bit values, largest first, so the returned
    # list preserves the original descending order.
    Numbers = [1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1]

    return [n for n in Numbers if FLAG & n]

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Handles upper- and lower-case bases; N/X placeholders and IUPAC
    ambiguity codes (R, Y, S, W, K, M) map to themselves.  Raises
    KeyError on any other character, as before.
    """

    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','X':'X','a':'t','t':'a','g':'c','c':'g','n':'n','x':'x','R':'R','r':'r','M':'M','m':'m','Y':'Y','y':'y','S':'S','s':'s','K':'K','k':'k','W':'W','w':'w'}
    # Walk the sequence back-to-front, complementing each base.
    return ''.join(DNA[base] for base in reversed(preliminarysequence))

def run():
    """Demultiplex a combinatorial-indexing single-cell RNA-seq run and
    build a gene-by-cell UMI count matrix.

    Positional arguments (see usage below): read2 BAM, read1 fastq
    (layout: UMI + RT index + oligo-dT), RT/i7/i5 index master lists,
    UMI length, RT index length, expected cell number, GTF annotation,
    chrom.sizes file, output prefix.

    Outputs:
      <outprefix>.UMIs_per_cell -- per-barcode UMI counts, ranked
      <outprefix>.table         -- gene x cell-barcode UMI count matrix
    """

    # ---- command-line parsing -------------------------------------
    # NOTE(review): 12 values (argv[0]..argv[11]) are read below, so
    # this check should presumably be `< 12` -- confirm.
    if len(sys.argv) < 11:
        print 'usage: python %s read2_BAMfilename read1_fastq RT_indexes_file PCR_i7_file PCR_i5_file UMI_length RT_len Expected_number_of_cells GTF chrom.sizes outprefix [-UMIedit N] [-RTedit N] [-exonicOnly] [-uniqueBAM] [-noNH samtools] [-noNHinfo] [-medianExpCellsCutoff fraction] [-3UTRextend bp] [-refFlat] [-splitAmbiguousReads]' % sys.argv[0]
        print '\t the read1 fastq files can be in .bz2 or .gz format'
        print '\t index files -- one sequence per line'
        print '\t the i7 indexes should be in their original oritentation, the script will reverse complement them'
        print '\t the script will only use uniquely mapped reads'
        print '\t the [-medianExpCellsCutoff] option will remove all cells with UMIs fewer than the indicated fraction of the median for the top N cells, where N is the expected number of cells (default: 0.20)'
        print '\t the [-3UTRextend] option will extend the 3UTRs of genes by the indicated number of bp. Default: 100'
        print '\t #### [-exonicOnly] option not implemented yet'
        print '\t The [-splitAmbiguousReads] will give each overlapping gene that a uniquely mapped read could be assigned a fractional UMI count'
        sys.exit(1)

    BAM = sys.argv[1]
    fastq = sys.argv[2]
    RTindexes = sys.argv[3]
    PCRindexes7 = sys.argv[4]
    PCRindexes5 = sys.argv[5]
    UMILen = int(sys.argv[6])
    RTLen = int(sys.argv[7])
    ExpCellNum = int(sys.argv[8])
    GTF = sys.argv[9]
    chrominfo = sys.argv[10]
    # Parse chrom.sizes: tab-separated <chrom> <length> lines.
    chromInfoList = []
    chromInfoDict = {}
    linelist = open(chrominfo)
    for line in linelist:
        fields = line.strip().split('\t')
        chr = fields[0]
        start = 0
        end = int(fields[1])
        chromInfoList.append((chr, start, end))
        chromInfoDict[chr] = end
    outprefix = sys.argv[11]

    # Only uniquely mapped reads are used (multi-mapper support was
    # removed; see commented-out -nomulti handling below).
    noMulti = True
#    if '-nomulti' in sys.argv:
#        print 'will only consider unique alignments'
#        noMulti=True

    # ---- optional arguments ---------------------------------------
    doExonicOnly = False          # parsed but not implemented (see usage)
    if '-exonicOnly' in sys.argv:
        doExonicOnly = True

    UMIedit = 1                   # max Levenshtein distance for UMI collapsing
    if '-UMIedit' in sys.argv:
        UMIedit = int(sys.argv[sys.argv.index('-UMIedit') + 1])

    RTedit = 1                    # max Levenshtein distance for index matching
    if '-RTedit' in sys.argv:
        RTedit = int(sys.argv[sys.argv.index('-RTedit') + 1])

    UMICutoff = 0.20              # fraction of top-cell median UMIs to keep a cell
    if '-medianExpCellsCutoff' in sys.argv:
        UMICutoff = float(sys.argv[sys.argv.index('-medianExpCellsCutoff') + 1])

    ReadIDDict = {}

    samfile = pysam.Samfile(BAM, "rb")

    UTR3ext = 100                 # bp to extend gene 3' ends (strand-aware)
    if '-3UTRextend' in sys.argv:
        UTR3ext = int(sys.argv[sys.argv.index('-3UTRextend') + 1])

    doUniqueBAM = False
    if '-uniqueBAM' in sys.argv:
        print 'will treat all alignments as unique'
        doUniqueBAM = True
        TotalReads = 0
        pass

    doRF = False
    if '-refFlat' in sys.argv:
        print 'will treat GTF file as if it is a refFlat one'
        doRF = True

    doSAR = False
    if '-splitAmbiguousReads' in sys.argv:
        print 'will assign fractional weights to genes with amniguous alignments'
        doSAR = True

    # ---- read-multiplicity source ---------------------------------
    # Multiplicity comes from one of: direct counting (-noNHinfo), the
    # NH tag, or is forced to 1 (-uniqueBAM).  If NH tags are missing,
    # -noNH reprocesses the BAM via bamPreprocessing.py.
    doNoNHinfo = False
    if '-noNHinfo' in sys.argv:
        doNoNHinfo = True
        print 'will directly evaluate read mulitplicity'
        MultiplicityDict = {}
    else:
        try:
            print 'testing for NH tags presence'
            for alignedread in samfile.fetch():
                multiplicity = alignedread.opt('NH')
                print 'file has NH tags'
                break
        except:	
            if '-noNH' in sys.argv:
                print 'no NH: tags in BAM file, will replace with a new BAM file with NH tags'
                samtools = sys.argv[sys.argv.index('-noNH')+1]
                BAMpreporcessingScript = sys.argv[0].rpartition('/')[0] + '/bamPreprocessing.py'
                cmd = 'python ' + BAMpreporcessingScript + ' ' + BAM + ' ' + BAM + '.NH'
                os.system(cmd)
                # NOTE(review): the original BAM is destroyed and replaced
                # in place; a crash mid-way loses the input file.
                cmd = 'rm ' + BAM
                os.system(cmd)
                cmd = 'mv ' + BAM + '.NH' + ' ' + BAM
                os.system(cmd)
                cmd = samtools + ' index ' + BAM
                os.system(cmd)
            elif doUniqueBAM:
                pass
            else:
                print 'no NH: tags in BAM file, exiting'
                sys.exit(1)

    if doNoNHinfo:
        # Count how many times each read ID appears across the BAM;
        # the per-region fetch with jj is only a probe that the region
        # exists in the index before the real pass.
        i = 0
        for (chr, start, end) in chromInfoList:
            try:
                jj = 0
                for alignedread in samfile.fetch(chr, start, end):
                    jj += 1
                    if jj == 1:
                        break
            except:
                print 'problem with region:', chr, start, end, 'skipping'
                continue
            for alignedread in samfile.fetch(chr, start, end):
                i += 1
                if i % 5000000 == 0:
                    print str(i/1000000) + 'M alignments processed in multiplicity assessment', chr, start, alignedread.pos, end
                fields = str(alignedread).split('\t')
                ID = fields[0]
#                if alignedread.is_read1:
#                    ID = ID + '/1'
#                if alignedread.is_read2:
#                    ID = ID + '/2'
                if MultiplicityDict.has_key(ID):
                    pass
                else:
                    MultiplicityDict[ID] = 0
                MultiplicityDict[ID] += 1

    # ---- GTF parsing ----------------------------------------------
    # Build GeneDict (per-gene exon coordinate pool, strand, chrom,
    # integer gene index G) and TranscriptDict (sorted exon list per
    # transcript, used later for splice-site disambiguation).
    j = 0
    lineslist = open(GTF)
    TranscriptDict = {}
    GeneDict = {}
    GeneIDDict = {}
    G = 0
    for line in lineslist:
        j += 1
        if j % 1000000 == 0:
            print j, 'lines processed'
        if line.startswith('#'):
            continue
        fields = line.strip().split('\t')
        if fields[2] != 'exon':
            continue
        chr = fields[0]
        if 'gene_name "' in fields[8]:
            geneName = fields[8].split('gene_name "')[1].split('";')[0]
        else:
            geneName = fields[8].split('gene_id "')[1].split('";')[0]
        geneID = fields[8].split('gene_id "')[1].split('";')[0]
        if doRF:
            # refFlat-derived GTFs can repeat gene IDs across chromosomes;
            # suffix with the chromosome to keep them distinct.
            geneID = geneID + '-' + chr
        if 'transcript_name "' in fields[8]:
            transcriptName = fields[8].split('transcript_name "')[1].split('";')[0]
        else:
            transcriptName = fields[8].split('transcript_id "')[1].split('";')[0]
        transcriptID = fields[8].split('transcript_id "')[1].split('";')[0]
        transcript = (geneID, geneName, transcriptName, transcriptID)
        if GeneDict.has_key(geneID):
            pass
        else:
            GeneDict[geneID] = {}
            GeneDict[geneID]['transcripts'] = {}
            GeneDict[geneID]['transcripts'][transcript] = 1
            GeneDict[geneID]['coordinates'] = []
            G += 1
            GeneDict[geneID]['G'] = G
            GeneIDDict[G] = geneID
        if TranscriptDict.has_key(transcript):
            pass
        else:
            TranscriptDict[transcript] = []
        left = int(fields[3])
        right = int(fields[4])
        strand = fields[6]
        TranscriptDict[transcript].append((chr, left, right, strand))
        GeneDict[geneID]['chr'] = chr
        GeneDict[geneID]['strand'] = strand
        GeneDict[geneID]['coordinates'].append(left)
        GeneDict[geneID]['coordinates'].append(right)

    for transcript in TranscriptDict.keys():
        TranscriptDict[transcript].sort()

    print 'finished parsing GTF file'

    print 'found', len(GeneDict.keys()), 'genes'

    # ---- assign uniquely mapped reads to genes --------------------
    # For each gene span (3'-extended by UTR3ext, clipped to chromosome
    # bounds), fetch alignments, keep strand-matching unique reads whose
    # 5' end falls inside the span, and record alignment + candidate
    # gene index per read ID.
    ReadDict = {}

    RN = 0
    GN = 0
    for geneID in GeneDict.keys():
        GN += 1
        chr = GeneDict[geneID]['chr']
        start = min(GeneDict[geneID]['coordinates'])
        end = max(GeneDict[geneID]['coordinates'])
        strand = GeneDict[geneID]['strand']
        if strand == '+':
            end = min(end + UTR3ext, chromInfoDict[chr])
        if strand == '-':
            start = max(0, start - UTR3ext)
        for alignedread in samfile.fetch(chr, start, end):
            RN += 1
            if RN % 1000000 == 0:
                print str(RN/1000000) + 'M alignments processed;', GN, 'genes processed'
            fields = str(alignedread).split('\t')
            FLAGfields = FLAG(int(fields[1]))
            if alignedread.is_reverse:
                s = '-'
            else:
                s = '+'
            # Sense-strand reads only.
            if s != strand:
                continue
            # Strip Illumina read-header suffix so read2 IDs match the
            # read1 fastq IDs parsed later.
            ID = fields[0].split('_2:N:0:')[0]
            if doUniqueBAM:
                multiplicity = 1
            elif doNoNHinfo:
                multiplicity = MultiplicityDict[ID]
            else:
                multiplicity = alignedread.opt('NH')
            if multiplicity > 1:
                continue
            cigar = alignedread.cigar
            pos = alignedread.pos
            # 5' end: leftmost position for '+', rightmost (pos + sum of
            # all CIGAR lengths) for '-'.
            FivePEnd = alignedread.pos
            if strand == '-':
                currentPos = alignedread.pos
                for (m, bp) in cigar:
                    currentPos = currentPos + bp
                FivePEnd = currentPos
            if FivePEnd < start or FivePEnd > end:
                continue
            if ReadDict.has_key(ID):
                pass
            else:
                ReadDict[ID] = {}
                ReadDict[ID]['al'] = []
                ReadDict[ID]['gs'] = []
            # NOTE(review): 'al' is initialized as a list but assigned a
            # tuple here, so only the last alignment is kept while 'gs'
            # accumulates every overlapping gene -- confirm intended.
            ReadDict[ID]['al'] = (chr, pos, s, str(cigar))
            ReadDict[ID]['gs'].append(GeneDict[geneID]['G'])

#    RN=0
#    for (chr,start,end) in chromInfoList:
#        try:
#            for alignedread in samfile.fetch(chr, 0, 100):
#                a='b'
#        except:
#            print 'region', chr,start,end, 'not found in bam file, skipping'
#            continue
#        currentPos=0
#        for alignedread in samfile.fetch(chr, start, end):
#            RN+=1
#            if RN % 5000000 == 0:
#                print str(RN/1000000) + 'M alignments processed', chr, currentPos, end
##            if doReadLength:
##                if len(alignedread.seq) > maxRL or len(alignedread.seq) < minRL:
##                    continue
#            fields=str(alignedread).split('\t')
#            FLAGfields = FLAG(int(fields[1]))
##            if doEnd1Only:
##                if 128 in FLAGfields:
##                    continue
##            if doEnd2Only:
##                if 64 in FLAGfields:
##                    continue
#            ID = fields[0].split('_2:N:0:')[0]
##            if doMaxMM:
##                mismatches = 0
##                for (m,bp) in alignedread.cigar:
##                    if m == 8:
##                        mismatches+=1
##                if mismatches > maxMM:
##                    continue
##            if doMaxMMMD:
##                MM = alignedread.opt('MD')
##                mismatches = 0
##                if MM.isdigit():
##                    pass
##                else:
##                    for s in range(len(MM)):
##                        if MM[s].isalpha():
##                            mismatches+=1
##                if mismatches > maxMM:
##                    continue
#            if doUniqueBAM:
#                multiplicity = 1
#            elif doNoNHinfo:
##                if alignedread.is_read1:
##                    ID = ID + '/1'
##                if alignedread.is_read2:
##                    ID = ID + '/2'
#                multiplicity = MultiplicityDict[ID]
#            else:
#                multiplicity = alignedread.opt('NH')
##            if noMulti and multiplicity > 1:
##                continue
#            if multiplicity > 1:
#                continue
#            scaleby = 1.0/multiplicity
#            FLAGfields = FLAG(int(fields[1]))
#            if alignedread.is_reverse:
#                s = '-'
#            else:
#                s = '+'
#            pos = alignedread.pos
#            cigar = alignedread.cigar
#            if ReadDict.has_key(ID):
#                pass
#            else:
#                ReadDict[ID] = {}
#                ReadDict[ID]['al'] = []
##            if noMulti:
##                ReadDict[ID]['al'] = (chr,pos,s,str(cigar))
##            else:
##                ReadDict[ID]['al'].append((chr,pos,s,str(cigar)))
#            ReadDict[ID]['al'] = (chr,pos,s,str(cigar))

    print 'found alignments for', len(ReadDict.keys()), 'reads'

    # ---- load index master lists ----------------------------------
    linelist = open(RTindexes)
    RTindexesDict = {}
    for line in linelist:
        index = line.strip().split('\t')[0]
        RTindexesDict[index] = 1

    linelist = open(PCRindexes7)
    PCRindexes7Dict = {}
    for line in linelist:
        index = line.strip().split('\t')[0]
        # i7 indexes are stored reverse-complemented to match how they
        # appear in the read header (see usage notes).
        index = getReverseComplement(index)
        PCRindexes7Dict[index] = 1

    linelist = open(PCRindexes5)
    PCRindexes5Dict = {}
    for line in linelist:
        index = line.strip().split('\t')[0]
        PCRindexes5Dict[index] = 1

    print 'finished parsing index sequences'

    # ---- stream read1 fastq, extract and match barcodes -----------
    if fastq.endswith('.gz'):
        cmd = 'zcat ' + fastq
    elif fastq.endswith('.bz2'):
        cmd = 'bzip2 -cd ' + fastq
    elif fastq.endswith('.zip'):
        cmd = 'unzip -p ' + fastq
    else:
        cmd = 'cat ' + fastq
    p = os.popen(cmd, "r")
    line = 'line'
    RN = 0.0
    # Counters: P = perfect match, E = edit-distance match, F = failed;
    # DD* = ambiguous (multiple equally-near master indexes); TTT =
    # missing oligo-dT stretch; BBC/BC = barcode-passing reads
    # (total / with a matching aligned read2).
    BC = 0
    BBC = 0
    BCRTP = 0
    BCPCR7P = 0
    BCPCR5P = 0
    BCRTF = 0
    BCPCR7F = 0
    BCPCR5F = 0
    BCRTE = 0
    BCPCR7E = 0
    BCPCR5E = 0
    TTT = 0
    DD7 = 0
    DD5 = 0
    DDRT = 0
    i = 1
    # i cycles 1..4 over the four fastq lines of each record; only the
    # header (i==1) and sequence (i==2) lines are used.
    while line != '':
        line = p.readline()
        if line == '':
            break
        RN += 1
        if RN % 20000000 == 0:
            print RN, str(RN/4000000) + 'M reads processed'
        if i == 1:
            i = 2
            if line.startswith('@'):
                # Header layout assumed: '@<ID> 1:N:0:<i7>+<i5>' -- the
                # PCR indexes ride in the Casava-style comment field.
                ID = line.strip()[1:].split(' 1:N:0:')[0]
                PCR7 = line.strip()[1:].split(' 1:N:0:')[1].split('+')[0]
                PCR5 = line.strip()[1:].split(' 1:N:0:')[1].split('+')[1]
            else:
                print 'fastq file broken, exiting'
                print line.strip()
                sys.exit(1)
            continue
        if i == 2:
            i = 3
            # Read1 layout: [UMI][RT index][oligo-dT...]; require at
            # least 6 T's immediately after the RT index.
            sequence = line.strip()
            UMI = sequence[0:UMILen]
            RT = sequence[UMILen:UMILen+RTLen]
            if sequence[UMILen+RTLen:UMILen+RTLen+6] != 'TTTTTT':
                TTT += 1
                continue
            # Match RT index: exact hit, else nearest within RTedit.
            # The == branch below can append duplicates of the < branch's
            # entry; the Set() dedup further down handles that.
            if RTindexesDict.has_key(RT):
                NearestRTIdx = [RT]
                BCRTP += 1
            else:
                EDist = len(RT)
                NearestRTIdx = []
                for RTidx in RTindexesDict.keys():
                    LDist = Levenshtein.distance(RT, RTidx)
                    if LDist <= RTedit:
                        if LDist < EDist:
                            EDist = LDist
                            NearestRTIdx = [RTidx]
                        if LDist == EDist:
                            NearestRTIdx.append(RTidx)
                if len(NearestRTIdx) == 0:
                    BCRTF += 1
                    continue
                else:
                    BCRTE += 1
            # NOTE(review): RTedit (not a separate option) is also used
            # as the i7/i5 edit-distance cutoff below -- confirm intended.
            if PCRindexes7Dict.has_key(PCR7):
                NearestPCR7Idx = [PCR7]
                BCPCR7P += 1
            else:
                EDist = len(PCR7)
                NearestPCR7Idx = []
                for PCR7idx in PCRindexes7Dict.keys():
                    LDist = Levenshtein.distance(PCR7, PCR7idx)
                    if LDist <= RTedit:
                        if LDist < EDist:
                            EDist = LDist
                            NearestPCR7Idx = [PCR7idx]
                        if LDist == EDist:
                            NearestPCR7Idx.append(PCR7idx)
                if len(NearestPCR7Idx) == 0:
                    BCPCR7F += 1
                    continue
                else:
                    BCPCR7E += 1
            if PCRindexes5Dict.has_key(PCR5):
                NearestPCR5Idx = [PCR5]
                BCPCR5P += 1
            else:
                EDist = len(PCR5)
                NearestPCR5Idx = []
                for PCR5idx in PCRindexes5Dict.keys():
                    LDist = Levenshtein.distance(PCR5, PCR5idx)
                    if LDist <= RTedit:
                        if LDist < EDist:
                            EDist = LDist
                            NearestPCR5Idx = [PCR5idx]
                        if LDist == EDist:
                            NearestPCR5Idx.append(PCR5idx)
                if len(NearestPCR5Idx) == 0:
                    BCPCR5F += 1
                    continue
                else:
                    BCPCR5E += 1
            # Deduplicate, then reject reads equidistant from multiple
            # master indexes.
            NearestRTIdx = list(Set(NearestRTIdx))
            NearestPCR5Idx = list(Set(NearestPCR5Idx))
            NearestPCR7Idx = list(Set(NearestPCR7Idx))
            if len(NearestRTIdx) > 1:
                print 'Master list RT indexes closer to each other than the allowed editing distance, skipping'
                print RT, NearestRTIdx
                DDRT += 1
                continue
            if len(NearestPCR5Idx) > 1:
                print 'Master list i5 indexes closer to each other than the allowed editing distance, skipping'
                print PCR5, NearestPCR5Idx
                DD5 += 1
                continue
            if len(NearestPCR7Idx) > 1:
                print 'Master list i7 indexes closer to each other than the allowed editing distance, skipping'
                print PCR7, NearestPCR7Idx
                DD7 += 1
                continue
            BBC += 1
            # Attach the barcode tuple to the matching aligned read2.
            if ReadDict.has_key(ID):
                ReadDict[ID]['bc'] = (NearestRTIdx[0], NearestPCR7Idx[0], NearestPCR5Idx[0], UMI)
                BC += 1
            continue
        if i == 3:
            i = 4
            continue
        if i == 4:
            i = 1
            continue

    print 'finished parsing barcodes, found', str(BBC), 'aligned reads with barcodes passing criteria'
    print 'finished parsing barcodes, found', str(BC), 'aligned reads with barcodes with matching ID in end2 file'
    print 'number RT perfect matches:', BCRTP
    print 'number RT imperfect matches:', BCRTE
    print 'number i7 perfect matches:', BCPCR7P
    print 'number i7 imperfect matches:', BCPCR7E
    print 'number i5 perfect matches:', BCPCR5P
    print 'number i5 imperfect matches:', BCPCR5E
    print 'failed due to not containing proper oligo dT', TTT
    print 'failed due to matching multiple RT barcodes', DDRT
    print 'failed due to matching multiple i5 barcodes', DD5
    print 'failed due to matching multiple i7 barcodes', DD7
    print 'failed to find RT matches', BCRTF
    print 'failed to find i7 matches', BCPCR7F
    print 'failed to find i5 matches', BCPCR5F

    # ---- resolve multi-gene reads, collapse UMIs, count -----------
    GeneExpressionMatrix = {}

    SeenGenes = {}

    i = 0
    ambiguous = 0
    for ID in ReadDict.keys():
        i += 1
        if i % 1000000 == 0:
            print i
        if ReadDict[ID].has_key('bc'):
            pass
        else:
            continue
        (RT, i7, i5, UMI) = ReadDict[ID]['bc']
        BC = (RT, i7, i5)
        AL = ReadDict[ID]['al']
        genes = ReadDict[ID]['gs']
        if len(genes) == 1:
            TheG = [genes[0]]
        else:
            # Multiple overlapping genes: try to disambiguate via splice
            # junctions.  The stringified pysam cigar '[(op, bp), ...]'
            # is parsed back into (op, bp) pairs here.
            (chr, pos, strand, cigar) = AL
            cigarfields = cigar.split('), (')
            if len(cigarfields) == 1:
                # Unspliced read: no junction evidence available.
                ambiguous += 1
                if doSAR:
                    TheG = genes
                else:
                    continue
            else:
                # Collect genomic positions of junction boundaries.
                # NOTE(review): currentPos is only initialized on the
                # first iteration path via the m==0/KCF>1 logic below --
                # it carries over from the enclosing loop's previous read
                # if the first op is not a match; verify the intended
                # starting value (presumably `pos`).
                splicesites = {}
                KCF = 0
                for CF in cigarfields:
                    KCF += 1
                    CF = CF.replace('[', '').replace(']', '').replace('(', '').replace(')', '')
                    m = int(CF.split(', ')[0])
                    bp = int(CF.split(', ')[1])
                    if m == 0 and KCF > 1:
                        splicesites[currentPos] = 1
                    elif m == 3:
                        if strand == '+':
                            splicesites[currentPos+1] = 1
                        if strand == '-':
                            splicesites[currentPos+1] = 1
                    else:
                        pass
                    currentPos = currentPos+bp
                # Score each candidate gene by how many of its annotated
                # exon boundaries coincide with the read's splice sites.
                matchedSS = {}
                for G in genes:
                    matchedSS[G] = {}
                    geneID = GeneIDDict[G]
                    for transcript in GeneDict[geneID]['transcripts']:
                        KTP = 0
                        for (chr, LL, RR, strand) in TranscriptDict[transcript]:
                            KTP += 1
                            if KTP == 1:
                                # First exon: only its right edge is a junction.
                                if splicesites.has_key(RR):
                                    matchedSS[G][RR] = 1
                            elif KTP == len(TranscriptDict[transcript]):
                                # Last exon: only its left edge is a junction.
                                if splicesites.has_key(LL):
                                    matchedSS[G][LL] = 1
                            else:
                                if splicesites.has_key(RR):
                                    matchedSS[G][RR] = 1
                                if splicesites.has_key(LL):
                                    matchedSS[G][LL] = 1
                matchedSSgenes = []
                for G in genes:
                    matchedSSgenes.append((len(matchedSS[G].keys()), G))
                matchedSSgenes.sort()
                matchedSSgenes.reverse()
                if matchedSSgenes[0][0] > matchedSSgenes[1][0]:
                    # NOTE(review): [0][0] is the match COUNT; the gene
                    # index is [0][1] -- this looks like a bug.
                    TheG = [matchedSSgenes[0][0]]
                else:
                    ambiguous += 1
                    if doSAR:
                        TheG = []
                        for G in matchedSSgenes:
                            # NOTE(review): compares an int (G[0]) to the
                            # (count, gene) tuple matchedSSgenes[0]; this
                            # is always False, so TheG stays empty --
                            # presumably meant G[0] == matchedSSgenes[0][0].
                            if G[0] == matchedSSgenes[0]:
                                TheG.append(G[1])
                    else:
                        continue
        # UMI collapsing: per (barcode, alignment position), count each
        # distinct UMI once; near-identical UMIs (within UMIedit) are
        # merged into the existing entry rather than counted anew.
        if GeneExpressionMatrix.has_key(BC):
            pass
        else:
            GeneExpressionMatrix[BC] = {}
            GeneExpressionMatrix[BC]['UMIs'] = {}
            GeneExpressionMatrix[BC]['gene_counts'] = {}
        if GeneExpressionMatrix[BC]['UMIs'].has_key(AL):
            if GeneExpressionMatrix[BC]['UMIs'][AL].has_key(UMI):
                GeneExpressionMatrix[BC]['UMIs'][AL][UMI] += 1
                continue
            else:
                EDist = len(UMI)
                NearestUMI = []
                for U in GeneExpressionMatrix[BC]['UMIs'][AL].keys():
                    LDist = Levenshtein.distance(U, UMI)
                    if LDist <= UMIedit:
                        if LDist < EDist:
                            EDist = LDist
                            NearestUMI = [U]
                        if LDist == EDist:
                            NearestUMI.append(U)
                NearestUMI = list(Set(NearestUMI))
                if len(NearestUMI) == 0:
                    # Genuinely new UMI: count it for the assigned
                    # gene(s), fractionally when split across genes.
                    GeneExpressionMatrix[BC]['UMIs'][AL][UMI] = 1
                    for G in TheG:
                        SeenGenes[G] = 1
                        if GeneExpressionMatrix[BC]['gene_counts'].has_key(G):
                            if len(TheG) == 1:
                                GeneExpressionMatrix[BC]['gene_counts'][G] += 1
                            else:
                                GeneExpressionMatrix[BC]['gene_counts'][G] += 1./len(TheG)
                        else:
                            if len(TheG) == 1:
                                GeneExpressionMatrix[BC]['gene_counts'][G] = 1
                            else:
                                GeneExpressionMatrix[BC]['gene_counts'][G] = 1./len(TheG)
                elif len(NearestUMI) == 1:
                    # Sequencing-error variant of an existing UMI.
                    GeneExpressionMatrix[BC]['UMIs'][AL][NearestUMI[0]] += 1
                else:
                    print 'multiple matching UMIs detected for the following read:'
                    print AL
                    print NearestUMI,
                    print ReadDict[ID]['bc']
        else:
            GeneExpressionMatrix[BC]['UMIs'][AL] = {}
            GeneExpressionMatrix[BC]['UMIs'][AL][UMI] = 1
            for G in TheG:
                SeenGenes[G] = 1
                if GeneExpressionMatrix[BC]['gene_counts'].has_key(G):
                    if len(TheG) == 1:
                        GeneExpressionMatrix[BC]['gene_counts'][G] += 1
                    else:
                        GeneExpressionMatrix[BC]['gene_counts'][G] += 1./len(TheG)
                else:
                    if len(TheG) == 1:
                        GeneExpressionMatrix[BC]['gene_counts'][G] = 1
                    else:
                        GeneExpressionMatrix[BC]['gene_counts'][G] = 1./len(TheG)

    print 'finished collapsing UMIs'
    if doSAR:
        print 'found', ambiguous, 'reads with ambiguous gene assignment'
    else:
        print 'discarded', ambiguous, 'reads with ambiguous gene assignment'
    print 'found', len(GeneExpressionMatrix.keys()), 'cell barcodes in total'
    print 'filtering poor quality cells'

    # ---- per-cell UMI ranking and cell filtering ------------------
    outfile = open(outprefix + '.UMIs_per_cell', 'w')
    outline = '#RT\ti7\ti5\trank\tUMIs\tAligned_Positions'
    outfile.write(outline + '\n')

    outlines = []

    for BC in GeneExpressionMatrix.keys():
        U = 0
        for AL in GeneExpressionMatrix[BC]['UMIs'].keys():
            U += len(GeneExpressionMatrix[BC]['UMIs'][AL].keys())
        ALs = len(GeneExpressionMatrix[BC]['UMIs'].keys())
        outlines.append((U, ALs, BC))

    outlines.sort()
    outlines.reverse()

    R = 0
    for (U, ALs, BC) in outlines:
        R += 1
        outline = BC[0] + '\t' + BC[1] + '\t' + BC[2] + '\t' + str(R) + '\t' + str(U) + '\t' + str(ALs)
        outfile.write(outline + '\n')

    outfile.close()

    # Cutoff = UMICutoff fraction of the median UMI count over the top
    # ExpCellNum barcodes (knee-style filter).
    GeneMatrixUMICounts = []
    for i in range(ExpCellNum):
        GeneMatrixUMICounts.append(outlines[i][0])

    print 'Expected number of cells:', ExpCellNum
    print 'fraction of the median UMIs of top ECN cells cutoff:', UMICutoff
    print 'median UMIs of top ECN cells:', np.median(GeneMatrixUMICounts)

    cutoff = np.median(GeneMatrixUMICounts)*UMICutoff

    print 'cutoff value:', cutoff, 'UMIs'

    for BC in GeneExpressionMatrix.keys():
        U = 0
        for AL in GeneExpressionMatrix[BC]['UMIs'].keys():
            U += len(GeneExpressionMatrix[BC]['UMIs'][AL].keys())
        if U < cutoff:
            del GeneExpressionMatrix[BC]

    print 'cells remaining after filtering poor quality cells:', len(GeneExpressionMatrix.keys())

    # ---- write gene x cell count matrix ---------------------------
    outfile = open(outprefix + '.table', 'w')

    outline = '#geneID'

    BCs = GeneExpressionMatrix.keys()
    BCs.sort()

    for BC in BCs:
        (RT, i7, i5) = BC
        outline = outline + '\t' + RT + '-' + i7 + '-' + i5
    outfile.write(outline+'\n')

    Gs = SeenGenes.keys()
    for G in Gs:
        geneID = GeneIDDict[G]
        outline = geneID
        for BC in BCs:
            if GeneExpressionMatrix[BC]['gene_counts'].has_key(G):
                umicounts = GeneExpressionMatrix[BC]['gene_counts'][G]
            else:
                umicounts = 0
            outline = outline + '\t' + str(umicounts)
        outfile.write(outline+'\n')

    print 'finished outputting counts'

    outfile.close()
            
# Guard the entry point so importing this module (e.g. for reuse of
# FLAG/getReverseComplement) does not launch the full pipeline.
if __name__ == '__main__':
    run()
