###########################################################################
#                                                                         #
# C O P Y R I G H T   N O T I C E                                         #
#  Copyright (c) 2003-10 by:                                              #
#    * California Institute of Technology                                 #
#                                                                         #
#    All Rights Reserved.                                                 #
#                                                                         #
# Permission is hereby granted, free of charge, to any person             #
# obtaining a copy of this software and associated documentation files    #
# (the "Software"), to deal in the Software without restriction,          #
# including without limitation the rights to use, copy, modify, merge,    #
# publish, distribute, sublicense, and/or sell copies of the Software,    #
# and to permit persons to whom the Software is furnished to do so,       #
# subject to the following conditions:                                    #
#                                                                         #
# The above copyright notice and this permission notice shall be          #
# included in all copies or substantial portions of the Software.         #
#                                                                         #
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,         #
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF      #
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND                   #
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS     #
# BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN      #
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN       #
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE        #
# SOFTWARE.                                                               #
###########################################################################
#
import string
import cistematic
from cistematic.core.geneinfo import geneinfoDB
import cistematic.core
from cistematic.core.motif import Motif
try:
    from pysqlite2 import dbapi2 as sqlite
except:
    from sqlite3 import dbapi2 as sqlite

import os, tempfile, shutil
from os import environ

# Root of the cistematic installation; doBlast() runs the BLAST tools from
# <cisRoot>/programs/blast/bin. An unset or empty CISTEMATIC_ROOT falls back to the default.
cisRoot = environ.get("CISTEMATIC_ROOT") or "/proj/genome"

# Scratch directory: the default output directory of this module's functions, and the
# directory the tempfile module uses. An unset or empty CISTEMATIC_TEMP falls back to /tmp.
cisTemp = environ.get("CISTEMATIC_TEMP") or "/tmp"

tempfile.tempdir = cisTemp

def _readSortedMatches(motfilename):
    """ Return (chrom, pos, line) for every line of a tab-delimited motif-match file, sorted by
        chromosome, then position, then line text. The chromosome is read from column index 3
        and the integer position from column index 4.
    """
    lineArray = []
    infile = open(motfilename)
    try:
        for line in infile:
            fields = line.split("\t")
            lineArray.append((fields[3], int(fields[4]), line))
    finally:
        infile.close()

    lineArray.sort()
    return lineArray


def _writeGenomeWindows(genomeDict, motifID, motLen, totalLen, upLen, downLen, filterMasked, outfile, matchfile):
    """ For each genome in genomeDict, write the sequence window around every usable motif match
        to outfile and the motif sequence itself to matchfile, both in fasta format.
    """
    for genome in genomeDict:
        matches = _readSortedMatches(genomeDict[genome])
        print("doing %s" % genome)
        cistematic.core.cacheGeneDB(genome)
        current = "-1"
        prevpos = -1 * totalLen
        prevchrom = "nochrom"
        for (chrom, pos, line) in matches:
            fields = line.split("\t")
            # consecutive lines carrying the same match ID (column index 1) are reported once
            if fields[1] == current:
                continue

            current = fields[1]
            misc = fields[0][-1]
            # skip matches whose window would be too close to the previously written one
            if chrom == prevchrom and pos < (prevpos + 2 * (upLen + downLen)):
                print("skipping %s-%s-%s%s\ttoo close to previous pos" % (genome, chrom, misc, current))
                continue

            try:
                motseq = cistematic.core.retrieveSequence(genome, chrom, pos + 1, pos + motLen)
                start = max(pos - upLen, 0)
                seq = cistematic.core.retrieveSequence(genome, chrom, start, start + totalLen)
            except Exception:
                print("skipping %s-%s-%s%s\terror retrieving sequence" % (genome, chrom, misc, current))
                continue

            sense = fields[5]
            if sense == "R":
                motseq = cistematic.core.complement(motseq)
                seq = cistematic.core.complement(seq, len(seq))

            # lowercase bases in the motif are treated as masked sequence
            if filterMasked and motseq != motseq.upper():
                continue

            stop = start + len(seq)
            # "-" separates the ID fields, so it is replaced inside the motif and chromosome names
            matchID = "%s-%s-%s-%s%s" % (motifID.replace("-", "_"), genome, chrom.replace("-", "_"), misc, current)
            outfile.write("> %s %d [%d to %d, %s]\n%s\n" % (matchID, pos, start, stop, sense, seq))
            matchfile.write("> %s\n%s\n" % (matchID, motseq))
            prevchrom = chrom
            prevpos = pos


