import os
import json
from collections import defaultdict
from snakemake.utils import min_version
from snakemake.remote.HTTP import RemoteProvider as HTTPRemoteProvider

# Refuse to run under Snakemake releases older than 6.6.1.
min_version("6.6.1")

# Workflow configuration; provides the `config` dict used throughout.
configfile: 
    "config/config.yaml"


# Working directory for the run, taken from the config file.
workdir: 
    config['workdir']

# Per-rule thread cap from config; presumably consumed by rules in the
# included .smk files (not referenced in this file) — verify there.
max_threads = config["max_threads_per_rule"]

def other_asm2(assembly):
    """Return the partner of *assembly* in the mm10 <-> cavpor_dnazoo pair.

    Args:
        assembly: Either "mm10" or "cavpor_dnazoo".

    Returns:
        "cavpor_dnazoo" when given "mm10", and "mm10" when given
        "cavpor_dnazoo".

    Raises:
        ValueError: if *assembly* is neither of the two paired assemblies.
    """
    if assembly == "mm10":
        return "cavpor_dnazoo"
    if assembly == "cavpor_dnazoo":
        return "mm10"
    # Name the offending value so a bad wildcard/config entry is easy to trace.
    raise ValueError(
        f"other_asm2: unknown assembly {assembly!r}; "
        "expected 'mm10' or 'cavpor_dnazoo'"
    )

# Map each assembly_2 name (presumably the sample sheet's "Assembly2" values)
# to its alternate ("assembly_alt") name used under results/assembly_alt/.
assembly_2_to_alt = {
    "cavpor_dnazoo": "cavpor3",
    "mm10": "mm10",
}

# Reverse mapping, derived rather than hand-written so the two cannot drift
# apart. It is only well-defined if the forward mapping is one-to-one.
assembly_alt_to_2 = {alt: asm2 for asm2, alt in assembly_2_to_alt.items()}
if len(assembly_alt_to_2) != len(assembly_2_to_alt):
    raise ValueError("assembly_2_to_alt must be one-to-one to be invertible")

# Parse the tab-separated sample sheet. The first line is the header; the
# "Experiment", "Assembly" and "Assembly2" columns are located by name.
# Lines starting with "#" are comments.
with open("config/samples.tsv", encoding="utf-8") as sample_file:
    # Strip "\r" as well so a CRLF-saved sheet still yields "Assembly2".
    h = sample_file.readline().rstrip('\r\n').split('\t')
    exp_ind = h.index("Experiment")
    asm_ind = h.index("Assembly")
    asm2_ind = h.index("Assembly2")
    sample_assembly = {}        # experiment -> "Assembly" value
    sample_assembly_2 = {}      # experiment -> "Assembly2" value
    samples = []                # experiment IDs, in file order
    samples_in_assembly = {}    # "Assembly" value -> [experiments]
    samples_in_assembly_2 = {}  # "Assembly2" value -> [experiments]
    for line in sample_file:
        # Skip comments and blank/whitespace-only lines; the latter would
        # otherwise raise IndexError when indexing the split columns.
        if line.startswith("#") or not line.strip():
            continue
        entries = line.rstrip('\r\n').split('\t')
        exp = entries[exp_ind]
        asm = entries[asm_ind]
        asm2 = entries[asm2_ind]
        sample_assembly[exp] = asm
        sample_assembly_2[exp] = asm2
        samples.append(exp)
        samples_in_assembly.setdefault(asm, []).append(exp)
        samples_in_assembly_2.setdefault(asm2, []).append(exp)
    # Distinct assemblies, in order of first appearance in the sheet.
    assemblies = list(samples_in_assembly.keys())
    assemblies_2 = list(samples_in_assembly_2.keys())

def script_path(script_name):
    """Return ``workflow.source_path(script_name)`` as a plain string.

    Args:
        script_name: Script path handed to Snakemake's ``workflow.source_path``.
    """
    resolved = workflow.source_path(script_name)
    return str(resolved)

