import os
import re
import ast


def make_c_files():
    """Generate the C source files from their ``*_template.c`` counterparts.

    For every module name handled here, ``<module>_template.c`` is read from
    the directory that contains this file, expanded with `template`, and the
    result is written to ``<module>.c`` in that same directory.
    """
    here = os.path.dirname(__file__)
    for name in ['reduce', 'move', 'nonreduce', 'nonreduce_axis']:
        template_path = os.path.join(here, name + '_template.c')
        output_path = os.path.join(here, name + '.c')
        with open(template_path, 'r') as fin:
            raw = fin.read()
        expanded = template(raw)
        with open(output_path, 'w') as fout:
            fout.write(expanded)


def template(src_str):
    """Expand every template directive found in `src_str`.

    The text is split into lines and passed, in order, through the repeat,
    dtype and multiline-string expansion passes. The lines are then joined
    back with newlines and any run of two or more blank (whitespace-only)
    lines is collapsed into a single blank line.

    Parameters
    ----------
    src_str : str
        Contents of a template file.

    Returns
    -------
    str
        The expanded source text.
    """
    lines = src_str.splitlines()
    # Order matters: repeat blocks are expanded before dtype blocks, and
    # multiline strings are quoted last.
    for expand in (repeat_templating, dtype_templating, string_templating):
        lines = expand(lines)
    joined = '\n'.join(lines)
    return re.sub(r'\n\s*\n\s*\n', r'\n\n', joined)


# repeat --------------------------------------------------------------------

# Matches the opening line of a repeat block, i.e. a line that starts with
# ``/* repeat =`` (whitespace around ``repeat`` and ``=`` is optional).
REPEAT_BEGIN = r'^/\*\s*repeat\s*=\s*'
# Matches the line that closes a repeat block, i.e. one that starts with
# ``/* repeat end``.
REPEAT_END = r'^/\*\s*repeat end'
# Matches any line that contains a C comment terminator (``*/``); used to
# find where the header comment of a block stops.
COMMENT_END = r'.*\*\/.*'


def repeat_templating(lines):
    """Expand every repeat block in `lines` and return the new list of lines.

    Blocks are located with `next_block` using the ``REPEAT_BEGIN`` and
    ``REPEAT_END`` patterns. Each block (header comment plus body) is replaced
    by the output of `expand_functions_repeat`; the closing ``repeat end``
    line is dropped.
    """
    search_from = 0
    begin, end = next_block(lines, search_from, REPEAT_BEGIN, REPEAT_END)
    while begin is not None:
        expanded = expand_functions_repeat(lines[begin:end])
        # ``end + 1`` skips over the /* repeat end */ line itself
        lines = lines[:begin] + expanded + lines[end + 1:]
        # resume scanning at the position where the block used to start
        search_from = begin
        begin, end = next_block(lines, search_from, REPEAT_BEGIN, REPEAT_END)
    return lines


def expand_functions_repeat(lines):
    """Expand a single repeat block.

    `lines` starts with the block's header comment. Everything up to and
    including the first line that matches ``COMMENT_END`` is parsed by
    `repeat_info`; the remaining lines form the body that `expand_repeat`
    duplicates.

    Returns the expanded body as a list of lines.
    """
    body_start = first_occurence(COMMENT_END, lines) + 1
    substitutions = repeat_info(lines[:body_start])
    body = '\n'.join(lines[body_start:])
    return expand_repeat(body, substitutions)


def repeat_info(lines):
    """Parse the substitution dictionary out of a repeat header comment.

    The header lines are concatenated and the text spanning the first ``{``
    to the last ``}`` is evaluated as a Python literal with
    `ast.literal_eval`.

    Parameters
    ----------
    lines : list of str
        Lines of the header comment of a repeat block.

    Returns
    -------
    The evaluated literal; `expand_repeat` expects a dict that maps each
    marker string to a list of replacement strings.

    Raises
    ------
    ValueError
        If the header contains no ``{...}`` specification. Malformed
        literals are reported by `ast.literal_eval` itself.
    """
    header = ''.join(lines)
    found = re.findall(r'\{.*\}', header)
    if not found:
        # Fail with a descriptive error, mirroring `dtype_info`, instead of
        # an opaque IndexError from indexing an empty list.
        raise ValueError("cannot find repeat specification")
    return ast.literal_eval(found[0])


