##################################
#                                #
# Last modified 2021/10/11       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import os
import pysam
import random
import string
from sets import Set

def _read_chrom_sizes(path):
    """Read a chrom.sizes file into a list of ``(chrom, 0, length)`` regions.

    Each non-blank line must start with a contig name followed by its length.
    Columns may be separated by tabs or any other whitespace. Blank lines are
    skipped.

    Raises ValueError for a non-blank line with fewer than two columns or a
    non-integer length.
    """
    regions = []
    with open(path) as handle:
        for lineno, line in enumerate(handle, 1):
            fields = line.split()
            if not fields:
                continue
            if len(fields) < 2:
                raise ValueError('%s line %d: expected "chrom<TAB>length", got %r'
                                 % (path, lineno, line.rstrip('\r\n')))
            regions.append((fields[0], 0, int(fields[1])))
    return regions


def _read_barcode_clusters(path):
    """Parse a ``barcode,cluster`` CSV file.

    Returns a tuple ``(bc_to_cluster, clusters)``:
      - ``bc_to_cluster`` maps each barcode to its cluster label. If a barcode
        is repeated, the last assignment wins.
      - ``clusters`` lists the distinct cluster labels in order of first
        appearance.
    Labels are kept as strings because they are used verbatim in output file
    names. Blank lines are skipped; any extra columns after the second are
    ignored.

    Raises ValueError for a non-blank line that has no comma.
    """
    bc_to_cluster = {}
    clusters = []
    seen_clusters = set()
    with open(path) as handle:
        for lineno, line in enumerate(handle, 1):
            line = line.strip()
            if not line:
                continue
            fields = line.split(',')
            if len(fields) < 2:
                raise ValueError('%s line %d: expected "barcode,cluster", got %r'
                                 % (path, lineno, line))
            barcode = fields[0].strip()
            cluster = fields[1].strip()
            if cluster not in seen_clusters:
                seen_clusters.add(cluster)
                clusters.append(cluster)
            bc_to_cluster[barcode] = cluster
    return bc_to_cluster, clusters


def run():
    """Split an indexed BAM into one BAM per cell cluster.

    Command line: ``BAM chrom.sizes cluster.csv outprefix``

    For every contig listed in chrom.sizes that also exists in the BAM, all
    alignments returned by ``fetch(chrom, 0, length)`` are examined:
      - Alignments without a CB tag are skipped.
      - Alignments whose CB barcode is listed in cluster.csv are written to
        ``<outprefix>.<cluster>.bam``.
      - All other barcodes are recorded as missing.
    At the end, the number of distinct found and missing barcodes is printed.
    Contigs absent from the BAM header are reported and skipped.
    """
    # Four positional arguments are read below (argv[1] .. argv[4]).
    if len(sys.argv) < 5:
        print('usage: python %s BAM chrom.sizes cluster.csv outprefix' % sys.argv[0])
        print('\tIt is assumed the clusters files looks like this: ')
        print('\t\tAAACGGATCAGGGCCT-1,7')
        print('\t\tAAACGGATCAGTTCCC-1,8')
        print('\tIt is assumed barcodes are recorded in the BAM file like this:')
        print('\t\tCB:Z:CTTGAATCATCTTGAG-1')
        print('\t\tCB:Z:CATTATCTCCTCAGCT-1')
        sys.exit(1)

    bam_path = sys.argv[1]
    regions = _read_chrom_sizes(sys.argv[2])
    bc_to_cluster, clusters = _read_barcode_clusters(sys.argv[3])
    outprefix = sys.argv[4]

    samfile = pysam.AlignmentFile(bam_path, 'rb')
    outfiles = {}
    try:
        for cluster in clusters:
            out_path = outprefix + '.' + cluster + '.bam'
            outfiles[cluster] = pysam.AlignmentFile(out_path, 'wb', template=samfile)
            print(out_path)

        # Check contig membership explicitly instead of swallowing every
        # exception from fetch(). A real problem, such as a missing BAM index,
        # now surfaces as an error instead of "not found" for every contig.
        bam_contigs = set(samfile.references)

        n_seen = 0
        n_written = 0
        found = set()
        missing = set()
        for chrom, start, end in regions:
            if chrom not in bam_contigs:
                print('not found in BAM file %s %d %d' % (chrom, start, end))
                continue
            for read in samfile.fetch(chrom, start, end):
                n_seen += 1
                if n_seen % 1000000 == 0:
                    print('processed %dM alignments %s %d %d %d written %d  alignments'
                          % (n_seen // 1000000, chrom, start,
                             read.reference_start, end, n_written))
                # Read the CB tag through the pysam API. Parsing str(read)
                # breaks when CB is the last tag, and depends on pysam's
                # __str__ output format.
                if not read.has_tag('CB'):
                    continue
                barcode = read.get_tag('CB')
                cluster = bc_to_cluster.get(barcode)
                if cluster is None:
                    missing.add(barcode)
                    continue
                found.add(barcode)
                outfiles[cluster].write(read)
                n_written += 1
    finally:
        # Closing the writers flushes their buffers so the output BAMs are
        # complete on disk.
        for out in outfiles.values():
            out.close()
        samfile.close()

    print('found:  %d' % len(found))
    print('missing:  %d' % len(missing))

# Run only when executed as a script, so importing this module has no side
# effects.
if __name__ == '__main__':
    run()
