##################################
#                                #
# Last modified 2017/05/30       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
from sets import Set

def _parse_junction_tables(config, StrandDict, JuncTypeDict):
    """Read every junction table listed in the config file.

    Each non-blank config line is ``label <tab> filename``. Each non-blank
    line of a junction table is tab-separated with the chromosome, left and
    right coordinates, strand code, junction-type code in columns 0-4 and the
    unique read count in column 6.
    NOTE(review): this column layout looks like STAR's SJ.out.tab - confirm.

    Returns a tuple ``(JuncDict, labels, JC)`` where
    ``JuncDict[chrom][strand][(left, right, junctype)][label]`` is the unique
    count, ``labels`` lists the labels in config order and ``JC`` is the
    number of distinct junctions stored. Junctions whose strand code maps to
    '.' are skipped.
    """
    JuncDict = {}
    labels = []
    JC = 0
    with open(config) as configfile:
        for lineC in configfile:
            if not lineC.strip():
                # tolerate blank/trailing lines in the config
                continue
            fieldsC = lineC.strip().split('\t')
            label = fieldsC[0]
            labels.append(label)
            with open(fieldsC[1]) as junctionfile:
                for line in junctionfile:
                    if not line.strip():
                        continue
                    fields = line.strip().split('\t')
                    strand = StrandDict[fields[3]]
                    if strand == '.':
                        # donor/acceptor cannot be assigned without a strand
                        continue
                    junction = (fields[1], fields[2], JuncTypeDict[fields[4]])
                    if fields[0] not in JuncDict:
                        JuncDict[fields[0]] = {'+': {}, '-': {}}
                    strandJunctions = JuncDict[fields[0]][strand]
                    if junction not in strandJunctions:
                        strandJunctions[junction] = {}
                        JC += 1
                    strandJunctions[junction][label] = int(fields[6])
    return JuncDict, labels, JC


def _write_clusters(outfile, chrom, strand, sitesDict, sitesAreDonors, labels,
                    minClusterCounts, doLC, ClustersCounts):
    """Write one cluster per shared splice site and return the running count.

    ``sitesDict[site][partner]`` maps labels to counts. A site forms a cluster
    when it has more than one partner and, for at least one label, the counts
    summed over its partners reach ``minClusterCounts``. ``sitesAreDonors``
    says whether the keys of ``sitesDict`` are donors (partners are
    acceptors) or the other way round; output rows are always written as
    ``chrom:donor:acceptor``. Sites are visited in sorted (string) order.
    """
    separator = ' ' if doLC else '\t'
    for site in sorted(sitesDict):
        partners = sitesDict[site]
        if len(partners) == 1:
            continue
        sufficient = False
        for label in labels:
            total = sum(counts.get(label, 0) for counts in partners.values())
            if total >= minClusterCounts:
                sufficient = True
                break
        if not sufficient:
            continue
        ClustersCounts += 1
        for partner in partners:
            if sitesAreDonors:
                D, A = site, partner
            else:
                D, A = partner, site
            name = chrom + ':' + D + ':' + A
            if not doLC:
                # the -leafCutter layout carries no strand field
                name = name + ':' + strand
            name = name + ':clu_' + str(ClustersCounts)
            counts = partners[partner]
            columns = [name] + [str(counts.get(label, 0)) for label in labels]
            outfile.write(separator.join(columns) + '\n')
        if ClustersCounts % 1000 == 0:
            print('%d clusters found' % ClustersCounts)
    return ClustersCounts


def run():
    """Build a junction-cluster count table from per-sample junction files.

    Command line: ``config minClusterCounts outprefix [-leafCutter]``.
    Junctions sharing a donor, and junctions sharing an acceptor, are grouped
    into clusters; a cluster is written only if at least one sample has
    ``minClusterCounts`` or more unique counts summed over the cluster. With
    ``-leafCutter`` the header and rows are space-separated and the row names
    omit the strand; otherwise the header starts with '#' and columns are
    tab-separated. Exits with status 1 after printing usage when fewer than
    three positional arguments are given.
    """
    # Every print below takes one pre-formatted string, so the output is the
    # same whether print is the Python 2 statement or the Python 3 function.
    if len(sys.argv) < 4:
        # three positional arguments are required (argv[3] is read below)
        print('usage: python %s config minClusterCounts outprefix [-leafCutter]' % sys.argv[0])
        print('\tconfig format: label <tab> filename')
        print('\tminClusterCounts refers to the minimum number of counts for a junction cluster')
        sys.exit(1)

    doLC = '-leafCutter' in sys.argv

    config = sys.argv[1]
    minClusterCounts = int(sys.argv[2])
    outfilename = sys.argv[3]

    StrandDict = {'1':'+', '2':'-', '0':'.'}
    JuncTypeDict = {'0': 'non-canonical', '1': 'GT|AG', '2': 'CT|AC', '3': 'GC|AG', '4': 'CT|GC', '5':'AT|AC', '6': 'GT|AT'}

    JuncDict, labels, JC = _parse_junction_tables(config, StrandDict, JuncTypeDict)

    print('finished parsing junction tables')
    print('found %d distinct splices' % JC)

    labels.sort()

    ClustersCounts = 0

    with open(outfilename, 'w') as outfile:
        for label in labels:
            print(label)
        if doLC:
            outfile.write(' '.join(labels).strip() + '\n')
        else:
            outfile.write('#' + ''.join(['\t' + label for label in labels]) + '\n')

        for chrom in sorted(JuncDict):
            print(chrom)
            # fixed strand order keeps the cluster numbering reproducible
            for strand in ('+', '-'):
                junctions = JuncDict[chrom][strand]
                print('%s %s %d' % (chrom, strand, len(junctions)))
                donorsDict = {}
                acceptorsDict = {}
                for junction in junctions:
                    (left, right, junctype) = junction
                    # on the minus strand the donor is the right coordinate
                    if strand == '+':
                        D, A = left, right
                    else:
                        D, A = right, left
                    donorsDict.setdefault(D, {})[A] = junctions[junction]
                    acceptorsDict.setdefault(A, {})[D] = junctions[junction]
                print('found %d donor sites' % len(donorsDict))
                print('found %d acceptor sites' % len(acceptorsDict))
                ClustersCounts = _write_clusters(outfile, chrom, strand, donorsDict, True, labels,
                                                 minClusterCounts, doLC, ClustersCounts)
                ClustersCounts = _write_clusters(outfile, chrom, strand, acceptorsDict, False, labels,
                                                 minClusterCounts, doLC, ClustersCounts)

# Only start the command-line tool when executed as a script, so the module
# can be imported without side effects.
if __name__ == '__main__':
    run()

