import sys

# Announce the script (as it was invoked) and its revision.
# Single-argument parenthesised print gives the same output as the
# Python 2 print statement.
versionString = '%s: version 1.6' % sys.argv[0]
print(versionString)

# psyco is optional: use it when it imports and starts, otherwise carry on
# without it.  Catching Exception rather than using a bare except keeps
# KeyboardInterrupt and SystemExit from being swallowed here.
try:
    import psyco
    psyco.full()
except Exception:
    print("psyco not running")

from som import SOM
from mapGraphics import saveMapPNG, coordTranslate

# The three positional arguments are mandatory; with fewer, print the usage
# line and stop.
if len(sys.argv) <= 3:
    print("usage: python %s somfile scorefile outprefix [-data datafile] [-diffs file] [-all] [-count] [-datacol dimID] [-verbose] [-dataheader] [-savemap] [-startField num] [-vmax value] [-translate dx,dy]" % sys.argv[0])
    sys.exit(0)

# Positional arguments: SOM file, score file, and the prefix used for every
# output file name.
(somfile, scorefile, outprefix) = sys.argv[1:4]

# Processing switches; the optional command-line flags flip these later.
doComponents = True
doCount = doData = doAllData = doDiffs = doCoords = saveMap = False
doVerbose = hasHeader = False

# Data-file layout: recField is passed as IDField and firstField as
# startField when the dataset is read.
recField = 0
firstField = 1

# Column indices to draw and the labels read from the data header line.
dataColumns = []
dataHeaders = []

# Value given with -vmax.
maxValue = 0

# Translation offsets (-translate) and map extent (filled in from the SOM).
dx = dy = 0
xmax = ymax = 0

# Optional switches.  Each option is found by searching sys.argv for its
# flag; options that take a value read the argument right after the flag
# (a flag given as the very last argument raises IndexError).

# -data datafile: enables the per-column data maps.
datafile = ''
if '-data' in sys.argv:
    datafile = sys.argv[sys.argv.index('-data') + 1]
    doData = True

# -diffs file: the file is opened here and read/closed by the diff-list
# loading code further down.
if '-diffs' in sys.argv:
    difflistfile = open(sys.argv[sys.argv.index('-diffs') + 1])
    doDiffs = True

# -count: data maps hold per-unit sums instead of per-unit means.
if '-count' in sys.argv:
    doCount = True

# -all: draw every data column (only meaningful together with -data).
# doComponents is already True by default, so it is unchanged here.
if '-all' in sys.argv:
    doComponents = True
    if doData:
        doAllData = True

# -datacol dimID: draw one specific data column.  The usage line advertises
# this option but it was never parsed, so no data map could be requested
# without -all.  It may be repeated; each occurrence adds one column index.
# When -all is also given, the full column range replaces this list later.
for (argPos, argText) in enumerate(sys.argv):
    if argText == '-datacol':
        dataColumns.append(int(sys.argv[argPos + 1]))

if '-verbose' in sys.argv:
    doVerbose = True

# -startField num: first data field, passed as startField to the SOM reader.
if '-startField' in sys.argv:
    firstField = int(sys.argv[sys.argv.index('-startField') + 1])

# -dataheader: the data file starts with a header line of column labels.
if '-dataheader' in sys.argv:
    hasHeader = True

# -translate dx,dy: integer offsets handed to coordTranslate for every unit.
if '-translate' in sys.argv:
    coords = sys.argv[sys.argv.index('-translate') + 1]
    (dx, dy) = coords.split(',')
    dx = int(dx)
    dy = int(dy)
    doCoords = True

# -savemap: also write each data/diff map with mysom.saveMap.
if '-savemap' in sys.argv:
    saveMap = True

# -vmax value.
# NOTE(review): maxValue is stored here but nothing else in this script
# reads it; confirm whether it was meant to cap maxVal in saveMapPNG.
if '-vmax' in sys.argv:
    maxValue = float(sys.argv[sys.argv.index('-vmax') + 1])

