import os
import gzip
import numpy as np
import pandas as pd


def get_links(hits_df, eg_df):
    """Join motif hits to enhancer-gene links via the peak/enhancer index.

    Only the columns needed downstream are kept. Hit columns are renamed with
    a ``finemo_`` prefix and enhancer-gene columns with an ``abc_`` prefix
    (presumably naming the tools that produced each table), so the merged
    frame has unambiguous column names.

    Args:
        hits_df: DataFrame with at least the columns ``motif_name``,
            ``peak_id``, ``hit_importance`` and ``hit_coefficient``.
        eg_df: DataFrame with at least the columns ``enh_idx``,
            ``TargetGene`` and ``powerlaw.Score``.

    Returns:
        A new DataFrame with one row per (hit, enhancer-gene link) pair whose
        ``peak_id`` equals ``enh_idx``. Rows without a match on either side
        are dropped (inner join).
    """
    hit_columns = ["motif_name", "peak_id", "hit_importance", "hit_coefficient"]
    link_columns = ["enh_idx", "TargetGene", "powerlaw.Score"]

    prefixed_hits = hits_df[hit_columns].add_prefix("finemo_")
    prefixed_links = eg_df[link_columns].add_prefix("abc_")

    return prefixed_hits.merge(
        prefixed_links,
        how='inner',
        left_on="finemo_peak_id",
        right_on="abc_enh_idx",
    )


def _sum_by_gene_and_motif(links_df, value_column):
    """Sum ``value_column`` into a target-gene x motif matrix (0 where absent)."""
    return pd.pivot_table(
        links_df,
        values=value_column,
        index="abc_TargetGene",
        columns="finemo_motif_name",
        # The string alias avoids the FutureWarning newer pandas emits when a
        # NumPy callable such as np.sum is passed; the result is the same.
        aggfunc="sum",
        fill_value=0,
    )


def get_total_loadings(links_df):
    """Aggregate score-weighted motif hits per target gene.

    For every (target gene, motif) pair, three quantities are summed over all
    matching rows of ``links_df``:

    * the raw ``abc_powerlaw.Score``,
    * ``abc_powerlaw.Score * finemo_hit_importance``,
    * ``abc_powerlaw.Score * finemo_hit_coefficient``.

    Args:
        links_df: DataFrame with the columns ``abc_TargetGene``,
            ``finemo_motif_name``, ``abc_powerlaw.Score``,
            ``finemo_hit_importance`` and ``finemo_hit_coefficient``.
            It is not modified.

    Returns:
        A tuple ``(hit_counts, hit_importances, hit_coefficients)`` of
        DataFrames indexed by ``abc_TargetGene`` with one column per
        ``finemo_motif_name``; gene/motif pairs with no rows are filled
        with 0. Note that ``hit_counts`` holds summed scores, not integer
        counts of hits.
    """
    score = links_df["abc_powerlaw.Score"]
    # assign() returns a new frame, so the caller's DataFrame is left without
    # the temporary loading columns.
    loaded = links_df.assign(
        importance_loading=score * links_df["finemo_hit_importance"],
        coefficient_loading=score * links_df["finemo_hit_coefficient"],
    )

    hit_counts = _sum_by_gene_and_motif(loaded, "abc_powerlaw.Score")
    hit_importances = _sum_by_gene_and_motif(loaded, "importance_loading")
    hit_coefficients = _sum_by_gene_and_motif(loaded, "coefficient_loading")

    return hit_counts, hit_importances, hit_coefficients


def main(hits_path, eg_path, out_path_counts, out_path_importances, out_path_coefficients):
    """Read hits and enhancer-gene tables, aggregate them, and write three TSVs.

    Args:
        hits_path: Directory containing a tab-separated ``hits_unique.tsv``.
        eg_path: Path to a gzip-compressed, tab-separated enhancer-gene table.
        out_path_counts: Destination for the summed-score matrix.
        out_path_importances: Destination for the importance-loading matrix.
        out_path_coefficients: Destination for the coefficient-loading matrix.

    Each output is written tab-separated with the row index included.
    """
    hits_file = os.path.join(hits_path, "hits_unique.tsv")
    hits_table = pd.read_csv(hits_file, sep='\t', header=0)
    eg_table = pd.read_csv(eg_path, sep='\t', header=0, compression='gzip')

    matrices = get_total_loadings(get_links(hits_table, eg_table))

    # get_total_loadings returns (counts, importances, coefficients), so the
    # destinations are listed in that same order.
    destinations = (out_path_counts, out_path_importances, out_path_coefficients)
    for matrix, destination in zip(matrices, destinations):
        matrix.to_csv(destination, sep="\t", index=True)


# Script entry point. `snakemake` is not defined in this file; it is expected
# to be injected by Snakemake when the file is run via a rule's `script:`
# directive.
hits_path = snakemake.input["hits"]
eg_path = snakemake.input["eg"]

out_path_counts = snakemake.output["counts"]
out_path_importances = snakemake.output["importances"]
out_path_coefficients = snakemake.output["coefficients"]

main(
    hits_path=hits_path,
    eg_path=eg_path,
    out_path_counts=out_path_counts,
    out_path_importances=out_path_importances,
    out_path_coefficients=out_path_coefficients,
)
