import os
import json
from collections import defaultdict
from snakemake.utils import min_version
from snakemake.remote.HTTP import RemoteProvider as HTTPRemoteProvider

# Require Snakemake >= 6.6.1; min_version aborts the run on older versions.
min_version("6.6.1")

# Workflow configuration. Keys read in this file include: workdir,
# max_threads_per_rule, folds_used, folds_used_2, chrombpnet_export_prefix,
# bc_whitelist, bc_prefix, genome, wl_atac_orig, wl_rna_orig (included rule
# files may read more).
configfile: 
    "config/config.yaml"

# Parse the sample sheet. Columns are located by header name ("Experiment",
# "Assembly"), so their order in the file does not matter. Lines starting with
# "#" and blank lines are skipped.
#   sample_config:       experiment -> {"experiment": ..., "assembly": ...}
#   samples:             experiment IDs in file order
#   samples_in_assembly: assembly -> experiment IDs in file order
#   assemblies:          assemblies in first-seen order
with open("config/samples.tsv", encoding="utf-8") as sample_file:
    h = sample_file.readline().rstrip('\n').split('\t')
    exp_ind = h.index("Experiment")
    asm_ind = h.index("Assembly")
    sample_config = {}
    samples = []
    samples_in_assembly = {}
    for line in sample_file:
        # A blank line (e.g. an extra trailing newline) would split to [''] and
        # make the column lookups below raise IndexError, so skip it.
        if line.startswith("#") or not line.strip():
            continue
        entries = line.rstrip('\n').split('\t')
        exp = entries[exp_ind]
        asm = entries[asm_ind]
        sample_config[exp] = {
            "experiment": exp,
            "assembly": asm
        }
        samples.append(exp)
        samples_in_assembly.setdefault(asm, []).append(exp)
    # dicts preserve insertion order, so this is first-seen assembly order.
    assemblies = list(samples_in_assembly.keys())

# Top-level cluster labels. Columns (located by header): "cluster", "label",
# "organism". Lines starting with "#" and blank lines are skipped.
#   cluster_names: organism -> {"orig": [cluster, ...], "names": [label, ...]}
#                  (the two lists are parallel, in file order)
#   clusters:      set of all labels across organisms
with open("config/cluster_labels.tsv", encoding="utf-8") as label_file:
    h = label_file.readline().rstrip('\n').split('\t')
    cluster_ind = h.index("cluster")
    label_ind = h.index("label")
    organism_ind = h.index("organism")
    cluster_names = {}
    clusters = set()
    for line in label_file:
        # Blank lines would split to [''] and raise IndexError below.
        if line.startswith("#") or not line.strip():
            continue
        entries = line.rstrip('\n').split('\t')
        cluster = entries[cluster_ind]
        label = entries[label_ind]
        organism = entries[organism_ind]
        cluster_names.setdefault(organism, {"orig": [], "names": []})
        cluster_names[organism]["orig"].append(cluster)
        cluster_names[organism]["names"].append(label)
        clusters.add(label)

# Level-1 cluster labels. Columns (located by header): "cluster", "label",
# "organism". Lines starting with "#" and blank lines are skipped.
#   cluster_names_l1: organism -> {"orig": [cluster, ...], "names": [label, ...]}
#                     (the two lists are parallel, in file order)
#   clusters_l1:      set of all level-1 labels across organisms
with open("config/cluster_labels_l1.tsv", encoding="utf-8") as label_file:
    h = label_file.readline().rstrip('\n').split('\t')
    cluster_ind = h.index("cluster")
    label_ind = h.index("label")
    organism_ind = h.index("organism")
    cluster_names_l1 = {}
    clusters_l1 = set()
    for line in label_file:
        # Blank lines would split to [''] and raise IndexError below.
        if line.startswith("#") or not line.strip():
            continue
        entries = line.rstrip('\n').split('\t')
        cluster = entries[cluster_ind]
        label = entries[label_ind]
        organism = entries[organism_ind]
        cluster_names_l1.setdefault(organism, {"orig": [], "names": []})
        cluster_names_l1[organism]["orig"].append(cluster)
        cluster_names_l1[organism]["names"].append(label)
        clusters_l1.add(label)

