import sys, copy, math
# psyco is optional: if it cannot be imported or enabled, report that and
# continue without it.  Catch Exception rather than using a bare except so
# KeyboardInterrupt / SystemExit are not silently swallowed here.
try:
    import psyco
    psyco.full()
except Exception:
    print("psyco not running")

from som import SOM
from mapGraphics import saveMapPNG, coordTranslate

# Startup banner: the script's own path followed by its version number.
versionString = sys.argv[0] + ': version 1.3'
print(versionString)

# Defaults; steps and rounds can be overridden with -timesteps / -trials.
steps = 100000   # training time steps per trial
rounds = 100     # number of independent training trials
firstField = 1   # NOTE(review): not referenced elsewhere in this file
dim = -1         # input dimensionality; recomputed once the original SOM is read

# Four positional arguments are required, so argv needs at least 5 entries
# (argv[0] is the script itself).  The previous "< 4" check let a missing
# outprefix through and then crashed with IndexError on sys.argv[4].
if len(sys.argv) < 5:
    print("usage: python %s numRows numCols somfile outprefix [-maps map1,map2,map3,...] [-translate dx,dy] [-withTopology] [-trials N] [-timesteps N]" % sys.argv[0])
    sys.exit(0)

# Shape of the clustering SOM (clusters = rows * cols).
clusterRows = int(sys.argv[1])
clusterCols = int(sys.argv[2])

somfile = sys.argv[3]                  # previously trained SOM to be clustered
clusterResultsPrefix = sys.argv[4]     # prefix for every output file

# Optional flags; each value-taking flag reads the argument right after it.
mapList = []        # maps loaded from the files named in mapfileList
mapfileList = []    # comma-separated file names given with -maps
if '-maps' in sys.argv:
    # Previously this looked up '-testset' (ValueError when absent) and stored
    # the names in an unused variable, so -maps had no effect.
    mapfileList = sys.argv[sys.argv.index('-maps') + 1].split(',')

# -translate dx,dy: integer offsets.
# NOTE(review): dx, dy and doCoords are not used elsewhere in this file.
doCoords = False
if '-translate' in sys.argv:
    coords = sys.argv[sys.argv.index('-translate') + 1]
    (dx, dy) = coords.split(',')
    dx = int(dx)
    dy = int(dy)
    doCoords = True

# -withTopology: append each unit's grid position to its weight vector.
useTopology = '-withTopology' in sys.argv

if '-trials' in sys.argv:
    rounds = int(sys.argv[sys.argv.index('-trials') + 1])

if '-timesteps' in sys.argv:
    steps = int(sys.argv[sys.argv.index('-timesteps') + 1])

# Passed to SOM.train as the starting radius.
# NOTE(review): -1 presumably means "let the SOM choose" - confirm in som.py.
startRadius = -1
origSOM = SOM(initialFile=somfile)
for mapfile in mapfileList:
    mapList.append(origSOM.readMap(mapfile))

# Check the requested cluster count before doing any work.  Doing this first
# also means an empty original SOM exits cleanly here instead of hitting a
# NameError on origVector below.
count = origSOM.outRows * origSOM.outCols
if count < clusterRows * clusterCols:
    print("number of clusters %d * %d > %d - exiting" % (clusterRows, clusterCols, count))
    sys.exit(1)

