import math, random, time, copy 
from operator import add

# Optional C acceleration. When the compiled _csom module cannot be imported
# the class falls back to its pure-Python code paths (see hasExtension checks).
# Only ImportError is treated as "extension unavailable"; anything else
# (e.g. KeyboardInterrupt) is allowed to propagate instead of being swallowed.
try:
    #import _motif
    from _csom import cPropagate, cEuclidean, cUpdateUnit
except ImportError:
    print("can not load SOM C-extension")
    hasExtension = False
else:
    print("loading SOM C-extension")
    hasExtension = True


class SOM:
    def __init__(self, inputDim=2, numRows=6, numCols=5, grid='hex', topology='toroid', initialFile=''):
        self.gridType = grid
        self.topologyType = topology
        if initialFile != '':
            self.readSOM(initialFile)
        else:
            self.resize(inputDim, numRows, numCols)
    
    def newMap(self, mapType='Array', numDim=-1):
        newMapResult = []
        if numDim < 1:
            numDim = self.inputDim
        if mapType == 'Dict':
            newMapResult = {}
            for row in range(self.outRows):
                for col in range(self.outCols):
                    newMapResult[row, col] = []
            
            return newMapResult
        
        for row in range(self.outRows):
            res = []
            for col in range(self.outCols):
                unit = [0.] * numDim
                res.append(unit)
            newMapResult.append(res)
        
        return newMapResult
    
    def getUnits(self):
        res = []
        for row in range(self.outRows):
            for col in range(self.outCols):
                res.append((row, col))
        return res
    
    def resize(self, inputDim=2, numRows=6, numCols=5):
        self.outRows = int(numRows)
        self.outCols = int(numCols)
        self.inputDim = int(inputDim)
        self.weightsArray = self.newMap()
        self.units = self.getUnits()
        self.tempUnits = self.getUnits()
    
    def readSOM(self, somFile):
        infile = open(somFile)
        firstline = infile.readline().strip()
        if 'grid' in firstline:
            fields = firstline.split('\t')
            self.gridType = fields[3][:-5]
            if 'topology' in firstline:
                self.topologyType = fields[4][:-9]
        else:
            self.gridType = 'rect'
            fields = firstline.split(' * ')
            #print fields
        self.outRows = int(fields[0][2:-4])
        self.outCols = int(fields[1][:-4])
        self.inputDim = int(fields[2][:-11])
        self.weightsArray = self.newMap()
        self.units = self.getUnits()
        self.tempUnits = self.getUnits()
        for line in infile:
            fields = line.strip().split('\t')
            row = int(fields[0])
            col = int(fields[1])
            if (row, col) in self.units:
                self.weightsArray[row][col] = [float(x) for x in fields[2:]]
            else:
                print "could not initialize some with: %s" % line.strip()
        
        infile.close()
    
    def saveSOM(self, somFile):
        outfile = open(somFile,'w')
        outfile.write('# %d rows\t%d cols\t%d dimensions\t%s grid\t%s topology\n' % (self.outRows, self.outCols, self.inputDim, self.gridType, self.topologyType))
        for (row, col) in self.units:
            outvector = '%d\t%d' % (row, col)
            for val in self.weightsArray[row][col]:
                outvector += '\t%f' % val
            outfile.write(outvector + '\n')
        outfile.close()
    
    def saveMap(self, regArray, mapFile, comment=''):
        """Write a 2-D array of numbers to mapFile, one tab-separated row per line.

        Values are converted with float() and formatted as '%.4f'. A non-empty
        comment is written first as a '#'-prefixed line. An empty row is
        written as an empty line. The file is closed even if writing raises.
        """
        with open(mapFile, 'w') as outfile:
            if comment != '':
                outfile.write('#%s\n' % comment)
            for row in regArray:
                outfile.write('\t'.join('%.4f' % float(val) for val in row) + '\n')
    
    def readMap(self, mapFile):
        resultArray = self.newMap()
        infile = open(mapFile)
        rowIndex = 0
        colIndex = 0
        for line in infile:
            if line[0] == '#':
                continue
            colIndex = 0
            fields = line.strip().split()
            for val in fields:
                resultArray[rowIndex][colIndex] = float(val)
                colIndex += 1
            rowIndex += 1
        infile.close()
        
        return resultArray
                
    def describe(self,dimension=-1):
        print "Rows = %d, Cols = %d" % (self.outRows, self.outCols)
        print "grid = %s inputDim = %d" % (self.gridType, self.inputDim)
        if dimension < 0:
            minDim = 0
            maxDim = self.outRows + 1
        else:
            minDim = dimension
            maxDim = dimension + 1
        if self.gridType == 'hex':
            spacer = " " * (len(self.prettyFloat(self.weightsArray[0][0][minDim:maxDim]))/2)
            for row in range(self.outRows):
                # if grid type is vertical columns
                #for col in range(0, self.outCols, 2):
                #    print "[" + self.prettyFloat(self.weightsArray[row][col][minDim:maxDim]) + "]\t\t",
                #print
                #for col in range(1, self.outCols, 2):
                #    print "\t[" + self.prettyFloat(self.weightsArray[row][col][minDim:maxDim]) + "]\t",
                # if grid type is horizontal columns:
                if row % 2 == 1:
                    print spacer,
                for col in range(self.outCols):
                    print " [" + self.prettyFloat(self.weightsArray[row][col][minDim:maxDim]) + "] ",
                print
                print
        else:
            for (row, col) in self.units:
                print row, col, "\t", self.prettyFloat(self.weightsArray[row][col][minDim:maxDim])
	
    def prop(self, inputVector, outRow, outCol):
        distSquared = 0.
        currentUnit = self.weightsArray[outRow][outCol]
        for input,current in zip(inputVector, currentUnit):
            distSquared += (input - current) ** 2
        return distSquared
    
    def euclidean(self, vecA, vecB):
        """Return the Euclidean distance between vecA and vecB.

        Components are paired with zip, so extra trailing components of the
        longer vector are ignored.
        """
        squaredSum = sum(((a - b) ** 2 for a, b in zip(vecA, vecB)), 0.)
        return math.sqrt(squaredSum)
        
    def propagate(self, inputVect):
        smallest = 16777216
        locUnits = self.units
        winUnit = (0, 0)
        
        for (row, col) in locUnits:
            magSquared = 0
            currentUnit = self.weightsArray[row][col]
            for input, current in zip(inputVect, currentUnit):
                magSquared += (input - current) ** 2
            if magSquared < smallest:
                winUnit = (row, col)
                smallest = magSquared
        return winUnit
    
    def recDistance(self, aRow, aCol, bRow, bCol):
        distRow = abs(aRow - bRow)
        distCol = abs(aCol - bCol)
        
        #if self.grid in ['cylinder','toroid']:
        diffRow = self.outRows - distRow
        if diffRow < distRow:
            distRow = diffRow
        
        #if self.grid == 'toroid':
        diffCol = self.outCols - distCol
        if diffCol < distCol:
            distCol = diffCol
        
        return max(distRow, distCol)
    
    def immediateHexNeighbors(self, arow, acol):
        arowm = arow - 1
        arowp = arow + 1
        acolm = acol - 1
        acolp = acol + 1
        results = []
        
        arowmOK = False
        if 0 <= arowm:
            arowmOK = True
            results.append((arowm, acol))

        arowpOK = False
        if arowp < self.outRows:
            arowpOK = True
            results.append((arowp, acol))

        acolmOK = False
        if 0 <= acolm:
            acolmOK = True
            results.append((arow, acolm))

        acolpOK = False
        if acolp < self.outCols:
            acolpOK = True
            results.append((arow, acolp))
        
        if arow % 2 == 0:
            if acolmOK:
                if arowmOK:
                    results.append((arowm, acolm))
                if arowpOK:
                    results.append((arowp, acolm))
        else:
            if acolpOK:
                if arowmOK:
                    results.append((arowm, acolp))
                if arowpOK:
                    results.append((arowp, acolp))
        
        return results
    
    def immediateHexNeighborsToroid(self, arow, acol):
        arowm = (arow - 1) % self.outRows
        arowp = (arow + 1) % self.outRows
        acolm = (acol - 1) % self.outCols
        acolp = (acol + 1) % self.outCols
        results = [(arowm, acol), (arowp, acol), (arow, acolm), (arow, acolp)]
        # for vertical columns grid
        #if acol % 2 == 0:
        #    results.append((arowm, acolp))
        #    results.append((arowm, acolm))
        #else:
        #    results.append((arowp, acolp))
        #    results.append((arowp, acolm))
        # for horizontal columns grid
        
        if arow % 2 == 0:
            results.append((arowm, acolm))
            results.append((arowp, acolm))
        else:
            results.append((arowm, acolp))
            results.append((arowp, acolp))
        
        return results
    
    def getHexNeighbors(self, theRow, theCol, startDict, maxRadius):
        """Return the rings of units at distance 2..maxRadius from (theRow, theCol).

        startDict maps each (row, col) to a list whose element [1] is that
        unit's immediate neighbours. The rings are found breadth-first: ring
        r holds the not-yet-seen neighbours of ring r-1, in discovery order.
        The result is a list of maxRadius - 1 lists (empty when maxRadius < 2);
        a ring may be empty once the whole map has been reached.

        Visited units are tracked in a set so each membership test is O(1).
        """
        frontier = startDict[theRow, theCol][1]
        seenUnits = set(frontier)
        seenUnits.add((theRow, theCol))
        resultList = []
        for aRadius in range(2, maxRadius + 1):
            ring = []
            for unit in frontier:
                for neighbor in startDict[unit][1]:
                    if neighbor not in seenUnits:
                        seenUnits.add(neighbor)
                        ring.append(neighbor)
            resultList.append(ring)
            # the ring just found is the starting point for the next radius
            frontier = ring

        return resultList
    
    def buildNeighborDict(self, radius):
        resultDict = self.newMap(mapType='Dict')
        radiusList = range(radius+1)
        if self.gridType == 'rect':
            for (row, col) in self.units:
                resultDict[row, col] = [[] for x in radiusList]
                for (arow, acol) in self.units:
                    dist = self.rectDistance(row, col, arow, acol)
                    if dist <= radius:
                        resultDict[row,col][dist].append((arow, acol))
        else:
            resultDict1 = self.newMap(mapType='Dict')
            for (row, col) in self.units:
                if self.topologyType == 'sheet':
                    resultDict[row,col] = [[(row, col)], self.immediateHexNeighbors(row, col)]
                    resultDict1[row,col] = [[(row, col)], self.immediateHexNeighbors(row, col)]
                else:
                    # default is toroid
                    resultDict[row,col] = [[(row, col)], self.immediateHexNeighborsToroid(row, col)]
                    resultDict1[row,col] = [[(row, col)], self.immediateHexNeighborsToroid(row, col)]
            print "%s\tbuilt neighborhood Dict of radius 1" % (time.asctime())
            index = 0
            for (row, col) in self.units:
                resultDict[row, col] += self.getHexNeighbors(row, col, resultDict1, radius)
                index += 1
                if index % 500 == 0:
                    print "%s\tfinished %d units" % (time.asctime(), index)
        
        return resultDict
    
    def updateUnit(self, oldVec, inVec, rate):
        """Return oldVec moved towards inVec by the fraction rate.

        Each component becomes old + rate * (in - old). Neither input is
        modified; the result is whatever map() returns for the element-wise
        sum.
        """
        deltas = []
        for target, current in zip(inVec, oldVec):
            deltas.append(rate * (target - current))
        return map(add, deltas, oldVec)
    
    def update(self, inputVector, learningRate = 0.1, upRadius=1, neighbors={}):
        """Run one training step for inputVector and return the winning unit.

        The winner is found with cPropagate when the C extension is loaded,
        otherwise with self.propagate. The winner is pulled towards
        inputVector at learningRate; units in rings 1..upRadius of
        neighbors[winner] are pulled at
        learningRate * exp(-0.5 * dist**2 / (upRadius + 1)**2).

        neighbors must be a dict as produced by buildNeighborDict, containing
        the winner with at least upRadius + 1 rings.
        """
        if hasExtension:
            winner = cPropagate(self.weightsArray, inputVector, self.outRows, self.outCols, self.inputDim)
            applyUpdate = cUpdateUnit
        else:
            applyUpdate = self.updateUnit
            winner = self.propagate(inputVector)
        winnerRow, winnerCol = winner
        span = upRadius + 1
        # look the rings up first so a missing winner fails before any
        # weights are changed
        rings = neighbors[winner]
        self.weightsArray[winnerRow][winnerCol] = applyUpdate(self.weightsArray[winnerRow][winnerCol], inputVector, learningRate)

        for dist in range(1, span):
            ringRate = learningRate * math.exp(-0.5 * (dist ** 2) / (span ** 2))
            for (row, col) in rings[dist]:
                self.weightsArray[row][col] = applyUpdate(self.weightsArray[row][col], inputVector, ringRate)

        return winner
	
    def train(self, trainingSet, learningRate=0.2, timeSteps=1000000, printScores=False, radius=-1, neighborDict = {}):
        """Train the map on trainingSet for timeSteps single-vector updates.

        trainingSet maps record IDs to input vectors. The IDs are shuffled
        once and then cycled through in that order. Learning rate and update
        radius both decay as exp(-t * log(radius) / timeSteps). A radius
        below 1 defaults to half the larger map side. When neighborDict is
        empty it is built with buildNeighborDict(radius). Progress is printed
        every 100000 steps, including the scoreData() score when printScores
        is set.
        """
        if radius < 1:
            radius = max(self.outRows, self.outCols) // 2
        # multiplier = -1/timeConstant
        multiplier = -1 * math.log(radius) / float(timeSteps)
        vecIDList = trainingSet.keys()
        numTraining = len(vecIDList)
        random.shuffle(vecIDList)
        if len(neighborDict) == 0:
            print("%s\tbuilding neighborhood Dict of radius %d" % (time.asctime(), radius))
            neighborDict = self.buildNeighborDict(radius)
            print("%s\tbuilt neighborhood Dict" % (time.asctime()))

        for epoch in xrange(timeSteps):
            vectorID = vecIDList[epoch % numTraining]
            decay = math.exp(epoch * multiplier)
            epochLearningRate = learningRate * decay
            epochRadius = int(radius * decay)
            self.update(trainingSet[vectorID], epochLearningRate, epochRadius, neighbors=neighborDict)
            if epoch % 100000 != 0:
                continue
            # periodic progress report
            if printScores:
                print("calling scoreData()")
                print("%s\ttime %d\tscore %.3f\tradius %d\tlearn %.4f" % (time.asctime(), epoch, self.scoreData(trainingSet), epochRadius, epochLearningRate))
            else:
                print("%s\ttime %d\tradius %d\tlearn %.4f" % (time.asctime(), epoch, epochRadius, epochLearningRate))
        
    def initializeWeights(self, trainingSet):
        sampleIDs = trainingSet.keys()
        randomSample = random.sample(sampleIDs, self.outRows * self.outCols)
        index = 0
        for (row, col) in self.units:
            self.weightsArray[row][col] = trainingSet[randomSample[index]]
            index += 1
    
    def datasetSize(self, infilename, startField=0):
        """Return (dimension, recordCount) for a whitespace-separated data file.

        Lines shorter than two characters after stripping and lines starting
        with '#' are ignored. The dimension is the number of fields from
        startField onward on the first line where that count is at least 1
        (-1 if the file has no data lines at all); only lines with exactly
        that many fields are counted. The file is closed even if reading
        raises.
        """
        index = 0
        dimension = -1
        with open(infilename) as infile:
            for line in infile:
                line = line.strip()
                if len(line) < 2 or line[0] == '#':
                    continue
                width = len(line.split()[startField:])
                if dimension < 1:
                    dimension = width
                if width == dimension:
                    index += 1

        return (dimension, index)
    
    def getDataset(self, infilename, numRecords=3000000000, startField=0, IDField=-1, dimension=0):
        infile = open(infilename)
        index = 0
        if dimension < 1:
            dimension = self.inputDim
        dataDict = {}
        trackID = False
        if IDField > -1:
            trackID = True
        for line in infile:
            line = line.strip()
            if len(line) < 2:
                continue
            if line[0] == '#':
                continue
            fields = line.strip().split()
            currentRec = fields[startField:startField+dimension]
            if len(currentRec) == dimension:
                record = []
                if trackID:
                    recordID = fields[IDField]
                else:
                    recordID = 'record%s' % str(index)                
                try:
                    dataDict[recordID] = [float(x) for x in currentRec]
                    index += 1
                except:
                    print "could not convert %s" % str(fields)
            if index > numRecords:
                break
        
        infile.close()
        
        return dataDict
    
    def getWinners(self, dataDict):
        """Group the record IDs of dataDict by their winning unit.

        Returns a dict keyed by (row, col) for every unit on the map; each
        value lists the IDs whose vector is closest to that unit (possibly
        none). The winner is found with cPropagate when the C extension is
        loaded, otherwise with self.propagate.
        """
        resultDict = self.newMap(mapType='Dict')
        if hasExtension:
            findWinner = lambda vector: cPropagate(self.weightsArray, vector, self.outRows, self.outCols, self.inputDim)
        else:
            findWinner = self.propagate
        for recID in dataDict:
            resultDict[findWinner(dataDict[recID])].append(recID)

        return resultDict
    
    def scoreData(self, dataset, outfilename='', printScores=False):
        """Return the mean distance between each record and its winning unit.

        dataset maps record IDs to vectors. Records are assigned to units with
        getWinners and scored with cEuclidean when the C extension is loaded,
        otherwise with self.euclidean. Returns 0 for an empty dataset.

        When printScores is set, every unit and its records are printed. When
        outfilename is non-empty, the same information is written there as
        'unit<TAB>row,col<TAB>weights' lines followed by
        'score<TAB>recID<TAB>vector' lines (the format readScoreFile parses).
        The output file is closed even if scoring raises.
        """
        scorefunc = self.euclidean
        if hasExtension:
            scorefunc = cEuclidean
        outfile = None
        if outfilename != '':
            outfile = open(outfilename,'w')
        totalScore = 0.

        try:
            winnerDict = self.getWinners(dataset)

            for unit in self.units:
                (row, col) = unit
                unitWeights = self.weightsArray[row][col]
                if printScores:
                    print('%s: [%s]' % (str(unit), self.prettyFloat(unitWeights)))
                if outfile is not None:
                    outfile.write('unit\t%d,%d\t%s\n' % (row, col, self.prettyFloat(unitWeights, delim='\t')))

                for recID in winnerDict[unit]:
                    winnerScore = scorefunc(dataset[recID], unitWeights)
                    totalScore += winnerScore
                    if printScores:
                        print('\t%s [%s] %.2f' % (recID, self.prettyFloat(dataset[recID]), winnerScore))
                    if outfile is not None:
                        outfile.write('%.2f\t%s\t%s\n' % (winnerScore, recID, self.prettyFloat(dataset[recID], delim='\t')))
        finally:
            if outfile is not None:
                outfile.close()

        try:
            return totalScore / len(dataset)
        except ZeroDivisionError:
            # empty dataset: there is nothing to average
            return 0
    
    def readScoreFile(self, scorefilename, full=False):
        wDict = self.newMap(mapType='Dict')
        scorefile = open(scorefilename)
        
        for line in scorefile:
            fields = line.split()
            if 'unit' in line:
                (wrow, wcol) = fields[1].split(',')
                wrow = int(wrow)
                wcol = int(wcol)
                wDict[wrow, wcol] = []
            else:
                if full:
                    wDict[wrow, wcol].append(line)
                else:
                    wDict[wrow, wcol].append(fields[1])
        scorefile.close()
        
        return wDict
        
    def prettyFloat(self,floatList, delim=' '):
        """Format floatList as '%.2f' values joined by delim.

        Returns an empty string for an empty list. The delimiter may be any
        length, including empty (the previous implementation trimmed exactly
        one trailing character, which was only right for a one-character
        delimiter).
        """
        return delim.join('%.2f' % afloat for afloat in floatList)
