##################################
#                                #
# Last modified 02/28/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import time
from sets import Set
import string
import random
import math
import sys
import pickle

def getUniqueElements(seq):
    """Return the distinct elements of seq as a list, keeping first-seen order.

    Hashable elements are de-duplicated through a set (O(1) per element);
    unhashable elements (e.g. lists) fall back to a linear scan of the
    result, so any iterable of comparable items is accepted.
    """
    seen = set()
    unique = []
    for e in seq:
        try:
            if e in seen:
                continue
            seen.add(e)
        except TypeError:
            # unhashable element: compare by equality against what we kept
            if e in unique:
                continue
        unique.append(e)
    return unique


def getDistance(geneA, geneB):
    """Return the smallest coordinate difference between two gene records.

    Each record holds its start position at index 2 and its end position at
    index 3 (ints or numeric strings).  Four pairings are compared --
    start(A)/end(B), start(B)/end(A), end(A)/end(B), start(A)/start(B) --
    and the smallest absolute difference is returned.
    """
    aStart, bStart, aEnd, bEnd = [
        int(record[column])
        for record, column in ((geneA, 2), (geneB, 2), (geneA, 3), (geneB, 3))
    ]
    gaps = [aStart - bEnd, bStart - aEnd, aEnd - bEnd, aStart - bStart]
    return min([abs(gap) for gap in gaps])


# ATTENTION: getClusters only works on a gene list sorted by chromosome and then by position (as read from a sorted gene list file); it compares each gene only with its predecessor in the list.

def getClusters(boundGenes, listOfGenes, distance):
    """Group runs of adjacent bound genes into clusters.

    listOfGenes is a list of records with this structure:
        [gene name, chromosome, start position, end position]
    and MUST be sorted by chromosome and position: each gene is only
    compared with the gene immediately before it in the list.  Two
    consecutive genes belong to the same cluster when
      1. they are adjacent in listOfGenes,
      2. both gene names are in boundGenes,
      3. they are on the same chromosome, and
      4. getDistance() between them is <= distance.

    boundGenes -- any iterable of gene names (list, set, ...).
    Returns a list of clusters; each cluster is a list of >= 2 gene names in
    list order.  Prints the elapsed time.
    """
    t1 = time.time()
    bound = set(boundGenes)  # O(1) membership instead of a list scan
    clusters = []
    currentcluster = []
    previous = None
    for gene in listOfGenes:
        if (previous is not None
                and previous[0] in bound
                and gene[0] in bound
                and previous[1] == gene[1]
                and getDistance(previous, gene) <= distance):
            currentcluster.append(gene[0])
        else:
            if len(currentcluster) > 1:
                clusters.append(currentcluster)
            currentcluster = [gene[0]]
        previous = gene
    # Flush the final run: a cluster reaching the end of the list would
    # otherwise be lost.
    if len(currentcluster) > 1:
        clusters.append(currentcluster)

    t2 = time.time()
    print('time %s' % (t2 - t1))

    return clusters


def getClustersNoAdjacencyRequired(boundGenes, listOfGenes, distance):
    """Group nearby bound genes into clusters, ignoring unbound genes between them.

    listOfGenes is a list of records with this structure:
        [gene name, chromosome, start position, end position]
    and is expected to be sorted by chromosome and position.  Genes whose
    name is not in boundGenes are skipped entirely.  Each bound gene is
    compared with the previous bound gene; they belong to the same cluster
    when
      1. they are on the same chromosome, and
      2. getDistance() between them is <= distance.
    Unlike getClusters, the two genes need not be adjacent in the list.

    boundGenes -- any iterable of gene names (list, set, ...).
    Returns a list of clusters; each cluster is a list of >= 2 gene names in
    list order.  Prints the elapsed time.
    """
    t1 = time.time()
    bound = set(boundGenes)  # O(1) membership instead of a list scan
    clusters = []
    currentcluster = []
    previous = None  # last bound gene seen
    for gene in listOfGenes:
        if gene[0] not in bound:
            continue
        # Coordinates on different chromosomes are not comparable, so a
        # chromosome change always closes the current cluster.
        if (previous is not None
                and previous[1] == gene[1]
                and getDistance(previous, gene) <= distance):
            currentcluster.append(gene[0])
        else:
            if len(currentcluster) > 1:
                clusters.append(currentcluster)
            currentcluster = [gene[0]]
        previous = gene
    # Flush the final run: a cluster reaching the end of the list would
    # otherwise be lost.
    if len(currentcluster) > 1:
        clusters.append(currentcluster)

    t2 = time.time()
    print('time %s' % (t2 - t1))

    return clusters



