##################################
#                                #
# Last modified 5/6/2009         # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import math
from cistematic.core import Genome
from cistematic.core.geneinfo import geneinfoDB

# Optional speed-up: enable the psyco accelerator when it is installed and
# usable. Any failure is deliberately ignored so the script still runs
# without it; `Exception` (not a bare except) lets Ctrl-C / sys.exit through.
try:
    import psyco
    psyco.full()
except Exception:
    pass

def getDistance(bindingsite, expressiondatadictionary):
    """Return the signed distance (bp) from a binding site to a gene's TSS.

    bindingsite must provide 'start' and 'stop'; expressiondatadictionary
    must provide 'orientation' ('F' or 'R'), 'leftPos' and 'rightPos'.
    The TSS is leftPos for 'F' genes and rightPos for 'R' genes, and the two
    orientations are handled as mirror images:

    * a site lying entirely upstream of the TSS (left of leftPos for 'F',
      right of rightPos for 'R') is measured from its edge nearest the TSS
      and gives a negative value;
    * any other site is measured from its upstream-most edge ('start' for
      'F', 'stop' for 'R'), so sites overlapping the TSS are negative and
      sites downstream of it (towards the gene body) are positive.

    Raises ValueError for any other orientation value.
    """
    start = bindingsite['start']
    stop = bindingsite['stop']
    orientation = expressiondatadictionary['orientation']

    if orientation == 'F':
        tss = expressiondatadictionary['leftPos']
        if start < tss and stop < tss:
            # Entirely upstream: use the edge closest to the TSS.
            return stop - tss
        return start - tss

    if orientation == 'R':
        tss = expressiondatadictionary['rightPos']
        if start > tss and stop > tss:
            # Entirely upstream on the reverse strand (right of the TSS):
            # the edge closest to the TSS is start.
            return tss - start
        # Mirror of the forward case: measure from the upstream-most edge,
        # which for a reverse-strand gene is the site's right edge (stop).
        return tss - stop

    # Previously this fell through and raised UnboundLocalError.
    raise ValueError('unknown gene orientation %r' % (orientation,))

# Region bins, listed in the column order they are written to the output file.
_REGION_KEYS = ['-500kb-100kb', '-100kb-10kb', '-50kb-10kb', '-10kb-1kb',
                '-1kb-TSS', 'TSS-+1kb-', 'genebody', '0-20kbdownstream',
                '20kb-100kbdownstream']


def _readLines(filename):
    """Return the lines of filename with trailing CR/LF removed; always closes the file."""
    infile = open(filename)
    try:
        return [line.rstrip('\r\n') for line in infile.readlines()]
    finally:
        infile.close()


def _readGeneIDs(filename):
    """Return the set of gene IDs in the first tab-separated column of filename."""
    geneIDs = set()
    for line in _readLines(filename):
        if not line:
            continue
        geneIDs.add(line.split('\t')[0])
    return geneIDs


def _readExpressionData(filename, geneIDs, featDict):
    """Parse the expression table for the genes in geneIDs.

    Columns: geneID, geneName, 0h RPKM, 60h RPKM, difference, fold change.
    Each gene gets empty region lists plus chromosome, leftPos/rightPos
    (min/max over all of its features) and orientation from featDict.
    Raises KeyError if a listed gene has no entry in featDict.
    """
    expressiondata = {}
    for line in _readLines(filename):
        if not line:
            continue
        fields = line.split('\t')
        geneID = fields[0]
        if geneID not in geneIDs:
            continue
        gene = {}
        gene['geneID'] = fields[0]
        gene['geneName'] = fields[1]
        gene['0hRPKM'] = float(fields[2])
        gene['60hRPKM'] = float(fields[3])
        gene['difference'] = float(fields[4])
        gene['foldchange'] = float(fields[5])
        for key in _REGION_KEYS:
            gene[key] = []
        features = featDict[geneID]
        gene['chromosome'] = 'chr' + str(features[0][1])
        gene['leftPos'] = min([feature[2] for feature in features])
        gene['rightPos'] = max([feature[3] for feature in features])
        gene['orientation'] = str(features[0][4])
        expressiondata[geneID] = gene
    return expressiondata


def _readSites(filename):
    """Parse a tab-separated binding-site file into {siteName: siteDict}.

    Lines starting with '#' and blank lines are skipped. A later line with
    the same name replaces an earlier one.
    """
    sites = {}
    for line in _readLines(filename):
        if not line or line.startswith('#'):
            continue
        fields = line.split('\t')
        site = {}
        site['name'] = fields[0]
        site['chromosome'] = fields[1]
        site['start'] = int(fields[2])
        site['stop'] = int(fields[3])
        site['RPM'] = float(fields[4])
        site['fold'] = float(fields[5])
        site['multi%'] = float(fields[6])
        site['plus%'] = float(fields[7])
        site['leftPlus%'] = float(fields[8])
        site['peakPos'] = int(fields[9])
        site['peakHeight'] = float(fields[10])
        sites[fields[0]] = site
    return sites


