##################################
#                                #
# Last modified 2017/08/30       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
from sets import Set

def _read_barcodes(path):
    """Return {1-based line number: first tab-separated field} for *path*."""
    with open(path) as handle:
        return {
            position: line.strip().split('\t')[0]
            for position, line in enumerate(handle, 1)
        }


def _read_genes(path):
    """Return {1-based line number: (geneID, geneName)} for *path*.

    The gene ID is the first tab-separated column and the gene name the
    second; any further columns are ignored.
    """
    gene_by_pos = {}
    with open(path) as handle:
        for position, line in enumerate(handle, 1):
            fields = line.strip().split('\t')
            gene_by_pos[position] = (fields[0], fields[1])
    return gene_by_pos


def _read_matrix(path, gene_by_pos, barcode_by_pos, full_barcode):
    """Parse the sparse matrix file into nested dictionaries.

    Lines starting with '%' are comments.  The first non-comment line is
    skipped as well: it is treated as the size line rather than an entry.
    Every remaining line is read as 'genePos cellNumber count', separated
    by single spaces, with all three values integers.

    Returns a pair (counts, cell_ids) where counts maps
    (geneID, geneName) -> {cell ID: count} and cell_ids is the set of all
    cell IDs encountered.
    """
    counts = {}
    cell_ids = set()
    seen_size_line = False
    with open(path) as handle:
        for line_number, line in enumerate(handle, 1):
            if line.startswith('%'):
                continue
            if not seen_size_line:
                # Skip the size line wherever it sits, instead of assuming
                # it is always the third line of the file.
                seen_size_line = True
                continue
            if line_number % 1000000 == 0:
                print(str(line_number / 1000000.) + '+e06 lines processed')
            fields = line.strip().split(' ')
            gene_pos = int(fields[0])
            cell_number = int(fields[1])
            umi_count = int(fields[2])
            gene = gene_by_pos[gene_pos]
            barcode = barcode_by_pos[cell_number]
            if full_barcode:
                cell_id = barcode
            else:
                # Short form: the part after the first '-' in the barcode
                # plus the cell's column number (the barcode must contain
                # a '-' for this to work).
                cell_id = barcode.split('-')[1] + '-' + str(cell_number)
            cell_ids.add(cell_id)
            counts.setdefault(gene, {})[cell_id] = umi_count
    return counts, cell_ids


def _write_table(path, counts, cell_ids, gene_id_only):
    """Write the dense tab-separated table; absent entries become 0."""
    cells = sorted(cell_ids)
    if gene_id_only:
        header = ['#geneID']
    else:
        header = ['#geneID\tgeneName']
    with open(path, 'w') as outfile:
        outfile.write('\t'.join(header + cells) + '\n')
        for gene in sorted(counts):
            gene_id, gene_name = gene
            per_cell = counts[gene]
            row = [gene_id] if gene_id_only else [gene_id, gene_name]
            row.extend(str(per_cell.get(cell, 0)) for cell in cells)
            outfile.write('\t'.join(row) + '\n')


def run():
    """Convert a sparse genes/barcodes/matrix triplet into a dense table.

    Command line (read from sys.argv):
        genes.tsv barcodes.tsv matrix.mtx outfile [-fullBarCode] [-geneIDOnly]

    -fullBarCode  use the complete barcode string as the column label
                  instead of '<barcode suffix>-<column number>'.
    -geneIDOnly   write only the gene ID column, omitting the gene name.

    Prints the usage message and exits with status 1 when fewer than the
    four required positional arguments are supplied.
    """
    # Four positional arguments are required, so argv needs at least five
    # entries (script name included); sys.argv[4] is read below.
    if len(sys.argv) < 5:
        print('usage: python %s genes.tsv barcodes.tsv matrix.mtx outfile [-fullBarCode] [-geneIDOnly]' % sys.argv[0])
        sys.exit(1)

    genes = sys.argv[1]
    barcodes = sys.argv[2]
    matrix = sys.argv[3]
    outfilename = sys.argv[4]

    doFBC = '-fullBarCode' in sys.argv
    doGID = '-geneIDOnly' in sys.argv

    BarcodePosDict = _read_barcodes(barcodes)
    print('finished parsing barcodes')

    GenePosDict = _read_genes(genes)
    print('finished parsing genes')

    DataMatrix, CellIDs = _read_matrix(matrix, GenePosDict, BarcodePosDict, doFBC)

    _write_table(outfilename, DataMatrix, CellIDs, doGID)

# Script entry point: the conversion starts as soon as this file is executed.
# NOTE: there is no __main__ guard, so importing this module also runs it.
run()