##################################
#                                #
# Last modified 03/07/2011       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
from sets import Set
import string
import math
from cistematic.core import Genome
from cistematic.core.geneinfo import geneinfoDB

# Psyco is an optional accelerator: use it when it can be imported and run
# unaccelerated otherwise. Only ImportError is swallowed so that Ctrl-C and
# genuine failures are not silently hidden.
try:
    import psyco
    psyco.full()
except ImportError:
    pass

# Strand labels accepted for the forward and reverse orientations.
_FORWARD_STRANDS = ('+', 'F')
_REVERSE_STRANDS = ('-', 'R')


def _usage_exit():
    """Print the command-line synopsis and terminate with exit status 1."""
    print('usage: python %s <gtf file name> <region calls> <MACS | ERANGE> outfilename [-CISTEMATIC genome] [-refFlat refFlat] [-intron (sorted gtf only)]' % sys.argv[0])
    sys.exit(1)


def _option_value(flag):
    """Return the argument that follows ``flag`` in sys.argv.

    Returns None when the flag is absent; prints usage and exits when the
    flag is the last argument and therefore has no value.
    """
    if flag not in sys.argv:
        return None
    position = sys.argv.index(flag) + 1
    if position >= len(sys.argv):
        _usage_exit()
    return sys.argv[position]


def _tss_for_strand(orientation, left, right):
    """Return ``left`` for a forward strand, ``right`` for a reverse strand.

    Returns None for any other strand label so the caller can skip the record
    instead of reusing a TSS left over from a previous record.
    """
    if orientation in _FORWARD_STRANDS:
        return left
    if orientation in _REVERSE_STRANDS:
        return right
    return None


def _load_cistematic_tss(genome):
    """Collect (chr, TSS, orientation, name) tuples from a cistematic genome."""
    hg = Genome(genome)
    idb = geneinfoDB()
    # NOTE(review): the result of this call was never used by the original
    # code; it is kept in case it has side effects -- confirm and remove.
    idb.getallGeneInfo(genome)
    featDict = hg.getallGeneFeatures()
    total = len(featDict)
    entries = []
    for i, k in enumerate(featDict):
        if i % 1000 == 0:
            # Progress indicator: number of genes still to process.
            print(total - i)
        features = featDict[k]
        if not features:
            continue
        chrom = 'chr' + str(features[0][1])
        orientation = str(features[0][4])
        leftmost = min(int(feature[2]) for feature in features)
        rightmost = max(int(feature[3]) for feature in features)
        tss = _tss_for_strand(orientation, leftmost, rightmost)
        if tss is None:
            continue
        info = idb.getGeneInfo((genome, k))
        if info == []:
            name = 'LOC' + str(k)
        else:
            name = info[0]
        entries.append((chrom, tss, orientation, name))
    return entries


def _load_refflat_tss(path):
    """Collect (chr, TSS, orientation, name) tuples from a refFlat file."""
    entries = []
    with open(path) as lineslist:
        for line in lineslist:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            orientation = fields[3]
            tss = _tss_for_strand(orientation, int(fields[4]), int(fields[5]))
            if tss is None:
                continue
            name = fields[0] + '::' + fields[1]
            entries.append((fields[2], tss, orientation, name))
    return entries


def _load_gtf_tss(path):
    """Collect (chr, TSS, orientation, name) from GTF 'transcript' records.

    The name is the gene_name attribute when present, otherwise gene_id.
    """
    entries = []
    with open(path) as lineslist:
        for line in lineslist:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            if fields[2] != 'transcript':
                continue
            orientation = fields[6]
            tss = _tss_for_strand(orientation, int(fields[3]), int(fields[4]))
            if tss is None:
                continue
            attributes = fields[8]
            if 'gene_name' in attributes:
                name = attributes.split('gene_name "')[1].split('";')[0]
            else:
                name = attributes.split('gene_id "')[1].split('";')[0]
            entries.append((fields[0], tss, orientation, name))
    return entries


def _closest_tss(candidates, peak):
    """Return the (TSS, orientation, name) tuple nearest to ``peak``.

    On equal distance the later candidate wins (``<=``), matching the
    original tie-breaking over the sorted candidate list.
    """
    best = None
    best_gap = None
    for candidate in candidates:
        gap = abs(candidate[0] - peak)
        if best_gap is None or gap <= best_gap:
            best_gap = gap
            best = candidate
    return best


def run():
    """Annotate each called region with its nearest transcription start site.

    Reads positional arguments from sys.argv: annotation GTF, region calls
    file, caller name ('MACS' or 'ERANGE') and output file name. Optional
    '-CISTEMATIC genome' or '-refFlat file' replace the GTF as the TSS source
    (CISTEMATIC takes precedence). Every region line whose chromosome has at
    least one TSS is written to the output prefixed with two tab-separated
    columns: the signed distance from the TSS to the peak (positive means
    downstream with respect to the gene's strand) and the gene name.

    Exits with status 1 after printing usage when arguments are missing or
    the caller name is not recognised.
    """
    # Four positional arguments are mandatory, so argv needs five entries.
    if len(sys.argv) < 5:
        _usage_exit()

    gtf = sys.argv[1]
    regions = sys.argv[2]
    Caller = sys.argv[3]
    outputfilename = sys.argv[4]

    if Caller not in ('MACS', 'ERANGE'):
        _usage_exit()

    if '-intron' in sys.argv:
        # The intron mode was never completed; accept the flag but ignore it.
        sys.stderr.write('warning: -intron is not implemented and will be ignored\n')

    genome = _option_value('-CISTEMATIC')
    if genome is not None:
        print('will use CISTEMATIC annotation %s' % genome)

    refFlat = _option_value('-refFlat')
    if refFlat is not None:
        print('will use refFlat file %s' % refFlat)

    if genome is not None:
        TSSList = _load_cistematic_tss(genome)
    elif refFlat is not None:
        TSSList = _load_refflat_tss(refFlat)
    else:
        TSSList = _load_gtf_tss(gtf)

    print('finished importing annotation')

    # De-duplicate, then sort so that per-chromosome lists are ordered by TSS.
    TSSDict = {}
    for (chrom, TSS, orientation, name) in sorted(set(TSSList)):
        TSSDict.setdefault(chrom, []).append((TSS, orientation, name))

    with open(outputfilename, 'w') as outfile:
        with open(regions) as lineslist:
            for line in lineslist:
                if line.startswith('#') or not line.strip():
                    continue
                fields = line.strip().split('\t')
                if Caller == 'MACS':
                    # Skip the column-header row.
                    if fields[0] == 'chr':
                        continue
                    chrom = fields[0]
                    # presumably region start plus summit offset -- verify
                    # against the MACS output version in use
                    peak = int(fields[1]) + int(fields[4])
                else:
                    chrom = fields[1]
                    peak = int(fields[9])
                if chrom not in TSSDict:
                    continue
                (TSS, orientation, name) = _closest_tss(TSSDict[chrom], peak)
                if orientation in _FORWARD_STRANDS:
                    distance = peak - TSS
                else:
                    distance = TSS - peak
                # 'line' still carries its own newline terminator.
                outfile.write(str(distance) + '\t' + name + '\t' + line)

    print('done')

# Script entry point: run() is invoked unconditionally whenever this file is
# executed (there is no __main__ guard, so importing the file also runs it).
run()

