##################################
#                                #
# Last modified 02/28/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import time
from sets import Set
import sys
import string
import math
import random
from cistematic.core import Genome
from cistematic.core.geneinfo import geneinfoDB
from commoncode import *
import sys
import pickle

def getDistance(gene1, gene2):
    """Return the smallest gap between the coordinate ends of two genes.

    Items 1 and 2 of each gene are read as its start and end coordinates
    (converted with int(), so numeric strings are accepted).  The result is
    the minimum of |start1-end2|, |start2-end1|, |end1-end2| and
    |start1-start2|.  The chromosome (item 0) is NOT compared here.
    """
    start1, end1 = int(gene1[1]), int(gene1[2])
    start2, end2 = int(gene2[1]), int(gene2[2])
    gaps = (start1 - end2, start2 - end1, end1 - end2, start1 - start2)
    return min(abs(gap) for gap in gaps)


def getClusters(genes, expressedGenes, distance, doRequireAdjacency):
    """Group consecutive expressed genes into clusters.

    expressedGenes is walked in the given order (run() sorts it); each entry
    is a (chr, start, stop, orientation, name, RPKM) tuple.  Two consecutive
    expressed genes are linked when they are on the same chromosome and
    getDistance() between them is below *distance*.  When
    doRequireAdjacency is true they must also be immediate neighbours in
    *genes*, i.e. no other gene lies between them.  Neighbours are matched
    on the first five fields, so a differing RPKM value (e.g. after
    shuffling) does not matter; every expressed gene must then appear in
    *genes*.

    Returns a list of clusters, each a list of two or more gene tuples.
    An empty or single-gene input yields [].
    """
    t1 = time.time()

    clusters = []
    if expressedGenes:
        position = {}
        if doRequireAdjacency:
            # Map gene identity (everything but RPKM) to its index in genes
            # once, instead of calling genes.index() for every pair.
            for index, gene in enumerate(genes):
                position.setdefault(tuple(gene[:5]), index)

        currentcluster = [expressedGenes[0]]
        for previous, current in zip(expressedGenes, expressedGenes[1:]):
            linked = (current[0] == previous[0]
                      and getDistance(previous, current) < distance)
            if linked and doRequireAdjacency:
                linked = (position[tuple(current[:5])]
                          == position[tuple(previous[:5])] + 1)
            if linked:
                currentcluster.append(current)
            else:
                if len(currentcluster) > 1:
                    clusters.append(currentcluster)
                # The gene that broke the chain starts the next candidate.
                currentcluster = [current]
        # Flush the final run; otherwise a cluster at the end is lost.
        if len(currentcluster) > 1:
            clusters.append(currentcluster)

    t2 = time.time()
    print('time %s' % (t2 - t1))

    return clusters

def shuffleGenes(genes, RPKMlist, minRPKM):
    """Return a copy of *genes* whose RPKM values are randomly permuted.

    Each gene is a (chr, start, stop, orientation, name, RPKM) tuple.  The
    first five fields stay in place; item i of a shuffled COPY of *RPKMlist*
    becomes gene i's new RPKM, so the caller's list is not modified.
    RPKMlist must have at least len(genes) values (IndexError otherwise).

    Returns (newGenes, newExpressedGenes).  newExpressedGenes holds the
    entries of newGenes (with their shuffled RPKM) whose new value exceeds
    minRPKM, in the same order as *genes*.
    """
    shuffledRPKMs = list(RPKMlist)
    random.shuffle(shuffledRPKMs)

    newGenes = []
    newExpressedGenes = []
    for index, (chrom, start, stop, orientation, name, _oldRPKM) in enumerate(genes):
        newRPKM = shuffledRPKMs[index]
        newGene = (chrom, start, stop, orientation, name, newRPKM)
        newGenes.append(newGene)
        # Use the shuffled tuple so it matches an entry of newGenes exactly.
        if newRPKM > minRPKM:
            newExpressedGenes.append(newGene)

    return (newGenes, newExpressedGenes)

def _readRPKMFile(RPKMfile, geneField, RPKMField):
    """Return a dict mapping gene name -> RPKM read from a tab-separated file.

    geneField and RPKMField are 0-based column indexes.  Lines starting with
    '#' and blank lines are skipped; a later line for the same gene
    overwrites an earlier one.
    """
    RPKMDict = {}
    infile = open(RPKMfile)
    try:
        for line in infile:
            if line.startswith('#') or not line.strip():
                continue
            fields = line.strip().split('\t')
            RPKMDict[fields[geneField]] = float(fields[RPKMField])
    finally:
        infile.close()
    return RPKMDict


