##################################
#                                #
# Last modified 5/6/2009         # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import matplotlib
matplotlib.use('Agg')
from pylab import *
import matplotlib.axes 
from matplotlib import *
import matplotlib.colorbar 
import matplotlib.patches
import math
from sets import Set
from cistematic.core import Genome


# Optional JIT speed-up (psyco, Python 2 only). Any failure just means the
# script runs without it; catching Exception rather than using a bare except
# keeps Ctrl-C / sys.exit working.
try:
    import psyco
    psyco.full()
except Exception:
    pass

import sys

# Six positional arguments are read below (sys.argv[1] .. sys.argv[6]), so
# anything shorter than seven argv entries cannot work.
if len(sys.argv) < 7:
    print('usage: python %s listofallgenes chromosome(s) <comma separated, no spaces> expressiondata bindingsitesFile genome outputPictureName [-PlotAllChromosomes] [-RPKM] [-FoldChange]' % sys.argv[0])
    sys.exit(1)

# Optional command-line flags.
#   -PlotAllChromosomes : plot every chromosome found in the gene list
#   -RPKM               : colour genes by log2 RPKM (the default)
#   -FoldChange         : colour genes by log2 fold change instead
# -RPKM and -FoldChange are mutually exclusive.
doAll = '-PlotAllChromosomes' in sys.argv
doXY = False  # not read anywhere in this script
doRPKM = True
doFold = False
if ('-RPKM' in sys.argv) and ('-FoldChange' in sys.argv):
    print('give either -RPKM or -FoldChange, not both')
    sys.exit(1)
if '-FoldChange' in sys.argv:
    doFold = True
    doRPKM = False

# Positional arguments, in the order listed in the usage message.
(listofallgenesFile, chrom, expression,
 bindingSitesFile, genome, outfilename) = [sys.argv[i] for i in range(1, 7)]

# Chromosome numbers/names arrive comma separated ("1,2,X") and are turned
# into UCSC-style names ("chr1", "chr2", "chrX").
chromargs = chrom.split(',')
chromosomes = ['chr' + str(c) for c in chromargs]

hg = Genome(genome)

# Load the gene list and attach expression values.
#
# After this block, genes[name] is:
#   [chromosome, start (float), end (float), column-4 text]
# and, for genes that appear in the expression file, two more floats:
#   [4] expression column 2 (used with -RPKM), [5] expression column 4
#   (used with -FoldChange). Zeros are replaced by 0.01 so log2 is defined.
# expressionvalues / foldchangevalues hold the log2 of those values.
genes = {}
expressionvalues = []
foldchangevalues = []

genelistFile = open(listofallgenesFile)
try:
    genelist = genelistFile.readlines()
finally:
    genelistFile.close()
for geneline in genelist:
    gene = geneline.rstrip('\r\n').split('\t')
    if len(gene) < 5:
        # blank or truncated line (e.g. trailing newline at end of file)
        continue
    genes[gene[0]] = [gene[1], float(gene[2]), float(gene[3]), gene[4]]

expressionFile = open(expression)
try:
    expressiondata = expressionFile.readlines()
finally:
    expressionFile.close()
for line in expressiondata:
    fields = line.rstrip('\r\n').split('\t')
    genename = fields[0]
    if genename in genes and len(fields) >= 5:
        rpkm = float(fields[2])
        foldchange = float(fields[4])
        if rpkm == 0:
            rpkm = 0.01
        if foldchange == 0:
            foldchange = 0.01
        genes[genename].append(rpkm)
        genes[genename].append(foldchange)
        expressionvalues.append(math.log(rpkm, 2))
        foldchangevalues.append(math.log(foldchange, 2))

logmaxRPKM = max(expressionvalues)
logminRPKM = min(expressionvalues)

if doFold:
    logmaxRPKM = max(foldchangevalues)
    logminRPKM = min(foldchangevalues)

print 'len(genelist)', len(genelist)
print 'len(expressiondata)', len(expressiondata)
print 'len(genes)', len(genes)

allchromosomes = []

for gene in genes.keys():
    allchromosomes.append(genes[gene][0])
allchromosomes = list(Set(allchromosomes))
allchromosomes.sort()

if doAll:
    print 'doAll'
    chromosomes = allchromosomes

