# Enable psyco when it is importable; otherwise report it and carry on.
# Catching Exception (rather than a bare except) keeps the best-effort
# behaviour without also swallowing KeyboardInterrupt / SystemExit.
try:
    import psyco
    psyco.full()
except Exception:
    print('psyco not running')

import sys
import optparse
from cistematic.core import genesIntersecting, featuresIntersecting
from commoncode import getConfigParser, getConfigIntOption, getConfigOption, getConfigBoolOption, getGeneInfoDict, getGeneAnnotDict, getExtendedGeneAnnotDict

# Announce the script version on stdout whenever the module is loaded.
print("getallgenes: version 5.6")


def main(argv=None):
    """Command-line entry point.

    Parses ``argv`` (``sys.argv`` when empty or None), requires the three
    positional arguments genome, regionfile and outfile, and hands them
    together with the parsed options to getallgenes().  Prints the usage
    string and exits with status 2 when fewer than three positional
    arguments are given.
    """
    argv = argv or sys.argv

    usage = "usage: python %prog genome regionfile outfile [--radius bp] [--nomatch nomatchfile] --trackfar --stranded --cache --compact [--step dist] [--startField colID]"

    optionParser = getParser(usage)
    options, positional = optionParser.parse_args(argv[1:])

    if len(positional) < 3:
        print(usage)
        sys.exit(2)

    # Extra positional arguments beyond the first three are ignored.
    genome, infilename, outfilename = positional[:3]

    getallgenes(
        genome,
        infilename,
        outfilename,
        maxRadius=options.maxRadius,
        nomatchfilename=options.nomatchfilename,
        step=options.step,
        trackFar=options.trackFar,
        trackStrand=options.trackStrand,
        compact=options.compact,
        colID=options.colID,
        doCache=options.doCache,
        extendGenome=options.extendGenome,
        replaceModels=options.replaceModels,
    )


def getParser(usage):
    """Build the optparse parser for the getallgenes command line.

    Option defaults are read from the "getallgenes" section of the project
    configuration (via the commoncode config helpers), falling back to the
    literal values given below when the configuration has no entry.

    usage: usage string shown by optparse for --help and on errors.

    Returns the configured optparse.OptionParser.
    """
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--radius", type="int", dest="maxRadius")
    parser.add_option("--nomatch", dest="nomatchfilename")
    parser.add_option("--trackfar", action="store_true", dest="trackFar")
    parser.add_option("--stranded", action="store_true", dest="trackStrand")
    # The destination must be "doCache": that is the name the default is
    # registered under below and the attribute main() reads.  With the old
    # destination ("cachePages") the --cache flag was silently ignored.
    parser.add_option("--cache", action="store_true", dest="doCache")
    parser.add_option("--compact", action="store_true", dest="compact")
    parser.add_option("--step", type="int", dest="step")
    parser.add_option("--startField", type="int", dest="colID")
    parser.add_option("--models", dest="extendGenome")
    parser.add_option("--replacemodels", action="store_true", dest="replaceModels")

    configParser = getConfigParser()
    section = "getallgenes"
    maxRadius = getConfigIntOption(configParser, section, "maxRadius", 20002)
    nomatchfilename = getConfigOption(configParser, section, "nomatchfilename", "")
    step = getConfigOption(configParser, section, "step", None)
    trackFar = getConfigBoolOption(configParser, section, "trackFar", False)
    trackStrand = getConfigBoolOption(configParser, section, "trackStrand", False)
    compact = getConfigBoolOption(configParser, section, "compact", False)
    colID = getConfigIntOption(configParser, section, "colID", 1)
    doCache = getConfigBoolOption(configParser, section, "doCache", False)
    extendGenome = getConfigOption(configParser, section, "extendGenome", "")
    replaceModels = getConfigBoolOption(configParser, section, "replaceModels", False)

    parser.set_defaults(maxRadius=maxRadius, nomatchfilename=nomatchfilename, step=step, trackFar=trackFar,
                        trackStrand=trackStrand, compact=compact, colID=colID, doCache=doCache,
                        extendGenome=extendGenome, replaceModels=replaceModels)

    return parser