def import_path(suffix):
    """Join *suffix* onto the configured import directory.

    Args:
        suffix: Relative path under ``config["import_dir"]``.

    Returns:
        ``os.path.join(config["import_dir"], suffix)``.
    """
    base_dir = config["import_dir"]
    return os.path.join(base_dir, suffix)

# def export_path(suffix):
#     return os.path.join(config["export_dir"], suffix)

# Rule modules included into this workflow.
include:
    "rules/atac.smk"
include:
    "rules/rna.smk"
include:
    "rules/rna_group.smk"
include:
    "rules/reference.smk"
include:
    "rules/rna_name_l1.smk"
include:
    "rules/chromap.smk"
include:
    "rules/chromap_alt.smk"
include:
    "rules/peaks_l1.smk"
include:
    "rules/peaks_l1_alt.smk"
include:
    "rules/chrombpnet_bias.smk"
include:
    "rules/chrombpnet_l1.smk"
include:
    "rules/genome_transfer_l1.smk"
include:
    "rules/genome_transfer_l1_alt.smk"
include:
    "rules/chrombpnet_l1_downsampled.smk"
include:
    "rules/genome_transfer_l1_downsampled.smk"
include:
    "rules/annotate_peaks.smk"
include:
    "rules/annotate_peaks_alt.smk"
include:
    "rules/abc.smk"
# Disabled: scarlink analyses (the matching `rule scarlink` target in this
# file is commented out too).
# include:
#     "rules/scarlink.smk"
include:
    "rules/finemo_l1.smk"
include:
    "rules/great.smk"


