##################################
#                                #
# Last modified 5/6/2009         # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from cistematic.core import Genome
from cistematic.core.geneinfo import geneinfoDB

# Optional speed-up: enable psyco when it can be imported, otherwise run
# without it.
try:
    import psyco
    psyco.full()
except Exception:
    # Best effort only. Catch Exception rather than using a bare except so that
    # KeyboardInterrupt / SystemExit are not swallowed here.
    pass

def _load_genes(genome):
    """Return a dict mapping gene name -> gene record for every gene in *genome*.

    Each record is a dict with keys 'geneID', 'name', 'chromosome',
    'leftPos', 'rightPos' and 'orientation'.  leftPos/rightPos are the
    smallest start and the largest stop over all of the gene's features,
    i.e. the full span of the gene.

    Genes with no gene-info entry are named 'LOC<geneID>'.  If two gene IDs
    resolve to the same name, the one seen last wins.
    """
    hg = Genome(genome)
    idb = geneinfoDB()
    featDict = hg.getallGeneFeatures()
    genes = {}
    for count, geneID in enumerate(featDict):
        # Progress indicator: this loop does one gene-info lookup per gene.
        if count % 1000 == 0:
            print(count)
        features = featDict[geneID]
        if not features:
            # No coordinates to work with; min()/max() below would fail.
            continue
        # Query gene info once per gene instead of twice.
        info = idb.getGeneInfo((genome, geneID))
        if info:
            name = info[0]
        else:
            name = 'LOC' + str(geneID)
        # Feature tuples are read as [1]=chromosome, [2]=start, [3]=stop,
        # [4]=orientation -- NOTE(review): confirm against cistematic's
        # getallGeneFeatures().
        first = features[0]
        genes[name] = {
            'geneID': geneID,
            'name': name,
            'chromosome': 'chr' + str(first[1]),
            'leftPos': min(int(feature[2]) for feature in features),
            'rightPos': max(int(feature[3]) for feature in features),
            'orientation': str(first[4]),
        }
    return genes


def _site_distances(gene, peak):
    """Return (distance_to_TSS, distance_to_3prime_end) for *peak* in *gene*.

    Distances are measured along the direction of transcription: the TSS
    distance is positive downstream of the TSS, and the 3'-end distance is
    negative while the peak lies inside the gene.  Returns None when the
    gene's orientation is neither 'F' nor 'R'.
    """
    if gene['orientation'] == 'F':
        return peak - gene['leftPos'], peak - gene['rightPos']
    if gene['orientation'] == 'R':
        return gene['rightPos'] - peak, gene['leftPos'] - peak
    return None


def run():
    """Annotate ERANGE getallgenes output with each site's distance to its gene.

    Command line:
        genome infile outfile [-getallgenesfrombed] [-usePeak fieldID]
        [-displaydistanceto3UTR]

    Every non-blank, tab-separated line of infile is copied to outfile with
    one extra column: the distance from the site to the gene's TSS (positive
    downstream in the direction of transcription).  With
    -displaydistanceto3UTR, a second extra column gives the distance to the
    gene's 3' end, which is negative while the site lies inside the gene.

    The gene is the first space-separated token of column 0.  The site
    position is the midpoint of columns 2 and 3 (columns 1 and 2 with
    -getallgenesfrombed), or the integer in column fieldID with -usePeak.
    Lines whose gene is unknown, or whose gene orientation is neither 'F'
    nor 'R', are reported on stdout and skipped.

    Prints a usage message and exits with status 1 on missing arguments.
    """
    usage = ('usage: python %s genome getallgenesERANGEfile outfilename '
             '[-getallgenesfrombed] [-usePeak fieldID] [-displaydistanceto3UTR]'
             % sys.argv[0])

    # Three positional arguments are required. The old check (< 3) let a
    # missing outfilename crash with IndexError.
    if len(sys.argv) < 4:
        print(usage)
        sys.exit(1)

    genome = sys.argv[1]
    infilename = sys.argv[2]
    outfilename = sys.argv[3]
    frombed = '-getallgenesfrombed' in sys.argv
    show3UTR = '-displaydistanceto3UTR' in sys.argv
    doPeak = '-usePeak' in sys.argv
    PeakfieldID = None
    if doPeak:
        print('doPeak')
        valueIndex = sys.argv.index('-usePeak') + 1
        if valueIndex >= len(sys.argv):
            print(usage)
            sys.exit(1)
        PeakfieldID = int(sys.argv[valueIndex])

    # Read the input before the slow gene load and before the output file is
    # truncated. A bad path then fails fast and cannot clobber anything.
    infile = open(infilename)
    try:
        lines = infile.readlines()
    finally:
        infile.close()

    genes = _load_genes(genome)

    outfile = open(outfilename, 'w')
    try:
        if show3UTR:
            outfile.write('Distance To TSS - second to last column\n')
            outfile.write('Distance To end of 3`UTR - last column\n')
        else:
            outfile.write('Distance To TSS - last column\n')
        header = ('gene_ID\tgene_name\tchromosome\tgene_start\tgene_end\t'
                  'gene_orientation\tsite\tsite_start\tsite_end\tDistance_from_TSS')
        if show3UTR:
            header += '\tDistance_from_3UTR_end'
        outfile.write(header + '\n')

        for line in lines:
            record = line.rstrip('\r\n')
            if not record.strip():
                continue
            fields = record.split('\t')
            genename = fields[0].split(' ')[0]
            # Dict lookup is O(1); genes.keys() is a list scan in Python 2.
            gene = genes.get(genename)
            if gene is None:
                print('%s not found in gene list extracted from cistematic'
                      % genename)
                continue
            if doPeak:
                peak = int(fields[PeakfieldID])
            elif frombed:
                peak = (int(fields[1]) + int(fields[2])) // 2
            else:
                peak = (int(fields[2]) + int(fields[3])) // 2
            distances = _site_distances(gene, peak)
            if distances is None:
                # Skip explicitly; otherwise the previous line's distance
                # would be silently reused.
                print('%s has unrecognized orientation %s, skipped'
                      % (genename, gene['orientation']))
                continue
            tssDistance, utrDistance = distances
            if show3UTR:
                outfile.write(record + '\t' + str(tssDistance) + '\t'
                              + str(utrDistance) + '\n')
            else:
                outfile.write(record + '\t' + str(tssDistance) + '\n')
    finally:
        outfile.close()

# Run only when executed as a script, not when this module is imported.
if __name__ == '__main__':
    run()