# Diff definitions, read from the -diffs file.  Each usable line is
# tab-separated:  diffname <TAB> startmap <TAB> stopmap [<TAB> ...]
# Lines starting with '#' are comments.
diffDict = {}        # diffname -> (startmap, stopmap, operation)
diffMapDict = {}     # map name -> map array; filled while maps are drawn
diffList = []        # every map name referenced by a diff definition
diffMapOrder = []    # diff names in the order they appear in the file
if doDiffs:
    try:
        for line in difflistfile:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            if len(fields) < 3:
                # Name, start map and stop map are all required.
                print("could not process fields: %s" % line.strip())
                continue
            (diffname, startmap, stopmap) = fields[:3]
            # The old guard (len > 3) did not match the index it read
            # (fields[4]), so every 4-field line raised IndexError and was
            # rejected.  Lines with five or more fields keep reading the
            # operation from fields[4] as before; a 4-field line now takes
            # it from its last field.
            if len(fields) > 4:
                operation = fields[4].upper()
            elif len(fields) == 4:
                operation = fields[3].upper()
            else:
                operation = 'SUB'
            diffList.append(startmap)
            diffList.append(stopmap)
            diffDict[diffname] = (startmap, stopmap, operation)
            diffMapOrder.append(diffname)
    finally:
        # Close the file even if reading it fails part-way through.
        difflistfile.close()

# Load the SOM from file; with -verbose, ask it to describe itself.
mysom = SOM(initialFile=somfile)
if doVerbose:
    mysom.describe()

# Map extent as reported by the SOM.
(xmax, ymax) = (mysom.outCols, mysom.outRows)

# One percent of xmax * ymax, but never less than 1.  Used later as the index
# offset into the sorted unit values when picking minVal / maxVal.
onePercent = max(1, int(xmax * ymax / 100.))

# Indexed by (row, col); each entry is the collection of record IDs for
# that unit.
winnerDict = mysom.readScoreFile(scorefile)

# Load the dataset that is projected onto the map (only with -data).
if doData:
    (dataDim, dataCount) = mysom.datasetSize(datafile, startField=firstField)
    dataDict = mysom.getDataset(datafile, IDField=recField, startField=firstField, dimension=dataDim)
    if hasHeader:
        # Column labels come from the first line of the data file.
        # NOTE(review): labels start at the third tab-separated field
        # regardless of firstField (default 1) -- confirm that they line up
        # with the data columns.
        headerFile = open(datafile)
        headerLine = headerFile.readline()
        headerFile.close()
        dataHeaders = headerLine.strip().split('\t')[2:]
    # doAllData is only ever set together with doData, so selecting every
    # column can happen here.
    if doAllData:
        dataColumns = range(dataDim)

# Hit-count map: each unit in winnerDict gets the number of records
# assigned to it.
countArray = mysom.newMap()
for (row, col) in winnerDict:
    # Draw at the unit's own position unless -translate was given.
    (tcol, trow) = (col, row)
    if doCoords:
        (tcol, trow) = coordTranslate(col, row, dx, dy, xmax, ymax)
    countArray[trow][tcol] = len(winnerDict[row, col])
saveMapPNG(countArray, '%s.datacount' % outprefix, grid=mysom.gridType)

# Component planes: one PNG per input dimension, showing that dimension's
# weight for every unit.
if doComponents:
    for dimension in range(mysom.inputDim):
        unitDimArray = mysom.newMap()
        for (row, col) in mysom.units:
            # Draw at the unit's own position unless -translate was given.
            (tcol, trow) = (col, row)
            if doCoords:
                (tcol, trow) = coordTranslate(col, row, dx, dy, xmax, ymax)
            unitDimArray[trow][tcol] = mysom.weightsArray[row][col][dimension]
        saveMapPNG(unitDimArray, '%s.component_%s' % (outprefix, dimension), grid=mysom.gridType)