def _readRegions(infile, compact, colID, trackStrand):
    """Parse region lines from an open file into the tables getallgenes uses.

    Lines starting with "#" are skipped.  Lines whose chromosome or
    coordinate columns are missing or not integers are printed and skipped.

    Returns (posList, altPosDict, altPosRevDict, posLine, posStrand):
      posList       -- every start and stop position, in file order
      altPosDict    -- start position -> stop position
      altPosRevDict -- stop position -> start position
      posLine       -- start position -> original input line
      posStrand     -- position -> "+" or "-" (filled only when trackStrand)
    Positions are (chromosome-without-first-3-chars, coordinate) tuples.
    """
    posList = []
    altPosDict = {}
    altPosRevDict = {}
    posLine = {}
    posStrand = {}

    for line in infile:
        if line[0] == "#":
            continue

        fields = line.split("\t")
        try:
            if compact:
                # compact form: a single "chrN:start-stop" column
                (chrom, span) = fields[colID].split(":")
                (start, stop) = span.split("-")
            else:
                chrom = fields[colID]
                start = fields[colID + 1]
                stop = fields[colID + 2]

            chrom = chrom[3:]
            pos = (chrom, int(start))
            altPos = (chrom, int(stop))
        except (IndexError, ValueError):
            # malformed line: report it and move on
            print(line)
            continue

        altPosDict[pos] = altPos
        altPosRevDict[altPos] = pos
        posList.append(pos)
        posList.append(altPos)
        posLine[pos] = line
        if trackStrand:
            # lines mentioning "RNAFARP" are treated as "+", all others as "-"
            if "RNAFARP" in line:
                strand = "+"
            else:
                strand = "-"

            posStrand[pos] = strand
            posStrand[altPos] = strand

    return (posList, altPosDict, altPosRevDict, posLine, posStrand)


def _featureDistance(hit, loc):
    """Return the smaller distance from loc to hit[3] or hit[4]."""
    return min(abs(hit[3] - loc), abs(hit[4] - loc))


def _nearestGeneID(hits, loc, strand=None):
    """Return the gene ID (hit[0]) of the hit nearest to loc, or "" if none.

    When strand is not None only hits whose last element equals strand are
    considered.  Ties go to the earliest hit in the list.
    """
    if strand is not None:
        hits = [hit for hit in hits if hit[-1] == strand]

    if not hits:
        return ""

    if len(hits) == 1:
        return hits[0][0]

    return min(hits, key=lambda hit: _featureDistance(hit, loc))[0]


def _geneSymbol(genome, geneID, geneinfoDict, geneannotDict):
    """Look up a display symbol for geneID, falling back to "LOC" + geneID."""
    try:
        if genome == "dmelanogaster":
            return geneinfoDict["Dmel_" + geneID][0][0]

        return geneinfoDict[geneID][0][0]
    except (KeyError, IndexError, TypeError):
        pass

    try:
        return geneannotDict[(genome, geneID)][0]
    except (KeyError, IndexError, TypeError):
        return "LOC" + geneID


def _claimRegion(pos, posList, altPosDict, altPosRevDict):
    """Drop a matched region's endpoints from posList.

    Returns the region's start position, or None when pos is no longer
    pending (or has no recorded partner, in which case pos is printed).
    """
    if pos not in posList:
        return None

    posList.remove(pos)
    if pos in altPosRevDict:
        # pos is a stop coordinate: report the region by its start
        startPos = altPosRevDict[pos]
        if startPos in posList:
            posList.remove(startPos)

        return startPos

    if pos not in altPosDict:
        print(pos)
        return None

    if altPosDict[pos] in posList:
        posList.remove(altPosDict[pos])

    return pos


def _writeGeneLines(outfile, geneList, geneDict, altPosRevDict, posLine):
    """Write one "symbol geneID line" record per distinct region line of each gene."""
    for (symbol, geneID) in geneList:
        # keep the symbol a single whitespace-free token in the output
        label = symbol.replace("\t", "|").replace(" ", "_")
        positions = geneDict[(symbol, geneID)]
        positions.sort()
        seenLines = set()
        for pos in positions:
            if pos in altPosRevDict:
                pos = altPosRevDict[pos]

            regionLine = posLine[pos]
            if regionLine in seenLines:
                continue

            seenLines.add(regionLine)
            outfile.write("%s %s %s" % (label, geneID, regionLine))


def _writeUnmatched(outfile, nomatchfilename, posList, altPosRevDict, posLine, trackFar, maxRadius):
    """Report regions still pending in posList; return how many there were.

    Each such region's line is written to nomatchfilename when one is given.
    With trackFar, a "FAR<n> -<n> line" record is also written to outfile;
    n increases on a chromosome change or when the start is more than
    maxRadius away from the previous unmatched start.
    """
    nomatchfile = None
    if nomatchfilename:
        nomatchfile = open(nomatchfilename, "w")

    unmatched = 0
    prevStart = 0
    prevChrom = ""
    farIndex = 0
    try:
        for pos in posList:
            if pos in altPosRevDict:
                # stop coordinates are skipped; regions are reported by start
                continue

            if nomatchfile is not None:
                nomatchfile.write(posLine[pos])

            unmatched += 1
            # need to add strand tracking here.....
            if trackFar:
                (chrom, start) = pos
                if chrom != prevChrom:
                    farIndex += 1
                    prevChrom = chrom
                elif abs(int(start) - prevStart) > maxRadius:
                    farIndex += 1

                outfile.write("FAR%d %d %s" % (farIndex, -1 * farIndex, posLine[pos]))
                prevStart = int(start)
    finally:
        if nomatchfile is not None:
            nomatchfile.close()

    return unmatched