def getModules(motifID, motLen, totalLen=60, mainGenome=None, queryGenomes=None, directory=cisTemp, filterMasked=True):
    """ build three fasta files containing sequences from (a) one main genome, (b) multiple query genomes, 
        and (c) motif matches in all genomes.

        motifID      - name used for the output files; with "-" replaced by "_" it also prefixes each fasta ID
        motLen       - length of the motif
        totalLen     - length of the sequence window extracted around each motif match
        mainGenome   - dict mapping genome name -> path of a tab-delimited motif-match file
        queryGenomes - dict of the same shape for the genomes to compare against
        directory    - output directory
        filterMasked - skip matches whose motif sequence is not entirely uppercase

        Writes <motifID>-main.fsa (windows from mainGenome), <motifID>-db.fsa (windows from
        queryGenomes) and <motifID>-matches.fsa (motif sequences from both) into directory.
        Each window header reads "> matchID pos [start to stop, sense]".
    """
    if mainGenome is None:
        mainGenome = {}

    if queryGenomes is None:
        queryGenomes = {}

    # floor division keeps the flank lengths integral under both Python 2 and 3
    upLen = (totalLen - motLen) // 2
    downLen = totalLen - upLen - motLen
    matchfile = open("%s/%s-matches.fsa" % (directory, motifID), "w")
    try:
        outfile = open("%s/%s-main.fsa" % (directory, motifID), "w")
        try:
            _writeGenomeWindows(mainGenome, motifID, motLen, totalLen, upLen, downLen, filterMasked, outfile, matchfile)
        finally:
            outfile.close()

        print("doing database genomes")
        outfile = open("%s/%s-db.fsa" % (directory, motifID), "w")
        try:
            _writeGenomeWindows(queryGenomes, motifID, motLen, totalLen, upLen, downLen, filterMasked, outfile, matchfile)
        finally:
            outfile.close()
    finally:
        matchfile.close()
        cistematic.core.uncacheGeneDB()


def _readBlastSegments(fsafilename):
    """ Turn each "> matchID loc [start to stop, sense]" header written by getModules() into a
        (genome, matchID, chrom, loc, start, stop, sense) row for the blast_segments table.
    """
    rows = []
    fsafile = open(fsafilename, "r")
    try:
        for line in fsafile:
            if not line.startswith(">"):
                continue

            (junk, matchid, loc, segstart, junk2, segstop, sense) = line.split()
            (motTag, genome, chromid, mid) = matchid.split("-")
            # the tokens look like "[123", "183," and "F]"
            rows.append((genome, matchid, chromid, int(loc), int(segstart[1:]), int(segstop[:-1]), sense[0]))
    finally:
        fsafile.close()

    return rows


def _collectBlastHits(blastfilename, matchLength, similarity, firstMatchOnly):
    """ Read tab-delimited blast output and group the hits that pass the filters.

        A hit is kept when its e-value (column index 10) is below 0.01, its length (index 3)
        exceeds matchLength and its similarity (index 2) exceeds similarity. Hits whose two IDs
        are identical are dropped, as are consecutive repeats of the same ID pair. With
        firstMatchOnly, only the first kept hit per query ID is retained.
        Returns {queryID: {subjectID: [(score, line), ...]}}.
    """
    matchDict = {}
    previousGene1 = ""
    previousGene2 = ""
    infile = open(blastfilename, "r")
    try:
        for line in infile:
            fields = line.split("\t")
            geneid1 = fields[0]
            geneid2 = fields[1]
            lineSimilarity = float(fields[2])
            lineMatchLength = int(fields[3])
            evalue = float(fields[10])
            score = float(fields[11].strip())
            if not (evalue < 0.01 and lineMatchLength > matchLength and lineSimilarity > similarity):
                continue

            if geneid1 == geneid2:
                continue

            if geneid1 == previousGene1 and geneid2 == previousGene2:
                continue

            previousGene1 = geneid1
            previousGene2 = geneid2
            if geneid1 not in matchDict:
                matchDict[geneid1] = {}
            elif firstMatchOnly:
                continue

            matchDict[geneid1].setdefault(geneid2, []).append((score, line))
    finally:
        infile.close()

    return matchDict


