##################################
#                                #
# Last modified 08/29/2010       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import math
from sets import Set

def _read_labels(labelsFile):
    """Parse the labels file into {cell type: {dataset name: ''}}.

    Each non-blank line is tab-separated:
        <label fields> <tab> <dataset name> <tab> <cell type>
    Only the dataset name and cell type are used; the first column is not
    needed by the filtering. Blank lines are skipped.
    """
    LabelDict = {}
    with open(labelsFile) as infile:
        for line in infile:
            if not line.strip():
                continue
            fields = line.strip().split('\t')
            dataset = fields[1]
            CellType = fields[2]
            LabelDict.setdefault(CellType, {})[dataset] = ''
    return LabelDict


def _header_positions(fields, LabelDict):
    """Map column positions of a '#' header row to (cell type, dataset name).

    Only the first occurrence of a repeated column name is mapped. If a
    dataset is listed under several cell types, the cell type seen last while
    iterating LabelDict wins.
    """
    dataset_to_cell = {}
    for CellType in LabelDict:
        for dataset in LabelDict[CellType]:
            dataset_to_cell[dataset] = CellType
    positions = {}
    seen = set()
    for pos, name in enumerate(fields):
        if name in seen:
            continue
        seen.add(name)
        if name in dataset_to_cell:
            positions[pos] = (dataset_to_cell[name], name)
    return positions


def _unique_cell_type(row_values, LabelDict, threshold, criteria):
    """Return the single cell type a row is specific to, or None.

    row_values maps (cell type, dataset) -> value. A dataset "passes" when its
    value is >= threshold. The criteria decide when a cell type counts as
    found:
      dataset     -- at least one of its datasets passes
      cellTypeAll -- every dataset listed for it in LabelDict passes
      cellType2   -- at least min(2, number of its datasets) pass
    The row is kept only if exactly one cell type is found.
    """
    passing = {}
    for (CellType, name), value in row_values.items():
        if value >= threshold:
            passing.setdefault(CellType, []).append(name)
    if criteria == 'dataset':
        found = list(passing)
    elif criteria == 'cellTypeAll':
        found = [ct for ct, names in passing.items()
                 if len(names) == len(LabelDict[ct])]
    else:  # 'cellType2'
        found = [ct for ct, names in passing.items()
                 if len(names) >= min(2, len(LabelDict[ct]))]
    if len(found) == 1:
        return found[0]
    return None


def run():
    """Filter a data table down to rows specific to exactly one cell type.

    Command line (sys.argv):
      1 data table  -- tab-separated; '#' lines are copied through, and the
                       columns of a '#' header row are matched to dataset
                       names from the labels file
      2 labels file -- see _read_labels
      3 threshold   -- float; a value >= threshold counts as expressed
      4 criteria    -- 'dataset', 'cellTypeAll' or 'cellType2'
      5 outfilename -- receives the '#' lines, the kept rows and a per-cell-
                       type count footer

    Exits with status 1 (after printing usage) when arguments are missing.
    """
    # Five positional arguments are read below, so argv needs 6 entries.
    if len(sys.argv) < 6:
        print('usage: python %s  <data table> <labels file> threshold <dataset | cellTypeAll | cellType2> outfilename' % sys.argv[0])
        print('\tdata table: spreadsheet with abundance values for each elements in each dataset')
        print('\tlabels files format:')
        print('\t\t<labels fields, comma separated> <tab> <dataset name (should be the same as in the data table) <tab> <cell type belonging to>>')
        print('\tnumber of datasets and values should be comma-separated')
        sys.exit(1)

    datafilename = sys.argv[1]
    labelsFile = sys.argv[2]
    threshold = float(sys.argv[3])
    criteria = sys.argv[4]
    outfilename = sys.argv[5]

    if criteria not in ('dataset', 'cellTypeAll', 'cellType2'):
        sys.exit("error: unknown criteria '%s' (expected dataset, cellTypeAll or cellType2)" % criteria)

    LabelDict = _read_labels(labelsFile)
    print(LabelDict)

    FinalsStatsDict = dict((CellType, 0) for CellType in LabelDict)

    # Column position -> (cell type, dataset). Stays None until a '#' header
    # row has been seen; positions from later '#' rows are merged in.
    IDtoLabelDict = None
    with open(outfilename, 'w') as outfile:
        with open(datafilename) as infile:
            i = 0
            for line in infile:
                i += 1
                if i % 10000 == 0:
                    print(i)
                fields = line.strip().split('\t')
                if line.startswith('#'):
                    outfile.write(line)
                    if IDtoLabelDict is None:
                        IDtoLabelDict = {}
                    IDtoLabelDict.update(_header_positions(fields, LabelDict))
                    continue
                if not line.strip():
                    continue
                if IDtoLabelDict is None:
                    sys.exit("error: %s has no '#' header line before the first data row" % datafilename)
                row_values = {}
                for pos, key in IDtoLabelDict.items():
                    row_values[key] = float(fields[pos])
                found = _unique_cell_type(row_values, LabelDict, threshold, criteria)
                if found is not None:
                    outfile.write(line)
                    FinalsStatsDict[found] += 1

        for _ in range(4):
            outfile.write('############################################################\n')
        outfile.write('#CellType\tNumber_uniquely_expressed_genes' + '\n')
        for CellType in sorted(FinalsStatsDict):
            outfile.write(CellType + '\t' + str(FinalsStatsDict[CellType]) + '\n')

    print(FinalsStatsDict)
        
# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    run()

