##################################
#                                #
# Last modified 2025/03/04       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import pysam
from sets import Set
# import h5py
# import numpy as np    
import os
import re

def FLAG(FLAG):
    """Decompose a SAM FLAG value into the individual bits that are set.

    FLAG -- the numeric FLAG field of an alignment (anything int() accepts).

    Returns the set bits as powers of two, largest first, e.g. 17 -> [16, 1]
    and 2064 -> [2048, 16].  Values <= 0 yield an empty list.  Bits above
    2048 are reported as well instead of triggering an IndexError.
    """
    value = int(FLAG)
    bits = []
    bit = 1
    # Walk upwards through the powers of two; the loop never runs for value <= 0.
    while bit <= value:
        if value & bit:
            bits.append(bit)
        bit <<= 1
    bits.reverse()  # callers historically received the largest bit first
    return bits

def getReverseComplement(preliminarysequence):
    """Return the reverse complement of a DNA sequence.

    Only A, C, G, T and N are accepted, in either case; the case of each base
    is preserved.  Any other character raises KeyError.
    """
    DNA = {'A':'T','T':'A','G':'C','C':'G','N':'N','a':'t','t':'a','g':'c','c':'g','n':'n'}
    # A single join is linear, unlike repeated string concatenation.
    return ''.join(DNA[base] for base in reversed(preliminarysequence))

# MM modification codes understood by run(), mapped to the base searched for
# in the read sequence on '+' alignments and on '-' alignments respectively.
_MOD_BASES = {
    'A+a': ('A', 'T'),
    'C+m': ('C', 'G'),
    'T+g': ('T', 'A'),
}

# Names accepted by the -contextFilter option (see _passes_context_filter).
_CONTEXT_FILTERS = ('GC', 'CG', 'CGGC', '6mA+GCCG', '6mA+GC', '6mA+CG', '6mA', 'T')


def _read_genome(fasta):
    """Load a FASTA file into a {record name: upper-cased sequence} dict.

    The record name is the first whitespace-delimited word of the header
    line (descriptions after it are dropped), matching the usual convention
    for BAM reference names.  Lines before the first header are ignored.
    """
    genome = {}
    name = None
    chunks = []
    with open(fasta) as handle:
        for line in handle:
            if line.startswith('>'):
                if name is not None:
                    genome[name] = ''.join(chunks).upper()
                words = line[1:].split()
                name = words[0] if words else ''
                chunks = []
            else:
                chunks.append(line.strip())
    if name is not None:
        genome[name] = ''.join(chunks).upper()
    return genome


def _mask_contexts(seq, contexts, radius):
    """Return (masked_seq, n_masked).

    Every base of `seq` lying within `radius` bp of a match of any regular
    expression in `contexts` is replaced by 'X'; `n_masked` is the number of
    bases that were replaced (clamped to the sequence bounds).
    """
    mask = bytearray(len(seq))
    for context in contexts:
        for match in re.finditer(context, seq):
            start = max(match.start() - radius, 0)
            end = min(match.end() + radius, len(seq))
            if end > start:
                mask[start:end] = b'\x01' * (end - start)
    masked = ''.join('X' if mask[i] else base for i, base in enumerate(seq))
    return masked, sum(mask)


def _read_to_genome_map(cigar, ref_start, read_id):
    """Map every aligned read-sequence index to its reference coordinate.

    `cigar` is a list of (op, length) pairs in pysam's numeric op encoding and
    `ref_start` is the reference coordinate of the first aligned base.  Only
    M/=/X positions get an entry.  I and S advance the read, D and N advance
    the reference, and H and P advance neither (hard-clipped bases are not in
    the stored sequence).  Any other op prints a message and exits.
    """
    mapping = {}
    read_pos = 0
    ref_pos = ref_start
    for op, length in cigar:
        if op in (0, 7, 8):  # M, =, X: consume read and reference
            for offset in range(length):
                mapping[read_pos + offset] = ref_pos + offset
            read_pos += length
            ref_pos += length
        elif op in (1, 4):  # I, S: read only
            read_pos += length
        elif op in (2, 3):  # D, N: reference only
            ref_pos += length
        elif op in (5, 6):  # H, P: neither
            pass
        else:
            print('%s %s unrecognized CIGAR field, exiting' % (read_id, (op, length)))
            sys.exit(1)
    return mapping


