##################################
#                                #
# Last modified 2020/12/15       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
import numpy
import gzip
from sets import Set

def _emit(stream, *parts):
    """Write *parts* to *stream*, space separated, followed by a newline.

    This reproduces the layout of a Python 2 ``print a, b`` statement while
    remaining valid syntax on Python 3.
    """
    stream.write(' '.join([str(part) for part in parts]) + '\n')


def _option_value(argv, flag, offset=1):
    """Return the argument *offset* positions after *flag* in *argv*.

    Exits with status 1 and a message on stderr when the command line ends
    before that position, instead of failing with an IndexError.
    """
    position = argv.index(flag) + offset
    if position >= len(argv):
        _emit(sys.stderr, 'missing value for option', flag)
        sys.exit(1)
    return argv[position]


def _decimal_places(value):
    """Return the number of digits after the decimal point in str(value).

    Values whose string form uses exponent notation (e.g. 1e-05) have no
    '.' to split on, so they are re-rendered in fixed notation first.
    """
    text = str(value)
    if 'e' in text or 'E' in text:
        text = ('%.15f' % value).rstrip('0')
    if '.' not in text:
        return 0
    return len(text.split('.')[1])


def _uniform_bins(hist, bin_size, low, high, log):
    """Re-bin *hist* into fixed-width bins and return (count, label) pairs.

    Bin labels are the values of numpy.arange(low, high, bin_size) rounded to
    the number of decimals of bin_size, plus *low* and *high* themselves.
    Values <= low are added to the *low* bin, values >= high to the *high*
    bin, and every other value to the bin with the largest label that does
    not exceed it.  The pairs are returned in ascending label order.
    """
    import bisect

    decimals = _decimal_places(bin_size)
    binned = {}
    for edge in numpy.arange(low, high, bin_size):
        binned[round(float(edge), decimals)] = 0
    binned[high] = 0
    binned[low] = 0
    log(binned)

    # Labels that can receive interior values; looking the bin up in this
    # list (rather than re-deriving edges by repeated addition) guarantees
    # the chosen label is always a key of `binned`.
    lower_edges = sorted([edge for edge in binned if edge < high])
    seen = 0
    for key, count in hist.items():
        seen += 1
        if seen % 1000 == 0:
            log(seen)
        if key <= low:
            binned[low] += count
        elif key >= high:
            binned[high] += count
        elif key == key:
            # NaN fails every comparison (including this one) and is left
            # uncounted.
            position = bisect.bisect_right(lower_edges, key) - 1
            binned[lower_edges[position]] += count
    return [(binned[edge], str(edge)) for edge in sorted(binned)]


def _specific_bins(hist, bin_edges):
    """Re-bin *hist* using explicit edges and return (count, label) pairs.

    The largest observed value is appended as a final edge.  A value goes to
    the bin of edge[i] when edge[i] <= value < edge[i + 1]; values at or
    above the largest edge go to that largest edge's bin.  Values below the
    first edge are not counted.  An empty histogram yields all-zero bins.
    """
    edges = list(bin_edges)
    if hist:
        edges.append(max(hist))
    binned = {}
    for edge in edges:
        binned[edge] = 0
    if edges:
        top = max(edges)
        for key, count in hist.items():
            if key >= top:
                binned[top] += count
                continue
            # Pair each edge with its positional successor so that a repeated
            # edge value cannot pick up the wrong upper bound.
            for position in range(len(edges) - 1):
                if edges[position] <= key < edges[position + 1]:
                    binned[edges[position]] += count
    return [(binned[edge], str(edge)) for edge in sorted(binned)]