# Level-2 cluster labels, grouped under their level-1 label. Columns (located
# by header): "cluster", "label_l1", "label_l2", "organism". Lines starting
# with "#" and blank lines are skipped.
#   cluster_names_l2: (organism, label_l1) -> {"orig": [...], "names": [label_l2, ...]}
#   l2_l1_map:        label_l2 -> label_l1 (the last row wins on duplicates)
#   clusters_l2:      set of all level-2 labels
with open("config/cluster_labels_l2.tsv", encoding="utf-8") as label_file:
    h = label_file.readline().rstrip('\n').split('\t')
    cluster_ind = h.index("cluster")
    label_l1_ind = h.index("label_l1")
    label_l2_ind = h.index("label_l2")
    organism_ind = h.index("organism")
    cluster_names_l2 = {}
    l2_l1_map = {}
    clusters_l2 = set()
    for line in label_file:
        # Blank lines would split to [''] and raise IndexError below.
        if line.startswith("#") or not line.strip():
            continue
        entries = line.rstrip('\n').split('\t')
        cluster = entries[cluster_ind]
        label_l1 = entries[label_l1_ind]
        label_l2 = entries[label_l2_ind]
        organism = entries[organism_ind]
        cluster_names_l2.setdefault((organism, label_l1), {"orig": [], "names": []})
        cluster_names_l2[(organism, label_l1)]["orig"].append(cluster)
        cluster_names_l2[(organism, label_l1)]["names"].append(label_l2)
        l2_l1_map[label_l2] = label_l1
        clusters_l2.add(label_l2)

# Level-3 cluster labels, grouped under their level-2 label. Columns (located
# by header): "cluster", "label_l2", "label_l3", "organism". Lines starting
# with "#" and blank lines are skipped.
#   cluster_names_l3: (organism, label_l2) -> {"orig": [...], "names": [label_l3, ...]}
#   l3_l2_map:        label_l3 -> label_l2 (the last row wins on duplicates)
#   clusters_l3:      level-3 labels, plus every level-2 label (from
#                     clusters_l2, parsed above) that has no level-3 split here
with open("config/cluster_labels_l3.tsv", encoding="utf-8") as label_file:
    h = label_file.readline().rstrip('\n').split('\t')
    cluster_ind = h.index("cluster")
    label_l2_ind = h.index("label_l2")
    label_l3_ind = h.index("label_l3")
    organism_ind = h.index("organism")
    cluster_names_l3 = {}
    l3_l2_map = {}
    clusters_l3 = set()
    clusters_l2_rem = set()  # level-2 labels that were subdivided in this file
    for line in label_file:
        # Blank lines would split to [''] and raise IndexError below.
        if line.startswith("#") or not line.strip():
            continue
        entries = line.rstrip('\n').split('\t')
        cluster = entries[cluster_ind]
        label_l2 = entries[label_l2_ind]
        label_l3 = entries[label_l3_ind]
        organism = entries[organism_ind]
        cluster_names_l3.setdefault((organism, label_l2), {"orig": [], "names": []})
        cluster_names_l3[(organism, label_l2)]["orig"].append(cluster)
        cluster_names_l3[(organism, label_l2)]["names"].append(label_l3)
        l3_l2_map[label_l3] = label_l2
        clusters_l3.add(label_l3)
        clusters_l2_rem.add(label_l2)
    # Level-2 labels without any level-3 subdivision are kept as-is.
    clusters_l3 |= clusters_l2 - clusters_l2_rem

# Clusters to exclude from downstream targets. Columns (located by header):
# "cluster", "organism". Lines starting with "#" and blank lines are skipped.
#   clusters_exclude:            list of (organism, cluster) pairs
#   clusters_exclude_clust_only: set of cluster names, ignoring organism
with open("config/clusters_exclude.tsv", encoding="utf-8") as exclude_file:
    h = exclude_file.readline().rstrip('\n').split('\t')
    cluster_ind = h.index("cluster")
    organism_ind = h.index("organism")
    clusters_exclude = []
    clusters_exclude_clust_only = set()
    for line in exclude_file:
        # Blank lines would split to [''] and raise IndexError below.
        if line.startswith("#") or not line.strip():
            continue
        entries = line.rstrip('\n').split('\t')
        cluster = entries[cluster_ind]
        organism = entries[organism_ind]
        clusters_exclude.append((organism, cluster))
        clusters_exclude_clust_only.add(cluster)