# Build one training vector per original unit, keyed 'unit_<row>_<col>'.
origUnits = origSOM.getUnits()
origDict = {}
for (arow, acol) in origUnits:
    unitLabel = 'unit_%d_%d' % (arow, acol)
    # Copy so appending topology coordinates leaves the SOM's weights untouched.
    origVector = copy.copy(origSOM.weightsArray[arow][acol])
    origDim = len(origVector)   # dimensionality before any coordinates are added
    # TODO: mapList values are loaded above but not yet added to the vector.
    if useTopology:
        if origSOM.topologyType == 'toroid':
            # Rows and half-shifted columns go through sin(pi * x), presumably
            # so units on opposite (wrapping) edges get similar values.
            # '//' keeps the shift integral on Python 3 as well as Python 2.
            origVector.append(0.4 + 0.1 * math.sin(math.pi * float(arow) / origSOM.outRows))
            origVector.append(0.6 + 0.1 * math.sin(math.pi * float((acol + origSOM.outCols // 2) % origSOM.outCols) / origSOM.outCols))
        else:
            # Sheet: scale row/col linearly into [0.25, 0.75).
            origVector.append(0.25 + 0.5 * float(arow) / origSOM.outRows)
            origVector.append(0.25 + 0.5 * float(acol) / origSOM.outCols)
    origDict[unitLabel] = origVector
    dim = len(origVector)   # same for every unit; the final value is used below

# Train `rounds` independent clustering SOMs and keep the lowest-scoring one.
# Start from infinity: with the old 1000000 sentinel, trials scoring above it
# were never saved, and a stale .somtemp from an earlier run could be reloaded.
bestScore = float('inf')
tempResults = clusterResultsPrefix + '.somtemp'   # best trial's SOM is saved here
savedAny = False

print("original SOM: %s (%d units)" % (somfile, count))
print("\twill use %d dimensions" % (dim))
print("SOM: %d rows * %d columns for %d time steps" % (clusterRows, clusterCols, steps))
# Previously referenced the undefined name clusterResults (NameError).
print("\tbest of %d trials will be saved in %s" % (rounds, tempResults))
print("")

for trial in range(rounds):
    mysom = SOM(topology='sheet')
    mysom.resize(inputDim=dim, numRows=clusterRows, numCols=clusterCols)
    print("trial %d" % trial)
    # Train on a deep copy so origDict is still pristine for scoring and
    # for the final winner assignment.
    trainingDict = copy.deepcopy(origDict)
    mysom.initializeWeights(trainingDict)
    mysom.train(trainingDict, timeSteps=steps, radius=startRadius)
    print("scoring data")
    currentScore = mysom.scoreData(origDict)
    print("trial score = %.3f" % currentScore)
    if currentScore < bestScore:
        mysom.saveSOM(tempResults)
        bestScore = currentScore
        savedAny = True

# Without a saved trial (rounds < 1, or no score beat infinity, e.g. NaN),
# tempResults would be missing or stale; stop instead of reloading it.
if not savedAny:
    print("no trial produced a usable score - exiting")
    sys.exit(1)

# Reload the best trial and assign each original unit to its winning cluster.
finalSOM = SOM(initialFile=tempResults)
finalClusters = finalSOM.getUnits()
winnerDict = finalSOM.getWinners(origDict)

# Every original unit starts at 0, then is tagged with its cluster's index.
clusterMap = origSOM.newMap()
for (arow, acol) in origUnits:
    clusterMap[arow][acol] = 0.

for index, (clustx, clusty) in enumerate(finalClusters):
    # NOTE(review): .get tolerates getWinners omitting clusters that won no
    # units (previously a KeyError) - confirm against som.py.
    members = winnerDict.get((clustx, clusty), [])
    outfile = open(clusterResultsPrefix + '.' + str(index) + '.cluster', 'w')
    try:
        print('cluster (%d, %d): %d units' % (clustx, clusty, len(members)))
        for regionID in members:
            # Labels have the form 'unit_<row>_<col>' (built when origDict was made).
            regionCoords = regionID.split('_')
            regionRow = int(regionCoords[1])
            regionCol = int(regionCoords[2])
            origVector = origSOM.weightsArray[regionRow][regionCol]
            clusterMap[regionRow][regionCol] = index
            # One line per unit: label, then its original weights tab-separated.
            outfile.write(regionID + ''.join(['\t%.3f' % weight for weight in origVector]) + '\n')
    finally:
        # Close every per-cluster file; previously only the last one was closed.
        outfile.close()
saveMapPNG(clusterMap, clusterResultsPrefix + '.clusters', grid=origSOM.gridType)

print("bestScore = %.3f" % bestScore)
