""" FileIO retuines for loading data from biological sources.  Such as
scanner outputs and the likes.
"""

from __future__ import nested_scopes

import string
import os
import sys
import operator

import Numeric
import MLab
import MA

import compClust.mlx.datasets
from compClust.mlx import labelings
from compClust.util.TimeStampedPrintStream import TimeStampedPrintStream
#from bioinformatics.schema import datasets
from compClust.util import listOps


def loadGPRFile(GPRFile, columns = ['Median of Ratios',
                                    'F635 Median',
                                    'F532 Median',
                                    '% > B532+2SD',
                                    '% > B635+2SD',
                                    'Flags'],
                ):
  
  """
  loadGPRFile(GPRFile, columns )

  if columns is set to None all columns are loaded, otherwise only the specified columns are loaded.
  """
  
  IF = open(GPRFile,'r')
  lines = IF.readlines()

  version = lines[0].split()[1].strip().replace("'", '')
  numHeaderLines = int(lines[1].split()[0])
  
  if version != '1.0' :
    print "Version Mismatch ! "
    return

  # read the header
  header = {}
  for i in range(2,numHeaderLines+2):
    tokens = map(lambda x: x.strip().replace('"','') , lines[i].split('='))
    header[tokens[0]]= tokens[1]


  columnHeaders = map(lambda x: x.strip().replace('"',''), lines[numHeaderLines+2].split('\t'))
  if columns is None:
    columns = columnHeaders[5:]

  columnDict = {}
  map(lambda x: columnDict.setdefault(x, columnHeaders.index(x)), columns)
  
  # set up the row Labels
  BCRs  = []
  names = []
  IDs   = []
  rows  = []
  data  = []
  
  # read the data, keeping track of the row labels
  for line in lines[numHeaderLines+3:]:
    tokens = map(lambda x: x.strip(), line.split('\t'))
    BCRs.append( "%s,%s,%s"%(tokens[0],tokens[1], tokens[2]))
    rows.append(tokens[2].replace('"',''))
    names.append(tokens[3].replace('"',''))
    IDs.append(tokens[4].replace('"',''))
    rowData = []
    for col in columns:
      try:
        rowData.append(float(tokens[columnDict[col]]))
      except:
        rowData.append(float('Nan'))
        #rowData.append(-99999999)
    data.append(rowData)

  a = Numeric.array(data)
  m = Numeric.not_equal(a, a)
  data = MA.MaskedArray(a, copy=0, mask=m)
  ds = compClust.mlx.datasets.Dataset(data)
  ds.setName(GPRFile)

  # now to add some labelings.
  l = labelings.GlobalWrapper(ds,'columns')
  l.labelCols(columns)

  l = labelings.GlobalWrapper(ds,'BCR')
  l.labelRows(BCRs)

  l = labelings.GlobalWrapper(ds,'names')
  l.labelRows(names)

  l = labelings.GlobalWrapper(ds,'IDs')
  l.labelRows(IDs)

  l = labelings.GlobalWrapper(ds,'rows')
  l.labelRows(rows)

  return(ds, header)