# Output working directory from config. The config/*.tsv files above are opened
# before this directive, so presumably their relative paths resolve against the
# directory Snakemake was launched from, not config['workdir'] — confirm.
workdir: 
    config['workdir']

# Remote provider for HTTP(S) inputs; presumably used by the included rule
# files — confirm.
HTTP = HTTPRemoteProvider()

# Upper bound on threads for any single rule, taken from config.yaml.
max_threads = config["max_threads_per_rule"]

def script_path(script_name):
    """Resolve *script_name* relative to the workflow source and return it as a string."""
    resolved = workflow.source_path(script_name)
    return str(resolved)

# Rule modules. HTTP, max_threads and script_path are defined above this point,
# presumably so the included files can reference them — confirm.
include:
    "rules/chromap.smk"
include:
    "rules/rna.smk"
include:
    "rules/peaks.smk"
include:
    "rules/chrombpnet.smk"
include:
    "rules/cluster.smk"
include:
    "rules/integration.smk"
include:
    "rules/cross_species_cells.smk"
include:
    "rules/cluster_l2.smk"
include:
    "rules/cluster_l3.smk"
include:
    "rules/cellspace.smk"
# NOTE(review): disabled; xs_projection_selected.smk below is presumably its
# replacement. Rule xs_projection_tracks still requests xs_projection/ paths.
# include:
#     "rules/xs_projection.smk"
include:
    "rules/xs_projection_selected.smk"
include:
    "rules/xs_chrombpnet_global.smk"
include:
    "rules/peaks_2.smk"
include:
    "rules/chrombpnet_2.smk"

def get_chrombpnet_outputs(template, assemblies, clusters, folds, exclude):
    """Expand *template* over every (assembly, cluster, fold) combination.

    *template* is a str.format pattern using the ``assembly``, ``cluster`` and
    ``fold`` fields. Any combination whose ``(assembly, cluster)`` pair is in
    *exclude* is dropped. Results are ordered with assemblies outermost, then
    clusters, then folds.

    Returns a list of formatted strings.
    """
    skip = set(exclude)
    return [
        template.format(assembly=asm, cluster=clu, fold=fold_id)
        for asm in assemblies
        for clu in clusters
        if (asm, clu) not in skip
        for fold_id in folds
    ]

def get_track_outputs(template, assemblies, clusters, folds, exclude):
    """Expand *template* over every (assembly, cluster, fold) combination.

    *template* is a str.format pattern using the ``assembly``, ``cluster`` and
    ``fold`` fields. Unlike get_chrombpnet_outputs, *exclude* holds bare
    cluster names: a cluster listed there is skipped for every assembly.
    Results are ordered with assemblies outermost, then clusters, then folds.

    Returns a list of formatted strings.
    """
    # Normalize once, as get_chrombpnet_outputs does. This gives O(1) membership
    # and stays correct when *exclude* is a one-shot iterable: testing ``in``
    # against a generator consumes it.
    exclude_set = set(exclude)
    out = []
    for a in assemblies:
        for c in clusters:
            if c in exclude_set:
                continue
            for f in folds:
                out.append(template.format(assembly=a, cluster=c, fold=f))
    return out


