##################################
#                                #
# Last modified 2017/08/21       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _read_config(config):
    """Return the (label, genes path, matrix path) triples listed in config.

    Each non-blank line holds three tab-separated fields:
    label, genes file, matrix file.
    """
    data_files = []
    with open(config) as handle:
        for line in handle:
            stripped = line.strip()
            if not stripped:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            fields = stripped.split('\t')
            data_files.append((fields[0], fields[1], fields[2]))
    return data_files


def _read_genes(genes):
    """Map each 1-based line number of the genes file to (geneID, geneName)."""
    gene_by_pos = {}
    with open(genes) as handle:
        for pos, line in enumerate(handle, 1):
            fields = line.strip().split('\t')
            gene_by_pos[pos] = (fields[0], fields[1])
    return gene_by_pos


def _load_matrix(label, matrix, gene_by_pos, data_matrix, cell_ids):
    """Add the entries of one sparse matrix file to data_matrix and cell_ids.

    Lines starting with '%' are comments.  The first remaining line is
    skipped as the dimensions line; every later line is read as
    'genePos cellNumber count'.
    """
    with open(matrix) as handle:
        size_line_seen = False
        for k, line in enumerate(handle, 1):
            if line.startswith('%'):
                continue
            if not size_line_seen:
                # Skip the dimensions line by what it is (first non-comment
                # line) rather than by a fixed line number, so files with
                # one or with several header comment lines are both handled.
                size_line_seen = True
                continue
            if k % 1000000 == 0:
                print('%s %s+e06 lines processed' % (label, k / 1000000.))
            fields = line.split()
            if not fields:
                continue
            gene = gene_by_pos[int(fields[0])]
            cell = (label, int(fields[1]))
            cell_ids.add(cell)
            data_matrix.setdefault(gene, {})[cell] = int(fields[2])


def _write_table(outfilename, data_matrix, cell_ids):
    """Write the dense gene-by-cell table, filling absent entries with 0."""
    cells = sorted(cell_ids)
    with open(outfilename, 'w') as outfile:
        header = ['#geneID', 'geneName']
        header.extend(label + '_' + str(number) for (label, number) in cells)
        outfile.write('\t'.join(header) + '\n')
        for gene in sorted(data_matrix):
            counts = data_matrix[gene]
            row = [gene[0], gene[1]]
            row.extend(str(counts.get(cell, 0)) for cell in cells)
            outfile.write('\t'.join(row) + '\n')


def run():
    """Merge several sparse count matrices into one dense tab-separated table.

    Command line: ``python <script> config outfile``.

    ``config`` lists one data set per line as
    ``label <tab> genes.tsv <tab> matrix.mtx``.  Cells are named
    ``<label>_<cellNumber>`` in the output header; genes are keyed by the
    (geneID, geneName) pair taken from the first two columns of the genes
    file, so the same gene seen under several labels shares one output row.
    Rows and columns are written in sorted order.

    Exits with status 1 after printing usage when either argument is missing.
    """
    # Both the config and the output file name are required.
    if len(sys.argv) < 3:
        print('usage: python %s config outfile' % sys.argv[0])
        print('\tconfig format:')
        print('\t\tlabel <tab> genes.tsv <tab> matrix.mtx')
        sys.exit(1)

    config = sys.argv[1]
    outfilename = sys.argv[2]

    data_matrix = {}   # (geneID, geneName) -> {(label, cellNumber): count}
    cell_ids = set()   # every (label, cellNumber) seen in any matrix

    for (label, genes, matrix) in _read_config(config):
        print(label)
        gene_by_pos = _read_genes(genes)
        _load_matrix(label, matrix, gene_by_pos, data_matrix, cell_ids)

    _write_table(outfilename, data_matrix, cell_ids)

# Run only when executed as a script, so that importing this module does not
# trigger argument parsing and file processing.
if __name__ == '__main__':
    run()