import sys, string

# Announce which script is running and its version before doing any work.
versionString = sys.argv[0] + ': version 1.1'
print(versionString)

# psyco is an optional accelerator: if it cannot be imported or started,
# report that and carry on without it.
try:
    import psyco
    psyco.full()
except Exception:
    # Narrower than a bare ``except:`` so Ctrl-C (KeyboardInterrupt) and
    # SystemExit are not silently swallowed during startup.
    print("psyco not running")

from som import SOM
from mapGraphics import saveMapPNG, coordTranslate

# Three positional arguments (somfile, scorefile, outprefix) are required;
# otherwise print usage and stop.
if len(sys.argv) < 4:
    print("usage: python %s somfile scorefile outprefix [-cluster clusterfile] [-translate dx,dy] [-full] [-regions]" % sys.argv[0])
    # Non-zero status so shells/pipelines can tell the invocation failed.
    sys.exit(1)

# Positional arguments: SOM file, score file, prefix for the output files.
somfile = sys.argv[1]
scorefile = sys.argv[2]
outprefix = sys.argv[3]

# -full : also emit the remaining columns of each winner line.
doFull = '-full' in sys.argv

# -translate dx,dy : integer offsets, parsed into dx and dy (default 0, 0).
dx = dy = 0
doCoords = False
if '-translate' in sys.argv:
    coords = sys.argv[sys.argv.index('-translate') + 1]
    dxText, dyText = coords.split(',')
    dx, dy = int(dxText), int(dyText)
    doCoords = True

def readClusterUnits(path):
    """Return the list of (row, col) SOM units named in a cluster file.

    Each non-blank line is tab-separated; only its first field is used. That
    field is a unit ID whose '_'-separated parts 1 and 2 are the integer row
    and column (e.g. "x_3_4" -> (3, 4)). Blank lines are skipped instead of
    raising IndexError.
    """
    units = []
    clusterfile = open(path)
    try:
        for line in clusterfile:
            line = line.strip()
            if not line:
                continue
            uCoords = line.split('\t')[0].split('_')
            units.append((int(uCoords[1]), int(uCoords[2])))
    finally:
        # Close explicitly so the handle is not leaked for the whole run.
        clusterfile.close()
    return units


# -cluster clusterfile : restrict output to the units listed in that file.
selectUnits = '-cluster' in sys.argv
selectList = []
if selectUnits:
    selectList = readClusterUnits(sys.argv[sys.argv.index('-cluster') + 1])

# -regions : expand "chrom:start-stop" region IDs into separate columns.
doRegions = '-regions' in sys.argv
    
# Load the trained map from disk.
mysom = SOM(initialFile=somfile)

# Map dimensions: columns along x, rows along y.
xmax, ymax = mysom.outCols, mysom.outRows

# One percent of the unit count, truncated, but never below 1.
onePercent = max(1, int(xmax * ymax / 100.))

# Winner lines per unit, read with full=True; the output loop indexes this
# by (row, col).
winnerDict = mysom.readScoreFile(scorefile, full=True)

def unitHeader(regions, full):
    """Return the '#'-prefixed header line (newline included) for a .unit file."""
    if regions:
        header = '#regionID\tchrom\tstart\tstop\tscore'
    else:
        header = '#score\tregionID'
    if full:
        header += '\tother_columns'
    return header + '\n'


def formatWinnerLine(line, regions, full):
    """Convert one winner line into a newline-terminated .unit output row.

    line    -- whitespace-separated record; field 0 is written as the score,
               field 1 as the region ID, and any further fields are the
               "other columns".
    regions -- if true, the region ID is split as "chrom:start-stop" and the
               row becomes regionID, chrom, start, stop, score.
    full    -- if true, keep the other columns (without `regions`, the line
               is written verbatim).
    """
    fields = line.split()
    if regions:
        chrom, coord = fields[1].split(':')
        start, stop = coord.split('-')
        columns = [fields[1], chrom, start, stop, fields[0]]
        if full:
            # Joined with the rest so the extra columns stay tab-separated
            # from the score column.
            columns.extend(fields[2:])
        return '\t'.join(columns) + '\n'
    if full:
        # NOTE(review): written verbatim, so this assumes the stored line
        # keeps its own trailing newline -- confirm against SOM.readScoreFile.
        return line
    return '%s\t%s\n' % (fields[0], fields[1])


units = mysom.getUnits()
# Build the selection set once: O(1) membership per unit instead of a list scan.
selectedUnits = set(selectList)

# Write one "<outprefix>_<row>_<col>.unit" file per (selected) unit.
for (row, col) in units:
    if selectUnits and (row, col) not in selectedUnits:
        continue
    outfilename = '%s_%d_%d.unit' % (outprefix, row, col)
    outfile = open(outfilename, 'w')
    try:
        outfile.write(unitHeader(doRegions, doFull))
        for line in winnerDict[row, col]:
            outfile.write(formatWinnerLine(line, doRegions, doFull))
    finally:
        # Always release the handle, even if a malformed line raises.
        outfile.close()