rule all:
    """
    Default target: request the main per-sample, per-assembly and unified
    outputs of the workflow.
    """
    input: 
        # Per-sample and per-assembly (merged) clustering outputs.
        expand("results/{sample}/cluster/archr_clustered", sample=samples),
        # expand("results/{sample}/cluster/archr_markers", sample=samples),
        expand("results/{sample}/cluster/archr_clustered_extern", sample=samples),
        expand("results_merged/{assembly}/cluster/archr_harmony", assembly=assemblies),
        expand("results_merged/{assembly}/cluster/archr_clustered_extern", assembly=assemblies),
        expand("results_merged/{assembly}/cluster/archr_markers", assembly=assemblies),
        expand("results_merged/{assembly}/cluster/archr_cluster_int", assembly=assemblies),
        expand("results_merged/{assembly}/cluster/archr_label_export", assembly=assemblies),
        # Per-sample RNA alignment outputs and expression-matrix match fraction.
        expand("results/{sample}/alignments_rna_star", sample=samples),
        expand("results/{sample}/cluster/exp_matrix_frac_match.txt", sample=samples),
        # # expand("results/{sample}/kmer_scores/mat.mtx", sample=samples),
        # expand("results/{sample}/kmer_scores/score/kmer_embed.h5ad", sample=samples),
        # expand("results/{sample}/kmer_scores/flank/kmer_embed.h5ad", sample=samples),
        # expand("results/{sample}/kmer_scores/footprint/kmer_embed.h5ad", sample=samples),
        # # expand("results/{sample}/kmer_scores/kmer_cluster_compare", sample=samples),
        # # expand("results/{sample}/gene_counts/ranges.tsv", sample=samples) ####
        # Filtered peaks for every assembly x top-level label (clusters).
        expand("results_merged/{assembly}/peaks/{cluster}/peaks_overlap_filtered.narrowPeak", assembly=assemblies, cluster=clusters),
        # expand("results/{sample}/peaks/{cluster}/coverage_total.bw", sample=samples, cluster=clusters),
        # expand("results/{sample}/peak_plots/log_coverage.html", sample=samples),
        # expand("results_merged/{assembly}/chrombpnet/{cluster}/negatives_fold_{fold}/negatives_with_summit.bed", assembly=assemblies, cluster=clusters, fold=config["folds_used"]),
        # expand("results_merged/{assembly}/chrombpnet/{cluster}/train_fold_{fold}", assembly=assemblies, cluster=clusters, fold=config["folds_used"]),
        # get_chrombpnet_outputs("results_merged/{assembly}/chrombpnet/{cluster}/negatives_fold_{fold}/negatives_with_summit.bed", assemblies, clusters, config["folds_used"], clusters_exclude),
        # get_chrombpnet_outputs("results_merged/{assembly}/chrombpnet/{cluster}/train_fold_{fold}", assemblies, clusters, config["folds_used"], clusters_exclude),
        # get_chrombpnet_outputs("results_merged/{assembly}/chrombpnet/{cluster}/bias_pwm.png", assemblies, clusters, config["folds_used"], clusters_exclude),
        # ChromBPNet interpretation and motif outputs per (assembly, cluster,
        # fold), skipping pairs listed in clusters_exclude.
        # NOTE(review): clusters_exclude holds (organism, cluster) pairs, but
        # get_chrombpnet_outputs compares them against (assembly, cluster).
        # Confirm that the organism values in config/clusters_exclude.tsv are
        # assembly names; otherwise nothing is excluded here.
        get_chrombpnet_outputs("results_merged/{assembly}/chrombpnet/{cluster}/interpret_fold_{fold}", assemblies, clusters, config["folds_used"], clusters_exclude),
        get_chrombpnet_outputs("results_merged/{assembly}/chrombpnet/{cluster}/motif_counts_fold_{fold}", assemblies, clusters, config["folds_used"], clusters_exclude),
        get_chrombpnet_outputs("results_merged/{assembly}/chrombpnet/{cluster}/motif_profile_fold_{fold}", assemblies, clusters, config["folds_used"], clusters_exclude),
        # Unified (all-assembly) outputs.
        "results_unified/cellspace/cellspace_embedding.tsv",
        "results_unified/seurat_cells_unified/proj.rds",
        "results_unified/seurat_peaks_unified/proj.rds",
        "results_unified/cellspace/cellspace_umap_data.csv",
        "results_unified/cellspace/cellspace_umap_plot_cluster.pdf",

def chrombpnet_export_path(suffix):
    """Return *suffix* joined onto config["chrombpnet_export_prefix"] via os.path.join."""
    prefix = config["chrombpnet_export_prefix"]
    return os.path.join(prefix, suffix)