# Data maps: for every selected data column, project the records onto the
# units that won them and draw one PNG per column.
if doData:
    for dimension in dataColumns:
        dataDimArray = mysom.newMap()
        unitValueList = []
        for (row, col) in winnerDict:
            unitHits = winnerDict[row, col]
            datasum = 0.
            for recID in unitHits:
                datasum += dataDict[recID][dimension]
            if doCoords:
                (tcol, trow) = coordTranslate(col, row, dx, dy, xmax, ymax)
            else:
                tcol = col
                trow = row
            # -count keeps the raw sum; otherwise use the mean over the
            # unit's records.  A unit with no records gets 0; this explicit
            # check replaces a bare except that hid the divide-by-zero.
            if doCount:
                unitValue = datasum
            elif len(unitHits) > 0:
                unitValue = datasum / float(len(unitHits))
            else:
                unitValue = 0
            dataDimArray[trow][tcol] = unitValue
            unitValueList.append(unitValue)

        outfileprefix = outprefix + '.data_' + str(dimension)
        outfiletitle = ''
        if hasHeader:
            outfileprefix += '_' + dataHeaders[dimension]
            outfiletitle = dataHeaders[dimension]
            print(outfiletitle)
            # Keep maps that a diff definition refers to by header name.
            if doDiffs:
                if dataHeaders[dimension] in diffList:
                    diffMapDict[dataHeaders[dimension]] = dataDimArray
        if saveMap:
            mysom.saveMap(dataDimArray, outfileprefix + '.map')

        # minVal / maxVal are the values onePercent positions in from each
        # end of the sorted unit values.  When there are too few values for
        # that index (previously an IndexError) use the full range; with no
        # values at all use 0 for both.
        unitValueList.sort()
        if len(unitValueList) > onePercent:
            lowValue = unitValueList[onePercent]
            highValue = unitValueList[-1 * onePercent]
        elif unitValueList:
            lowValue = unitValueList[0]
            highValue = unitValueList[-1]
        else:
            lowValue = highValue = 0
        saveMapPNG(dataDimArray, outfileprefix, outfiletitle, minVal=lowValue, maxVal=highValue, grid=mysom.gridType)

# Diff maps: combine two previously drawn maps unit by unit.  A finished diff
# is stored under its own name, so later definitions may refer to it.
if doDiffs:
    for diffname in diffMapOrder:
        outfileprefix = outprefix + '.diffed_' + diffname
        (startMap, stopMap, operation) = diffDict[diffname]
        if startMap not in diffMapDict:
            print("Could not find %s for diffmap %s - skipping" % (startMap, diffname))
            continue
        if stopMap not in diffMapDict:
            print("Could not find %s for diffmap %s - skipping" % (stopMap, diffname))
            continue
        # Only subtraction is implemented.  Any other operation used to fall
        # through and save an untouched newMap() as if it were a diff, so
        # report it and skip instead.
        if operation != 'SUB':
            print("unsupported operation %s for diffmap %s - skipping" % (operation, diffname))
            continue
        outfiletitle = '%s %s %s' % (startMap, operation, stopMap)
        print("doing diff for %s = %s %s %s" % (diffname, startMap, operation, stopMap))
        inmap1 = diffMapDict[startMap]
        inmap2 = diffMapDict[stopMap]
        outmap = mysom.newMap()
        unitValueList = []
        for (row, col) in mysom.units:
            outmap[row][col] = inmap1[row][col] - inmap2[row][col]
            # Non-positive differences count as 0 when choosing minVal/maxVal.
            if outmap[row][col] > 0:
                unitValueList.append(outmap[row][col])
            else:
                unitValueList.append(0)
        diffMapDict[diffname] = outmap
        if saveMap:
            mysom.saveMap(outmap, outfileprefix + '.map')

        # minVal / maxVal are the values onePercent positions in from each
        # end of the sorted values.  When there are too few values for that
        # index (previously an IndexError) use the full range; with no values
        # at all use 0 for both.
        unitValueList.sort()
        if len(unitValueList) > onePercent:
            lowValue = unitValueList[onePercent]
            highValue = unitValueList[-1 * onePercent]
        elif unitValueList:
            lowValue = unitValueList[0]
            highValue = unitValueList[-1]
        else:
            lowValue = highValue = 0
        saveMapPNG(outmap, outfileprefix, outfiletitle, minVal=lowValue, maxVal=highValue, grid=mysom.gridType)