rule all:
    """
    Generate all outputs (default)
    """
    input: 
        # Per-sample RNA naming/metadata and chromap ATAC fragments
        expand("results/sample/{sample}/rna/seurat_name_l1/proj.rds", sample=samples),
        expand("results/sample/{sample}/rna/seurat_write_metadata_l1/metadata.tsv", sample=samples),
        expand("results/sample/{sample}/chromap/fragments.tsv.gz", sample=samples),
        # Per-assembly_2, per-L1-cluster peak calls and non-peak backgrounds
        expand("results/assembly_2/{assembly_2}/peaks_l1/{cluster}/fragments_sorted.tsv", assembly_2=assemblies_2, cluster=config["l1_labels"]),
        expand("results/assembly_2/{assembly_2}/peaks_l1/{cluster}/peaks_overlap_filtered.narrowPeak", assembly_2=assemblies_2, cluster=config["l1_labels"]),
        expand("results/assembly_2/{assembly_2}/peaks_l1/{cluster}/peaks_overlap_clipped.narrowPeak", assembly_2=assemblies_2, cluster=config["l1_labels"]),
        expand("results/assembly_2/{assembly_2}/peaks_l1/{cluster}/nonpeaks/fold_{fold}", assembly_2=assemblies_2, cluster=config["l1_labels"], fold=config["folds_used"]),
        # chrombpnet bias model, per-fold training, predictions and contributions
        expand("results/assembly_2/{assembly_2}/chrombpnet_train_bias", assembly_2=assemblies_2),
        expand("results/assembly_2/{assembly_2}/chrombpnet_train_l1/{cluster}/fold_{fold}", assembly_2=assemblies_2, cluster=config["l1_labels"], fold=config["folds_used"]),
        expand("results/assembly_2/{assembly_2}/chrombpnet_predict_l1/{cluster}/fold_{fold}/ss", assembly_2=assemblies_2, cluster=config["l1_labels"], fold=config["folds_used"]),
        expand("results/assembly_2/{assembly_2}/chrombpnet_predict_l1/{cluster}/fold_{fold}/xs", assembly_2=assemblies_2, cluster=config["l1_labels"], fold=config["folds_used"]),
        expand("results/assembly_2/{assembly_2}/chrombpnet_contributions_l1/{cluster}/fold_{fold}/ss", assembly_2=assemblies_2, cluster=config["l1_labels"], fold=config["folds_used"]),
        expand("results/assembly_2/{assembly_2}/chrombpnet_contributions_l1/{cluster}/fold_{fold}/xs", assembly_2=assemblies_2, cluster=config["l1_labels"], fold=config["folds_used"]),
        # Peak scores and collated per-peak outputs
        expand("results/assembly_2/{assembly_2}/score_peaks_predictions_l1/{cluster}/scores.tsv", assembly_2=assemblies_2, cluster=config["l1_labels"]),
        expand("results/assembly_2/{assembly_2}/score_peaks_contributions_l1/{cluster}/scores.tsv", assembly_2=assemblies_2, cluster=config["l1_labels"]),
        expand("results/assembly_2/{assembly_2}/collate_peak_outputs/{cluster}/peak_outputs.h5", assembly_2=assemblies_2, cluster=config["l1_labels"]),
        expand("results/assembly_2/{assembly_2}/merge_peak_scores/scores.tsv", assembly_2=assemblies_2),
        # modisco results/reports and seqlet occurrences (all L1 clusters)
        expand("results/assembly_2/{assembly_2}/modisco_l1_counts_ss/{cluster}/modisco_results.h5", assembly_2=assemblies_2, cluster=config["l1_labels"]),
        expand("results/assembly_2/{assembly_2}/modisco_l1_counts_xs/{cluster}/modisco_results.h5", assembly_2=assemblies_2, cluster=config["l1_labels"]),
        expand("results/assembly_2/{assembly_2}/modisco_l1_counts_ss/{cluster}/modisco_report", assembly_2=assemblies_2, cluster=config["l1_labels"]),
        expand("results/assembly_2/{assembly_2}/modisco_l1_counts_xs/{cluster}/modisco_report", assembly_2=assemblies_2, cluster=config["l1_labels"]),
        expand("results/assembly_2/{assembly_2}/modisco_seqlet_occurences/{cluster}/occurences_annotated.tsv", assembly_2=assemblies_2, cluster=config["l1_labels"]),
        # Downsampled counterparts; clusters are the keys of config["chrombpnet_downsample_l1"]
        expand("results/assembly_2/{assembly_2}/merge_peak_scores_downsampled/scores.tsv", assembly_2=assemblies_2),
        expand("results/assembly_2/{assembly_2}/modisco_l1_downsampled_counts_ss/{cluster}/modisco_results.h5", assembly_2=assemblies_2, cluster=config["chrombpnet_downsample_l1"].keys()),
        expand("results/assembly_2/{assembly_2}/modisco_l1_downsampled_counts_xs/{cluster}/modisco_results.h5", assembly_2=assemblies_2, cluster=config["chrombpnet_downsample_l1"].keys()),
        expand("results/assembly_2/{assembly_2}/modisco_l1_downsampled_counts_ss/{cluster}/modisco_report", assembly_2=assemblies_2, cluster=config["chrombpnet_downsample_l1"].keys()),
        expand("results/assembly_2/{assembly_2}/modisco_l1_downsampled_counts_xs/{cluster}/modisco_report", assembly_2=assemblies_2, cluster=config["chrombpnet_downsample_l1"].keys()),
        expand("results/assembly_2/{assembly_2}/modisco_seqlet_occurences_downsampled/{cluster}/occurences_annotated.tsv", assembly_2=assemblies_2, cluster=config["chrombpnet_downsample_l1"].keys()),
        # Peak/TSS distance table, mm10 only
        expand("results/assembly_2/{assembly_2}/peak_tss_distances/{cluster}/peak_data.tsv", assembly_2=["mm10"], cluster=config["l1_labels"]),