def _blastEntryRow(line):
    """ Convert one tab-delimited blast line into a blast_entries row (every column except ID). """
    fields = line.split("\t")
    geneid1 = fields[0]
    geneid2 = fields[1]
    (motTag1, genome1, mchrom1, gid1) = geneid1.split("-")
    (motTag2, genome2, mchrom2, gid2) = geneid2.split("-")
    return (genome1, geneid1, genome2, geneid2,
            float(fields[2]), int(fields[3]), int(fields[4]), int(fields[5]),
            int(fields[6]), int(fields[7]), int(fields[8]), int(fields[9]),
            float(fields[10]), float(fields[11].strip()))


def doBlast(motifID, matchLength, similarity, directory=cisTemp, maxInputSize=3000, firstMatchOnly=False):
    """ Blast sequences generated by getModules() to identify regions in the main genome sequences that are 
        conserved in the query genomes. Results are saved in a motifIDBlast.db sqlite database.

        matchLength, similarity - minimum (exclusive) hit length and similarity for a hit to be kept
        directory               - directory holding the getModules() output; results are written there
        maxInputSize            - abort (returning None) if <motifID>-main.fsa has more sequences than this
        firstMatchOnly          - keep only the first passing hit for each main-genome sequence

        formatdb/blastall are run from <cisRoot>/programs/blast/bin through os.system; their exit
        status is not checked. For each (query, subject) pair the best-scoring hit is written to
        <motifID>.cisblast and stored in the blast_entries table. The window headers of
        <motifID>-main.fsa are stored in blast_segments. The raw <motifID>.blastres file is deleted.
    """
    mainfilename = "%s/%s-main.fsa" % (directory, motifID)
    segments = _readBlastSegments(mainfilename)
    # check that we don't exceed maxInputSize
    if len(segments) > maxInputSize:
        print("number of matches in main genome exceeded maxInputSize %d - aborting cisMatcher.doBlast()" % maxInputSize)
        return

    # build a blast DB from the query-genome windows, then blast the main-genome windows against it
    blastCommands = ["cd %s" % directory,
                     "nohup %s/programs/blast/bin/formatdb -t %smodDB -n %smodDB -p F -i %s-db.fsa" % (cisRoot, motifID, motifID, motifID),
                     "%s/programs/blast/bin/blastall -p blastn -d %smodDB -i %s-main.fsa -o %s.blastres -m 8 -v 15 -b 15" % (cisRoot, motifID, motifID, motifID)]
    os.system("; ".join(blastCommands))

    blastfilename = "%s/%s.blastres" % (directory, motifID)
    print("Building matchDict")
    matchDict = _collectBlastHits(blastfilename, matchLength, similarity, firstMatchOnly)
    # the raw blast output is no longer needed once it has been filtered (and closed)
    os.remove(blastfilename)

    print("Processing matchDict")
    entries = []
    outfile = open("%s/%s.cisblast" % (directory, motifID), "w")
    try:
        for geneid1 in sorted(matchDict):
            hitsBySubject = matchDict[geneid1]
            for geneid2 in sorted(hitsBySubject):
                # best score wins; ties are broken by the raw line text
                (bestScore, line) = max(hitsBySubject[geneid2])
                outfile.write(line)
                entries.append(_blastEntryRow(line))
    finally:
        outfile.close()

    # build the database in a private temp file, then copy it into place
    (fd, tempdb) = tempfile.mkstemp(suffix="%sBlast.db" % motifID)
    os.close(fd)
    try:
        db = sqlite.connect(tempdb)
        try:
            sql = db.cursor()
            sql.execute("create table blast_entries(ID INTEGER PRIMARY KEY, GENOME1 varchar, GENEID1 varchar, GENOME2 varchar, GENEID2 varchar, SIMILARITY float, LENGTH int, MISMATCHES int, INDEL int, start1 int, stop1 int, start2 int, stop2 int, evalue float, score float)")
            sql.execute("create table blast_segments(ID INTEGER PRIMARY KEY, GENOME varchar, MATCHID varchar, chrom varchar, loc int, start int, stop int, sense varchar)")
            sql.executemany("INSERT into blast_segments VALUES(NULL,  ?, ?, ?, ?, ?, ?, ?) ", segments)
            sql.executemany("INSERT into blast_entries VALUES(NULL, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ", entries)
            sql.execute("CREATE index blastE on blast_entries(genome1, geneid1)")
            sql.execute("CREATE index blastS on blast_segments(GENOME, MATCHID)")
            db.commit()
            sql.close()
        finally:
            db.close()

        shutil.copyfile(tempdb, "%s/%sBlast.db" % (directory, motifID))
    finally:
        os.remove(tempdb)

    print("inserted %d lines" % len(entries))