def expand_repeat(func_str, repeat_dict):
    """Produce one copy of `func_str` per entry of the repeat lists.

    Parameters
    ----------
    func_str : str
        Text of the block body containing the marker strings.
    repeat_dict : dict
        Maps each marker string to a list of replacements. All lists must
        have the same length; copy ``n`` of the body uses element ``n`` of
        every list.

    Returns
    -------
    list of str
        The lines of all copies, each copy preceded by an empty line.

    Raises
    ------
    ValueError
        If the replacement lists do not all share one length (this includes
        an empty `repeat_dict`).
    """
    lengths = {len(values) for values in repeat_dict.values()}
    if len(lengths) != 1:
        raise ValueError("All repeat lists must be the same length")
    (count,) = lengths
    copies = []
    for n in range(count):
        text = func_str
        # markers are substituted one after another, in dict order
        for marker, values in repeat_dict.items():
            text = text.replace(marker, values[n])
        copies.append('\n' + text)
    return ''.join(copies).splitlines()


# dtype ---------------------------------------------------------------------

# Matches the opening line of a dtype block, i.e. a line that starts with
# ``/* dtype =`` (whitespace around ``dtype`` and ``=`` is optional).
DTYPE_BEGIN = r'^/\*\s*dtype\s*=\s*'
# Matches the line that closes a dtype block, i.e. one that starts with
# ``/* dtype end``.
DTYPE_END = r'^/\*\s*dtype end'


def dtype_templating(lines):
    """Expand every dtype block in `lines` and return the new list of lines.

    Blocks are located with `next_block` using the ``DTYPE_BEGIN`` and
    ``DTYPE_END`` patterns. Each block (header comment plus body) is replaced
    by the output of `expand_functions_dtype`; the closing ``dtype end`` line
    is dropped.
    """
    search_from = 0
    begin, end = next_block(lines, search_from, DTYPE_BEGIN, DTYPE_END)
    while begin is not None:
        expanded = expand_functions_dtype(lines[begin:end])
        # ``end + 1`` skips over the /* dtype end */ line itself
        lines = lines[:begin] + expanded + lines[end + 1:]
        # resume scanning at the position where the block used to start
        search_from = begin
        begin, end = next_block(lines, search_from, DTYPE_BEGIN, DTYPE_END)
    return lines


def expand_functions_dtype(lines):
    """Expand a single dtype block.

    `lines` starts with the block's header comment. Everything up to and
    including the first line that matches ``COMMENT_END`` is parsed by
    `dtype_info`; the remaining lines form the body that `expand_dtypes`
    duplicates.

    Returns the list produced by `expand_dtypes` (one string per dtype
    entry).
    """
    body_start = first_occurence(COMMENT_END, lines) + 1
    dtypes = dtype_info(lines[:body_start])
    body = '\n'.join(lines[body_start:])
    return expand_dtypes(body, dtypes)


def dtype_info(lines):
    """Parse the dtype specification out of a dtype header comment.

    The header lines are concatenated and the text spanning the first ``[``
    to the last ``]`` is evaluated as a Python literal with
    `ast.literal_eval`.

    Raises
    ------
    ValueError
        If the header does not contain exactly one ``[...]`` specification.
    """
    matches = re.findall(r'\[.*\]', ''.join(lines))
    if len(matches) == 1:
        return ast.literal_eval(matches[0])
    raise ValueError("expecting exactly one dtype specification")