def _modified_read_positions(mm, ml, sequence, strand, wanted):
    """Decode MM/ML tags into parallel lists (read indices, ML values).

    `mm` is the MM tag string and `ml` the matching ML values.  For each
    ';'-separated entry whose first three characters are a code in `wanted`,
    the candidate bases are located in `sequence` (the read sequence from
    the BAM record).  '+' alignments search forward for the canonical base.
    '-' alignments search for the complementary base and reverse the list,
    because MM skip counts run along the originally sequenced read.  Entries
    with other codes are skipped, but their ML values are still consumed so
    later entries stay aligned.  If MM lists more calls than there are
    candidate bases, the surplus calls (and their ML values) are dropped
    rather than raising IndexError.

    NOTE(review): combined codes such as 'C+mh' carry several ML values per
    call; they are matched as 'C+m' here and their ML values will misalign.
    """
    positions = []
    probs = []
    ml_pos = 0
    for entry in mm.split(';'):
        parts = entry.split(',')
        code = parts[0][0:3]
        deltas = parts[1:]
        if code in wanted:
            plus_base, minus_base = _MOD_BASES[code]
            if strand == '+':
                candidates = [m.start() for m in re.finditer(plus_base, sequence)]
            else:
                candidates = [m.start() for m in re.finditer(minus_base, sequence)]
                candidates.reverse()
            picked = []
            index = -1
            for delta in deltas:
                index += int(delta) + 1  # skip `delta` candidates, take the next
                if index >= len(candidates):
                    break
                picked.append(candidates[index])
            positions += picked
            # ML values are listed in the same order as the MM calls.
            probs += list(ml[ml_pos:ml_pos + len(picked)])
        ml_pos += len(deltas)
    return positions, probs


def _passes_context_filter(mode, strand, here_next, prev_here):
    """Decide whether a call at reference position gpos survives `mode`.

    here_next is genome[gpos:gpos + 2] and prev_here is genome[gpos - 1:gpos + 1].
    On '-' alignments the contexts are read on the opposite strand, so the
    roles of the two dinucleotides, and of A/T, are swapped.  A mode of None
    accepts every position.
    """
    if mode is None:
        return True
    base = here_next[:1]  # empty at the chromosome end instead of IndexError
    if strand == '+':
        gpc = prev_here == 'GC'
        cpg = here_next == 'CG'
        is_a = base == 'A'
        is_t = base == 'T'
    else:
        gpc = here_next == 'GC'
        cpg = prev_here == 'CG'
        is_a = base == 'T'
        is_t = base == 'A'
    if mode == 'GC':
        return gpc
    if mode == 'CG':
        return cpg
    if mode == 'CGGC':
        return cpg or gpc
    if mode == '6mA+GCCG':
        return cpg or gpc or is_a
    if mode == '6mA+GC':
        return gpc or is_a
    if mode == '6mA+CG':
        return cpg or is_a
    if mode == '6mA':
        return is_a
    if mode == 'T':
        return is_t
    raise ValueError('unknown context filter: %s' % mode)


