#! /usr/bin/env python
# Convert a .model file into a .inl file that can be directly compiled into nanopolish
from __future__ import print_function

import argparse

def assign(name, value):
    """Print one C++ statement assigning ``value`` to member ``name`` of ``tmp``.

    The emitted line has the form ``<TAB>tmp.<name> = <value>;``. ``value`` is
    inserted as-is (via ``str``), so string values must already be quoted by
    the caller.

    ``name`` must be non-empty and ``value`` must not be ``None``; both
    conditions are enforced with assertions.
    """
    assert len(name) > 0
    assert value is not None
    statement = "".join(["\ttmp.", str(name), " = ", str(value), ";"])
    print(statement)

def quote(value):
    """Return ``value`` wrapped in double quotes as a C++ string literal.

    Backslashes and double quotes inside ``value`` are escaped so that the
    result stays a valid literal with the same contents; for example a
    Windows path such as ``C:\\models\\r9.model`` would otherwise be emitted
    with unintended escape sequences. Values without those characters are
    returned exactly as before, i.e. simply surrounded by quotes.

    ``value`` must be a non-empty string (enforced with an assertion).
    """
    assert(len(value) > 0)
    # Escape backslashes first so the backslashes added for quotes are not
    # themselves doubled.
    escaped = value.replace("\\", "\\\\").replace("\"", "\\\"")
    return "\"" + escaped + "\""

# Command line: -i/--input is the .model file to read, -f/--function-name is
# the name of the C++ function (and the stem of the include guard and data
# array) that the output stage generates.
parser = argparse.ArgumentParser(
    description='Rewrite a .model file into a file that can be directly compiled into nanopolish')
parser.add_argument('-i', '--input', type=str, required=True)
parser.add_argument('-f', '--function-name', type=str, required=True)
args = parser.parse_args()

# State filled in while reading the model file and consumed by the output
# stage further down.
K = 0               # k-mer length, taken from the first data row
model = list()      # one tuple of whitespace-separated fields per data row
header_kv = dict()  # "#key value" header lines, keyed without the leading '#'
bases = dict()      # every character seen in a k-mer; a dict used as a set

# The context manager guarantees the input file is closed after parsing.
with open(args.input) as f:
    for line in f:
        fields = line.split()

        # Blank or whitespace-only lines carry no information; skipping them
        # also avoids indexing into an empty line below.
        if not fields:
            continue

        # check if this is a header line
        if line.startswith(("#", "kmer")):
            # A line whose first field is exactly "kmer" is ignored
            # (presumably the column-title row). Header lines without a
            # value, such as a bare "#comment", have nothing to record.
            if fields[0] != "kmer" and len(fields) > 1:
                key = fields[0][1:]  # drop the leading '#'
                header_kv[key] = fields[1]
            continue

        # Data row: the first field is the k-mer, the rest are its parameters.
        kmer = fields[0]
        if K == 0:
            # the first non-header line seen fixes the k-mer length
            K = len(kmer)
        else:
            assert len(kmer) == K, \
                "k-mer %s does not have the expected length %d" % (kmer, K)
        model.append(tuple(fields))
        for b in kmer:
            bases[b] = 1

# Output stage: write the generated C++ include file to stdout.

# Include guard derived from the requested function name.
include_guard = "NANOPOLISH_%s_INL" % args.function_name.upper()
print("// Autogenerated by convert_model_to_header.py")
print("#ifndef " + include_guard)
print("#define " + include_guard)

data_name = "%s_data" % args.function_name
num_states = len(model)

# Flat array holding four numbers per model row; the k-mer is appended as a
# trailing comment and every row except the last ends with a comma.
print("static std::vector<double> %s = {" % data_name)
final_row = num_states - 1
for row_number, row in enumerate(model):
    numbers = ", ".join("%.5f" % float(row[column]) for column in (1, 2, 3, 4))
    comma = '' if row_number == final_row else ','
    print("\t\t" + numbers + comma + " // " + row[0])
print("};")

# Function that builds a PoreModel from the array above.
print("PoreModel %s()\n{" % args.function_name)
print("\tPoreModel tmp;")

# Output metadata
assign("model_filename", quote(args.input))
assign("k", K)

# Copy the four values of each state out of the flat array, in the order in
# which they were written, then call the per-state update methods.
print("\ttmp.states.resize(%d);" % num_states)
print("\tfor(size_t i = 0; i < %d; ++i) {" % num_states)
state_members = ("level_mean", "level_stdv", "sd_mean", "sd_stdv")
for offset, member in enumerate(state_members):
    print("\t\ttmp.states[i].%s = %s[4*i + %d];" % (member, data_name, offset))
print("\t\ttmp.states[i].update_sd_lambda();")
print("\t\ttmp.states[i].update_logs();")
print("\t}")

# Use the alphabet named in the header when there is one; otherwise pass the
# characters observed in the k-mers to best_alphabet.
alphabet_name = header_kv.get("alphabet")
if alphabet_name is None:
    print("\ttmp.pmalphabet = best_alphabet(%s);" % (quote("".join(bases))))
else:
    print("\ttmp.pmalphabet = get_alphabet_by_name(%s);" % (quote(alphabet_name)))

# The "kit" and "strand" header entries are required; a missing one raises
# KeyError here.
kit_literal = quote(header_kv["kit"])
strand_literal = quote(header_kv["strand"])
print("\ttmp.set_metadata(%s, %s);" % (kit_literal, strand_literal))
print("\treturn tmp;\n}")
print("#endif")
