 ##################################
#                                #
# Last modified 08/28/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def _parse_thresholds(text, convert):
    """Parse a comma-separated list of thresholds.

    ``convert`` (``int`` or ``float``) is applied to every item.  A zero
    threshold is always added so that every element falls into some bin.
    The result is sorted and free of duplicates.
    """
    values = set(convert(item) for item in text.split(','))
    values.add(0)
    return sorted(values)


def _read_labels(labels_path):
    """Read the labels file.

    Every non-blank line holds three tab-separated fields: the
    comma-separated indices of the label columns in the data table, a
    dataset name, and the cell type that dataset belongs to.

    Returns ``(label_columns, cell_type_of)``: the label column indices
    taken from the last line of the file, and a dict mapping each dataset
    name to its cell type (if a dataset is listed more than once, the last
    line wins).

    Raises ``ValueError`` when the file contains no entries.
    """
    label_columns = None
    cell_type_of = {}
    with open(labels_path) as handle:
        for line in handle:
            fields = line.strip().split('\t')
            if fields == ['']:
                # Tolerate blank (e.g. trailing) lines.
                continue
            label_columns = [int(column) for column in fields[0].split(',')]
            cell_type_of[fields[1]] = fields[2]
    if label_columns is None:
        raise ValueError('labels file %s contains no entries' % labels_path)
    return label_columns, cell_type_of


def _read_data(data_path, label_columns, cell_type_of):
    """Read the data table.

    The line starting with ``#`` is the header; its tab-separated fields
    are matched against the dataset names in ``cell_type_of`` to find the
    columns to read.  On every other non-blank line the fields at
    ``label_columns`` form the element label, and each selected column is
    read as a float (the literal ``FAIL`` counts as 0).

    Returns ``(line_count, data)`` where ``line_count`` is the number of
    lines read (header and blank lines included) and ``data`` maps
    ``label -> {(cell_type, dataset): value}``.  A repeated label keeps
    the values of its last row.

    Raises ``ValueError`` when a data row appears before the header.
    """
    columns = {}
    header_seen = False
    data = {}
    line_count = 0
    with open(data_path) as handle:
        for line in handle:
            line_count += 1
            if line_count % 10000 == 0:
                print(line_count)
            fields = line.strip().split('\t')
            if line.startswith('#'):
                header_seen = True
                seen = set()
                for position, name in enumerate(fields):
                    if name in seen:
                        # Duplicated column name: only its first
                        # occurrence is used.
                        continue
                    seen.add(name)
                    if name in cell_type_of:
                        columns[position] = (cell_type_of[name], name)
                continue
            if fields == ['']:
                # Tolerate blank (e.g. trailing) lines.
                continue
            if not header_seen:
                raise ValueError(
                    "data row found before the '#' header line in %s"
                    % data_path)
            label = tuple(fields[index] for index in label_columns)
            values = {}
            for position, key in columns.items():
                cell = fields[position]
                values[key] = 0 if cell == 'FAIL' else float(cell)
            data[label] = values
    return line_count, data


def _count_found(values, by_cell_type, cutoff, criteria):
    """Count the datasets or cell types in which one element passes ``cutoff``.

    * ``dataset``: number of datasets whose value is >= cutoff.
    * ``cellTypeAll``: number of cell types in which every dataset passes.
    * ``cellType2``: number of cell types in which at least two datasets
      pass (or the single dataset, for cell types that only have one).
    """
    if criteria == 'dataset':
        return sum(1 for value in values.values() if value >= cutoff)
    found = 0
    for cell_values in by_cell_type.values():
        passing = sum(1 for value in cell_values if value >= cutoff)
        if criteria == 'cellTypeAll':
            required = len(cell_values)
        else:
            required = min(2, len(cell_values))
        if passing >= required:
            found += 1
    return found


def _bin_for(found, thresholds):
    """Return the largest entry of sorted ``thresholds`` that is <= ``found``.

    Returns ``None`` when ``found`` is below every threshold.
    """
    for threshold in reversed(thresholds):
        if found >= threshold:
            return threshold
    return None


def _tally(data, thresholds, cutoffs, criteria):
    """Build ``stats[cutoff][threshold]``: the number of elements per bin.

    For each cutoff, every element is assigned to the bin of the largest
    threshold not exceeding its count from ``_count_found``.  An
    unrecognised ``criteria`` leaves all counts at zero.
    """
    stats = dict(
        (cutoff, dict((threshold, 0) for threshold in thresholds))
        for cutoff in cutoffs)
    if criteria not in ('dataset', 'cellTypeAll', 'cellType2'):
        return stats
    print(criteria)
    for values in data.values():
        by_cell_type = {}
        if criteria != 'dataset':
            # Group this element's values by cell type once, not per cutoff.
            for (cell_type, _name), value in values.items():
                by_cell_type.setdefault(cell_type, []).append(value)
        for cutoff in cutoffs:
            found = _count_found(values, by_cell_type, cutoff, criteria)
            target = _bin_for(found, thresholds)
            if target is not None:
                stats[cutoff][target] += 1
    return stats


def run():
    """Tabulate in how many datasets / cell types each element is detected.

    Command-line arguments (read from ``sys.argv``):

    1. data table: tab-separated, with a header line starting with ``#``
    2. labels file (see ``_read_labels``)
    3. comma-separated integer thresholds on the number of datasets
    4. comma-separated float value cutoffs
    5. criteria: ``dataset``, ``cellTypeAll`` or ``cellType2``
    6. output file name

    Writes a tab-separated table with one row per threshold and one column
    per cutoff.  Each cell is the number of elements whose count of passing
    datasets (or cell types) falls in that threshold's bin.  The output
    file is only opened once parsing and counting have succeeded, so a
    failed run does not leave a truncated file behind.

    Prints the usage text and exits with status 1 when fewer than six
    arguments are given.
    """
    # Six positional arguments are required; sys.argv[0] is the script name.
    if len(sys.argv) < 7:
        print('usage: python %s  <data table> <labels file> <number datasets separators> <value separators> <dataset | cellTypeAll | cellType2> outfilename' % sys.argv[0])
        print('\tdata table: spreadsheet with abundance values for each elements in each dataset')
        print('\tlabels files format:')
        print('\t\t<labels fields, comma separated> <tab> <dataset name (should be the same as in the data table) <tab> <cell type belonging to>>')
        print('\tnumber of datasets and values should be comma-separated')
        sys.exit(1)

    datafilename = sys.argv[1]
    labelsFile = sys.argv[2]
    thresholds = _parse_thresholds(sys.argv[3], int)
    cutoffs = _parse_thresholds(sys.argv[4], float)
    criteria = sys.argv[5]
    outfilename = sys.argv[6]

    label_columns, cell_type_of = _read_labels(labelsFile)
    line_count, data = _read_data(datafilename, label_columns, cell_type_of)

    print('%s elements found' % line_count)
    print(thresholds)

    stats = _tally(data, thresholds, cutoffs, criteria)

    with open(outfilename, 'w') as outfile:
        header = ['#datasets\\cutoff'] + [str(cutoff) for cutoff in cutoffs]
        outfile.write('\t'.join(header) + '\n')
        for threshold in thresholds:
            row = [str(threshold)]
            row.extend(str(stats[cutoff][threshold]) for cutoff in cutoffs)
            outfile.write('\t'.join(row) + '\n')
        
# Script entry point: run the analysis using the arguments in sys.argv.
run()