rule peak_annotations:
    """
    Generate per-cluster peak read depths (all assemblies) and mm10 peak/TSS
    distance annotations, including per-comparison hit coefficients
    """
    input:
        expand("results/assembly_2/{assembly_2}/peaks_l1/{cluster}/read_depth.txt", assembly_2=assemblies_2, cluster=config["l1_labels"]),
        expand("results/assembly_2/{assembly_2}/peak_tss_distances/{cluster}/{comparison}/peak_data_hit_coefficients.tsv", assembly_2=["mm10"], comparison=["ss", "xs"], cluster=config["l1_labels"]),
        expand("results/assembly_2/{assembly_2}/peak_tss_distances/{cluster}/peak_tss_dists_wide.bed", assembly_2=["mm10"], cluster=config["l1_labels"]),


rule great:
    """
    Run great analyses
    """
    input:
        # NOTE: currently restricted to endothelial_cell (and, for GREAT, the
        # "ss" comparison only); the full sweep over all L1 clusters is the
        # commented-out line below.
        # expand("results/assembly_2/{assembly_2}/great/{cluster}/{comparison}/touch_{quantile}.txt", assembly_2=["mm10"], cluster=config["l1_labels"], comparison=["ss", "xs"], quantile=config["great_quantiles"]),
        expand("results/assembly_2/{assembly_2}/great/{cluster}/{comparison}/touch_{quantile}.txt", assembly_2=["mm10"], cluster=["endothelial_cell"], comparison=["ss"], quantile=config["great_quantiles"]),
        expand("results/assembly_2/{assembly_2}/great/{cluster}/{comparison}/bg_touch_{quantile}.txt", assembly_2=["mm10"], cluster=["endothelial_cell"], comparison=["ss"], quantile=config["great_quantiles"]),
        # Hit footprints are built for every assembly_2 and both comparisons.
        expand("results/assembly_2/{assembly_2}/hit_footprints/{cluster}/{comparison}/touch_{quantile}.txt", assembly_2=assemblies_2, cluster=["endothelial_cell"], comparison=["ss", "xs"], quantile=config["great_quantiles"]),
        expand("results/assembly_2/{assembly_2}/hit_footprints/{cluster}/{comparison}/counts_touch_{quantile}.txt", assembly_2=assemblies_2, cluster=["endothelial_cell"], comparison=["ss", "xs"], quantile=config["great_quantiles"]),


rule peak_conservation:
    """
    Generate mm10 conservation summaries (phyloP and phastCons) for per-cluster
    peaks, all peaks and exons, plus the TSS distance table
    """
    input:
        # expand("results/assembly_2/{assembly_2}/peak_tss_distances/{cluster}/peak_data.tsv", assembly_2=["mm10"], cluster=config["l1_labels"]),
        expand("results/assembly_2/{assembly_2}/peak_tss_distances/tss_dists.tsv", assembly_2=["mm10"]),
        expand("results/assembly_2/{assembly_2}/peak_conservation/{cluster}/phylop.tsv", assembly_2=["mm10"], cluster=config["l1_labels"]),
        expand("results/assembly_2/{assembly_2}/exon_conservation/phylop.tsv", assembly_2=["mm10"]),
        expand("results/assembly_2/{assembly_2}/peak_conservation/phylop_all.tsv", assembly_2=["mm10"]),
        expand("results/assembly_2/{assembly_2}/peak_conservation/{cluster}/phastcons.tsv", assembly_2=["mm10"], cluster=config["l1_labels"]),
        expand("results/assembly_2/{assembly_2}/exon_conservation/phastcons.tsv", assembly_2=["mm10"]),
        expand("results/assembly_2/{assembly_2}/peak_conservation/phastcons_all.tsv", assembly_2=["mm10"]),



