##################################
#                                #
# Last modified 2025/03/03       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import math
import os
import string
import subprocess
import sys
from sets import Set

import numpy as np
import pysam

def run():

    if len(sys.argv) < 3:
        print 'usage: python %s minimap_sam dorado_bam outfilename [-minReadLen bp] [-minGeneCov fraction]' % sys.argv[0]
        print '\t use - for stdin for the minimap SAM file'
        print '\tit is assumed the BAM file has the polyA tail lengths encoded as pt:i tags'
        sys.exit(1)

    SAM = sys.argv[1]
    BAM = sys.argv[2]
    outfilename = sys.argv[3]

    minGeneCov = 0
    minGeneCovSkipped = 0
    if '-minGeneCov' in sys.argv:
        minGeneCov = float(sys.argv[sys.argv.index('-minGeneCov') + 1])

    print 'minimum gene coverage:', minGeneCov

    minRL = 0
    minRLSkipped = 0
    if '-minReadLen' in sys.argv:
        minRL = int(sys.argv[sys.argv.index('-minReadLen') + 1])

    print 'minimum read length:', minRL

    readPADict = {}
    readGeneDict = {}

    i=0
    skipped=0
    samfile = pysam.Samfile(BAM, "rb", check_sq=False)
    for read in samfile.fetch(until_eof=True):
        i+=1
        if i % 1000000 == 0:
            print 'inputting raw reads', str(i/1000000) + 'M reads processed'
        fields=str(read).split('\t')
        ID = read.qname
        sequence = fields[9]
        if len(sequence) < minRL:
            minRLSkipped += 1
            continue
        try:
            PAtail = read.opt('pt')
        except:
            skipped += 1
            continue
        if readPADict.has_key(ID):
            print 'duplicate read ID detected, exiting'
            print ID
            sys.exit(1)
        readPADict[ID] = (PAtail,len(sequence))

    print 'found', i, ' reads in the BAM file'
    print 'skipped', skipped, ' reads without a pt:i tag'
    print 'skipped', minRLSkipped, ' reads shorter than the minimum read length, which is set to', minRL

    TranscriptDict = {}

    TotalReads = 0
    L = 0    
    doStdIn = False
    if SAM != '-':
        if SAM.endswith('.bz2'):
            cmd = 'bzip2 -cd ' + SAM
        elif SAM.endswith('.gz'):
            cmd = 'gunzip -c ' + SAM
        else:
            cmd = 'cat ' + SAM
        p = os.popen(cmd, "r")
    else:
        doStdIn = True
    line = 'line'
    while line != '':
        if doStdIn:
            line = sys.stdin.readline()
        else:
            line = p.readline()
        if line == '':
            break
        if line.startswith('[M::'):
            continue
        if line.startswith('@SQ'):
            fields = line.strip().split('\t')
            TID = fields[1][3:]
            TIDlen = int(fields[2][3:])
            TranscriptDict[TID] = TIDlen
            continue
        if line.startswith('@PG'):
            continue
        fields = line.strip().split('\t')
        if len(fields) < 9:
            continue
        L += 1.
        if L % 1000000 == 0:
            print str(L/1000000) + 'M alignments processed'
        readID = fields[0]
        geneID = fields[2]

        if geneID != '*' and readPADict.has_key(readID):

            geneLen = TranscriptDict[geneID]
#            print geneLen, readPADict[readID][1]
            GeneCov = readPADict[readID][1]/(geneLen + 0.0)
            if GeneCov < minGeneCov:
                minGeneCovSkipped += 1
                continue

            if readGeneDict.has_key(readID):
                pass
            else:
                if readPADict.has_key(readID):
                    pass
                else:
                    continue            
                readGeneDict[readID] = []
                TotalReads += 1
            readGeneDict[readID].append(geneID)

    print 'skipped', minGeneCovSkipped, ' reads shorter than the minimum transcript coverage treshold, which is set to', minGeneCov

    print 'found', TotalReads, 'aligned reads passing the minimum length requirement'
    print len(readGeneDict.keys())

    outfile = open(outfilename,'w')

    outline = '#geneID\tTPM\tmean_tail_length\tstdev_tail_length\tmean_read_length\tread_length_stdev\ttail_lengths'
    outfile.write(outline + '\n')

    NormFactor = TotalReads/1000000.

    GeneDict = {}

    for readID in readGeneDict.keys():
        (PA,RL) = readPADict[readID]
        for geneID in readGeneDict[readID]:
            if GeneDict.has_key(geneID):
                pass
            else:
                GeneDict[geneID] = {}
                GeneDict[geneID]['PA'] = []
                GeneDict[geneID]['RL'] = []
            GeneDict[geneID]['PA'].append(int(PA))
            GeneDict[geneID]['RL'].append(int(RL))

    genes = GeneDict.keys()
    genes.sort()

    for geneID in genes:
        TPM = len(GeneDict[geneID]['PA'])/NormFactor
        TLmean = np.mean(GeneDict[geneID]['PA'])
        TLstd = np.std(GeneDict[geneID]['PA'])
        RLmean = np.mean(GeneDict[geneID]['RL'])
        RLstd = np.std(GeneDict[geneID]['RL'])
        outline = geneID + '\t' + str(TPM) + '\t' + str(TLmean) + '\t' + str(TLstd) + '\t' + str(RLmean) + '\t' + str(RLstd) + '\t'
        for TL in GeneDict[geneID]['PA']:
            outline = outline + str(TL) + ','
        outline = outline[:-1]
        outfile.write(outline + '\n')
        
    outfile.close()

        
# Run only when executed as a script, so importing this module has no side
# effects (no argv parsing, file reads, or sys.exit).
if __name__ == '__main__':
    run()