def run():
    """Command-line entry point: export per-read modification calls.

    Usage (see the usage message for all options):
        python script.py bam genome.fa outfile_prefix [options]

    Unmapped and secondary (FLAG 256) alignments, and alignments on
    references absent from the FASTA or listed in -excludeChr, are skipped.
    Supplementary (FLAG 2048) alignments are kept.  For the rest, the MM/ML
    entries selected by -mods (default: every code in _MOD_BASES) are
    decoded and placed on the reference via the CIGAR.  Calls can then be
    restricted by -contextFilter and -excludeContext masking.  One line per
    read is written to <outfile_prefix>.reads.tsv:

        chrom  first_pos  last_pos  strand  read_id  .  pos1,pos2,...  ml1,ml2,...

    The ML values are written as taken from the tag.  Reference coordinates
    start from the position column of str(read).  NOTE(review): confirm
    whether the installed pysam prints that column 0- or 1-based.
    """
    argv = sys.argv
    if len(argv) < 4:
        print('usage: python %s bam genome.fa outfile_prefix [-excludeContext string(,string2,string3,...,stringN) radius] [-excludeChr chr1[,chr2,...,chrN]] [-chrPrefix string] [-strand +|-] [-mods code1[,code2,...]] [-contextFilter %s]' % (argv[0], '|'.join(_CONTEXT_FILTERS)))
        sys.exit(1)

    def option_value(name, offset=1):
        """Return the token `offset` places after command-line flag `name`."""
        return argv[argv.index(name) + offset]

    do_plus_only = False
    do_minus_only = False
    if '-strand' in argv:
        wanted_strand = option_value('-strand')
        if wanted_strand == '+':
            do_plus_only = True
            print('will only output plus-strand reads')
        if wanted_strand == '-':
            do_minus_only = True
            print('will only output minus-strand reads')

    # NOTE(review): -chrPrefix appears in the usage message, but its value is
    # not applied anywhere; the option is accepted and ignored.

    excluded_chroms = set()
    if '-excludeChr' in argv:
        excluded_chroms = set(option_value('-excludeChr').split(','))

    wanted_mods = set(_MOD_BASES)
    if '-mods' in argv:
        wanted_mods = set(option_value('-mods').split(','))
        unknown = wanted_mods - set(_MOD_BASES)
        if unknown:
            print('unsupported -mods value(s): %s (supported: %s)' % (','.join(sorted(unknown)), ','.join(sorted(_MOD_BASES))))
            sys.exit(1)
    print('will extract the following modifications: %s' % ','.join(sorted(wanted_mods)))

    context_filter = None
    if '-contextFilter' in argv:
        context_filter = option_value('-contextFilter')
        if context_filter not in _CONTEXT_FILTERS:
            print('unsupported -contextFilter value: %s (supported: %s)' % (context_filter, '|'.join(_CONTEXT_FILTERS)))
            sys.exit(1)
        print('will only output positions passing the %s context filter' % context_filter)

    bam_path = argv[1]
    fasta = argv[2]
    outprefix = argv[3]

    genome = _read_genome(fasta)
    for chrom in excluded_chroms:
        genome.pop(chrom, None)
    print('finished inputting genomic sequence')

    do_exclude_context = '-excludeContext' in argv
    if do_exclude_context:
        contexts = [c.upper() for c in option_value('-excludeContext').split(',')]
        radius = int(option_value('-excludeContext', 2))
        print('will filter out the following sequence contexts: %s with a radius of %s bp around each match' % (contexts, radius))
        for chrom in list(genome):
            masked, n_masked = _mask_contexts(genome[chrom], contexts, radius)
            length = len(masked)
            fraction = n_masked / float(length) if length else 0.0
            print('masking %s: %s %s %s' % (chrom, n_masked, length, fraction))
            genome[chrom] = masked
        print('finished masking genomic sequence')

    with open(outprefix + '.reads.tsv', 'w') as outfile:
        samfile = pysam.Samfile(bam_path, "rb")
        try:
            n_alignments = 0
            chrom = None
            for read in samfile.fetch(until_eof=True):
                n_alignments += 1
                if n_alignments % 100000 == 0:
                    print('%d alignments processed %s' % (n_alignments, chrom))
                fields = str(read).split('\t')
                read_id = read.qname
                if read.is_unmapped:
                    continue
                chrom = samfile.getrname(read.tid)
                if chrom not in genome:
                    continue
                flag_bits = FLAG(int(fields[1]))
                if 256 in flag_bits:  # secondary alignment
                    continue
                # Supplementary (2048) alignments are deliberately kept.
                if 16 in flag_bits:
                    strand = '-'
                    if do_plus_only:
                        continue
                else:
                    strand = '+'
                    if do_minus_only:
                        continue
                # The sequence is used as stored for both strands (no reverse
                # complementing); _modified_read_positions compensates on '-'.
                sequence = fields[9]
                try:
                    mm = read.opt('MM')
                    ml = read.opt('ML')
                except KeyError:
                    print('MM tag not present, skipping %s' % read_id)
                    continue

                read_to_genome = _read_to_genome_map(read.cigar, int(fields[3]), read_id)
                mod_positions, mod_probs = _modified_read_positions(mm, ml, sequence, strand, wanted_mods)
                calls = sorted(zip(mod_positions, mod_probs))
                if calls:
                    print('preliminaryOutput %d %s' % (len(calls), calls[0]))
                else:
                    print('preliminaryOutput 0')

                ref = genome[chrom]
                final = []
                for read_pos, prob in calls:
                    gpos = read_to_genome.get(read_pos)
                    if gpos is None:  # call lies in an insertion or soft clip
                        continue
                    if do_exclude_context and ref[gpos:gpos + 1] == 'X':
                        continue  # masked by -excludeContext, whatever the filter
                    if _passes_context_filter(context_filter, strand, ref[gpos:gpos + 2], ref[gpos - 1:gpos + 1]):
                        final.append((gpos, prob))

                print('finalOutput %d %s %s' % (len(final), strand, read_id))
                if not final:
                    print('skipping %s %s %s due to no aligned informative positions found' % (read_id, strand, fields[5]))
                    continue

                final.sort()
                outfile.write('\t'.join([
                    chrom,
                    str(final[0][0]),
                    str(final[-1][0]),
                    strand,
                    read_id,
                    '.',
                    ','.join(str(gpos) for gpos, _ in final),
                    ','.join(str(prob) for _, prob in final),
                ]) + '\n')
        finally:
            samfile.close()

    
# Only run the command-line pipeline when executed as a script, so the
# helpers above can be imported without side effects.
if __name__ == '__main__':
    run()