rule embeddings:
    """
    Generate chrombpnet embeddings and promoter-level outputs.

    The embedding targets themselves are currently commented out; the active
    targets are promoter peak sets, collated promoter contributions and peak
    outputs, and annotated promoter hit occurrences for cavpor3,
    endothelial_cell only.
    """
    input:
        # expand("results/assembly_2/{assembly_2}/chrombpnet_embeddings_l1/{cluster}/embeddings_concat.npy", assembly_2=assemblies_2, cluster=config["l1_labels"]),
        # expand("results/assembly_alt/{assembly_alt}/chrombpnet_embeddings_l1/{cluster}/embeddings_concat.npy", assembly_alt=["cavpor3"], cluster=config["l1_labels"]),
        # expand("results/assembly_alt/{assembly_alt}/collate_contribs_l1_counts_xs/{cluster}/seqs.npz", assembly_alt=["cavpor3"], cluster=config["l1_labels"]),
        # expand("results/assembly_2/{assembly_2}/gkm_embeddings_l1/{cluster}/{comparison}/embeddings.npz", assembly_2=assemblies_2, cluster=config["l1_labels"], comparison=["ss", "xs"]),
        expand("results/assembly_alt/{assembly_alt}/collate_contribs_l1_counts_xs_promoters/{cluster}/seqs.npz", assembly_alt=["cavpor3"], cluster=["endothelial_cell"]),
        # Listed once; a duplicate of this target was removed.
        expand("results/assembly_alt/{assembly_alt}/peaks_l1_alt/{cluster}/promoter_peaks.bed", assembly_alt=["cavpor3"], cluster=["endothelial_cell"]),
        expand("results/assembly_alt/{assembly_alt}/collate_peak_outputs_alt_promoters/{cluster}/peak_outputs.h5", assembly_alt=["cavpor3"], cluster=["endothelial_cell"]),
        expand("results/assembly_alt/{assembly_alt}/annotate_hit_occurences_l1_alt_promoters/{cluster}/{comparison}/hits_annotated_coefficients.tsv", assembly_alt=["cavpor3"], cluster=["endothelial_cell"], comparison=["ss", "xs"]),



rule finemo:
    """
    Run hit calling
    """
    input: 
        # Per-cluster, per-comparison hit calls, reports and annotated counts
        expand("results/assembly_2/{assembly_2}/finemo_call_hits/{cluster}/{comparison}", assembly_2=assemblies_2, cluster=config["l1_labels"], comparison=["ss", "xs"]),
        expand("results/assembly_2/{assembly_2}/finemo_report/{cluster}/{comparison}", assembly_2=assemblies_2, cluster=config["l1_labels"], comparison=["ss", "xs"]),
        expand("results/assembly_2/{assembly_2}/finemo_hit_occurences/{cluster}/{comparison}/hits_annotated_counts.tsv", assembly_2=assemblies_2, cluster=config["l1_labels"], comparison=["ss", "xs"]),
        # expand("results/assembly_2/{assembly_2}/motif_gene_linking_l1/{cluster}/{comparison}/counts.tsv", assembly_2=["mm10"], cluster=config["l1_labels"], comparison=["ss", "xs"]),
        # Browser export of hits, mm10 only
        expand("results/assembly_2/{assembly_2}/finemo_hits_bed/{cluster}/browser_data_full.json", assembly_2=["mm10"], cluster=config["l1_labels"]),

rule abc:
    """
    Run ABC enhancer-gene linking
    """
    input:
        # mm10 only: per-cluster predictions and per-comparison motif-gene link counts
        expand("results/assembly_2/{assembly_2}/abc/{cluster}/predictions/EnhancerPredictionsAllPutative.tsv.gz", assembly_2=["mm10"], cluster=config["l1_labels"]),
        expand("results/assembly_2/{assembly_2}/abc/{cluster}/motif_gene_linking/{comparison}/counts.tsv", assembly_2=["mm10"], cluster=config["l1_labels"], comparison=["ss", "xs"]),

# rule scarlink:
#     """
#     Run scarlink analyses
#     """
#     input: 
#         expand("results/assembly_2/{assembly_2}/scarlink/analysis/coassay_matrix.h5", assembly_2=["mm10"]),
#         expand(f"results/assembly_2/{{assembly_2}}/scarlink/analysis/gene_linked_tiles_{config['scarlink_celltype_col']}.csv.gz", assembly_2=["mm10"])


