###########################################################################
#                                                                         #
# C O P Y R I G H T   N O T I C E                                         #
#  Copyright (c) 2003-10 by:                                              #
#    * California Institute of Technology                                 #
#                                                                         #
#    All Rights Reserved.                                                 #
#                                                                         #
# Permission is hereby granted, free of charge, to any person             #
# obtaining a copy of this software and associated documentation files    #
# (the "Software"), to deal in the Software without restriction,          #
# including without limitation the rights to use, copy, modify, merge,    #
# publish, distribute, sublicense, and/or sell copies of the Software,    #
# and to permit persons to whom the Software is furnished to do so,       #
# subject to the following conditions:                                    #
#                                                                         #
# The above copyright notice and this permission notice shall be          #
# included in all copies or substantial portions of the Software.         #
#                                                                         #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,         #
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF      #
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND                   #
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS     #
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN      #
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN       #
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE        #
# SOFTWARE.                                                               #
###########################################################################
#
# This module defines the Draw class, which contains the core code for visualizing sequences from several of the experiments.

from cistematic.core.motif import Motif
from cistematic.core.geneinfo import geneinfoDB
import math, random

pilPresent = False

# Optional dependency: without an imaging library the Draw class still loads but cannot draw.
try:
    # Pillow (and later PIL releases) expose the modules inside the PIL package.
    from PIL import Image, ImageDraw, ImageFont
    pilPresent = True
except ImportError:
    try:
        # Fallback for classic PIL installs that also exposed top-level module names.
        import Image, ImageDraw, ImageFont
        pilPresent = True
    except ImportError:
        pass