lenchromlist = []
genome = {}
for chr in chromosomes:
    genome[chr] = []

genomebindingsites = {}
for chr in chromosomes:
    genomebindingsites[chr] = []
print 'genomebindingsites', genomebindingsites

######################################## parse binding sites data ##########################################
# Lines starting with '#' are comments. Other lines are tab-separated;
# column 1 is the chromosome, and columns 2 and 3 are read as floats
# (presumably site start/end coordinates -- TODO confirm against the file format).
# Each kept site becomes a [col2, col3] pair under its chromosome.
bindingSitesData = open(bindingSitesFile)
try:
    bindingSitesList = bindingSitesData.readlines()
finally:
    bindingSitesData.close()
for line in bindingSitesList:
    if line.startswith('#'):
        continue
    fields = line.split('\t')
    if len(fields) < 4:
        # blank or truncated line
        continue
    if fields[1] in genomebindingsites:
        genomebindingsites[fields[1]].append([float(fields[2]), float(fields[3])])

for gene in genes.keys():
    chr = genes[gene][0]
    if chr in chromosomes:
        genome[chr].append(gene)

print 'chromosomes:', chromosomes
print 'total number of chromosomes:', len(allchromosomes)

F = gcf()
DPI = F.get_dpi()
F.set_figsize_inches(70, 35)
ax2 = pylab.axes([0.9, 0.05, 0.05, 0.45])
#ax2 = F.add_axes([0.9, 0.05, 0.05, 0.45])
boundary2 = abs(logminRPKM)/(logmaxRPKM-logminRPKM)
boundary1 = boundary2-0.1
boundary3 = boundary2+0.1
print 'boundary1', boundary1
print 'boundary2', boundary2
print 'boundary3', boundary3

ax = F.add_axes([0.0, 0.0, 1.0, 1.0])
ax.set_xlabel('')
ax.set_ylabel('', size=0)
ax.set_axis_bgcolor('black')
ticklocs = []
ax.set_yticks(ticklocs)
ax.set_xticks(ticklocs)

cdict = {'red': ((0.0, 0.0, 0.0), (boundary1, 0.0, 0.0), (boundary2, 0.0, 0.0), (boundary3, 0.6, 0.6) ,(1.0, 1.0, 1.0)),
         'green': ((0.0, 1.0, 1.0), (boundary1, 0.6, 0.6), (boundary2, 0.0, 0.0), (boundary3, 0.0, 0.0), (1.0, 0.0, 0.0)),
         'blue': ((0.0, 0.0, 0.0), (1.0, 0.0, 0.0))}
my_cmap = matplotlib.colors.LinearSegmentedColormap('my_colormap',cdict,1024)

norm = colors.Normalize(vmin=logminRPKM, vmax=logmaxRPKM)
cb2 = colorbar.ColorbarBase(ax2, cmap=my_cmap, norm=norm, orientation='vertical')

F.savefig(outfilename, dpi=200)

pos = 0

for chr in chromosomes:
    pos+=1
    if chr == 'chrX' or chr == 'chrY':
        lenchr = len(hg.getChromosomeSequence(chr[3:len(chr)]))
    else:
        lenchr = len(hg.getChromosomeSequence(int(chr[3:len(chr)])))
    lenchromlist.append(lenchr)
    if len(chromosomes)==1:
        spaceperchrom = 0.1
    else:
        spaceperchrom = 1.0/(len(chromosomes)+1)
    if len(chromosomes) == len(allchromosomes):
        if chr == 'chrX':
            if len(chromosomes)==1:
                verticalPos = 1-(len(chromosomes)*spaceperchrom)
            else: 
                verticalPos = 1-((len(chromosomes)-1)*spaceperchrom) 
        if chr == 'chrY':
            if len(chromosomes)==1:
                verticalPos = 1-(len(chromosomes)*spaceperchrom)
            else: 
                verticalPos = 1-((len(chromosomes))*spaceperchrom) 
        else:
            verticalPos = 1-(int(chr[3:len(chr)])*spaceperchrom)
    else:
        verticalPos = 1-pos*spaceperchrom
    horizontallength = 0.98*float(lenchr)/float(max(lenchromlist))
    ax1 = F.add_axes([0.015, verticalPos, horizontallength, spaceperchrom/2])
    ax1.set_xlabel('', size=40,color='r')
    ax1.set_ylabel(chr, size=60,color='r')