rule pre_name_l1:
    """
    Generate the per-sample RNA outputs that precede L1 naming: ortholog
    mapping, QC plots, the human reference and its transfer, cell-cycle
    scores, gene UMAPs and clustering
    """
    input: 
        expand("results/sample/{sample}/rna/seurat_map_orthologs/proj.rds", sample=samples),
        expand("results/sample/{sample}/rna/seurat_plot_qc/qc_plots_merged.pdf", sample=samples),
        "references/seurat_build_reference_human/proj.rds",
        expand("results/sample/{sample}/rna/seurat_transfer_human/proj.rds", sample=samples),
        expand("results/sample/{sample}/rna/seurat_cc_score/proj.rds", sample=samples),
        expand("results/sample/{sample}/rna/seurat_plot_genes/umaps", sample=samples),
        expand("results/sample/{sample}/rna/seurat_cluster/proj.rds", sample=samples),
        # "results_merged/rna/seurat_merge/proj.rds",


rule download_barcode_wl:
    """
    Download barcode whitelist
    """
    output:
        rna = "resources/whitelist_rna.txt",
        atac = "resources/whitelist_atac.txt"
    params:
        url_rna = config["bc_wl_rna"],
        url_atac = config["bc_wl_atac"],
    conda:
        "envs/fetch.yaml"
    # Fetch each whitelist URL (curl -f fails on HTTP errors) and pipe it
    # through `zcat -f`, which decompresses gzip input and passes plain text
    # through unchanged.
    shell:
        "curl --no-progress-meter -L -f '{params.url_rna}' | zcat -f > {output.rna}; "
        "curl --no-progress-meter -L -f '{params.url_atac}' | zcat -f > {output.atac}"

rule import_fasta:
    """
    Load genome fastas (copy from the import directory into inputs/)
    """
    input:
        import_path("assembly/{assembly}/genome.fa"),
    output:
        "inputs/assembly/{assembly}/genome.fa",
    conda:
        "envs/fetch.yaml"
    shell:
        "cp {input} {output}"

rule import_gtf:
    """
    Load genome annotation data (gzipped GTF, copied as-is into inputs/)
    """
    input:
        import_path("assembly/{assembly}/annotations.gtf.gz"),
    output:
        "inputs/assembly/{assembly}/annotations.gtf.gz",
    conda:
        "envs/fetch.yaml"
    shell:
        "cp {input} {output}"

rule import_blacklists:
    """
    Load genome blacklist (BED, copied as-is into inputs/)
    """
    input:
        import_path("assembly/{assembly}/genome_blacklist.bed"),
    output:
        "inputs/assembly/{assembly}/genome_blacklist.bed",
    conda:
        "envs/fetch.yaml"
    shell:
        "cp {input} {output}"

rule import_chromsizes:
    """
    Load chromosome sizes and derive a chromosome-name list
    """
    input:
        import_path("assembly/{assembly}/chrom_sizes.txt"),
    output:
        sizes = "inputs/assembly/{assembly}/chrom_sizes.txt",
        chroms = "inputs/assembly/{assembly}/chroms.txt"
    conda:
        "envs/fetch.yaml"
    # Copy the sizes file, then write its first column (chromosome names) to
    # chroms.txt. `{{ }}` escapes the awk braces from Snakemake formatting.
    shell:
        "cp {input} {output.sizes}; "
        "awk '{{ print $1 }}' {output.sizes} > {output.chroms}"

# rule import_folds:
#     """
#     Load chrombpnet training folds
#     """
#     input:
#         import_path("assembly/{assembly}/folds/fold_{fold}.json")
#     output:
#         "inputs/assembly/{assembly}/folds/fold_{fold}.json"
#     conda:
#         "envs/fetch.yaml"
#     shell:
#         "cp -R {input} {output}"

