##################################
#                                #
# Last modified 5/6/2009         # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

from sets import Set
import string
import random
import math


# Chromosomes whose genes and binding sites are considered; records on any
# other chromosome are ignored.
_CHROMOSOMES = ('chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7',
                'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14',
                'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chrX', 'chrY')


def _data_lines(filename):
    """Return the lines of `filename`, skipping blank (whitespace-only) lines."""
    with open(filename) as handle:
        return [line for line in handle if line.strip()]


def _read_regulated_genes(filename, minExpression, minFoldChange):
    """Return the set of gene names that are both expressed and changed.

    Each tab-separated line is read as: gene name in column 0, two expression
    values in columns 1 and 2, fold change in column 4. A gene qualifies when
    its fold change is strictly outside [1/minFoldChange, minFoldChange] and
    at least one of the two expression values exceeds minExpression.
    """
    regulated = set()  # set: O(1) membership tests in _read_gene_coordinates
    for line in _data_lines(filename):
        fields = line.split('\t')
        foldChange = float(fields[4])
        dataPoint1 = float(fields[1])
        dataPoint2 = float(fields[2])
        changed = foldChange < 1.0 / minFoldChange or foldChange > minFoldChange
        expressed = dataPoint1 > minExpression or dataPoint2 > minExpression
        if changed and expressed:
            regulated.add(fields[0])
    return regulated


def _read_gene_coordinates(filename, regulatedGenes):
    """Map chromosome -> {gene name: [start, end]} for the regulated genes only.

    Columns: gene name (0), chromosome (1), start (2), end (3). Genes on
    chromosomes not listed in _CHROMOSOMES are dropped.
    """
    genes = dict((chromosome, {}) for chromosome in _CHROMOSOMES)
    for line in _data_lines(filename):
        fields = line.split('\t')
        geneName = fields[0]
        if geneName in regulatedGenes and fields[1] in genes:
            genes[fields[1]][geneName] = [int(fields[2]), int(fields[3])]
    return genes


def _read_binding_sites(filename):
    """Return (chromosome, start, end) tuples from the binding-sites file.

    Lines starting with '#' are skipped as comments. Chromosome, start and
    end are read from tab-separated columns 1, 2 and 3.
    """
    sites = []
    for line in _data_lines(filename):
        if line.startswith('#'):
            continue
        fields = line.split('\t')
        sites.append((fields[1], int(fields[2]), int(fields[3])))
    return sites


def _write_site(outfile, geneName, chromosome, siteStart, siteEnd):
    """Write one tab-separated 'gene, chromosome, siteStart, siteEnd' line."""
    outfile.write('%s\t%s\t%d\t%d\n' % (geneName, chromosome, siteStart, siteEnd))


def _assign_sites(outfile, bindingSites, genes, radius):
    """Write, per site, every overlapping gene plus the nearest other gene.

    The nearest non-overlapping gene is written only if its distance is
    <= radius. Sites on chromosomes not in `genes` are skipped.
    """
    for chromosome, siteStart, siteEnd in bindingSites:
        if chromosome not in genes:
            continue
        nearby = []  # (distance, geneName) for genes the site does not overlap
        for geneName, (geneStart, geneEnd) in genes[chromosome].items():
            # Strict interval overlap: covers a site endpoint inside the gene
            # AND a site that fully contains the gene.
            if siteStart < geneEnd and geneStart < siteEnd:
                _write_site(outfile, geneName, chromosome, siteStart, siteEnd)
            else:
                dist = min(abs(siteStart - geneEnd), abs(geneStart - siteEnd),
                           abs(siteEnd - geneEnd), abs(geneStart - siteStart))
                nearby.append((dist, geneName))
        # A chromosome may have no non-overlapping genes at all; min() of an
        # empty sequence would raise, so only look when there are candidates.
        if nearby:
            dist, geneName = min(nearby)
            if dist <= radius:
                _write_site(outfile, geneName, chromosome, siteStart, siteEnd)


def run():
    """Command-line entry point: assign binding sites to regulated genes.

    Arguments (sys.argv[1:8]): sortedListOfAllGenes, bindingSitesFile,
    expressionDataFile, radius (int), minimalExpressionLevel (float),
    minimalFoldChange (float), outputFile.

    Writes tab-separated lines 'gene, chromosome, siteStart, siteEnd' to
    outputFile. Exits with status 1 and a usage message if too few
    arguments are given.
    """
    try:
        import psyco  # optional JIT; the script works without it
    except ImportError:
        print('psyco not running')
    else:
        psyco.full()
    import sys
    # Seven arguments plus the program name: sys.argv[7] must exist.
    if len(sys.argv) < 8:
        print('usage: python %s sortedListOfAllGenes bindingSitesFile expressionDataFile radius minimalExpressionLevel minimalFoldChange outputFile' % sys.argv[0])
        sys.exit(1)
    listOfGenesFilename = sys.argv[1]
    bindingSitesFile = sys.argv[2]
    expressionDataFile = sys.argv[3]
    radius = int(sys.argv[4])
    minExpression = float(sys.argv[5])
    minFoldChange = float(sys.argv[6])
    print('minFoldChange %s' % minFoldChange)
    outfilename = sys.argv[7]

    regulatedGenes = _read_regulated_genes(expressionDataFile, minExpression,
                                           minFoldChange)
    genes = _read_gene_coordinates(listOfGenesFilename, regulatedGenes)
    print('all %d' % sum(len(byName) for byName in genes.values()))

    bindingSites = _read_binding_sites(bindingSitesFile)
    with open(outfilename, 'w') as outfile:
        _assign_sites(outfile, bindingSites, genes, radius)

#    boundGenes = Set(boundGenes)
#    allgeneslist = Set(allgeneslist)
#    boundGenes = Set.intersection(boundGenes,allgeneslist)
#    print 'len(boundGenes)', len(boundGenes)
        
	
# Run only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    run()