##################################
#                                #
# Last modified 2023/07/25       # 
#                                #
# Georgi Marinov                 #
#                                # 
##################################

import sys
import string
import numpy
from sets import Set

def _read_indexed_lines(path):
    """Return {1-based line number: stripped line text} for every line of *path*.

    The sparse matrix refers to barcodes and genes by their 1-based line in
    the barcodes / genes files, so every line keeps its position (blank lines
    included).
    """
    with open(path) as handle:
        return {index: line.strip() for index, line in enumerate(handle, 1)}


def _read_wanted_barcodes(path, field_id):
    """Return the set of values in tab-separated column *field_id* of *path*.

    Blank lines are skipped instead of raising IndexError.
    """
    wanted = set()
    with open(path) as handle:
        for line in handle:
            stripped = line.strip()
            if stripped:
                wanted.add(stripped.split('\t')[field_id])
    return wanted


def _accumulate_matrix(path, barcodes, genes, wanted, label_counts, seen_genes):
    """Tally matrix entries whose barcode is in *wanted* into *label_counts*.

    Lines starting with '%' are comments, and the first remaining line (the
    dimensions line) is skipped. Every other line is read as
    ``gene_index barcode_index value``, with both indices 1-based keys into
    *genes* / *barcodes*.

    :param label_counts: {gene: tally} dict for one label; updated in place.
    :param seen_genes: set of every gene tallied so far; updated in place.
    """
    header_skipped = False
    with open(path) as handle:
        for line in handle:
            if line.startswith('%'):
                continue
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines, e.g. a trailing empty line
            if not header_skipped:
                # First non-comment line is the matrix size, not an entry.
                header_skipped = True
                continue
            barcode = barcodes[int(fields[1])]
            if barcode not in wanted:
                continue
            gene = genes[int(fields[0])]
            # NOTE(review): each kept entry adds 1 and the value column
            # (fields[2]) is ignored. The tables therefore hold, per gene, the
            # number of kept entries (presumably barcodes with a nonzero value,
            # if each gene/barcode pair appears once), not summed values.
            # Confirm this is intended.
            label_counts[gene] = label_counts.get(gene, 0) + 1
            seen_genes.add(gene)


def _write_table(path, labels, genes, cell):
    """Write a tab-separated gene x label table to *path*.

    The header row is '#' followed by the labels. Each data row is the gene
    (its tabs replaced by spaces so it stays one column), followed by
    ``cell(label, gene)`` for every label.
    """
    with open(path, 'w') as outfile:
        outfile.write('\t'.join(['#'] + list(labels)) + '\n')
        for gene in genes:
            row = [gene.replace('\t', ' ')]
            row.extend(cell(label, gene) for label in labels)
            outfile.write('\t'.join(row) + '\n')


def run(argv=None):
    """Build per-label tally and TPM tables from sparse matrices listed in a config.

    Usage: python <script> config outprefix

    Each non-blank config line is tab-separated:
        label  list1  barcode_fieldID  barcodes.tsv  genes.tsv  matrix.mtx
    Column barcode_fieldID (0-based) of the tab-separated file list1 names
    the barcodes to keep. Config lines sharing a label are pooled.

    Writes two files:
        <outprefix>.counts  -- per gene and label, the number of matrix
                               entries with a kept barcode
        <outprefix>.TPM     -- that number / (label total / 1000000)
    Genes with no entry for a label are written as 0.

    :param argv: argument vector; defaults to ``sys.argv``.
    """
    if argv is None:
        argv = sys.argv

    # Both the config (argv[1]) and the output prefix (argv[2]) are required.
    if len(argv) < 3:
        prog = argv[0] if argv else 'script'
        sys.stdout.write('usage: python %s config outprefix\n' % prog)
        sys.stdout.write('\tconfig format:\n')
        sys.stdout.write('\t\tlabel\tlist1\tbarcode_fieldID\tbarcodes.tsv\tgenes.tsv\tmatrix.mtx\n')
        sys.exit(1)

    label_counts = {}   # label -> {gene: tally}
    seen_genes = set()  # every gene tallied under any label

    with open(argv[1]) as config:
        for line in config:
            if not line.strip():
                continue  # blank config lines carry no dataset
            fields = line.strip().split('\t')
            label = fields[0]
            counts = label_counts.setdefault(label, {})
            barcodes_file = fields[1]
            sys.stdout.write('%s %s\n' % (label, barcodes_file))
            field_id = int(fields[2])
            barcodes_tsv = fields[3]
            genes_tsv = fields[4]
            matrix = fields[5]

            wanted = _read_wanted_barcodes(barcodes_file, field_id)
            barcodes = _read_indexed_lines(barcodes_tsv)
            gene_names = _read_indexed_lines(genes_tsv)
            _accumulate_matrix(matrix, barcodes, gene_names, wanted,
                               counts, seen_genes)

    labels = sorted(label_counts)
    genes = sorted(seen_genes)
    # Float totals keep the TPM division a true division under Python 2 too.
    totals = dict((label, float(sum(label_counts[label].values())))
                  for label in labels)

    def count_cell(label, gene):
        return str(label_counts[label].get(gene, 0))

    def tpm_cell(label, gene):
        tally = label_counts[label].get(gene)
        if tally is None:
            return '0'
        # A present tally is >= 1, so totals[label] > 0 here.
        return str(tally / (totals[label] / 1000000))

    _write_table(argv[2] + '.counts', labels, genes, count_cell)
    _write_table(argv[2] + '.TPM', labels, genes, tpm_cell)

# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    run()