# written for listOfGenes being a list!!	
def _samplingDistribution(clusterFunction, boundGenes, listOfGenes,
                          numIterations, distance, rng):
    """Shared sampling loop behind both getSamplingDistribution* functions.

    Each iteration draws len(boundGenes) random records from listOfGenes,
    clusters them with clusterFunction and records how many genes ended up
    in clusters.  Returns those counts sorted ascending.  Prints progress
    and the total elapsed time.
    """
    if rng is None:
        rng = random
    t1 = time.time()
    sampleSize = len(boundGenes)
    samplingDistribution = []
    for i in range(numIterations):
        # A set of the drawn names gives the clustering function O(1)
        # membership tests (a list made each test a linear scan).
        drawn = set([gene[0] for gene in rng.sample(listOfGenes, sampleSize)])
        clusters = clusterFunction(drawn, listOfGenes, distance)
        samplingDistribution.append(sum([len(cluster) for cluster in clusters]))
        print('sampling iteration %s' % i)
    samplingDistribution.sort()
    t2 = time.time()
    print('total iteration time %s' % (t2 - t1))
    return samplingDistribution


def getSamplingDistribution(boundGenes, listOfGenes, numIterations, distance,
                            rng=None):
    """Sampling distribution of clustered-gene counts, adjacency required.

    listOfGenes must be a list of [name, chromosome, start, end] records.
    Each of numIterations iterations picks as many random genes as there
    are in boundGenes and counts the genes that getClusters() places in
    clusters.  Returns the counts sorted ascending.

    rng -- optional object with a random.sample-compatible sample() method
           (e.g. random.Random(seed)) for reproducible runs; defaults to the
           random module.
    """
    return _samplingDistribution(getClusters, boundGenes, listOfGenes,
                                 numIterations, distance, rng)


def getSamplingDistributionNoAdjacencyRequired(boundGenes, listOfGenes,
                                               numIterations, distance,
                                               rng=None):
    """Sampling distribution of clustered-gene counts, adjacency not required.

    Same as getSamplingDistribution but clusters each random gene set with
    getClustersNoAdjacencyRequired().  Returns the counts sorted ascending.

    rng -- optional object with a random.sample-compatible sample() method;
           defaults to the random module.
    """
    return _samplingDistribution(getClustersNoAdjacencyRequired, boundGenes,
                                 listOfGenes, numIterations, distance, rng)


def samplingMean(samplingDistribution):
    """Return the arithmetic mean of the sampled values as a float.

    An empty distribution raises ZeroDivisionError.
    """
    total = float(sum(samplingDistribution))
    count = len(samplingDistribution)
    return total / count
    
def samplingVariance(samplingDistribution):
    """Return the population variance (divide by N) of the sampled values.

    An empty distribution raises ZeroDivisionError.
    """
    mean = float(sum(samplingDistribution)) / len(samplingDistribution)
    deviations = [value - mean for value in samplingDistribution]
    squaredTotal = sum([d * d for d in deviations])
    return squaredTotal / len(samplingDistribution)

def _readBoundGenes(filename):
    """Return the gene names listed in a bound-genes file.

    The name is the first tab-separated field of each line, truncated at the
    first space.  Line terminators are stripped first (otherwise a
    single-column file yields names ending in a newline that never match the
    gene list), and blank lines are skipped.
    """
    names = []
    infile = open(filename)
    try:
        for line in infile:
            name = line.rstrip('\r\n').split('\t')[0].split(' ')[0]
            if name:
                names.append(name)
    finally:
        infile.close()
    return names


def _readGeneList(filename):
    """Parse a gene list file into [name, chromosome, start, end] records.

    Every non-blank line must have at least four tab-separated fields:
    gene name, chromosome, start position, end position (positions are
    converted to int).  Records keep file order, which is the order the
    clustering functions scan them in.
    """
    genes = []
    infile = open(filename)
    try:
        for line in infile:
            line = line.rstrip('\r\n')
            if not line.strip():
                continue
            fields = line.split('\t')
            genes.append([fields[0], fields[1], int(fields[2]), int(fields[3])])
    finally:
        infile.close()
    return genes


