##################################
#                                #
# Last modified 2017/08/02       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
from sets import Set

def _parse_column_list(spec):
    """Turn a comma-separated string of column indices (e.g. '6,7') into ints."""
    return [int(ID) for ID in spec.split(',')]


def _read_tss_table(table, label1, fields1, label2, fields2):
    """Parse the tab-separated TSS table into a nested dictionary.

    Lines starting with '#' and blank lines are skipped.  For every other
    line, columns 0-5 are read as chromosome, left, right, strand, peak and
    gene (the gene column is the key that groups TSSs into a cluster), and
    the columns listed in fields1 / fields2 are converted with
    int(float(...)) and appended under label1 / label2.

    Returns a tuple (TSSDict, T) where TSSDict has the shape
    {gene: {(chrom, peak, left, right, strand): {label: [counts]}}}
    and T is the number of data lines that were parsed.
    """
    TSSDict = {}
    T = 0
    # The context manager guarantees the table is closed, even on a parse error.
    with open(table) as linelist:
        for line in linelist:
            if line.startswith('#'):
                continue
            stripped = line.strip()
            if not stripped:
                # A blank line carries no TSS; indexing its fields would fail.
                continue
            fields = stripped.split('\t')
            chrom = fields[0]
            left = fields[1]
            right = fields[2]
            strand = fields[3]
            peak = fields[4]
            gene = fields[5]
            TSS = (chrom, peak, left, right, strand)
            cluster = TSSDict.setdefault(gene, {})
            if TSS not in cluster:
                cluster[TSS] = {}
                cluster[TSS][label1] = []
                cluster[TSS][label2] = []
            for ID in fields1:
                cluster[TSS][label1].append(int(float(fields[ID])))
            for ID in fields2:
                cluster[TSS][label2].append(int(float(fields[ID])))
            T += 1
    return TSSDict, T


def _write_clusters(outfile, TSSDict, label1, label2, minClusterCounts, doLC):
    """Write one line per TSS for every cluster that passes the filters.

    A cluster (all TSSs sharing a gene key) is written when it has more than
    one TSS and the largest count found anywhere in the cluster is at least
    minClusterCounts.  Retained clusters are numbered clu_1, clu_2, ... in
    iteration order of TSSDict.

    Returns the number of clusters written.
    """
    separator = ' ' if doLC else '\t'
    retained = 0
    for gene in TSSDict:
        cluster = TSSDict[gene]
        # The maximum is taken over every TSS in the cluster, not only the
        # last one visited, so the decision does not depend on dict order.
        # NOTE(review): assumes "counts for a cluster" means the maximum
        # single count rather than the sum -- confirm.
        maxCC = 0
        for TSS in cluster:
            for label in (label1, label2):
                for C in cluster[TSS][label]:
                    maxCC = max(maxCC, C)
        if maxCC < minClusterCounts or len(cluster) <= 1:
            continue
        retained += 1
        for TSS in cluster:
            (chrom, peak, left, right, strand) = TSS
            if doLC:
                # The leafCutter-style row name omits the strand.
                name = chrom + ':' + str(left) + ':' + str(right) + ':clu_' + str(retained)
            else:
                name = chrom + ':' + str(left) + ':' + str(right) + ':' + strand + ':clu_' + str(retained)
            columns = [name]
            for C in cluster[TSS][label1]:
                columns.append(str(C))
            for C in cluster[TSS][label2]:
                columns.append(str(C))
            outfile.write(separator.join(columns) + '\n')
    return retained


def run():
    """Command-line entry point: convert a TSS table into a cluster count table.

    Positional arguments are taken from sys.argv, in this order:
    TSStable, label1, fields1, label2, fields2, minClusterCounts, outprefix.
    fields1 / fields2 are comma-separated 0-based column indices of the
    table holding the counts for label1 / label2.  The optional flag
    '-leafCutter' may appear anywhere on the command line and switches the
    output to space-separated columns with strand-less row names.

    The result is written to the file named by outprefix; progress messages
    go to standard output.  When fewer than seven positional arguments are
    given, the usage text is printed and the process exits with status 1.
    """
    doLC = '-leafCutter' in sys.argv
    # Remove the flag so it can never be mistaken for a positional argument
    # (for example, for the output file name).
    args = [arg for arg in sys.argv[1:] if arg != '-leafCutter']

    # All seven positional arguments, including outprefix, are required.
    if len(args) < 7:
        print('usage: python %s TSStable label1 fields1 label2 fields2 minClusterCounts outprefix [-leafCutter]' % sys.argv[0])
        print('\tminCounts refers to the minimum number of counts for a cluster')
        print('\tcluster table format: #chr <tab> left <tab> right <tab> strand <tab> peak <tab> d00.b1::peak d00.b1::counts d00.b1::RPM d00.b1::SI')
        sys.exit(1)

    table = args[0]
    label1 = args[1]
    fields1 = _parse_column_list(args[2])
    label2 = args[3]
    fields2 = _parse_column_list(args[4])
    minClusterCounts = int(args[5])
    outfilename = args[6]

    TSSDict, T = _read_tss_table(table, label1, fields1, label2, fields2)

    print('finished parsing TSS table')
    print('found %s distinct TSSs' % T)

    # One header entry per count column: label1 repeated for every column in
    # fields1, followed by label2 repeated for every column in fields2.
    labels = [label1] * len(fields1) + [label2] * len(fields2)
    if doLC:
        header = ' '.join(labels).strip()
    else:
        header = '\t'.join(['#'] + labels)

    with open(outfilename, 'w') as outfile:
        outfile.write(header + '\n')
        retained = _write_clusters(outfile, TSSDict, label1, label2, minClusterCounts, doLC)

    print('retained %s clusters' % retained)

# Script entry point: run() is called unconditionally, so it executes when the
# file is run as a script and also whenever the module is imported.
run()