def loadGPRFiles(path, index,
                 columns = ['Median of Ratios',
                            'F635 Median',
                            'F532 Median',
                            'F635 Mean',
                            'F532 Mean',
                            '% > B532+2SD',
                            '% > B635+2SD'],
                 headers = ['Normalization',
                            'RatioFormulation',
                            'PMTVolts'],
                 ):

  """
  Load all files specified in the index file into a single dataset.

  Each column specified will become a layer in the MicroarrayDataset
  (order in the list, dictates layer order).

  A column labeling will be created for each header specified and the
  contents of the header will be the label for each dimension in the
  dataset.

  The index file specifies which GPR files to load, and what
  annotations should be marked on each of the columns.  Each line is:

  FileName (relative to path) \t ID \t <any amount of other information>

  Returns the assembled Dataset.  Row labelings are copied from the
  last GPR file loaded, assuming every file shares the same row order
  (this isn't checked) and that the index lists at least one file.
  """

  outstream = TimeStampedPrintStream(stream=sys.stdout)

  # process the index file; its first line names the covariate columns
  lines = open(os.path.join(path, index), 'r').readlines()
  indexCols = map(lambda x: x.strip(), lines[0].split('\t'))

  # all the column labels for the growing dataset
  colLabels = {}
  colLabels['GPR Column'] = []
  for col in indexCols:
    colLabels[col] = []
  for header in headers:
    colLabels[header] = []

  # the beginnings of a dataset, each column is a layer in the dataset
  data = []

  # map column number -> index-file header name
  tokens = map(string.strip, lines[0].split('\t'))
  col2head = {}
  map(col2head.setdefault, range(len(tokens)), tokens)

  # now to process the GPR files listed in the index file...
  for line in lines[1:]:
    # stop at the first blank line; checked before tokenizing
    # (previously the line was split first and the work discarded)
    if not line.strip():
      break
    tokens = map(lambda x: x.strip(), line.split('\t'))
    gprFile = os.path.join(path, tokens[0])
    outstream.write("loading  %s..\n" % (gprFile))
    sys.stdout.flush()
    # BUG FIX: gprFile is already joined with path; the old call
    # joined path a second time, which broke relative paths.
    gpr, gprHeader = loadGPRFile(gprFile, columns=columns)
    gprColLabs = gpr.getLabeling('columns')

    # now we get the data from the gpr dataset
    for column in columns:
      data.append(gpr.getColData(gprColLabs.getColsByLabel(column)[0]))
      # keep track of the column labelings
      for header in headers:
        colLabels[header].append(gprHeader[header])
      for token, col in zip(tokens, range(len(tokens))):
        colLabels[col2head[col]].append(token)
      colLabels['GPR Column'].append(column)

  # build and label the real dataset; the transpose turns the
  # collected per-column vectors back into dataset columns
  ds = compClust.mlx.datasets.Dataset(Numeric.transpose(Numeric.array(data)))
  ds.setName('GPR Data loaded from %s' % (os.path.join(path, index)))

  for key in colLabels.keys():
    l = labelings.GlobalWrapper(ds, key)
    l.labelCols(colLabels[key])

  # add on some row labelings that are common between the GPRs (this
  # isn't checked, but assumed).  'gpr' is the last file loaded.
  l = labelings.GlobalWrapper(ds, 'BCR')
  l.labelRows(gpr.getLabeling('BCR').getLabelByRows(range(gpr.getNumRows())))

  l = labelings.GlobalWrapper(ds, 'names')
  l.labelRows(gpr.getLabeling('names').getLabelByRows(range(gpr.getNumRows())))

  l = labelings.GlobalWrapper(ds, 'IDs')
  l.labelRows(gpr.getLabeling('IDs').getLabelByRows(range(gpr.getNumRows())))

  return (ds)

def loadAgilentFile(file, numericColumns, rowLabelings):

  """
  loadAgilentFile(file, numericColumns, rowLabelings)

  Reads the tab-delimited text output of an Agilent microarray scan.

  file           -- path to the Agilent text file.
  numericColumns -- header names whose values become the dataset's
                    (float) data columns.
  rowLabelings   -- header names whose values become row labelings.

  Returns the resulting Dataset; a 'Column Names' labeling records the
  source header of each data column.
  """

  lines = open(file, 'r').readlines()

  # skip forward to the 'FEATURES' line, which names the data columns
  startline = 0
  while lines[startline].split()[0] != 'FEATURES':
    startline += 1

  # build both directions of the name <-> index mapping; setdefault
  # keeps the first occurrence when a header name repeats
  headerTokens = string.strip(lines[startline]).split('\t')
  col2head = {}
  head2col = {}
  for position in range(len(headerTokens)):
    col2head.setdefault(position, headerTokens[position])
    head2col.setdefault(headerTokens[position], position)

  # translate the requested header names into column indices
  dataCols  = [head2col[name] for name in numericColumns]
  labelCols = [head2col[name] for name in rowLabelings]

  data = []
  rowLabels = []
  for line in lines[startline+1:]:
    fields = string.strip(line).split('\t')
    data.append([float(fields[c]) for c in dataCols])
    rowLabels.append([fields[c] for c in labelCols])

  # now create the dataset
  ds = compClust.mlx.datasets.Dataset(data)

  # attach one row labeling per requested label column
  labs = [labelings.Labeling(ds, name) for name in rowLabelings]
  for row in range(ds.getNumRows()):
    for label, labeling in zip(rowLabels[row], labs):
      labeling.addLabelToRow(label, row)

  # record the source header name of each data column
  colNames = labelings.Labeling(ds, 'Column Names')
  for i in range(len(dataCols)):
    colNames.addLabelToCol(col2head[dataCols[i]], i)

  return (ds)