def _isDifferentiallyExpressed(gene):
    """Return True if the gene passes the difference / fold-change filters."""
    if gene['foldchange'] == 0 and gene['difference'] > -1.0:
        return False
    return (math.fabs(gene['difference']) > 1
            and (gene['foldchange'] > 1.5 or gene['foldchange'] < 0.7))


def _regionBounds(gene):
    """Return (regionKey, low, high) triples; a site joins every region with low < distance < high."""
    geneLength = math.fabs(gene['rightPos'] - gene['leftPos'])
    # NOTE(review): '-50kb-10kb' and '-100kb-10kb' overlap and all bounds are
    # exclusive (e.g. distance 0 or -1000 lands in no bin) - kept as written.
    return [('-1kb-TSS', -1000, 0),
            ('TSS-+1kb-', 0, 1000),
            ('-10kb-1kb', -10000, -1000),
            ('-50kb-10kb', -50000, -10000),
            ('-100kb-10kb', -100000, -10000),
            ('-500kb-100kb', -500000, -100000),
            ('genebody', 0, geneLength),
            ('0-20kbdownstream', geneLength, geneLength + 20000),
            ('20kb-100kbdownstream', geneLength + 20000, geneLength + 100000)]


def _assignSites(gene, sites):
    """Append the name of each same-chromosome site within 500 kb to its region lists."""
    bounds = _regionBounds(gene)
    for site in sites.values():
        if site['chromosome'] != gene['chromosome']:
            continue
        distance = getDistance(site, gene)
        if math.fabs(distance) >= 500000:
            continue
        for key, low, high in bounds:
            if low < distance < high:
                gene[key].append(site['name'])


def _writeOutput(outfilename, expressiondata):
    """Write one tab-separated row per gene; region columns list site names, each followed by ', '."""
    header = ['geneID', 'geneName', 'chromosome', 'leftPos', 'rightPos',
              'orientation', '0hRPKM', '60hRPKM', 'Net Difference',
              'Fold Change'] + _REGION_KEYS
    outfile = open(outfilename, 'w')
    try:
        outfile.write('\t'.join(header) + '\n')
        for geneID in expressiondata:
            gene = expressiondata[geneID]
            print(gene['geneID'])
            columns = [gene['geneID'], gene['geneName'], gene['chromosome'],
                       str(gene['leftPos']), str(gene['rightPos']),
                       gene['orientation'], str(gene['0hRPKM']),
                       str(gene['60hRPKM']), str(gene['difference']),
                       str(gene['foldchange'])]
            for key in _REGION_KEYS:
                columns.append(''.join([site + ', ' for site in gene[key]]))
            outfile.write('\t'.join(columns) + '\n')
    finally:
        outfile.close()


def run():
    """Command-line entry point.

    Usage: genome listofgenesfile expressiondata myoD_24h_sites
    myogenin_60h_sites outfilename.

    For each listed gene that passes the expression filters, the MyoD 24h
    and myogenin 60h binding sites are binned by distance to the gene's TSS,
    and the results are written as a tab-separated table.
    """
    # Six arguments plus the script name; sys.argv[6] is read below.
    if len(sys.argv) < 7:
        print('usage: python %s  genome listofgenesfile expressiondata myoD_24h_sites myogenin_60h_sites outfilename' % sys.argv[0])
        sys.exit(1)

    genome = sys.argv[1]
    listofgenesfilename = sys.argv[2]
    expressiondatafilename = sys.argv[3]
    myoD24hfilename = sys.argv[4]
    myogenin60hfilename = sys.argv[5]
    outfilename = sys.argv[6]

    hg = Genome(genome)
    featDict = hg.getallGeneFeatures()

    geneIDs = _readGeneIDs(listofgenesfilename)
    expressiondata = _readExpressionData(expressiondatafilename, geneIDs, featDict)
    myoD24hsites = _readSites(myoD24hfilename)
    myogenin60hsites = _readSites(myogenin60hfilename)

    for k, geneID in enumerate(expressiondata):
        print(k)  # progress counter
        gene = expressiondata[geneID]
        if not _isDifferentiallyExpressed(gene):
            continue
        _assignSites(gene, myogenin60hsites)
        _assignSites(gene, myoD24hsites)

    _writeOutput(outfilename, expressiondata)

# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    run()
