#!/usr/bin/env python3

import yaml
import subprocess, tempfile
import logging, argparse, os, sys
from collections import defaultdict

import branchSpecificMask

def getArgs():
    """Parse the command line and return the argparse namespace, which has the attributes
    samplePaths, maskFile and yamlIn (three required positional arguments)"""
    # RawDescriptionHelpFormatter keeps the line breaks of the bullet list in the description;
    # the default formatter would re-wrap the whole description into a single paragraph.
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="""
Given a full sample-paths file and a maskfile from the same protobuf, print
(to stdout) tab-separated output with three columns:
* sample name
* most specific Pango lineage (e.g. if BA.2, BA.5 and BA.5.3 all apply, BA.5.3)
* total number of sites (and/or reversions) masked across all lineages
"""
)
    parser.add_argument('samplePaths', metavar='merged.pb.sample-paths',
                        help='sample-paths from same protobuf file as was branch-specific masked')
    parser.add_argument('maskFile', metavar='merged.pb.branchSpecificMask.tsv',
                        help='branch-specific masking file generated by branchSpecificMask.py')
    parser.add_argument('yamlIn', metavar='branchSpecificMask.yml',
                        help='YAML spec that identifies representative node and sites to mask')
    args = parser.parse_args()
    return args

def die(message):
    """Log an error message and exit with nonzero status"""
    logging.error(message)
    # Use sys.exit rather than the builtin exit(): the latter is installed by the site module
    # for interactive sessions and is not guaranteed to exist (e.g. under `python -S`).
    sys.exit(1)

def getRepresentativeNodeIds(maskFile):
    """Count how often each node ID appears in maskFile.  Every line must have exactly two
    tab-separated columns once trailing whitespace is stripped; the second column is the node ID.
    Return a defaultdict(int) of node ID -> number of lines naming it; die() on a malformed line"""
    nodeCounts = defaultdict(int)
    with open(maskFile, 'r') as maskIn:
        for line in maskIn:
            columns = line.rstrip().split('\t')
            if len(columns) != 2:
                die(f"maskFile {maskFile} has unexpected format (expect two tab-sep columns:\n" + line)
            # First column (the mutation) is not needed here, only the node it is masked at.
            nodeCounts[columns[1]] += 1
    return nodeCounts

def getLineageNodeId(name, nodeMuts, branchSpec):
    """Pick the path element of nodeMuts that starts the branch represented by sample name,
    and return its node ID (the text before the first ':')"""
    # Count how many elements back from the end of the path the wanted node is.
    # Normally it is the very last element.
    stepsBack = 1
    # If the last element begins with the sample name (sample with private mutations), that
    # element covers only the sample itself, so go one element further back.
    if nodeMuts[-1].startswith(name):
        stepsBack += 1
    # The spec may ask for additional backtracking (e.g. to the parent or grandparent).
    extraSteps = branchSpec.get('representativeBacktrack')
    if extraSteps is not None:
        stepsBack += extraSteps
    # Keep only the node ID; the mutations after ':' are discarded.
    nodeId, _, _ = nodeMuts[-stepsBack].partition(':')
    return nodeId

def printSampleStats(spec, repNodes, samplePaths):
    """Print one tab-separated line per sample in samplePaths: full sample name, the lineage of
    the last branch-starting node on the sample's path (empty if none), and the sum of repNodes
    counts over all branch-starting nodes on the path.  Lines of samplePaths that do not have
    exactly two tab-separated columns are skipped."""
    # Representative sample name -> lineage (branch) it represents
    repLineages = {spec[branch]['representative']: branch for branch in spec}
    # Branch-starting node ID -> lineage; filled in as representatives are encountered, so
    # output has to wait until the whole file has been read.
    nodeToLineage = {}
    rows = []
    with open(samplePaths, 'r') as pathsIn:
        for line in pathsIn:
            columns = line.rstrip().split('\t')
            if len(columns) != 2:
                continue
            fullName, path = columns
            nodeMuts = path.split(' ')
            # Node IDs on this path that appear in repNodes, in path order
            branchNodes = [nodeId
                           for nodeId in (nodeMut.split(':')[0] for nodeMut in nodeMuts)
                           if nodeId in repNodes]
            # If this sample is a representative, record which node starts its lineage.
            name = fullName.split('|')[0]
            lineage = repLineages.get(name)
            if lineage is not None:
                nodeToLineage[getLineageNodeId(name, nodeMuts, spec[lineage])] = lineage
            if branchNodes:
                # The last branch-starting node on the path is the most specific one.
                rows.append((fullName, branchNodes[-1],
                             sum(repNodes[node] for node in branchNodes)))
            else:
                rows.append((fullName, '', 0))

    for fullName, branchNode, siteCount in rows:
        if branchNode == '':
            lineage = ''
        else:
            lineage = nodeToLineage.get(branchNode)
            if lineage is None:
                die(f"No lineage for node {branchNode} (sample {fullName})")
        print('\t'.join([fullName, lineage, str(siteCount)]))


def main():
    """Script entry point: parse arguments, get the spec from branchSpecificMask.getSpec and the
    node counts from the mask file, then print per-sample stats to stdout"""
    options = getArgs()
    printSampleStats(branchSpecificMask.getSpec(options.yamlIn),
                     getRepresentativeNodeIds(options.maskFile),
                     options.samplePaths)


# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