def _ucscChrom(chrom):
    """ Return the UCSC form of a chromosome name: a name containing "rand" has its last four
        characters replaced by "_random" (e.g. "1rand" -> "1_random").
    """
    if "rand" in chrom:
        return "%s_random" % chrom[:-4]

    return chrom


def _readMatchSequences(mfilename):
    """ Return {matchID: motif sequence} read from a <motifID>-matches.fsa file ("> matchID" headers). """
    motseqs = {}
    mid = ""
    mfile = open(mfilename, "r")
    try:
        for line in mfile:
            if line[0] != ">":
                motseqs[mid] = line.strip()
            else:
                mid = line.strip()[2:]
    finally:
        mfile.close()

    return motseqs


def _motifAnnotationLines(motfilename, motifID, motLen):
    """ Return one "motif" annotation line per match in a tab-delimited motif-match file.
        Consecutive lines with the same match ID on the same chromosome are reported once.
    """
    lines = []
    current = "-1"
    prevchrom = "nochrom"
    motfile = open(motfilename, "r")
    try:
        for line in motfile:
            fields = line.split("\t")
            chrom = fields[3]
            if fields[1] == current and chrom == prevchrom:
                continue

            prevchrom = chrom
            current = fields[1]
            motstart = int(fields[4]) + 1
            motstop = motstart + motLen
            if fields[5] == "F":
                strand = "+"
            else:
                strand = "-"

            lines.append("chr%s\tcistematic\tmotif\t%d\t%d\t.\t%s\t.\t%s-%s-%s\n" % (_ucscChrom(chrom), motstart, motstop, strand, motifID, chrom, current))
    finally:
        motfile.close()

    return lines


def _blastRegionLines(modfilename, genome):
    """ Return one "blast_region" annotation line per window of genome in a <motifID>-main.fsa file. """
    lines = []
    modfile = open(modfilename, "r")
    try:
        for line in modfile:
            if not line.startswith(">"):
                continue

            (junk, matchid, loc, segstart, junk2, segstop, sense) = line.split()
            (matchMotif, matchGenome, chrom, match) = matchid.split("-")
            if matchGenome != genome:
                continue

            start = int(segstart[1:]) + 1
            stop = int(segstop[:-1]) + 1
            # the header's last token is "F]" or "R]", so only its first character is compared
            if sense.startswith("F"):
                strand = "+"
            else:
                strand = "-"

            lines.append("chr%s\tcistematic\tblast_region\t%d\t%d\t.\t%s\t.\t%sb%d\n" % (_ucscChrom(chrom), start, stop, strand, matchid, len(lines)))
    finally:
        modfile.close()

    return lines


