#
#  buildMatrix.py
#  ENRAGE
#
#  Created by Ali Mortazavi on 3/6/09.
#
import sys
import string
import optparse
from commoncode import writeLog, getConfigParser, getConfigOption, getConfigBoolOption

# Version banner: recorded in every log entry and echoed once at import time.
versionString = "buildMatrix: version 1.5"
print(versionString)


def main(argv=None):
    """Parse command-line arguments and run buildMatrix.

    argv is the full argument vector, program name included; sys.argv is used
    when argv is None or empty. With fewer than three positional arguments the
    usage string is printed and the process exits with status 0.

    Truncation is enabled exactly when options.maxRPKM (from --truncate or the
    config default supplied by makeParser) is not None. Otherwise the ceiling
    passed along is 100000000 and is unused.
    """
    argv = argv or sys.argv

    usage = "usage: python %prog matrix.step.N-1 data.part matrix.step.N [--rescale] [--truncate maxRPKM] [--log altlogfile]"
    options, args = makeParser(usage).parse_args(argv[1:])

    if len(args) < 3:
        print(usage)
        sys.exit(0)

    inFileName, colfilename, outfilename = args[0], args[1], args[2]

    truncateRPKM = options.maxRPKM is not None
    if truncateRPKM:
        maxRPKM = options.maxRPKM
    else:
        maxRPKM = 100000000

    buildMatrix(inFileName, colfilename, outfilename, truncateRPKM, maxRPKM,
                rescale=options.rescale, logfilename=options.logfilename)


def makeParser(usage=""):
    """Build the optparse parser for buildMatrix.

    Defaults for --rescale, --truncate (dest maxRPKM) and --log come from the
    "buildMatrix" section of the configuration returned by getConfigParser().
    Fallbacks are False, None and "buildMatrix.log" respectively.
    """
    configParser = getConfigParser()
    section = "buildMatrix"
    # Evaluated in order: rescale, maxRPKM, logfilename.
    defaults = {
        "rescale": getConfigBoolOption(configParser, section, "rescale", False),
        "maxRPKM": getConfigOption(configParser, section, "maxRPKM", None),
        "logfilename": getConfigOption(configParser, section, "logfilename", "buildMatrix.log"),
    }

    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--rescale", action="store_true", dest="rescale")
    parser.add_option("--truncate", type="int", dest="maxRPKM")
    parser.add_option("--log", dest="logfilename")
    parser.set_defaults(**defaults)
    return parser


def buildMatrix(inFileName, colfilename, outfilename, truncateRPKM,
                maxRPKM=100000000, rescale=False, logfilename="buildMatrix.log"):
    """Append one data column to a tab-delimited matrix file.

    Reads the matrix built so far from inFileName and a new column from
    colfilename, then writes the combined matrix to outfilename.

    Each non-comment, non-blank line of colfilename contributes one value:
    its last whitespace-separated field, parsed as a float. Values are paired
    with the rows of inFileName (after its header line) purely by order.
    Surplus matrix rows are dropped. If the matrix runs out of rows, the value
    is written with an empty row prefix.

    The new column's header is the basename of colfilename up to its first
    ".". A blank matrix header is replaced by a "#" followed by a tab.

    Args:
        inFileName: path of the matrix built so far (step N-1).
        colfilename: path of the data file supplying the new column.
        outfilename: path of the matrix to write (step N).
        truncateRPKM: when true, values above maxRPKM are replaced by maxRPKM.
        maxRPKM: truncation ceiling; only used when truncateRPKM is true.
        rescale: when true, min-max rescale the column into [0, 1]. A column
            whose values are all equal is written as all zeros.
        logfilename: log file passed to writeLog.

    Raises:
        IOError/OSError if a file cannot be opened; ValueError if a value
        field is not a number.
    """
    writeLog(logfilename, versionString, " ".join(sys.argv[1:]))

    # Column ID: basename up to the first "." (e.g. "/x/sample1.rpkm" -> "sample1").
    colname = colfilename.split("/")[-1]
    colID = colname.split(".")[0]

    values, untruncatedMax = _readColumnValues(colfilename, truncateRPKM, maxRPKM)

    # Use the real maximum rather than a -1 sentinel, which was wrong when
    # every value is below -1. An empty column keeps the historical -1.
    if values:
        maxValue = max(values)
    else:
        maxValue = -1.

    if rescale:
        finalValues = _rescaleValues(values)
    else:
        finalValues = values

    with open(inFileName) as infile:
        with open(outfilename, "w") as outfile:
            # rstrip instead of [:-1] so a header without a trailing newline
            # keeps its last character.
            header = infile.readline().rstrip("\r\n")
            if not header.strip():
                header = "#\t"

            outfile.write("%s\t%s\n" % (header, colID))
            for val in finalValues:
                # At EOF readline() returns "", giving an empty row prefix.
                rowPrefix = infile.readline().strip()
                outfile.write("%s\t%1.3f\n" % (rowPrefix, val))

    # When anything was truncated, report the largest pre-truncation value.
    if untruncatedMax is not None:
        maxValue = untruncatedMax

    message = "max value in %s was %.2f" % (colname, maxValue)
    if untruncatedMax is not None:
        message += " but was truncated to %d" % maxRPKM

    print(message)
    writeLog(logfilename, versionString, message)


def _readColumnValues(colfilename, truncateRPKM, maxRPKM):
    """Return (values, untruncatedMax) parsed from the data column file.

    Each line that is neither a comment (per doNotProcessLine) nor blank
    contributes float(last whitespace-separated field). When truncateRPKM is
    true, values above maxRPKM are replaced by maxRPKM, and untruncatedMax
    holds the largest original value. It is None when nothing was truncated.
    """
    values = []
    untruncatedMax = None
    with open(colfilename) as colfile:
        for line in colfile:
            if doNotProcessLine(line):
                continue

            fields = line.split()
            if not fields:
                # Blank line: there is no value field to read.
                continue

            val = float(fields[-1])
            if truncateRPKM and val > maxRPKM:
                if untruncatedMax is None or val > untruncatedMax:
                    untruncatedMax = val
                val = maxRPKM

            values.append(val)

    return values, untruncatedMax


def _rescaleValues(values):
    """Min-max rescale values into [0, 1]; zero spread maps every value to 0.0."""
    if not values:
        return []

    low = min(values)
    # float() keeps this true division under Python 2 even if every entry is an int.
    spread = float(max(values) - low)
    if spread == 0:
        return [0.0] * len(values)

    return [(val - low) / spread for val in values]


def doNotProcessLine(line):
    """Return True if line is a comment line (starts with "#") to be skipped.

    An empty string is treated as a non-comment line and yields False. It does
    not raise IndexError as indexing line[0] would.
    """
    return line.startswith("#")


# Script entry point: hand the full argument vector (program name included) to main.
if __name__ == "__main__":
    main(argv=sys.argv)