rule chrombpnet_export:
    """
    Data for round 2 chrombpnet training
    """
    # Every target lives under config["chrombpnet_export_prefix"]. Clusters are
    # the level-3 labels (clusters_l3), which also include level-2 labels that
    # have no level-3 split. Folds come from config["folds_used_2"].
    input:
        chrombpnet_export_path("bias_model.h5"),
        expand(chrombpnet_export_path("assembly/{assembly}/genome.fa"), assembly=assemblies),
        expand(chrombpnet_export_path("assembly/{assembly}/folds"), assembly=assemblies),
        expand(chrombpnet_export_path("assembly/{assembly}/clusters/{cluster}/coverage.bw"), assembly=assemblies, cluster=clusters_l3),
        expand(chrombpnet_export_path("assembly/{assembly}/clusters/{cluster}/peaks.narrowPeak"), assembly=assemblies, cluster=clusters_l3),
        expand(chrombpnet_export_path("assembly/{assembly}/clusters/{cluster}/folds/{fold}/negatives.bed"), assembly=assemblies, cluster=clusters_l3, fold=config["folds_used_2"]),

rule xs_global:
    """
    Cross-species global visualizations
    """
    # One target per assembly x top-level label x fold in config["folds_used"],
    # skipping any label in clusters_exclude_clust_only (excluded for all
    # assemblies).
    input:
        get_track_outputs("results_merged/{assembly}/xs_chrombpnet_global/markers_{cluster}/fold_{fold}_peaks_chrombpnet_compare", assemblies, clusters, config["folds_used"], clusters_exclude_clust_only),

rule xs_projection_tracks:
    """
    Cross-species chrombpnet visualizations (xs_projection track data)
    """
    # NOTE(review): the include of rules/xs_projection.smk is commented out.
    # Unless another included file produces these xs_projection/ paths, this
    # target cannot be built — confirm.
    input:
        get_track_outputs("results_merged/{assembly}/xs_projection/markers_{cluster}/build_track_data_fold_{fold}.touch", assemblies, clusters, config["folds_used"], clusters_exclude_clust_only),

rule xs_projection_tracks_selected:
    """
    Cross-species chrombpnet visualizations (xs_projection_selected track data)
    """
    # Same target pattern as xs_projection_tracks, but under the
    # xs_projection_selected/ output tree.
    input:
        get_track_outputs("results_merged/{assembly}/xs_projection_selected/markers_{cluster}/build_track_data_fold_{fold}.touch", assemblies, clusters, config["folds_used"], clusters_exclude_clust_only),



rule plot_markers:
    """
    Plot marker genes
    """
    input:
        # Marker outputs for the merged top-level clustering, per assembly.
        expand("results_merged/{assembly}/cluster/archr_markers", assembly=assemblies),
        expand("results_merged/{assembly}/cluster/archr_markers_ameen", assembly=assemblies),
        expand("results_merged/{assembly}/cluster/archr_markers_merged_hu_all", assembly=assemblies),
        expand("results_merged/{assembly}/cluster/archr_markers_merged_ameen_all", assembly=assemblies),
        expand("results_merged/{assembly}/cluster/archr_markerlist_merged", assembly=assemblies),
        expand("results_merged/{assembly}/cluster/archr_markerlist_merged_data", assembly=assemblies),
        expand("results_merged/{assembly}/cluster/archr_markers_cross_species", assembly=assemblies),
        # Level-2 subclustering outputs, one set per level-1 label (clusters_l1).
        expand("results_merged/{assembly}/cluster_l2/{cluster_l1}/archr_cluster_l2", assembly=assemblies, cluster_l1=clusters_l1),
        expand("results_merged/{assembly}/cluster_l2/{cluster_l1}/archr_markers_l2_hu_all", assembly=assemblies, cluster_l1=clusters_l1),
        expand("results_merged/{assembly}/cluster_l2/{cluster_l1}/archr_markers_l2_ameen_all", assembly=assemblies, cluster_l1=clusters_l1),
        expand("results_merged/{assembly}/cluster_l2/{cluster_l1}/archr_markerlist_l2", assembly=assemblies, cluster_l1=clusters_l1),
        expand("results_merged/{assembly}/cluster_l2/{cluster_l1}/archr_markerlist_l2_data", assembly=assemblies, cluster_l1=clusters_l1),
        expand("results_merged/{assembly}/cluster_l2/{cluster_l1}/archr_markers_cross_species_l2", assembly=assemblies, cluster_l1=clusters_l1),
        # Level-3 subclustering outputs, one set per level-2 label (clusters_l2).
        expand("results_merged/{assembly}/cluster_l3/{cluster_l2}/archr_cluster_l3", assembly=assemblies, cluster_l2=clusters_l2),
        expand("results_merged/{assembly}/cluster_l3/{cluster_l2}/archr_markers_l3_hu_all", assembly=assemblies, cluster_l2=clusters_l2),
        expand("results_merged/{assembly}/cluster_l3/{cluster_l2}/archr_markers_l3_ameen_all", assembly=assemblies, cluster_l2=clusters_l2),
        expand("results_merged/{assembly}/cluster_l3/{cluster_l2}/archr_markerlist_l3", assembly=assemblies, cluster_l2=clusters_l2),
        expand("results_merged/{assembly}/cluster_l3/{cluster_l2}/archr_markerlist_l3_data", assembly=assemblies, cluster_l2=clusters_l2),
        expand("results_merged/{assembly}/cluster_l3/{cluster_l2}/archr_markers_cross_species_l3", assembly=assemblies, cluster_l2=clusters_l2),
        expand("results_merged/{assembly}/cluster_l3/{cluster_l2}/clusters_integrated_l3.txt", assembly=assemblies, cluster_l2=clusters_l2),
        # Per-assembly level-3 label import and merged level-3 marker list.
        expand("results_merged/{assembly}/cluster/archr_label_import_l3", assembly=assemblies),
        expand("results_merged/{assembly}/cluster/archr_markerlist_merged_l3", assembly=assemblies)