def _nearestFeature(genome, chrom, start, motLen, radiusStep, maxRadius, idb):
    """ Look up the gene feature reported for a blast segment starting at start.

        The first lookup uses getFeaturesIntersecting(..., start, motLen) for CDS features, plus a
        radius-0 %UTR lookup. If nothing is found, the window (start - radius, 2 * radius) is
        widened by radiusStep until a CDS, %UTR (or, for "human", WGRNA) feature turns up or
        radius reaches maxRadius. Annotation lookups are best-effort.
        Returns (feature, symbol, desc, relativeLoc, radius); feature stays "NONE" if nothing was found.
    """
    feature = "NONE"
    symbol = ""
    desc = ""
    relativeLoc = ""
    radius = 0
    # NOTE(review): with radius == 0 the %UTR query has length 0 while the CDS query uses
    # motLen - confirm this asymmetry is intended.
    featureList = cistematic.core.getFeaturesIntersecting(genome, chrom, start, motLen, ftype="CDS") + cistematic.core.getFeaturesIntersecting(genome, chrom, start - radius, 2 * radius, ftype="%UTR")
    if featureList:
        relativeLoc = featureList[0][6]
        feature = featureList[0][0] # always pick first feature
        try:
            desc = cistematic.core.getAnnotInfo((genome, feature))[0]
            symbol = idb.getGeneInfo((genome, feature))[0]
        except Exception:
            pass

    # a non-positive step would never reach maxRadius, so it disables the widening search
    while feature == "NONE" and radiusStep > 0 and radius < maxRadius:
        radius += radiusStep
        featureList = cistematic.core.getFeaturesIntersecting(genome, chrom, start - radius, 2 * radius, ftype="CDS") + cistematic.core.getFeaturesIntersecting(genome, chrom, start - radius, 2 * radius, ftype="%UTR")
        if genome == "human":
            featureList += cistematic.core.getFeaturesIntersecting(genome, chrom, start - radius, 2 * radius, ftype="WGRNA")

        if not featureList:
            continue

        feature = featureList[0][0] # always pick first feature
        try:
            geneID = (genome, feature)
            (gchrom, gstart, gstop, glength, gsense) = cistematic.core.geneEntry(geneID)
            if gstart <= start <= gstop:
                relativeLoc = "GENE"
            elif start < gstart:
                if gsense == "F":
                    relativeLoc = "UP"
                else:
                    relativeLoc = "DOWN"
            else:
                if gsense == "F":
                    relativeLoc = "DOWN"
                else:
                    relativeLoc = "UP"

            desc = cistematic.core.getAnnotInfo(geneID)[0]
            symbol = idb.getGeneInfo(geneID)[0]
        except Exception:
            pass

    return (feature, symbol, desc, relativeLoc, radius)


def _copyBlastDB(directory, motifID):
    """ Copy <directory>/<motifID>Blast.db to a new private temporary file and return its path,
        or None if the database cannot be copied.
    """
    (fd, tempdb) = tempfile.mkstemp(suffix="%sBlast.db" % motifID)
    os.close(fd)
    try:
        shutil.copyfile("%s/%sBlast.db" % (directory, motifID), tempdb)
    except (IOError, OSError):
        os.remove(tempdb)
        return None

    return tempdb


def _annotateMatches(tempdb, genome, motLen, radiusStep, maxRadius, ucscOrg, ucscDB, annotfile):
    """ Write a "match" annotation line for every blast_entries row of genome and build HTML rows.

        Returns (res, outlines): the blast_entries rows for genome, and one
        (similarity, length, mainMatchID, htmlRow) tuple per row that has a blast segment.
    """
    db = sqlite.connect(tempdb)
    try:
        sql = db.cursor()
        sql.execute("select * from blast_entries where GENOME1 = ? ORDER BY GENEID1", (genome,))
        res = sql.fetchall()
        idb = geneinfoDB(cache=True)
        outlines = []
        index = 0
        for entry in res:
            (ID, gen1, mid1, gen2, mid2, sim, length, mis, indel, start1, stop1, start2, stop2, evalue, score) = entry
            sql.execute("select chrom, loc, start, stop, sense from blast_segments where GENOME = ? and MATCHID = ?", (gen1, mid1))
            segment = sql.fetchone()
            if segment is None:
                print("skipping %s\tno blast segment" % mid1)
                continue

            (chrom, loc, start, stop, sense) = segment
            refchrom = _ucscChrom(chrom)
            # start1/stop1 are hit offsets within the segment window; "R" windows were written
            # reverse-complemented, so their offsets count back from the window's stop
            if sense == "F":
                matchstart = start + 1 + int(start1)
                matchstop = start + 1 + int(stop1)
                strand = "+"
            else:
                matchstart = stop + 1 - int(start1)
                matchstop = stop + 1 - int(stop1)
                strand = "-"

            if matchstart > matchstop:
                (matchstart, matchstop) = (matchstop, matchstart)

            annotfile.write("chr%s\tcistematic\tmatch\t%d\t%d\t%d\t%s\t.\t%s.%d\n" % (refchrom, matchstart, matchstop, (float(sim) - 90) * 100, strand, mid2, index))
            index += 1
            (feature, symbol, desc, relativeLoc, radius) = _nearestFeature(genome, chrom, start, motLen, radiusStep, maxRadius, idb)
            line = '<tr><td>%s</td><td>%s</td><td>%s</td><td>%s</td><td>%s</td><td><a href="http://genome.ucsc.edu/cgi-bin/hgTracks?org=%s&db=%s&position=chr%s:%d-%d" target="new">chr%s:%d-%d</a></td><td>%s</td><td>%s</td><td>%d</td><td>%s</td></tr>\n' % (relativeLoc, mid1, mid2, sim, length, ucscOrg, ucscDB, refchrom, start - radius, stop + radius, refchrom, matchstart , matchstop, feature, symbol, radius, desc)
            outlines.append((float(sim), int(length), mid1, line))
            if index % 100 == 0:
                print(".")

        sql.close()
    finally:
        db.close()

    return (res, outlines)