def processedAgilentLogRatios(ds):

  """
  processedAgilentLogRatios(ds)

  This is a FunctionView.

  returns a list of log10 ratios for the agilent dataset.  No data
  is truncated.  If both the red and green channel data are below
  background, a log ratio of 0.0 is imputed.  If only one of the
  channels is below background then the channel whose measurement is
  below background is replaced with that channel's background
  measurement + 1 std. dev. (our detection threshold) and a log ratio
  is calculated.  A labeling named 'floored rows' is attached to ds
  which marks all rows which were floored in the 'green', 'red', or
  'both' channels.

  required Labelings:
    'Column Names'
  required data columns (indexed in Column Names):
    'gMedianSignal'
    'gBGMeanSignal'
    'gPixSDev'
    'rMedianSignal'
    'rBGMeanSignal'
    'rPixSDev'
    'gIsPosAndSignif'
    'rIsPosAndSignif'
    'gIsWellAboveBG'
    'rIsWellAboveBG'
    'LogRatio'
  """

  flooredRows = labelings.Labeling(ds, 'floored rows')

  data = []
  cLabs = ds.getLabeling('Column Names')

  for row in range(ds.getNumRows()):
    rowData = ds.getRowData(row)

    redBad = 0
    greenBad = 0

    # check for control spots
    # NOTE(review): this branch re-assigns both flags to 0, which they
    # already hold, so it is a no-op -- control spots are treated like
    # any other spot.  Possibly it was meant to flag or skip controls;
    # confirm the intent before changing it.
    if rowData[cLabs.getColsByLabel('ControlType')[0]] != 0:
      redBad = greenBad = 0

    # a channel is bad if its median signal falls below its mean
    # background + 1 std. dev. (the detection threshold)...
    if (rowData[cLabs.getColsByLabel('gMedianSignal')[0]] <
        (rowData[cLabs.getColsByLabel('gBGMeanSignal')[0]] +
         rowData[cLabs.getColsByLabel('gPixSDev')[0]] )):
      greenBad = 1

    if (rowData[cLabs.getColsByLabel('rMedianSignal')[0]] <
        (rowData[cLabs.getColsByLabel('rBGMeanSignal')[0]] +
         rowData[cLabs.getColsByLabel('rPixSDev')[0]] )):
      redBad = 1

    # ...or if the scanner did not mark it positive and significant...
    if (rowData[cLabs.getColsByLabel('gIsPosAndSignif')[0]] == 0):
      greenBad = 1
    if (rowData[cLabs.getColsByLabel('rIsPosAndSignif')[0]] == 0):
      redBad = 1

    # ...or if the scanner did not mark it well above background
    if (rowData[cLabs.getColsByLabel('gIsWellAboveBG')[0]] == 0):
      greenBad = 1
    if (rowData[cLabs.getColsByLabel('rIsWellAboveBG')[0]] == 0):
      redBad = 1

    # BUG FIX: this test used to be "not (redBad and greenBad)", which
    # is true whenever at most one channel is bad, so the single-channel
    # flooring branches below were unreachable dead code.  Per the
    # docstring, the scanner's own LogRatio should only be trusted when
    # BOTH channels are good.
    if not (redBad or greenBad):
      data.append(rowData[cLabs.getColsByLabel('LogRatio')[0]])
    elif redBad and greenBad:
      # both channels below threshold: impute a log ratio of 0.0
      data.append(0.0)
      flooredRows.addLabelToRow('both', row)
    elif redBad:
      # replace red measurement with red's BG + 1xSTD
      data.append(
        MLab.log( (rowData[cLabs.getColsByLabel('rBGMeanSignal')[0]] +
                   rowData[cLabs.getColsByLabel('rPixSDev')[0]] ) /
                  rowData[cLabs.getColsByLabel('gMedianSignal')[0]]) /
        MLab.log(10)
        )
      flooredRows.addLabelToRow('red', row)
    elif greenBad:
      # replace green measurement with green's BG + 1xSTD
      data.append(
        MLab.log(rowData[cLabs.getColsByLabel('rMedianSignal')[0]] /
                 (rowData[cLabs.getColsByLabel('gBGMeanSignal')[0]] +
                  rowData[cLabs.getColsByLabel('gPixSDev')[0]] )) /
        MLab.log(10)
        )
      flooredRows.addLabelToRow('green', row)

  return (data)