rule import_archr_project:
    """
    Load ATAC ArchR project and its QC directory (recursive copies)
    """
    input:
        project = import_path("sample/{sample}/archr_project"),
        qc = import_path("sample/{sample}/archr_qc"),
    output:
        project = directory("inputs/sample/{sample}/archr_project"),
        qc = directory("inputs/sample/{sample}/archr_qc"),
    conda:
        "envs/fetch.yaml"
    shell:
        "cp -R {input.project} {output.project}; "
        "cp -R {input.qc} {output.qc}; "

rule import_counts:
    """
    Load RNA gene counts (matrix plus feature and cell name lists)
    """
    input:
        mat = import_path("sample/{sample}/rna_counts.mtx"),
        features = import_path("sample/{sample}/rna_features.txt"),
        cells = import_path("sample/{sample}/rna_cells.txt"),
    output:
        mat = "inputs/sample/{sample}/rna_counts.mtx",
        features = "inputs/sample/{sample}/rna_features.txt",
        cells = "inputs/sample/{sample}/rna_cells.txt",
    conda:
        "envs/fetch.yaml"
    shell:
        "cp {input.mat} {output.mat}; "
        "cp {input.features} {output.features}; "
        "cp {input.cells} {output.cells}"

rule import_fragments:
    """
    Load ATAC fragments file and its tabix (.tbi) index
    """
    input:
        frags = import_path("sample/{sample}/atac_fragments.tsv.gz"),
        index = import_path("sample/{sample}/atac_fragments.tsv.gz.tbi"),
    output:
        frags = "inputs/sample/{sample}/atac_fragments.tsv.gz",
        index = "inputs/sample/{sample}/atac_fragments.tsv.gz.tbi"
    conda:
        "envs/fetch.yaml"
    shell:
        "cp {input.frags} {output.frags}; "
        "cp {input.index} {output.index}"

rule import_fastqs:
    """
    Load ATAC fastq's
    """
    input:
        # R1 and R3 are the two reads; R2 is stored as the barcode read.
        fastq_1 = import_path("atac_fastq/{sample}_S1_L001_R1_001.fastq.gz"),
        fastq_2 = import_path("atac_fastq/{sample}_S1_L001_R3_001.fastq.gz"),
        fastq_bc = import_path("atac_fastq/{sample}_S1_L001_R2_001.fastq.gz"),
    output:
        fastq_1 = "inputs/sample/{sample}/fastq/R1.fastq.gz",
        fastq_2 = "inputs/sample/{sample}/fastq/R3.fastq.gz",
        fastq_bc = "inputs/sample/{sample}/fastq/R2.fastq.gz",
    conda:
        "envs/fetch.yaml"
    resources:
        # NOTE(review): 40 GB for plain `cp` looks excessive — confirm whether
        # this is intentional (e.g. to throttle concurrent copies).
        mem_mb = 40000,
    shell:
        "cp {input.fastq_1} {output.fastq_1}; "
        "cp {input.fastq_2} {output.fastq_2}; "
        "cp {input.fastq_bc} {output.fastq_bc}; "

rule build_atac_observed_wl:
    """
    Build the observed ATAC barcode whitelist from the downloaded ATAC
    whitelist (see scripts/build_atac_observed_wl.py; uses config bc_prefix)
    """
    input:
        "resources/whitelist_atac.txt"
    output:
        "resources/whitelist_atac_observed.txt"
    params:
        bc_prefix = config["bc_prefix"]
    conda:
        "envs/fetch.yaml"
    script:
        "scripts/build_atac_observed_wl.py"

rule filter_orthologs:
    """
    Filter for one-to-one mouse/GP orthologs
    (resources/orthologs.tsv -> resources/orthologs_filtered.tsv)
    """
    input:
        "resources/orthologs.tsv"
    output:
        "resources/orthologs_filtered.tsv"
    conda:
        "envs/fetch.yaml"
    script:
        "scripts/filter_orthologs.py"