def expand_dtypes(func_str, dtypes):
    """Produce one copy of `func_str` per entry of `dtypes`.

    Parameters
    ----------
    func_str : str
        Text of the block body; must contain the substring ``DTYPE``.
    dtypes : iterable
        Each entry is a sequence of replacement strings. Within one entry,
        the string at position ``k`` replaces the marker ``DTYPEk``.

    Returns
    -------
    list of str
        One string per entry: two leading newlines, the substituted body,
        and one trailing newline for every name after the first in that
        entry.

    Raises
    ------
    ValueError
        If `func_str` contains no ``DTYPE`` marker at all.
    """
    if 'DTYPE' not in func_str:
        raise ValueError("cannot find dtype marker")
    expanded = []
    for entry in dtypes:
        text = func_str
        extra_newlines = 0
        for position, name in enumerate(entry):
            text = text.replace('DTYPE%d' % position, name)
            # ends up as the number of names beyond the first
            extra_newlines = position
        expanded.append('\n\n' + text + '\n' * extra_newlines)
    return expanded


# multiline strings ---------------------------------------------------------

# Matches any line that contains the text ``MULTILINE STRING BEGIN``.
STRING_BEGIN = r'.*MULTILINE STRING BEGIN.*'
# Matches any line that contains the text ``MULTILINE STRING END``.
STRING_END = r'.*MULTILINE STRING END.*'


def string_templating(lines):
    """Convert every multiline string block in `lines` to quoted C strings.

    Blocks are located with `next_block` using the ``STRING_BEGIN`` and
    ``STRING_END`` patterns. The lines strictly between the two marker lines
    are passed through `quote_string`; both marker lines are dropped from the
    result.
    """
    search_from = 0
    begin, end = next_block(lines, search_from, STRING_BEGIN, STRING_END)
    while begin is not None:
        # ``begin + 1`` leaves out the BEGIN marker line
        quoted = quote_string(lines[begin + 1:end])
        # ``end + 1`` leaves out the END marker line
        lines = lines[:begin] + quoted + lines[end + 1:]
        search_from = begin
        begin, end = next_block(lines, search_from, STRING_BEGIN, STRING_END)
    return lines


def quote_string(lines):
    """Turn plain text lines into adjacent C string literals.

    Each line is wrapped in double quotes with a literal ``\\n`` escape
    appended inside the quotes, and a ``;`` is appended after the last
    literal so the sequence terminates a C statement.

    Parameters
    ----------
    lines : list of str
        Text lines to quote. The list is not modified.

    Returns
    -------
    list of str
        A new list with one quoted literal per input line. For empty input a
        single empty literal ``"";`` is returned so the generated statement
        is still terminated.

    Notes
    -----
    The text is inserted verbatim: double quotes or backslashes in `lines`
    are not escaped.
    """
    if not lines:
        # An empty block would otherwise fail on ``lines[-1]``.
        return ['"";']
    # Build a fresh list rather than overwriting the caller's list in place.
    quoted = ["\"" + text + r"\n" + "\"" for text in lines]
    quoted[-1] += ";"
    return quoted


# utility -------------------------------------------------------------------

def first_occurence(pattern, lines):
    """Return the index of the first line in `lines` that matches `pattern`.

    Matching uses `re.match`, so `pattern` is anchored at the start of each
    line.

    Raises
    ------
    ValueError
        If no line matches.
    """
    for position, text in enumerate(lines):
        if re.match(pattern, text):
            return position
    raise ValueError("`pattern` not found")


def next_block(lines, index, begine_pattern, end_pattern):
    """Locate the next begin/end delimited block at or after `index`.

    Lines are scanned from `index` onward. The most recent line matching
    `begine_pattern` is remembered; as soon as a line matching `end_pattern`
    is reached, the pair ``(begin_index, end_index)`` is returned. A line is
    tested against `end_pattern` only if it did not match `begine_pattern`.

    Returns
    -------
    tuple
        ``(begin_index, end_index)``, or ``(None, None)`` when the scan
        reaches the end of `lines` without meeting an end line.

    Raises
    ------
    ValueError
        If an end line is found before any begin line.
    """
    begin_at = None
    for position in range(index, len(lines)):
        text = lines[position]
        if re.match(begine_pattern, text):
            # a later begin line replaces an earlier, still-open one
            begin_at = position
            continue
        if not re.match(end_pattern, text):
            continue
        if begin_at is None:
            raise ValueError("found end of function before beginning")
        return begin_at, position
    return None, None
