##################################
#                                #
# Last modified 2022/05/09       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import math
import os
import string
import subprocess
import sys
from sets import Set

import numpy as np

def _open_input(datafilename):
    """Open the tailfindr table for line-by-line reading.

    Returns a ``(stream, proc)`` pair.  ``proc`` is the decompressor
    subprocess for ``.bz2`` / ``.gz`` inputs and ``None`` otherwise.
    ``'-'`` selects standard input.
    """
    if datafilename == '-':
        return sys.stdin, None
    if datafilename.endswith('.bz2'):
        cmd = ['bzip2', '-cd', datafilename]
    elif datafilename.endswith('.gz'):
        cmd = ['gunzip', '-c', datafilename]
    else:
        return open(datafilename, 'r'), None
    # An argument list (no shell) keeps file names containing spaces or
    # shell metacharacters from being split or interpreted.
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                            universal_newlines=True)
    return proc.stdout, proc


def run():
    """Summarise tailfindr tail lengths per transcript.

    Command line: ``<script> tailfindr outputfilename``.

    The input is comma-separated with the columns listed in the usage
    message; only ``tail_length`` (column 5) and ``transcript_id``
    (column 7) are used.  The input may be ``-`` for standard input, or a
    path ending in ``.bz2`` / ``.gz``, which is streamed through
    ``bzip2`` / ``gunzip``.

    Every non-header line counts towards the total.  Lines with too few
    columns, or whose tail length is not a number, are counted as skipped
    and contribute to no transcript.

    The output is tab-separated, one row per transcript sorted by ID:
    ID, TPM (rows kept for the transcript divided by total lines / 1e6),
    ``np.mean`` and ``np.std`` of the tail lengths, and the comma-joined
    tail lengths in input order.

    Raises:
        IOError: if the decompressor exits with a non-zero status.
    """
    # Both the input and the output path are required; checking for only
    # one let the script parse the whole input and then fail on argv[2].
    if len(sys.argv) < 3:
        print('usage: python %s tailfindr outputfilename' % sys.argv[0])
        print('\tassumed input format:')
        print('\tread_id,tail_start,tail_end,samples_per_nt,tail_length,file_path,transcript_id,mapping_quality,sam_flag')
        sys.exit(1)

    datafilename = sys.argv[1]
    outfilename = sys.argv[2]

    GeneDict = {}

    TotalReads = 0
    Skipped = 0

    stream, proc = _open_input(datafilename)
    try:
        for line in stream:
            if line.startswith('read_id,tail_start'):
                continue
            TotalReads += 1
            fields = line.strip().split(',')
            try:
                geneID = fields[6]
                TL = float(fields[4])
            except (IndexError, ValueError):
                # Truncated line or non-numeric tail length.
                Skipped += 1
                continue
            GeneDict.setdefault(geneID, []).append(TL)
    finally:
        if stream is not sys.stdin:
            stream.close()
        if proc is not None:
            proc.wait()

    if proc is not None and proc.returncode != 0:
        raise IOError('decompression of %s failed with exit status %s'
                      % (datafilename, proc.returncode))

    # Guard the ratio so an empty input reports 0.0 instead of raising
    # ZeroDivisionError.
    if TotalReads:
        fraction = Skipped / (0.0 + TotalReads)
    else:
        fraction = 0.0
    # Message spelling is kept as-is so existing log output is unchanged.
    print('skipped: %s alignemnts out of %s total alignemnts, fraction skipped: %s'
          % (Skipped, TotalReads, fraction))

    # GeneDict is empty whenever TotalReads is 0, so the division below
    # never sees a zero NormFactor.
    NormFactor = TotalReads / 1000000.

    with open(outfilename, 'w') as outfile:
        outfile.write('#geneID\tTPM\tmean_tail_length\tstdev_tail_length\ttail_lengths\n')
        for geneID in sorted(GeneDict):
            tails = GeneDict[geneID]
            columns = [
                geneID,
                str(len(tails) / NormFactor),
                str(np.mean(tails)),
                str(np.std(tails)),
                ','.join(str(TL) for TL in tails),
            ]
            outfile.write('\t'.join(columns) + '\n')

        
# Run only when executed as a script, so importing this module does not
# parse the importer's command line or exit the interpreter.
if __name__ == '__main__':
    run()