def _loadGenes(genome, RPKMDict, minRPKM):
    """Build sorted gene and expressed-gene tuple lists for *genome*.

    Each tuple is (chr, start, stop, orientation, name, RPKM).  Genes whose
    name is not in RPKMDict, or whose orientation is not F/+/R/-, are
    skipped.  For reverse-strand genes start is the larger coordinate.
    Expressed genes are those with RPKM > minRPKM.
    Returns (genes, expressedGenes), both sorted.
    """
    genes = []
    expressedGenes = []
    hg = Genome(genome)
    idb = geneinfoDB()
    featDict = hg.getallGeneFeatures()
    numGeneIDs = len(featDict)
    for i, k in enumerate(featDict):
        if i % 1000 == 0:
            print(numGeneIDs - i)  # progress: genes remaining
        geneInfo = idb.getGeneInfo((genome, k))
        if not geneInfo:
            name = 'LOC' + str(k)
        else:
            name = geneInfo[0]
        if name not in RPKMDict:
            continue

        # Items read from each feature: [1] chromosome suffix,
        # [2]/[3] left/right coordinates, [4] orientation (first feature).
        features = featDict[k]
        leftPos = [int(feature[2]) for feature in features]
        rightPos = [int(feature[3]) for feature in features]
        chrom = 'chr' + str(features[0][1])
        orientation = str(features[0][4])
        if orientation in ('F', '+'):
            start, stop = min(leftPos), max(rightPos)
        elif orientation in ('R', '-'):
            start, stop = max(rightPos), min(leftPos)
        else:
            # Unknown strand: skip instead of using stale coordinates.
            continue

        RPKM = RPKMDict[name]
        gene = (chrom, start, stop, orientation, name, RPKM)
        genes.append(gene)
        if RPKM > minRPKM:
            expressedGenes.append(gene)

    genes.sort()
    expressedGenes.sort()
    return genes, expressedGenes


def _writeObservedClusters(outfile, clusters):
    """Write the summary and member list of the observed clusters."""
    outfile.write(str(len(clusters)) + ' clusters found\n')
    outfile.write('Number of genes in clusters:\n')
    sizes = [len(cluster) for cluster in clusters]
    outfile.write('Total genes: ' + str(sum(sizes)) + '\n')
    # Every size is followed by a comma, e.g. "[3,2,]" (established format).
    outfile.write('[' + ''.join(str(size) + ',' for size in sizes) + ']\n')
    outfile.write('-------------------------------------------------------------------------------------\n')
    outfile.write('Clusters:\n')
    for cluster in clusters:
        outfile.write(''.join(gene[4] + ',' for gene in cluster) + '\n')
    outfile.write('---------------------------------------------------------------------------------------\n')


def run():
    """Command-line entry point.

    Finds clusters of expressed genes (RPKM > minRPKM) lying closer than
    *distance*.  It then repeats the search number_Iterations times with
    RPKM values shuffled across genes, and writes both results to outfile.
    Arguments are taken from sys.argv (see the usage message).
    """
    try:
        import psyco
        psyco.full()
    except Exception:
        # psyco is an optional speed-up; run without it when unavailable.
        print('psyco not running')

    # Script name plus eight positional arguments (sys.argv[1]..sys.argv[8]).
    if len(sys.argv) < 9:
        # NOTE(review): -customAnnotation is advertised but not parsed below.
        print('usage: python %s genome RPKMfile geneField RPKMfield minRPKM distance number_Iterations outfile [-AdjacencyRequired] [-customAnnotation annotationFileName annotationExpressionLevel]' % sys.argv[0])
        sys.exit(1)

    genome = sys.argv[1]
    RPKMfile = sys.argv[2]
    geneField = int(sys.argv[3])
    RPKMField = int(sys.argv[4])
    minRPKM = float(sys.argv[5])
    distance = int(sys.argv[6])
    numIterations = int(sys.argv[7])
    outfilename = sys.argv[8]
    doRequireAdjacency = '-AdjacencyRequired' in sys.argv

    # Open the output first so a bad path fails before the slow loading.
    outfile = open(outfilename, 'w')
    try:
        outfile.write('GeneID\tGeneName\tChr\tStart\tEnd\tOrientation\tRPKM\n')
        RPKMDict = _readRPKMFile(RPKMfile, geneField, RPKMField)
        genes, expressedGenes = _loadGenes(genome, RPKMDict, minRPKM)
        # The multiset of RPKM values that the controls redistribute.
        RPKMList = [gene[5] for gene in genes]

        clusters = getClusters(genes, expressedGenes, distance, doRequireAdjacency)
        _writeObservedClusters(outfile, clusters)

        randomClusterCount = []
        randomClusteredGenes = []
        for j in range(numIterations):
            print(j)
            (newGenes, newExpressedGenes) = shuffleGenes(genes, RPKMList, minRPKM)
            randomClusters = getClusters(newGenes, newExpressedGenes, distance, doRequireAdjacency)
            randomClusterCount.append(len(randomClusters))
            randomClusteredGenes.append(sum(len(cluster) for cluster in randomClusters))

        outfile.write('Randomized genes clusters found distribution: \n')
        outfile.write(str(randomClusterCount) + '\n')
        outfile.write('Randomized genes clustered genes found distribution: \n')
        outfile.write(str(randomClusteredGenes) + '\n')
    finally:
        outfile.close()

# Run only when executed as a script, not when imported for its helpers.
if __name__ == '__main__':
    run()