def fastLoadAgilentFiles(index, numericColumns= ['gMedianSignal',
                                                 'rMedianSignal',
                                                 'LogRatio',
                                                 'PValueLogRatio'],
                         rowLabelings=['FeatureNum'],
                         process=1, outfile=None):

  """
  fastLoadAgilentFiles(index, numericColumns, rowLabelings,
                       process, outfile)

  Load every Agilent tab-delimited file listed in the index file into a
  single Dataset, reading only the requested numericColumns.  The first
  line of the index file names per-file covariates; each following line
  is:

  FileName \t <covariate values ...>

  Each numeric column of each file becomes one dataset column (tracked
  in the 'data column' labeling), and one column labeling is added per
  covariate.  Row labelings are taken from the last file listed, which
  assumes every file shares the same row order -- not checked.

  NOTE(review): the 'process' and 'outfile' arguments are currently
  unused; kept for interface compatibility.
  """

  outstream = TimeStampedPrintStream("%Y-%b-%d %H:%M:%S > ", stream=sys.stdout)
  # read the index file; the first line holds the covariate names
  lines = open(index).readlines()
  indexCols = map(string.strip, lines[0].split('\t'))

  colLabels = {}
  colLabels['data column'] = []

  for col in indexCols:
    colLabels[col] = []

  data = []
  indexBody = lines[1:]
  numFiles = len(indexBody)
  for fileNum in range(numFiles):
    indexLine = indexBody[fileNum]
    indexTokens = indexLine.split('\t')
    agilentFile = indexTokens[0].strip()
    # BUG FIX: the progress counter used lines[1:].index(indexLine),
    # which is O(n) per file and repeats the first match's position
    # when two index lines are identical.
    outstream.write('loading(%i/%i): %s\n' % (fileNum, numFiles, agilentFile))
    sys.stdout.flush()
    dataLines = open(agilentFile, 'r').readlines()

    # figure out where the data actually starts.
    n = 0
    while dataLines[n].split()[0] != 'FEATURES':
      n += 1
    startline = n

    # create mappings of col -> header and header -> col
    col2head = {}
    head2col = {}
    tokens = string.strip(dataLines[startline]).split('\t')
    map(col2head.setdefault, range(len(tokens)), tokens)
    map(head2col.setdefault, tokens, range(len(tokens)))

    # work out which columns we want for data
    dataCols = map(operator.getitem, [head2col]*len(numericColumns), numericColumns)
    colLabels['data column'].extend(numericColumns)

    # one covariate label per data column pulled from this file
    # (loop variable renamed from 'index', which shadowed the argument)
    for name, covIdx in zip(indexCols,
                            range(len(indexCols))):
      colLabels[name].extend([indexTokens[covIdx].strip()]*len(numericColumns))

    # row labels are only collected from the last file in the index.
    # BUG FIX: the last-file test compared line *content*
    # (indexLine == lines[-1]), which misfires on duplicate lines;
    # compare by position instead.
    lastFile = (fileNum == numFiles - 1)
    if lastFile:
      rowLabels = []
      labelCols = map(operator.getitem, [head2col]*len(rowLabelings), rowLabelings)

    fileData = []
    for line in dataLines[startline+1:]:
      tokens = string.strip(line).split('\t')
      fileData.append(map(float, map(operator.getitem, [tokens]*len(dataCols), dataCols)))

      if lastFile:
        rowLabels.append(map(operator.getitem, [tokens]*len(labelCols), labelCols))
    data.append(fileData)


  # remarkably this works.  The zip merges each line of each file into
  # a single list, giving shape (rows, files, columns-per-file).
  data = Numeric.array(apply(zip, data))
  # this merges the last 2 dimensions, so each row holds every file's
  # columns side by side
  data = Numeric.reshape(data, (data.shape[0], data.shape[1]*data.shape[2],))

  ds = compClust.mlx.datasets.Dataset(data)

  # now add row labelings.
  labs = map(lambda x: labelings.GlobalWrapper(ds, x), rowLabelings)
  for row in range(ds.getNumRows()):
    for label, labeling in zip(rowLabels[row], labs):
      labeling.addLabelToRow(label, row)

  # now add the col labels
  for key in colLabels.keys():
    l = labelings.GlobalWrapper(ds, key)
    l.labelCols(colLabels[key])

  return (ds)

