##################################
#                                #
# Last modified 5/6/2009         # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from cistematic.core import Genome
from cistematic.core.geneinfo import geneinfoDB

# psyco is an optional Python 2 JIT: enable it when it can be imported and
# started, otherwise run unaccelerated. Catch Exception rather than using a
# bare except so Ctrl-C (KeyboardInterrupt) and SystemExit still propagate.
try:
    import psyco
    psyco.full()
except Exception:
    pass

def getSequence(genome, chromosome, start, stop, sense):
    """Return the upper-cased genomic sequence of a site, oriented by *sense*.

    genome      -- genome name passed to cistematic's Genome().
    chromosome  -- chromosome name; a leading 'chr' (any case) is stripped
                   before the lookup (the gene tables built in this script
                   use 'chr'-prefixed names; presumably Genome.sequence
                   expects bare names -- confirm against cistematic).
    start, stop -- site coordinates; stop - start bases are requested from
                   Genome.sequence starting at start.
    sense       -- 'F' returns the sequence as stored, 'R' returns its
                   reverse complement.

    Raises ValueError if sense is neither 'F' nor 'R'. A base outside
    A/C/G/T/N and the IUPAC ambiguity codes raises KeyError when sense is 'R'.
    """
    # Validate first so a bad orientation fails before any genome access.
    if sense not in ('F', 'R'):
        raise ValueError("sense must be 'F' or 'R', not %r" % (sense,))
    complement = {'A': 'T', 'T': 'A', 'G': 'C', 'C': 'G', 'N': 'N',
                  # IUPAC ambiguity codes map to their complementary code.
                  'R': 'Y', 'Y': 'R', 'K': 'M', 'M': 'K', 'S': 'S', 'W': 'W',
                  'B': 'V', 'V': 'B', 'D': 'H', 'H': 'D'}
    if chromosome[:3].lower() == 'chr':
        chromosome = chromosome[3:]
    hg = Genome(genome)
    sequence = hg.sequence(chromosome, start, stop - start).upper()
    if sense == 'R':
        # One join pass; repeated string concatenation is quadratic in length.
        sequence = ''.join([complement[base] for base in reversed(sequence)])
    return sequence

def _optionValue(argv, flag, usage):
    """Return the command-line argument that follows *flag*.

    Prints *usage* and exits with status 1 when *flag* is the last argument,
    rather than failing with an IndexError.
    """
    position = argv.index(flag) + 1
    if position >= len(argv):
        print(usage % argv[0])
        sys.exit(1)
    return argv[position]


def _loadGenesFromCistematic(genome):
    """Build the gene table for *genome* from cistematic gene features.

    Each gene spans from its smallest feature field 2 to its largest feature
    field 3; chromosome ('chr' + field 1) and orientation (field 4) come from
    the first feature. Genes without gene info are named 'LOC<geneID>'.
    Prints a progress counter every 1000 genes.
    """
    hg = Genome(genome)
    idb = geneinfoDB()
    featDict = hg.getallGeneFeatures()
    genes = {}
    for count, geneID in enumerate(featDict):
        if count % 1000 == 0:
            print(count)
        features = featDict[geneID]
        info = idb.getGeneInfo((genome, geneID))
        if info:
            name = info[0]
        else:
            name = 'LOC' + str(geneID)
        genes[geneID] = {
            'geneID': geneID,
            'name': name,
            'chromosome': 'chr' + str(features[0][1]),
            'leftPos': min([int(feature[2]) for feature in features]),
            'rightPos': max([int(feature[3]) for feature in features]),
            'orientation': str(features[0][4]),
        }
    return genes


def _loadGenesFromFile(filename):
    """Read a tab-delimited gene list.

    Columns: geneID, name, chromosome, leftPos, rightPos, orientation.
    Blank lines and lines starting with '#' are skipped; CR/LF line endings
    are stripped so the orientation column compares cleanly.
    """
    genes = {}
    with open(filename) as infile:
        for line in infile:
            if not line.strip() or line.startswith('#'):
                continue
            fields = line.rstrip('\r\n').split('\t')
            genes[fields[0]] = {
                'geneID': fields[0],
                'name': fields[1],
                'chromosome': fields[2],
                'leftPos': int(fields[3]),
                'rightPos': int(fields[4]),
                'orientation': fields[5],
            }
    return genes