rule plot_markers_debug:
    """
    Debug target: build only clusters_integrated_l3.txt for the hard-coded
    level-2 label "dCM", for every assembly.
    """
    input:
        expand("results_merged/{assembly}/cluster_l3/{cluster_l2}/clusters_integrated_l3.txt", assembly=assemblies, cluster_l2=["dCM"]),


rule fetch_whitelist:
    """
    Fetch barcode whitelist, reverse-complement each barcode and prepend
    config["bc_prefix"]
    """
    output:
        "bc_whitelist.txt"
    params:
        url = config["bc_whitelist"],
        prefix = config["bc_prefix"]
    conda:
        "envs/fetch.yaml"
    # Download (following redirects), gunzip, then complement bases (tr) and
    # reverse each line (rev), and prepend the prefix. The commented line inside
    # shell is the variant without reverse-complementing.
    shell:
        "curl --no-progress-meter -L {params.url} | "
        "zcat | tr ACGTacgt TGCAtgca | rev | sed 's/^/{params.prefix}/' > {output}"
        # "zcat | sed 's/^/{params.prefix}/' > {output}"

rule fetch_genome:
    """
    Copy the FASTA configured at config["genome"][assembly] to
    genomes/{assembly}.fa
    """
    input:
        lambda w: config["genome"][w.assembly]
    output:
        "genomes/{assembly}.fa"
    conda:
        "envs/fetch.yaml"
    shell:
        "cp {input} {output}"

rule build_bc_map:
    """
    Build ATAC to RNA barcode mapping: reverse-complemented, prefixed ATAC
    barcodes pasted tab-separated beside the RNA barcodes, line by line
    """
    input:
        atac = config["wl_atac_orig"],
        rna = config["wl_rna_orig"]
    output:
        "bc_map.txt"
    params:
        prefix = config["bc_prefix"]
    conda:
        "envs/fetch.yaml"
    # Uses bash process substitution <(...). NOTE(review): paste pairs lines
    # by position, so this assumes both whitelists list corresponding barcodes
    # in the same order and have equal length — confirm.
    shell:
        "zcat {input.atac} | tr ACGTacgt TGCAtgca | rev | sed 's/^/{params.prefix}/' |"
        "paste - <(zcat {input.rna}) > {output}"

rule setup_r_kernel:
    """
    Setup Jupyter IRkernel by running scripts/setup_kernel.R in the
    envs/cluster.yaml environment
    """
    # The output is only a touch() flag file recording that setup completed.
    output:
        touch("results/irkernel_touch.txt")
    conda:
        "envs/cluster.yaml"
    script:
        "scripts/setup_kernel.R"