def loadAgilentFiles(index, numericColumns= ['gMedianSignal',
                                             'rMedianSignal',
                                             'LogRatio',
                                             'processedRatio',
                                             'PValueLogRatio'],
                     process=1, outfile=None):

  """
  Reads all files specified in the index file.  The first line describes
  the columns of the file.  The first column is always the filename of the
  file to be loaded, then for each additional column specified column
  covariates can be added.

  For every listed file, each name in numericColumns becomes one column
  of the returned Dataset (recorded in the 'data column' labeling).
  The pseudo-column 'processedRatio' is not read from the file: it is
  computed by processedAgilentLogRatios() on the loaded data.  One
  column labeling is also added per index-file covariate.  Row labels
  ('FeatureNum') are taken from the last file listed, which assumes all
  files share the same row order -- this is not checked.

  NOTE(review): the 'process' and 'outfile' arguments are currently
  unused (outfile appears only in commented-out code below).
  """

  lines = open(index).readlines()
  indexCols = map(lambda x: x.strip(), lines[0].split('\t'))

  colLabels = {}
  colLabels['data column'] =[]

  for col in indexCols:
    colLabels[col] = []

  data = []
  for line in lines[1:]:

    tokens = line.split('\t')
    agilentFile = tokens[0].strip()
    print 'reading %s...\n'%(agilentFile)
    sys.stdout.flush()
    # the full set of columns processedAgilentLogRatios() needs,
    # loaded regardless of which columns the caller asked for
    tmpNumericColumns=['ControlType', 'gMedianSignal',
                       'rMedianSignal', 'gBGMeanSignal',
                       'rMeanSignal', 'gMeanSignal',
                       'rBGMeanSignal', 'gPixSDev',
                       'rPixSDev', 'PValueLogRatio',
                       'gIsPosAndSignif', 'rIsPosAndSignif',
                       'gIsWellAboveBG', 'rIsWellAboveBG',
                       'LogRatio' ]
    # row labels are only collected from the last file in the index
    if line == lines[-1]:
      rowLabelings=['FeatureNum',]
                    #'ProbeUID',
                    #'ProbeName',
                    #'GeneName',
                    #'SystematicName',
                    #'Description']
    else:
      rowLabelings = []

    tmpDS = loadAgilentFile(agilentFile, tmpNumericColumns, rowLabelings)
    # one covariate label per requested numeric column for this file
    # (keeps colLabels lengths in step with the data columns appended
    # below)
    for c in numericColumns:
      for colName,colIndex in zip(indexCols, 
                                  range(len(indexCols))):
        colLabels[colName].append(tokens[colIndex].strip())


    colNamesLab = tmpDS.getLabeling('Column Names')
    # copy over every requested column except the computed pseudo-column
    for col in listOps.difference(numericColumns, ['processedRatio']):
      data.append(tmpDS.getColData(colNamesLab.getColsByLabel(col)[0]))
      colLabels['data column'].append(col)

 
    if 'processedRatio' in numericColumns:
      colLabels['data column'].append('processedRatio')
      colData  = processedAgilentLogRatios(tmpDS)
      data.append(colData)
    #if outfile:
    #  OF = open(outfile, 'a')
    #  for v in colData:
    #    OF.write('%3.4f\t'%(v))
    #  OF.write('\n')
    #  OF.close()
                 

  #ostream.write('Constructing the Dataset...')
  # build the dataset; the transpose turns the collected per-column
  # vectors back into dataset columns
  ds = compClust.mlx.datasets.Dataset(Numeric.transpose(data))

  # label the columns
  for key in colLabels.keys():
    l = labelings.GlobalWrapper(ds,key)
    l.labelCols(colLabels[key])
   
  # label the rows (taken from the last file loaded; see docstring).
  l = labelings.GlobalWrapper(ds, 'FeatureNum')
  l.labelRows(tmpDS.getLabeling('FeatureNum').getLabelByRows())
  #l = labelings.GlobalWrapper(ds, 'ProbeUID')
  #l.labelRows(tmpDS.getLabeling('ProbeUID').getLabelByRows())
  #l = labelings.GlobalWrapper(ds, 'ProbeName')
  #l.labelRows(tmpDS.getLabeling('ProbeName').getLabelByRows())
  #l = labelings.GlobalWrapper(ds, 'GeneName')
  #l.labelRows(tmpDS.getLabeling('GeneName').getLabelByRows())
  #l = labelings.GlobalWrapper(ds, 'SystematicName')
  #l.labelRows(tmpDS.getLabeling('SystematicName').getLabelByRows())
  #l = labelings.GlobalWrapper(ds, 'Description')
  #l.labelRows(tmpDS.getLabeling('Description').getLabelByRows())

  return (ds)