def _loadGeneSubset(filename):
    """Return a dict keyed by the gene IDs in the first column of *filename*.

    Only the keys are used; blank lines and '#' lines are skipped.
    """
    subset = {}
    with open(filename) as infile:
        for line in infile:
            if not line.strip() or line.startswith('#'):
                continue
            subset[line.rstrip('\r\n').split('\t')[0]] = True
    return subset


def _loadSites(filename, peakFieldID=None):
    """Read a tab-delimited site list: name, chromosome, start, stop, ...

    Blank lines and lines starting with '#', a space or a tab are skipped.
    peakPos is the midpoint of start/stop, or the integer in column
    *peakFieldID* when one is given.
    """
    sites = {}
    with open(filename) as infile:
        for line in infile:
            if not line.strip() or line.startswith(('#', ' ', '\t')):
                continue
            fields = line.rstrip('\r\n').split('\t')
            start = int(fields[2])
            stop = int(fields[3])
            if peakFieldID is None:
                peakPos = (start + stop) // 2
            else:
                peakPos = int(fields[peakFieldID])
            sites[fields[0]] = {
                'name': fields[0],
                'chromosome': fields[1],
                'start': start,
                'stop': stop,
                'peakPos': peakPos,
            }
    return sites


def _nearestGene(peakPos, candidates, genes, radius, doTSS):
    """Return the ID of the candidate gene nearest to *peakPos*, or 0.

    Only genes within *radius* count; on ties the later candidate wins.
    With doTSS the distance is to leftPos for 'F' genes and rightPos for
    'R' genes (other orientations are ignored); otherwise it is the
    distance to the closer of leftPos and rightPos.
    """
    nearest = 0
    best = radius
    for geneID in candidates:
        gene = genes[geneID]
        if doTSS:
            if gene['orientation'] == 'F':
                distance = abs(peakPos - gene['leftPos'])
            elif gene['orientation'] == 'R':
                distance = abs(peakPos - gene['rightPos'])
            else:
                continue
        else:
            distance = min(abs(peakPos - gene['leftPos']),
                           abs(peakPos - gene['rightPos']))
        if distance <= best:
            best = distance
            nearest = geneID
    return nearest


def _siteWindow(gene, radius, upstreamOnly):
    """Return the (low, high) interval a site must overlap to match *gene*.

    'F' genes are anchored at leftPos, 'R' genes at rightPos; the window
    extends *radius* bases away from the gene and, unless *upstreamOnly*,
    also covers the gene body. Returns None for any other orientation.
    """
    orientation = gene['orientation']
    if orientation == 'F':
        low = gene['leftPos'] - radius
        high = gene['leftPos'] if upstreamOnly else gene['rightPos']
    elif orientation == 'R':
        low = gene['rightPos'] if upstreamOnly else gene['leftPos']
        high = gene['rightPos'] + radius
    else:
        return None
    return low, high


def _distanceFromTSS(gene, site):
    """Signed peak offset: rightPos - peakPos for 'R' genes, else peakPos - leftPos."""
    if gene['orientation'] == 'R':
        return gene['rightPos'] - site['peakPos']
    return site['peakPos'] - gene['leftPos']


def _writeRecord(outfile, genome, gene, site):
    """Write one tab-delimited gene/site line, including the site sequence."""
    seq = getSequence(genome, gene['chromosome'], site['start'], site['stop'], gene['orientation'])
    outfile.write('%s\t%s\t%s\t%d\t%d\t%s\t%s\t%d\t%d\t%d\t%s\n' % (
        gene['geneID'], gene['name'], gene['chromosome'], gene['leftPos'],
        gene['rightPos'], gene['orientation'], site['name'], site['start'],
        site['stop'], _distanceFromTSS(gene, site), seq))