def run():
    """Build a histogram of one column of a delimited file (command line tool).

    Reads sys.argv: datafilename, fieldID (0-based column), outfilename, then
    optional flags (see the usage text below).  '-' as the input reads
    standard input; '-' as the output writes to standard output, in which
    case status messages are suppressed and per-line problem reports go to
    standard error so that only data reaches standard output.

    Each output line is "<label><TAB><count>", or "<label><TAB><fraction>"
    with -fraction.  With -fields the label is the selected fields followed
    by the count, so the count appears before the final column as well.

    Exits with status 1 after printing usage when fewer than three
    positional arguments are given.
    """
    argv = sys.argv
    # Three positional arguments are read below, so argv needs 4 entries.
    if len(argv) < 4:
        _emit(sys.stdout, 'usage: python %s  datafilename datafileFieldID1 outfilename [-fraction total_counts|nan] [-sortCounts] [-bins size min max] [-specificbins (0),number1,number2,number3...,numberN] [-pvalues] [-RegionLength leftCoordinateField] [-fields ID1,ID2,..IDN] [-splitby string] [-nosort] [-gtfIDR]' % argv[0])
        _emit(sys.stdout, '\tnote use _s_ if you want to split by " ')
        _emit(sys.stdout, '\tnote use _ss_ if you want to split by space')
        _emit(sys.stdout, '\tuse - for input if you want to read from standard input')
        _emit(sys.stdout, '\tuse - if you want to print to standard output')
        sys.exit(1)

    datafilename = argv[1]
    fieldID = int(argv[2])
    outfilename = argv[3]
    doStdOut = outfilename == '-'

    def log(*parts):
        # Status chatter; silent when the data itself goes to stdout.
        if not doStdOut:
            _emit(sys.stdout, *parts)

    def warn(*parts):
        # Per-line problem reports; kept out of the data stream.
        if doStdOut:
            _emit(sys.stderr, *parts)
        else:
            _emit(sys.stdout, *parts)

    dopvalues = '-pvalues' in argv
    if dopvalues:
        log('treating data as pvalues')

    doFrac = '-fraction' in argv
    TotalFrac = None
    if doFrac:
        TotalFrac = _option_value(argv, '-fraction')
        if TotalFrac != 'nan':
            TotalFrac = int(TotalFrac)
            if TotalFrac == 0:
                _emit(sys.stderr, '-fraction total_counts must not be 0')
                sys.exit(1)
        log('will output fractions, assuming total counts', TotalFrac)

    doSortCounts = '-sortCounts' in argv
    if doSortCounts:
        log('will sort output')

    doGTFIDR = '-gtfIDR' in argv
    if doGTFIDR:
        log('treating data as a GTF file with an npIDR score')

    doBins = '-bins' in argv
    if doBins:
        binsize = float(_option_value(argv, '-bins', 1))
        minB = float(_option_value(argv, '-bins', 2))
        maxB = float(_option_value(argv, '-bins', 3))
        if not binsize > 0:
            _emit(sys.stderr, '-bins size must be greater than 0')
            sys.exit(1)
        log('will split data into bins, size:', binsize, 'from', minB, 'to', maxB)

    doFields = '-fields' in argv
    IDfields = []
    if doFields:
        fieldSpec = _option_value(argv, '-fields')
        log('will use fields', fieldSpec, 'as data')
        fieldNames = fieldSpec.split(',')
        log(fieldNames)
        IDfields = [int(ID) for ID in fieldNames]

    doSort = '-nosort' not in argv
    if not doSort:
        log('will not sort output')

    doSpecificBins = '-specificbins' in argv
    binList = []
    if doSpecificBins:
        log('will split data into bins')
        binList = [float(edge) for edge in _option_value(argv, '-specificbins').split(',')]
        log(binList)

    doRegionLength = '-RegionLength' in argv
    if doRegionLength:
        log('will output region length')
        RLField = int(_option_value(argv, '-RegionLength'))
        log('RLField', RLField)

    splitBy = '\t'
    if '-splitby' in argv:
        splitBy = _option_value(argv, '-splitby')
        if splitBy == '_s_':
            splitBy = '"'
        if splitBy == '_ss_':
            splitBy = ' '

    binning = doBins or doSpecificBins

    outfile = None
    if not doStdOut:
        outfile = open(outfilename, 'w')
    try:
        HistDict = {}
        if datafilename == '-':
            lineslist = sys.stdin
        elif datafilename.endswith('.gz'):
            lineslist = gzip.open(datafilename)
        else:
            lineslist = open(datafilename)
        try:
            t = 0
            for line in lineslist:
                t += 1
                if t % 1000000 == 0:
                    log(t, 'lines processed')
                if line.startswith('#'):
                    continue
                fields = [field.strip() for field in line.strip().split(splitBy)]
                if len(fields) < fieldID + 1:
                    continue

                num = None
                if binning:
                    if doRegionLength:
                        # The region length is computed once, below.
                        pass
                    elif doGTFIDR:
                        try:
                            num = float(fields[8].split('npIDR "')[1].split('";')[0])
                        except (IndexError, ValueError):
                            warn('problem', line)
                            continue
                    else:
                        try:
                            num = float(fields[fieldID])
                        except (IndexError, ValueError):
                            warn('problem with:               ', line)
                            continue
                elif doFields:
                    try:
                        num = tuple([fields[ID] for ID in IDfields])
                    except IndexError:
                        warn('skipping line:', line.strip())
                        continue
                else:
                    num = fields[fieldID]

                # -pvalues overrides the value chosen above; -RegionLength
                # overrides everything.
                if dopvalues:
                    num = 'E-' + fields[fieldID].split('-')[-1]
                if doRegionLength:
                    try:
                        num = int(fields[RLField + 1]) - int(fields[RLField])
                    except (IndexError, ValueError):
                        warn('problem', line)
                        continue

                HistDict[num] = HistDict.get(num, 0) + 1
        finally:
            if lineslist is not sys.stdin:
                lineslist.close()

        if doFrac and TotalFrac == 'nan':
            # Use the number of values actually counted, so comment lines and
            # skipped lines do not dilute the fractions.
            TotalFrac = sum(HistDict.values())

        if doBins:
            OutList = _uniform_bins(HistDict, binsize, minB, maxB, log)
        elif doSpecificBins:
            OutList = _specific_bins(HistDict, binList)
        else:
            keys = list(HistDict)
            if doSort:
                keys.sort()
            OutList = []
            for key in keys:
                count = HistDict[key]
                if doFields:
                    label = '\t'.join(list(key) + [str(count)]).strip()
                else:
                    label = str(key)
                OutList.append((count, label))

        if doSortCounts:
            # Highest count first; ties fall back to descending label order.
            OutList.sort()
            OutList.reverse()

        if doStdOut:
            sink = sys.stdout
        else:
            sink = outfile
        for (a, b) in OutList:
            if doFrac:
                if TotalFrac:
                    A = a / (TotalFrac + 0.0)
                else:
                    # Nothing was counted; avoid dividing by zero.
                    A = 0.0
            else:
                A = a
            outline = b + '\t' + str(A)
            sink.write(outline.strip() + '\n')
    finally:
        if outfile is not None:
            outfile.close()
        
# Script entry point: run() is invoked unconditionally whenever this file is
# executed or imported; it reads its arguments from sys.argv.
run()

