##################################
#                                #
# Last modified 08/25/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _readFileList(listPath):
    """Parse the list-of-files table into a {filename: field spec} dict.

    Each usable line is ``filename <tab> spec`` where spec is either the
    literal string 'bed12' or a comma-separated list of integer column
    indexes in chr,left,right,strand order.  Lines with fewer than two
    tab-separated fields (e.g. blank lines) are reported and skipped.
    If a filename is listed more than once, the last entry wins.

    Raises ValueError if a spec other than 'bed12' contains a non-integer.
    """
    fileDict = {}
    with open(listPath) as listFile:
        for line in listFile:
            fields = line.strip().split('\t')
            if len(fields) < 2:
                print('skipping %s' % (fields,))
                # Without this the lookup of fields[1] below would crash.
                continue
            name = fields[0]
            if fields[1] == 'bed12':
                fileDict[name] = 'bed12'
            else:
                fileDict[name] = [int(i) for i in fields[1].split(',')]
    return fileDict


def _parseJunction(fields, fieldIDs, leftShift, rightShift):
    """Build one (chr, left, right, strand) tuple from a split input line.

    fieldIDs is either 'bed12' or a list of four column indexes that are
    used directly as 0-based positions into fields.  The shifts are added
    to the left and right coordinates respectively.

    Raises IndexError or ValueError when the line lacks the needed columns
    or a coordinate is not an integer; the caller treats that as "skip".
    """
    if fieldIDs == 'bed12':
        chrom = fields[0]
        strand = fields[5]
        # Column 11 of the line; per the BED12 layout these are the block
        # sizes flanking the junction (the first two values are used).
        blockSizes = fields[10].split(',')
        left = int(fields[1]) + int(blockSizes[0]) - 1
        right = int(fields[2]) - int(blockSizes[1])
    else:
        chrom = fields[fieldIDs[0]]
        left = int(fields[fieldIDs[1]])
        right = int(fields[fieldIDs[2]])
        strand = fields[fieldIDs[3]]
    return (chrom, left + leftShift, right + rightShift, strand)


def run():
    """Merge junctions from several files into one sorted, unique table.

    Command line (read from sys.argv):
        list-of-files outfilename [-rightShift bp] [-leftShift bp]

    Every file named in list-of-files is read line by line; each line is
    turned into a (chr, left, right, strand) tuple, with the optional
    shifts added to the coordinates.  Lines that cannot be parsed are
    reported with 'skipping' and ignored.  The combined junctions are
    de-duplicated, sorted, and written tab-separated to outfilename.
    Progress (file name, running junction count, final unique count) is
    printed to stdout.  Exits with status 1 if too few arguments are given.
    """
    if len(sys.argv) < 3:
        print('usage: python %s list-of-files outfilename [-rightShift bp] [-leftShift bp]' % sys.argv[0])
        print('      Note: list-of-files format: filename <tab> 1,2,3,4 (fields positions in the chr,left,right,strand order)')
        print('            or                    filename <tab> bed12 (if the file is a TopHat junctions.bed file)')
        sys.exit(1)

    inputfilename = sys.argv[1]
    outfilename = sys.argv[2]

    leftShift = 0
    rightShift = 0

    if '-leftShift' in sys.argv:
        leftShift = int(sys.argv[sys.argv.index('-leftShift') + 1])
    if '-rightShift' in sys.argv:
        rightShift = int(sys.argv[sys.argv.index('-rightShift') + 1])

    fileDict = _readFileList(inputfilename)

    junctions = []
    for path in fileDict:
        print(path)
        fieldIDs = fileDict[path]
        with open(path) as infile:
            for line in infile:
                fields = line.strip().split('\t')
                try:
                    junction = _parseJunction(fields, fieldIDs,
                                              leftShift, rightShift)
                except (IndexError, ValueError):
                    # Header lines, short lines and non-numeric coordinates.
                    print('skipping %s' % (fields,))
                else:
                    junctions.append(junction)
        # Running total across all files processed so far.
        print(len(junctions))

    # The same junction may appear in several files; keep one copy of each.
    junctions = sorted(set(junctions))

    print(len(junctions))

    with open(outfilename, 'w') as outfile:
        for (chrom, left, right, strand) in junctions:
            outfile.write('\t'.join([chrom, str(left), str(right), strand]) + '\n')

# Execute the merge only when this file is run as a script, so that
# importing the module does not parse sys.argv or call sys.exit().
if __name__ == '__main__':
    run()