def _writeReport(outfile, clusters, samplingDistribution, mean, variance):
    """Write sampling statistics, clusters, the raw sampling distribution and
    the cluster-length histogram to the open file object outfile."""
    parts = [
        'Sampling Mean ', str(mean),
        '\nSampling Variance ', str(variance),
        '\nSampling SD ', str(math.sqrt(variance)),
        '\n\n',
        'Number of clusters ', str(len(clusters)),
        '\nNumber of genes in clusters ',
        str(sum([len(cluster) for cluster in clusters])),
        '\n\n',
    ]
    for i, cluster in enumerate(clusters):
        parts.append('\nCluster%d\t' % i)
        for gene in cluster:
            parts.append(gene)
            parts.append('\t')
    parts.append('\n\nsamplingDistribution ')
    for value in samplingDistribution:
        parts.append(str(value))
        parts.append('\t')
    parts.append('\n\n')

    # Histogram over every length from the shortest to the longest cluster
    # (zero counts included), written in ascending length order.  Skipped
    # when there are no clusters, where min()/max() would fail.
    lengths = [len(cluster) for cluster in clusters]
    if lengths:
        counts = dict((n, 0) for n in range(min(lengths), max(lengths) + 1))
        for n in lengths:
            counts[n] += 1
        for n in sorted(counts):
            parts.append('\nLength of clusters / Number of clusters of that length\t')
            parts.append(str(n))
            parts.append('\t')
            parts.append(str(counts[n]))

    outfile.write(''.join(parts))


def run():
    """Command-line entry point.

    usage: python SCRIPT genesAssociatedWithBindingSitesFile
           ListOfGenesInTheGenomAssemblyFile number_Iterations distance
           outfile [-AdjacencyRequired]

    Finds clusters among the bound genes that also appear in the gene list,
    builds a sampling distribution from numIterations random gene sets of
    the same size, prints a summary and writes a report to outfile.
    """
    try:
        import psyco
        psyco.full()
    except Exception:
        # psyco is an optional JIT; the script still runs without it.
        print('psyco not running')

    flag = '-AdjacencyRequired'
    doRequireAdjacency = flag in sys.argv
    args = [arg for arg in sys.argv[1:] if arg != flag]
    # Five positional arguments are required (the original check allowed
    # four and then crashed reading the outfile name).
    if len(args) < 5:
        print('usage: python %s genesAssociatedWithBindingSitesFile ListOfGenesInTheGenomAssemblyFile number_Iterations distance outfile [-AdjacencyRequired]' % sys.argv[0])
        sys.exit(1)

    boundGenesFilename = args[0]
    listOfGenesFilename = args[1]
    numIterations = int(args[2])
    distance = int(args[3])
    outfilename = args[4]
    if numIterations < 1:
        # samplingMean/samplingVariance divide by the number of iterations.
        print('number_Iterations must be at least 1')
        sys.exit(1)

    # boundGene structure: list of gene names
    boundGenes = _readBoundGenes(boundGenesFilename)
    print('boundGenes %d' % len(set(boundGenes)))

    # listOfGenes structure: list of [gene name, chrN, start, end]
    listOfGenes = _readGeneList(listOfGenesFilename)
    print('len(listOfGenes) %d' % len(listOfGenes))

    # Keep only bound genes present in the assembly's gene list.
    allGeneNames = set([gene[0] for gene in listOfGenes])
    boundGenes = sorted(set(boundGenes) & allGeneNames)
    print('len(boundGenes) %d' % len(boundGenes))

    # Open the output before the long computation so a bad path fails fast.
    outfile = open(outfilename, 'w')
    try:
        if doRequireAdjacency:
            clusters = getClusters(boundGenes, listOfGenes, distance)
            samplingDistribution = getSamplingDistribution(boundGenes, listOfGenes, numIterations, distance)
        else:
            clusters = getClustersNoAdjacencyRequired(boundGenes, listOfGenes, distance)
            samplingDistribution = getSamplingDistributionNoAdjacencyRequired(boundGenes, listOfGenes, numIterations, distance)

        mean = samplingMean(samplingDistribution)
        variance = samplingVariance(samplingDistribution)

        print('Number of clusters %d' % len(clusters))
        print('Number of genes in clusters %d' % sum([len(cluster) for cluster in clusters]))
        print('Sampling Mean (number of genes inclusters) %s' % mean)
        print('Sampling Variance (number of genes inclusters) %s' % variance)
        print('Sampling SD (number of genes inclusters) %s' % math.sqrt(variance))

        _writeReport(outfile, clusters, samplingDistribution, mean, variance)
    finally:
        outfile.close()
	
# Run only when executed as a script, so importing this module does not
# start the analysis (or sys.exit on missing arguments).
if __name__ == '__main__':
    run()