class Draw:
    """ The Draw class contains the code used to visualize the location of motifs in their
        genomic neighborhood. It is meant to be used as the parent of other classes,
        such as the orthology classes. It relies on the python imaging library, and saves
        PNG images.

        Subclasses must provide the attributes and methods used below: genepool,
        boundToNextGene, experimentID, expFile, experimentType, analysisID, getResults(),
        getSeqParameters(), getDatasetNames(), getSetting(), motifToGene(), geneToMotif(),
        findMotif() and getFeatures(). getConservedSequenceWindows() and returnHomologs()
        are used on a best-effort basis; failures in them are ignored.
    """
    drawable = False
    maxWidth = 1200      # pixel width of the area the sequences are scaled into
    leftMargin = 100     # room on the left for the gene labels
    rightMargin = 50     # room on the right for the sequence-length label
    lineHeight = 110     # pixel height of one gene row (a footer row is half of this)

    if pilPresent:
        theFont = ImageFont.load_default()
        drawable = True
    else:
        print("Draw: python image library missing - will not be able to draw on this system")


    def _selectGenes(self, geneList, excludeGeneList):
        """ Return the ordered, de-duplicated list of genes to draw.

            Candidates come from the gene lists stored under each dataset name, or from
            self.genepool when there are no datasets. A non-empty geneList restricts the
            candidates; genes in excludeGeneList are always dropped.
        """
        datasetIDs = self.getDatasetNames()
        if len(datasetIDs) > 0:
            candidates = []
            for datasetName in datasetIDs:
                dataset = self.getSetting(datasetName)
                # NOTE(review): eval() executes arbitrary code. If stored settings can ever come
                # from an untrusted source, switch to ast.literal_eval.
                candidates.extend(eval(dataset[0]))
        else:
            candidates = self.genepool

        limitGenes = len(geneList) > 0
        selected = []
        seen = set()    # gene IDs are used as dict keys below, so they are hashable
        for gene in candidates:
            if limitGenes and gene not in geneList:
                continue

            if gene in excludeGeneList or gene in seen:
                continue

            seen.add(gene)
            selected.append(gene)

        return selected


    def _shortTagLabel(self, tagID):
        """ Return a compact label for tagID: "a-bc-d" becomes "abd"; other IDs are unchanged. """
        if tagID.count("-") == 2:
            parts = tagID.split("-")
            # [:1] rather than [0] so an empty middle field ("a--b") cannot raise IndexError.
            return parts[0] + parts[1][:1] + parts[2]

        return tagID


    def draw(self, picName, geneList=None, motifList=None, excludeGeneList=None, excludeMotifList=None,
            showHeader=True, showFooter=True, maxOccurences=100, skipSanity=False):
        """ Draws an image of motifs on the sequences and saves it as a PNG file named picName.

            geneList / excludeGeneList restrict which genes are drawn. An empty or missing
            geneList means "all genes". motifList / excludeMotifList do the same for motifs.
            showHeader adds experiment information. showFooter adds a motif key giving each
            motif's consensus and the number of matches drawn.
            Motifs with maxOccurences or more matches (as reported by motifToGene) are not
            shown. Unless skipSanity is set, motifs whose isSane() is false are skipped too.
            Returns None. Does nothing when the imaging library is unavailable.
        """
        if not self.drawable:
            return

        # None replaces the former shared [] defaults.
        geneList = [] if geneList is None else geneList
        motifList = [] if motifList is None else motifList
        excludeGeneList = [] if excludeGeneList is None else excludeGeneList
        excludeMotifList = [] if excludeMotifList is None else excludeMotifList
        limitMotifs = len(motifList) > 0

        idb = geneinfoDB()
        bound = ""
        if self.boundToNextGene:
            bound = "up to "

        expResults = self.getResults()
        (up, cds, down) = self.getSeqParameters()
        if cds > 1:
            hasORF = "MASKED"
        elif cds > 0:
            hasORF = "YES"
        else:
            hasORF = "NO"

        orthologyList = self._selectGenes(geneList, excludeGeneList)

        # Scale all sequences against the longest one, which then spans maxWidth pixels.
        geneLength = {}
        maxLength = 1
        for geneID in orthologyList:
            geneLength[geneID] = len(self.genepool[geneID])
            maxLength = max(maxLength, geneLength[geneID])

        seqScaler = float(self.maxWidth) / float(maxLength)

        # Choose the motifs to draw and give each one a random colour.
        motColor = {}
        motConsensus = {}
        for mot in expResults:
            tagID = mot.tagID
            if limitMotifs and tagID not in motifList:
                continue

            if tagID in excludeMotifList:
                continue

            if len(self.motifToGene(tagID)) >= maxOccurences:
                continue

            if not skipSanity and not mot.isSane():
                continue

            motConsensus[tagID] = mot.buildConsensus()
            motColor[tagID] = (random.randint(5, 240), random.randint(5, 240), random.randint(5, 240))

        motKeys = sorted(motConsensus)
        motNumber = dict.fromkeys(motKeys, 0)   # number of matches drawn, per motif
        adjustedMotLength = {}
        for tagID in motKeys:
            adjustedMotLength[tagID] = int(math.ceil(len(self.findMotif(tagID)) * seqScaler))

        numLines = len(orthologyList)
        if showHeader:
            numLines += 1

        # The footer key lists three motifs per row, and each row is half a lineHeight tall.
        # Space is reserved only when the footer is actually drawn.
        footerLines = 0
        if showFooter:
            footerLines = (len(motKeys) + 2) // 3

        imageWidth = self.maxWidth + self.leftMargin + self.rightMargin
        imsize = (imageWidth, int(round((numLines + footerLines / 2.) * self.lineHeight)))
        image = Image.new("RGB", imsize, color="#ffffff")
        draw = ImageDraw.Draw(image)
        currentHeight = 0
        if showHeader:
            line1 = "Experiment: %s in %s Type: %s Analysis: %s" % (self.experimentID, self.expFile, self.experimentType, self.analysisID)
            draw.text([10, 10], line1, font=self.theFont, fill=0)
            draw.line((10, 30, imageWidth - 10, 30), fill=0)
            line2 = "Upstream: %s%s ORF: %s Downstream: %s%s" % (bound, up, hasORF, bound, down)
            draw.text([10, 40], line2, font=self.theFont, fill=0)
            draw.line((10, 60, imageWidth - 10, 60), fill=0)
            currentHeight = self.lineHeight

        for geneID in orthologyList:
            seqLength = geneLength[geneID]
            geneNames = ""
            try:
                res = idb.geneIDSynonyms(geneID)
                for entry in res[1:]:
                    geneNames += "%s " % str(entry)
            except Exception:
                # Synonyms are cosmetic; still draw the gene if the lookup fails.
                pass

            motList = self.geneToMotif(geneID)
            draw.text([5, currentHeight + 45], str(geneID[0]), font=self.theFont, fill=0)
            draw.text([5, currentHeight + 55], str(geneID[1]), font=self.theFont, fill=0)
            draw.text([5, currentHeight + 65], geneNames, font=self.theFont, fill=0)
            # Sequences are right-aligned so that every one ends at leftMargin + maxWidth.
            adjustedSeqLength = seqLength * seqScaler
            seqStart = self.leftMargin + int(self.maxWidth) - adjustedSeqLength
            for (ftype, fstart, fstop, forientation) in self.getFeatures(geneID):
                if ftype != "CDS":
                    continue

                fstart = float(fstart)
                fstop = float(fstop)
                if fstop < fstart:
                    # Coordinates may be given stop-first. Swap them so the box spans the CDS.
                    # (The old code assigned fstart = fstop first, which collapsed it to zero width.)
                    fstart, fstop = fstop, fstart

                start = int(math.floor(fstart * seqScaler))
                consLength = int(math.ceil((fstop - fstart) * seqScaler))
                if start + consLength > adjustedSeqLength:
                    consLength = adjustedSeqLength - start

                start += seqStart
                draw.rectangle([start, currentHeight + 51, start + consLength, currentHeight + 64], fill="#aaaaaa")

            draw.rectangle([seqStart, currentHeight + 35, seqStart + adjustedSeqLength, currentHeight + 65], outline=0)
            tagIndex = 0
            tagPosList = []
            for (tagID, (pos, sense)) in motList:
                if tagID not in motNumber:
                    continue

                tagIndex += 1
                motNumber[tagID] += 1
                start = int(math.floor(float(pos) * seqScaler)) + seqStart
                # "F" matches extend above the sequence box and other matches extend below it.
                # Labels alternate between two heights to reduce overlap.
                bottom = 0
                if sense == "F":
                    top = 12
                    textHeight = 0 if tagIndex % 2 else 70
                else:
                    top = 35
                    bottom = 18
                    textHeight = 15 if tagIndex % 2 else 85

                # Nudge the label away from any earlier label drawn at nearly the same spot.
                for (prevStart, prevHeight) in tagPosList:
                    if abs(prevStart - start) <= 7 and abs(prevHeight - textHeight) <= 7:
                        if textHeight < 50:
                            textHeight -= 7
                        else:
                            textHeight += 7

                tagPosList.append((start, textHeight))
                draw.rectangle([start, currentHeight + top, start + adjustedMotLength[tagID], currentHeight + 65 + bottom], fill=motColor[tagID])
                draw.text([start - 5, currentHeight + textHeight], self._shortTagLabel(tagID), font=self.theFont, fill=motColor[tagID])

            conservedWindows = []
            try:
                conservedWindows = self.getConservedSequenceWindows(geneID)
            except Exception:
                # Best-effort: the method may be missing or may fail for this gene.
                pass

            for (location, cLength, criteria) in conservedWindows:
                start = int(math.floor(float(location) * seqScaler)) + seqStart
                consLength = int(math.ceil(float(cLength) * seqScaler))
                draw.rectangle([start, currentHeight + 38, start + consLength, currentHeight + 50], fill='#ff0000')

            # Tick mark every 1000 positions along the sequence.
            for location in range(0, seqLength, 1000):
                start = int(math.floor(float(location) * seqScaler)) + seqStart
                draw.rectangle([start, currentHeight + 60, start + 1, currentHeight + 65], fill='#000000')

            draw.text([self.leftMargin + self.maxWidth + 5, currentHeight + 45], str(seqLength), font=self.theFont, fill=0)
            currentHeight += self.lineHeight

        if showFooter:
            # Floor division keeps the layout identical on Python 2 and Python 3.
            columnWidth = self.maxWidth // 3
            rowHeight = self.lineHeight // 2
            for (motNum, motID) in enumerate(motKeys):
                x = self.leftMargin + columnWidth * (motNum % 3)
                rowTop = currentHeight + rowHeight * (motNum // 3)
                draw.text([x, rowTop + 5], motID, font=self.theFont, fill=motColor[motID])
                draw.text([x, rowTop + 15], motConsensus[motID], font=self.theFont, fill=0)
                draw.text([x, rowTop + 25], str(motNumber[motID]) + " matches", font=self.theFont, fill=0)

        del draw
        image.save(picName, "PNG")


    def drawMotifs(self, picName, motifList, geneList=None, excludeGeneList=None, showHeader=True,
                   showFooter=True, maxOccurences=100, genesWithMotifOnly=True, skipSanity=False):
        """ Draws an image of one or more motifs on the sequences. Can specifically list or exclude genes.
            Options showHeader will add experiment information while showFooter will add a motif key.
            This function will not show motifs that occur more than maxOccurences in the dataset, and will
            only show sequences with the motif by default.

            NOTE: with genesWithMotifOnly (the default) geneList is ignored. If no gene carries any
            of the motifs, the restricted gene list is empty, and draw() treats that as "all genes".
        """
        excludeGeneList = [] if excludeGeneList is None else excludeGeneList
        if genesWithMotifOnly:
            restrictedGeneList = []
            seen = set()
            for motID in motifList:
                for (loc, pos) in self.motifToGene(motID):
                    if loc in seen or loc in excludeGeneList:
                        continue

                    seen.add(loc)
                    restrictedGeneList.append(loc)
        else:
            restrictedGeneList = geneList

        self.draw(picName, restrictedGeneList, motifList, excludeGeneList, [], showHeader, showFooter, maxOccurences, skipSanity)


    def drawGenes(self, picName, geneList, motifList=None, excludeMotifList=None, showHeader=True,
                  showFooter=True, maxOccurences=100, includeHomologs=False, motifsOnGeneOnly=True,
                  skipSanity=False):
        """ Draws an image of one or more genes (optionally with their homologs that are in the
            gene pool) and the motifs on them. Can specifically list or exclude motifs.
            Options showHeader will add experiment information while showFooter will add a motif key.
            This function will not show motifs that occur more than maxOccurences in the dataset, and will
            only show motifs on the sequence by default in the footer.

            NOTE: with motifsOnGeneOnly (the default) motifList is ignored, and the motifs found on
            the drawn genes are used instead.
        """
        excludeMotifList = [] if excludeMotifList is None else excludeMotifList
        if includeHomologs:
            theGeneList = []
            for geneID in geneList:
                if geneID not in theGeneList:
                    theGeneList.append(geneID)

                try:
                    for gID in self.returnHomologs(geneID):
                        if gID in self.genepool and gID not in theGeneList:
                            theGeneList.append(gID)
                except Exception:
                    # Homolog lookup is best-effort; keep the genes gathered so far.
                    pass
        else:
            theGeneList = geneList

        if motifsOnGeneOnly:
            restrictedMotifList = []
            for geneID in theGeneList:
                for (motID, pos) in self.geneToMotif(geneID):
                    if motID not in restrictedMotifList and motID not in excludeMotifList:
                        restrictedMotifList.append(motID)
        else:
            restrictedMotifList = motifList

        self.draw(picName, theGeneList, restrictedMotifList, [], excludeMotifList, showHeader, showFooter, maxOccurences, skipSanity)