rule filter_orthologs_human:
    """
    Filter for human/mouse orthologs
    (resources/orthologs_human.tsv -> resources/orthologs_human_filtered.tsv)
    """
    input:
        "resources/orthologs_human.tsv"
    output:
        "resources/orthologs_human_filtered.tsv"
    conda:
        "envs/fetch.yaml"
    script:
        "scripts/filter_orthologs_human.py"

rule calc_effective_sizes:
    """
    Calculate effective chromosome sizes from the genome fasta; also writes
    true per-chromosome sizes and a total-size file
    """
    input:
        "inputs/assembly/{assembly}/genome.fa"
    output:
        eff_sizes = "inputs/assembly/{assembly}/effective_sizes.txt",
        true_sizes = "inputs/assembly/{assembly}/true_sizes.txt",
        total_size = "inputs/assembly/{assembly}/total_size.txt"
    conda:
        "envs/fetch.yaml"
    script:
        "scripts/calc_effective_sizes.py"

rule index_genome:
    """
    Build genome faidx (samtools faidx -> genome.fa.fai)
    """
    input:
        "inputs/assembly/{assembly}/genome.fa"
    output:
        "inputs/assembly/{assembly}/genome.fa.fai"
    conda:
        "envs/chrombpnet.yaml"
    shell:
        "samtools faidx -o {output} {input}"

rule build_folds:
    """
    Build chrombpnet training folds
    """
    input:
        "inputs/assembly/{assembly}/true_sizes.txt"
    output:
        # One JSON per fold: fold_0.json .. fold_{num_folds-1}.json. The
        # doubled braces keep {assembly} as a wildcard inside the f-string.
        [f"inputs/assembly/{{assembly}}/folds/fold_{fold}.json" for fold in range(config["num_folds"])]
    params:
        train_val_ratio = config["train_val_ratio"],
        # Per-assembly settings looked up by the {assembly} wildcard.
        mito_chr = lambda w: config["mito_chr"][w.assembly],
        chr_prefix = lambda w: config["chr_prefix"][w.assembly]
    conda:
        "envs/fetch.yaml"
    script:
        "scripts/build_folds.py"

# rule filter_sizes:
#     """
#     Filter chromsizes
#     """
#     input:
#         in_sizes = "inputs/assembly/{assembly}/true_sizes.txt"
#     output:
#         out_sizes = "inputs/assembly/{assembly}/true_sizes_filtered.txt",
#     conda:
#         "envs/fetch.yaml"
#     script:
#         "scripts/filter_chromsizes.py"

rule sort_genome_sizes:
    """
    Sort true chromosome sizes by chromosome name (first column)
    """
    input:
        "inputs/assembly/{assembly}/true_sizes.txt"
    output:
        "inputs/assembly/{assembly}/true_sizes_sorted.txt"
    conda:
        "envs/fetch.yaml"
    shell:
         "sort -k 1,1 {input} > {output}"


# NOTE(review): despite its name, this rule clones config["amulet_repo"] into
# resources/amulet (the standard AMULET), while `download_amulet` clones the
# low-memory repo. The rule names look swapped; left as-is because renaming
# could break `rules.<name>` references in the included .smk files — confirm.
rule download_amulet_lowmem:
    """
    Download AMULET utility
    """
    output:
        directory("resources/amulet")
    params:
        repo = config["amulet_repo"]
    conda:
        "envs/fetch.yaml"
    shell:
        "git clone {params.repo} {output}"

# NOTE(review): despite its name, this rule clones the low-memory AMULET repo
# (config["amulet_repo_lowmem"]) into resources/amulet_lowmem; the names of
# this rule and `download_amulet_lowmem` look swapped — confirm before renaming.
rule download_amulet:
    """
    Download AMULET utility (low memory patch)
    """
    output:
        directory("resources/amulet_lowmem")
    params:
        repo = config["amulet_repo_lowmem"]
    conda:
        "envs/fetch.yaml"
    shell:
        "git clone {params.repo} {output}"