def run():
    """Command-line entry point: write the sites associated with each gene.

    Positional arguments (sys.argv): genome, radius, gene list file (or
    '-cistematic' to load genes from cistematic), site file, output file.
    Options are listed in the usage string.

    A site is reported for a gene when it overlaps:
      * radius == 0: the gene's start (leftPos for 'F', rightPos for 'R');
      * -upstreamOnly: the radius bases upstream of that start, inclusive;
      * otherwise: the gene body extended radius bases upstream.
    With -doNearestOnly and radius != 0, a pair is reported only when the
    gene is the site's nearest gene. -getAllGenes (requires -doNearestOnly)
    instead reports every site next to its nearest gene. Exits with status 1
    on bad arguments.
    """
    usage = 'usage: python %s genome radius listofgenes/[-cistematic] listofsites outfilename [-getAllGenes] [-doNearestOnly listofsubsetofgenes/[-all]] [-upstreamOnly] [-distancetoTSS] [-usePeak fieldID]'
    argv = sys.argv
    # Five positional arguments follow the script name.
    if len(argv) < 6:
        print(usage % argv[0])
        sys.exit(1)

    genome = argv[1]
    radius = int(argv[2])
    listofgenesfilename = argv[3]
    listofsitesfilename = argv[4]
    outfilename = argv[5]

    doGetGenes = '-getAllGenes' in argv
    doUpstream = '-upstreamOnly' in argv
    doNearestOnly = '-doNearestOnly' in argv
    doPeak = '-usePeak' in argv
    doTSS = '-distancetoTSS' in argv

    genesubsetname = None
    peakFieldID = None
    if doUpstream:
        print('doUpstream')
    if doNearestOnly:
        genesubsetname = _optionValue(argv, '-doNearestOnly', usage)
        print('doNearestOnly')
    if doPeak:
        print('doPeak')
        peakFieldID = int(_optionValue(argv, '-usePeak', usage))
    if doTSS:
        print('get sites closest to TSS, not to gene')
    if doGetGenes and not doNearestOnly:
        # -getAllGenes groups sites by nearest gene, which only -doNearestOnly computes.
        print('-getAllGenes requires -doNearestOnly')
        sys.exit(1)

    if listofgenesfilename == '-cistematic':
        genes = _loadGenesFromCistematic(genome)
    else:
        genes = _loadGenesFromFile(listofgenesfilename)

    # Candidate genes for the nearest-gene search, grouped by chromosome.
    genesByChromosome = {}
    if doNearestOnly:
        if genesubsetname == '-all':
            genesubset = genes
        else:
            genesubset = _loadGeneSubset(genesubsetname)
        missing = 0
        for geneID in genesubset:
            if geneID not in genes:
                missing += 1
                continue
            genesByChromosome.setdefault(genes[geneID]['chromosome'], []).append(geneID)
        if missing:
            print('warning: %d subset genes are not in the gene list and were ignored' % missing)

    sites = _loadSites(listofsitesfilename, peakFieldID)

    if doNearestOnly:
        count = 0
        for site in sites.values():
            site['nearestGeneID'] = _nearestGene(
                site['peakPos'], genesByChromosome.get(site['chromosome'], []),
                genes, radius, doTSS)
            if site['nearestGeneID'] != 0:
                count += 1
        print('number of sites with genes %d' % count)

    with open(outfilename, 'w') as outfile:
        outfile.write('gene_ID\tgene_name\tchromosome\tgene_start\tgene_end\tgene_orientation\tsite\tsite_start\tsite_end\tDistance_from_TSS\tsequence\n')

        if doGetGenes:
            # Index sites by nearest gene (in site order) instead of rescanning
            # every site for every gene.
            sitesByGene = {}
            for siteID in sites:
                sitesByGene.setdefault(sites[siteID]['nearestGeneID'], []).append(siteID)
            for geneID in genes:
                for siteID in sitesByGene.get(geneID, []):
                    _writeRecord(outfile, genome, genes[geneID], sites[siteID])
            return

        # Index sites by chromosome (in site order) so each gene only scans
        # sites on its own chromosome.
        sitesByChromosome = {}
        for siteID in sites:
            sitesByChromosome.setdefault(sites[siteID]['chromosome'], []).append(siteID)

        if radius == 0:
            # Sites must contain the gene start itself; nearest-gene filtering
            # is not applied in this mode.
            upstreamOnly = True
            useNearest = False
        else:
            print('radius != 0')
            if doUpstream:
                print('doUpstreamOnly')
            else:
                print('upstream and downstream')
            upstreamOnly = doUpstream
            useNearest = doNearestOnly

        for geneID in genes:
            gene = genes[geneID]
            window = _siteWindow(gene, radius, upstreamOnly)
            if window is None:
                continue
            low, high = window
            for siteID in sitesByChromosome.get(gene['chromosome'], []):
                site = sites[siteID]
                if useNearest and (site['nearestGeneID'] == 0 or site['nearestGeneID'] != geneID):
                    continue
                if low <= site['stop'] and high >= site['start']:
                    _writeRecord(outfile, genome, gene, site)

# Run only when executed as a script, so importing this module has no side effects.
if __name__ == '__main__':
    run()