def getallgenes(genome, infilename, outfilename, maxRadius=20002, nomatchfilename="",
                step=None, trackFar=False, trackStrand=False, compact=False, colID=1,
                doCache=False, extendGenome="", replaceModels=False):
    """Annotate each region in infilename with a nearby gene and write the result.

    Region endpoints are queried with genesIntersecting() at radius 1 and
    with featuresIntersecting() (CDS and UTR) at each larger radius in
    range(1, maxRadius, step).  A region is assigned once and then removed
    from later searches.  Each assigned region is written to outfilename as
    "symbol geneID <original line>".

    genome          -- genome name passed to the cistematic / commoncode lookups
    infilename      -- tab-delimited region file; "#" lines are ignored
    outfilename     -- output file (overwritten)
    maxRadius       -- upper bound (exclusive) of the radii searched
    nomatchfilename -- optional file receiving the lines of unassigned regions
    step            -- radius increment; derived as maxRadius - 2 (minimum 1)
                       when unset or larger than maxRadius
    trackFar        -- also write "FAR<n>" records for unassigned regions
    trackStrand     -- only accept genes whose strand matches the region's
    compact         -- region is one "chrN:start-stop" column instead of three
    colID           -- index of the chromosome (or compact) column
    doCache         -- passed to getGeneInfoDict as cache=
    extendGenome    -- optional extra gene models passed to the lookups
    replaceModels   -- replace rather than extend models; only honoured
                       together with extendGenome

    NOTE(review): hits returned by the intersect calls are assumed to be
    tuples with the gene ID at index 0, coordinates at indexes 3 and 4 and
    the strand last -- that is how they are indexed here; confirm against
    cistematic.
    """
    if not step or maxRadius < step:
        # never derive a step below 1: range() rejects 0
        step = max(maxRadius - 2, 1)

    replaceModels = bool(extendGenome and replaceModels)

    infile = open(infilename)
    try:
        outfile = open(outfilename, "w")
        try:
            geneinfoDict = getGeneInfoDict(genome, cache=doCache)

            (posList, altPosDict, altPosRevDict, posLine, posStrand) = _readRegions(
                infile, compact, colID, trackStrand)

            if extendGenome != "":
                geneannotDict = getExtendedGeneAnnotDict(genome, extendGenome, replace=replaceModels, inRAM=True)
            else:
                geneannotDict = getGeneAnnotDict(genome, inRAM=True)

            geneList = []
            geneDict = {}
            radius = 0  # stays 0 when the radius range is empty
            for radius in range(1, maxRadius, step):
                print("radius %d" % radius)
                print(len(posList))
                if radius == 1:
                    posDict = genesIntersecting(genome, posList, extendGen=extendGenome, replaceMod=replaceModels)
                else:
                    posDict = featuresIntersecting(genome, posList, radius, "CDS", extendGen=extendGenome, replaceMod=replaceModels)
                    utrDict = featuresIntersecting(genome, posList, radius, "UTR", extendGen=extendGenome, replaceMod=replaceModels)
                    for apos in utrDict:
                        if apos in posDict:
                            posDict[apos] += utrDict[apos]
                            posDict[apos].sort()
                        else:
                            posDict[apos] = utrDict[apos]

                for pos in posDict:
                    hits = posDict[pos]
                    if not hits:
                        continue

                    strand = None
                    if trackStrand:
                        strand = posStrand[pos]

                    geneID = _nearestGeneID(hits, pos[1], strand)
                    if geneID == "":
                        continue

                    regionPos = _claimRegion(pos, posList, altPosDict, altPosRevDict)
                    if regionPos is None:
                        continue

                    symbol = _geneSymbol(genome, geneID, geneinfoDict, geneannotDict)
                    geneKey = (symbol, geneID)
                    if geneKey not in geneDict:
                        geneList.append(geneKey)
                        geneDict[geneKey] = []

                    if regionPos not in geneDict[geneKey]:
                        geneDict[geneKey].append(regionPos)

            _writeGeneLines(outfile, geneList, geneDict, altPosRevDict, posLine)

            matchIndex = _writeUnmatched(outfile, nomatchfilename, posList, altPosRevDict,
                                         posLine, trackFar, maxRadius)
        finally:
            outfile.close()
    finally:
        infile.close()

    print("%d sites without a gene within radius of %d" % (matchIndex, radius))


# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main(argv=sys.argv)