#    ax1.set_axis_bgcolor('white')
    ticklocs = []
    ax1.set_yticks(ticklocs)
    ax1.set_xticks(ticklocs)
#    ax1.set_axis_bgcolor('black')
    print chr
    print 'genes on chromosome', len(genome[chr])
    print 'verticalPos', verticalPos
    k = len(genome[chr])
    for gene in genome[chr]:
        if genes[gene][0] == chr:
            leftPos = (genes[gene][1]/lenchr)*horizontallength
            rightPos = (genes[gene][2]/lenchr)*horizontallength
            ax2 = F.add_axes([0.015+leftPos, verticalPos, rightPos-leftPos, spaceperchrom/4])
            ax2.set_xlabel(gene, fontsize=4,rotation='vertical',color='y')
            rcParams["axes.linewidth"]=0.00
            ax2.set_ylabel('', size=0)
            ax2.set_yticks(ticklocs)
            ax2.set_xticks(ticklocs)
            if (len(genes[gene])==6):
                if doRPKM:
                    lognormvalue = (math.log(genes[gene][4],2)-logminRPKM)/(logmaxRPKM-logminRPKM)
                if doFold:
                    lognormvalue = (math.log(genes[gene][5],2)-logminRPKM)/(logmaxRPKM-logminRPKM)
                r = 0.0
                g = 0.0
                b = 0.0
                if lognormvalue <= boundary1:
                    g=1-lognormvalue
                    if g<0:
                        print lognormvalue
                        print genes[gene]
                        print 'lognormvalue <= boundary1'
                    color = (r,g,b,1.0)
                    ax2.set_axis_bgcolor(color)
                    continue
                if (lognormvalue > boundary1) and (lognormvalue <= boundary2):
                    g=0.6*(1-((lognormvalue-boundary1)/0.1))
                    color = (r,g,b,1.0)
                    ax2.set_axis_bgcolor(color)
                    if g<0:
                        print lognormvalue
                        print genes[gene]
                        print '(lognormvalue > boundary1) and (lognormvalue <= boundary2)'
                    continue
                if (lognormvalue > boundary2) and (lognormvalue <= boundary3):
                    r=0.6*(1-((boundary3-lognormvalue)/0.1))
                    color = (r,g,b,1.0)
                    ax2.set_axis_bgcolor(color)
                    if g<0:
                        print lognormvalue
                        print genes[gene]
                        print '(lognormvalue > boundary2) and (lognormvalue <= boundary3)'
                    continue
                if lognormvalue > boundary3:
                    r=lognormvalue
                    color = (r,g,b,1.0)
                    ax2.set_axis_bgcolor(color)
                    if g<0:
                        print lognormvalue
                        print genes[gene]
                        print 'lognormvalue > boundary3'
                    continue
        else:
            ax2.set_axis_bgcolor('w')
            continue
    rcParams["lines.linewidth"] = 2
    rcParams["lines.color"] = 'k'
    F.savefig(outfilename, dpi=200)
    for y in genomebindingsites[chr]:
        diff = (y[1]-y[0])*10
        leftPos = ((y[0]-diff)/lenchr)*horizontallength
        rightPos = ((y[0]+diff)/lenchr)*horizontallength
        ax2 = F.add_axes([0.015+leftPos, verticalPos+spaceperchrom/4, rightPos-leftPos, spaceperchrom/4])
#        ax2 = F.add_axes([0.01+leftPos, verticalPos, rightPos-leftPos, spaceperchrom/2])
        ax2.set_xlabel('', fontsize=0,rotation='vertical',color='y')
        rcParams["axes.linewidth"]=0.00
        ax2.set_ylabel('', size='x-small')
        ax2.set_yticks(ticklocs)
        ax2.set_xticks(ticklocs)
        ax2.set_axis_bgcolor('black')
    rcParams["lines.linewidth"] = 2
    rcParams["lines.color"] = 'k'
    F.savefig(outfilename, dpi=200)

# The 'Agg' backend selected at the top of the file is non-interactive, so no
# window opens here; the picture is written by the savefig calls.
show()