def getCandidates(motifID, motLen, mainGenome=None, radiusStep=1000, maxRadius=200000, maxInputSize=3000, refineMotifs=False, directory=cisTemp, ucscOrg="Human", ucscDB="hg17"):
    """ Core of cismatcher. Extracts conserved regions from data in the Blast.db database generated by doBlast().

        mainGenome   - dict mapping the main genome's name to its tab-delimited motif-match file;
                       only its first key is used
        radiusStep, maxRadius - step and limit of the widening search for the nearest gene feature
        maxInputSize - abort if the main genome has more distinct motif matches than this
        refineMotifs - also build a refined motif from every sequence taking part in a conserved match
        ucscOrg, ucscDB - organism and assembly used in the UCSC browser links

        Reads <motifID>-main.fsa, <motifID>Blast.db and, with refineMotifs, <motifID>-matches.fsa
        from directory. Writes cismatcher.<motifID>.annot.txt (motif, blast_region and match
        lines), cismatcher.<motifID>.out (HTML rows, best match first, one per main-genome match ID)
        and, with refineMotifs, <motifID>+R.mot. Returns None. It also returns early, after
        printing a message, when the input is too large or the blast database is missing.
        Raises ValueError if mainGenome is empty.
    """
    if not mainGenome:
        raise ValueError("getCandidates() requires a non-empty mainGenome dict")

    genome = list(mainGenome)[0]
    motseqs = {}
    if refineMotifs:
        motseqs = _readMatchSequences("%s/%s-matches.fsa" % (directory, motifID))

    cistematic.core.cacheGeneDB(genome)
    outfile = open("%s/cismatcher.%s.out" % (directory, motifID), "w")
    annotfile = open("%s/cismatcher.%s.annot.txt" % (directory, motifID), "w")
    res = []
    try:
        motifLines = _motifAnnotationLines(mainGenome[genome], motifID, motLen)
        if len(motifLines) > maxInputSize:
            print("number of matches in main genome exceeded maxInputSize %d - aborting cisMatcher.getCandidates()" % maxInputSize)
            return

        if motifLines:
            annotfile.write('track name=%strack description="Cistematic %s hits"\n' % (motifID, motifID))
            annotfile.writelines(motifLines)

        tempdb = _copyBlastDB(directory, motifID)
        if tempdb is None:
            print("no blast database - aborting cisMatcher->getCandidates()")
            return

        try:
            annotfile.writelines(_blastRegionLines("%s/%s-main.fsa" % (directory, motifID), genome))
            (res, outlines) = _annotateMatches(tempdb, genome, motLen, radiusStep, maxRadius, ucscOrg, ucscDB, annotfile)
        finally:
            os.remove(tempdb)

        print("writing outfile")
        # best similarity/length first; each main-genome match ID is reported only once
        outlines.sort(reverse=True)
        alreadySeen = set()
        for (sim, length, mid1, line) in outlines:
            if mid1 in alreadySeen:
                continue

            alreadySeen.add(mid1)
            outfile.write(line)
    finally:
        cistematic.core.uncacheGeneDB()
        annotfile.close()
        outfile.close()

    if refineMotifs:
        print("Refining motifs")
        goodIDs = set()
        goodseqs = []
        for entry in res:
            # entry[2] and entry[4] are the two match IDs (GENEID1, GENEID2)
            for mid in (entry[2], entry[4]):
                if mid not in goodIDs:
                    goodIDs.add(mid)
                    goodseqs.append(motseqs[mid].upper())

        mot = Motif("%s+R" % motifID, "", "", goodseqs)
        mot.saveMotif("%s/%s+R.mot" % (directory, motifID))