import os
import json
from snakemake.utils import min_version

min_version("6.6.1")

# Populate the global `config` mapping from the workflow's YAML configuration.
configfile: "config/config.yaml"

# Invert config["superregions"] (superregion -> member regions) into a
# region -> superregion lookup. If a region is listed more than once, the
# last occurrence wins.
superregion_map = {
    member_region: superregion
    for superregion, member_regions in config["superregions"].items()
    for member_region in member_regions
}

# Per-sample bookkeeping parsed from config/samples.tsv: a tab-separated table
# whose header row must name the Series, Region, Status, Sex and Donor columns.
samples = []          # Series IDs, in file order
groups = {}           # "{Region}-{Status}-{Sex}" -> [Series, ...]
groups_unisex = {}    # "{Region}-{Status}" -> [Series, ...]
donor_data = {}       # Series -> Donor
supergroups = {}      # superregion -> group names (sets while parsing)
sample_to_group = {}  # Series -> group name
with open("config/samples.tsv") as sample_file:
    h = sample_file.readline().rstrip('\n').split('\t')
    series_ind = h.index("Series")
    region_ind = h.index("Region")
    status_ind = h.index("Status")
    sex_ind = h.index("Sex")
    donor_ind = h.index("Donor")

    for line in sample_file:
        # Skip commented-out rows and blank lines; a blank line (such as a
        # trailing newline at the end of the file) has too few columns.
        if line.startswith("#") or not line.strip():
            continue
        entries = line.rstrip('\n').split('\t')
        series = entries[series_ind]
        region = entries[region_ind]
        status = entries[status_ind]
        sex = entries[sex_ind]
        donor = entries[donor_ind]

        group_name = f"{region}-{status}-{sex}"
        samples.append(series)
        groups.setdefault(group_name, []).append(series)
        groups_unisex.setdefault(f"{region}-{status}", []).append(series)
        donor_data[series] = donor
        sample_to_group[series] = group_name

        # Every region must be assigned to a superregion in the config; say
        # which sample is affected instead of raising a bare KeyError.
        if region not in superregion_map:
            raise KeyError(
                f"Region {region!r} (sample {series!r}) is not listed under "
                "config['superregions']"
            )
        supergroup_name = superregion_map[region]
        supergroups.setdefault(supergroup_name, set()).add(group_name)

# Sort the members so their order is reproducible between runs; the iteration
# order of a set of strings is not.
supergroups = {k: sorted(v) for k, v in supergroups.items()}
group_to_supergroup = {}
for k, v in supergroups.items():
    for g in v:
        group_to_supergroup[g] = k

group_names = list(groups.keys())
group_unisex_names = list(groups_unisex.keys())
supergroup_names = list(supergroups.keys())

# Hi-C experiments parsed from config/hic_samples.tsv: a tab-separated table
# whose header row must name the Experiment, Region, Status, Sex and File
# columns. Only rows whose Region is listed in config["hic_regions"] are kept.
hic_regions = set(config["hic_regions"])
hic_labels = config["hic_labels"]

groups_hic = {}         # "{Region}-{Status}-{Sex}" -> [Experiment, ...]
groups_unisex_hic = {}  # "{Region}-{Status}" -> [Experiment, ...]
hic_files = {}          # Experiment -> File
with open("config/hic_samples.tsv") as sample_file:
    h = sample_file.readline().rstrip('\n').split('\t')
    exp_ind = h.index("Experiment")
    region_ind = h.index("Region")
    status_ind = h.index("Status")
    sex_ind = h.index("Sex")
    file_ind = h.index("File")

    # NOTE: `samples` is deliberately NOT reset here. It holds the Series IDs
    # read from config/samples.tsv and is still needed by rules defined
    # further down (e.g. `soupx`).
    for line in sample_file:
        # Skip commented-out rows and blank lines (too few columns to index).
        if line.startswith("#") or not line.strip():
            continue
        entries = line.rstrip('\n').split('\t')
        experiment = entries[exp_ind]
        region = entries[region_ind]
        if region not in hic_regions:
            continue
        status = entries[status_ind]
        sex = entries[sex_ind]
        file = entries[file_ind]

        group_name = f"{region}-{status}-{sex}"
        groups_hic.setdefault(group_name, []).append(experiment)
        groups_unisex_hic.setdefault(f"{region}-{status}", []).append(experiment)
        hic_files[experiment] = file

group_unisex_names_hic = list(groups_unisex_hic.keys())

# Allele-specific analysis settings taken from the configuration.
as_regions = set(config["as_regions"])
as_labels = config["as_labels"]

# Donor -> WGS directory, parsed from config/wgs_samples.tsv: a tab-separated
# table whose header row must name the Donor and Directory columns. If a donor
# appears more than once, the last row wins.
donor_to_wgs_dir = {}
with open("config/wgs_samples.tsv") as sample_file:
    h = sample_file.readline().rstrip('\n').split('\t')
    donor_ind = h.index("Donor")
    dir_ind = h.index("Directory")

    for line in sample_file:
        # Skip commented-out rows and blank lines (too few columns to index).
        if line.startswith("#") or not line.strip():
            continue
        entries = line.rstrip('\n').split('\t')
        donor = entries[donor_ind]
        wdir = entries[dir_ind]
        donor_to_wgs_dir[donor] = wdir

def _members_with_wgs(group_map):
    """
    Restrict a group -> [Series, ...] mapping to samples that have WGS data.

    Only groups whose name starts with a region in `as_regions` (the text
    before the first "-") are considered. Each retained sample is returned as
    a (wgs_directory, series) tuple; groups without any such sample are
    omitted from the result.
    """
    selected = {}
    for name, members in group_map.items():
        if name.split("-")[0] not in as_regions:
            continue
        for series_id in members:
            donor_id = donor_data[series_id]
            if donor_id not in donor_to_wgs_dir:
                continue
            pair = (donor_to_wgs_dir[donor_id], series_id)
            selected.setdefault(name, []).append(pair)
    return selected


groups_wgs = _members_with_wgs(groups)
groups_unisex_wgs = _members_with_wgs(groups_unisex)

group_unisex_names_wgs = list(groups_unisex_wgs.keys())

# subcluster_jobs = []
# with open(workflow.source_path("files/subcluster_jobs.txt")) as f:
#     for line in f:
#         subcluster_jobs.append(line.rstrip("\n"))

# with open("config/rna_cluster_names.txt") as f:
#     rna_cluster_names = []
#     for line in f:
#         rna_cluster_names.append(line.rstrip("\n"))
        
# Switch to the configured working directory.
# NOTE(review): the config/*.tsv files above are opened before this directive
# appears; confirm which directory those relative paths are meant to resolve
# against.
workdir: config['workdir']

# Thread cap from the configuration; not referenced again in this file, so
# presumably consumed by the included rule files.
max_threads = config["max_threads_per_rule"]

# def script_path(script_name):
#     return str(workflow.source_path(script_name))

# Rule modules, included in this exact order.
include: "rules/atac.smk"
include: "rules/rna.smk"
include: "rules/reference.smk"
include: "rules/rna_group.smk"
include: "rules/rna_import_emb.smk"
include: "rules/atac_group.smk"
include: "rules/rna_name_group.smk"
include: "rules/rna_subcluster_cleanup.smk"
include: "rules/rna_name_subclusters.smk"
include: "rules/rna_name_subclusters_refined.smk"
include: "rules/rna_l1_labeled.smk"
include: "rules/export.smk"
include: "rules/diff_exp.smk"
include: "rules/soupx.smk"
include: "rules/abc.smk"
include: "rules/allele_specific.smk"

# NOTE: every candidate target listed under `input:` below is commented out,
# so this rule currently requests no files. The last commented entry uses a
# {sample} wildcard but supplies `group=`; fix that before re-enabling it.
rule all:
    """
    Generate all outputs (default)
    """
    input: 
        # expand("results_supergroups/{supergroup}/{label}/rna/seurat_name_subclusters_supergroup/proj.rds", supergroup=supergroup_names, label=config["supergroup_labels"]),
        # expand("results_supergroups/{supergroup}/{label}/rna/seurat_named_subcluster_markers_supergroup/markers.tsv", supergroup=supergroup_names, label=config["supergroup_labels"]),
        # expand("results_groups/{group}/rna/seurat_subclusters_supergroups_to_groups/proj.rds", group=group_names),
        # expand("results_groups/{group}/rna/seurat_named_cluster_markers_refined/markers.tsv", group=group_names),
        # expand("results_groups/{group}/rna/seurat_add_seq_emb/proj.rds", group=group_names),
        # expand("results_groups/{group}/rna/seurat_embed_all/proj.rds", group=group_names),
        # expand("results/{sample}/rna/seurat_soupx_filtered/proj.rds", group=group_names),


rule allele_specific:
    """
    Target rule: RASQUAL output for every unisex group that has WGS-backed
    samples, across all allele-specific labels.
    """
    input:
        expand(
            "results_unisex/{group}/allele_specific/label/{label}/atac/rasqual_output.txt",
            group=group_unisex_names_wgs,
            label=as_labels,
        )


rule abc:
    """
    Target rule: ABC enhancer-gene linking outputs.

    Requests the gene list for every unisex group with Hi-C data, plus one
    prediction table per (group, label, Hi-C experiment) combination.
    """
    input:
        expand(
            "results_unisex/{group}/abc/{label}/neighborhoods/GeneList.txt",
            group=group_unisex_names_hic,
            label=hic_labels,
        ),
        # Predictions depend on the experiments of each group, so the list is
        # built by hand, keeping only groups that have Hi-C data.
        [
            f"results_unisex/{group}/abc/{label}/predictions/{experiment}/EnhancerPredictionsAllPutative.tsv.gz"
            for group in group_unisex_names
            if group in groups_unisex_hic
            for label in config["hic_labels"]
            for experiment in groups_unisex_hic[group]
        ]


rule soupx:
    """
    Target rule: SoupX-filtered project and adjusted count matrix for every
    sample, plus the combined QC report.
    """
    input:
        expand(
            "results/{sample}/rna/seurat_soupx_filtered/proj.rds",
            sample=samples,
        ),
        "qc_all/seurat_soupx_filtered.pdf",
        expand(
            "results/{sample}/rna/seurat_soupx_filtered/adjusted_counts/counts_mat.mtx",
            sample=samples,
        ),

rule export_l1:
    """
    Target rule: the level-1 export bundle (metadata, labels, markers,
    embeddings and figures) for every group.
    """
    input:
        # One expand over a list of patterns: results are produced pattern by
        # pattern, each expanded over all groups.
        expand(
            [
                "export_l1/{group}/metadata.tsv.gz",
                "export_l1/{group}/datasets.txt",
                "export_l1/{group}/labels/cell_types_l1.tsv.gz",
                "export_l1/{group}/markers",
                "export_l1/{group}/embeddings/rna_pca.tsv.gz",
                "export_l1/{group}/embeddings/rna_harmony.tsv.gz",
                "export_l1/{group}/embeddings/rna_umap.tsv.gz",
                "export_l1/{group}/embeddings/atac_lsi.tsv.gz",
                "export_l1/{group}/embeddings/atac_harmony.tsv.gz",
                "export_l1/{group}/embeddings/atac_umap.tsv.gz",
                "export_l1/{group}/figures.tar.gz",
            ],
            group=group_names,
        ),

# rule subcluster_naming:
#     """
#     Subcluster naming (pre-supergroup-refinement)
#     """
#     input: 
#         expand("results_subcluster_unified/{label}/rna/seurat_name_subclusters/proj.rds", label=config["l1_labels"]),
#         expand("results_subcluster_unified/{label}/rna/seurat_named_subcluster_markers/markers.tsv", label=config["l1_labels"]),
#         expand("results_groups/{group}/rna/seurat_subclusters_to_groups/proj.rds", group=group_names),
#         expand("results_groups/{group}/rna/seurat_named_cluster_markers/markers.tsv", group=group_names),
#         expand("results_supergroups/{supergroup}/{label}/rna/seurat_subcluster_supergroup/proj.rds", supergroup=supergroup_names, label=config["supergroup_labels"]),
#         expand("results_supergroups/{supergroup}/{label}/rna/seurat_subcluster_supergroup/label_plots.pdf", supergroup=supergroup_names, label=config["supergroup_labels"]),
#         expand("results_supergroups/{supergroup}/{label}/rna/seurat_subcluster_supergroup/label_plots.pdf", supergroup=supergroup_names, label=config["supergroup_labels"]),
#         expand("results_supergroups/{supergroup}/{label}/rna/seurat_plot_subcluster_markers_supergroup/markers.tsv", supergroup=supergroup_names, label=config["supergroup_labels"]),
#         expand("results_supergroups/{supergroup}/{label}/rna/seurat_plot_subcluster_genes_supergroup/heatmap.pdf", supergroup=supergroup_names, label=config["supergroup_labels"]),

rule diff_exp:
    """
    Run differential expression analysis

    Target rule: for every level-1 label, requests the collated pseudobulk
    matrix and the coefficients of each DESeq variant (plain, sva, lv, sva_lv).
    """
    input:
        expand("results_diff_exp/label/{label}/collate_pseudobulks/mat.tsv", label=config["l1_labels"]),
        expand("results_diff_exp/label/{label}/run_deseq/coefficients", label=config["l1_labels"]),
        expand("results_diff_exp/label/{label}/run_deseq_sva/coefficients", label=config["l1_labels"]),
        expand("results_diff_exp/label/{label}/run_deseq_lv/coefficients", label=config["l1_labels"]),
        expand("results_diff_exp/label/{label}/run_deseq_sva_lv/coefficients", label=config["l1_labels"]),

rule l1_finalize:
    """
    Target rule: finalized level-1 labels.

    Requests the unified subcluster outputs for every level-1 label and the
    per-group projects, metadata and marker tables.
    """
    input:
        # expand("results_subcluster/{supergroup}/{label}/rna/seurat_subcluster_2/proj.rds", supergroup=supergroup_names, label=config["l1_labels"]),
        # expand("results_subcluster/{supergroup}/{label}/rna/seurat_subcluster_2/label_plots.pdf", supergroup=supergroup_names, label=config["l1_labels"]),
        # expand("results_subcluster/{supergroup}/{label}/rna/seurat_plot_subcluster_markers/dotplot.pdf", supergroup=supergroup_names, label=config["l1_labels"]),
        expand(
            [
                "results_subcluster_unified/{label}/rna/seurat_subcluster_cleanup/proj.rds",
                "results_subcluster_unified/{label}/rna/seurat_subcluster_2/proj.rds",
                "results_subcluster_unified/{label}/rna/seurat_subcluster_2/label_plots.pdf",
                "results_subcluster_unified/{label}/rna/seurat_plot_subcluster_markers/markers.tsv",
                "results_subcluster_unified/{label}/rna/seurat_plot_subcluster_genes/dotplot.pdf",
            ],
            label=config["l1_labels"],
        ),
        expand(
            [
                "results_groups/{group}/rna/seurat_clusters_to_groups/proj.rds",
                "results_groups/{group}/rna/seurat_write_rna_all/metadata.tsv",
                "results_groups/{group}/rna/seurat_write_l1_markers/markers",
            ],
            group=group_names,
        ),

rule name_1:
    """
    Target rule: outputs that depend on the first round of cluster naming,
    per group and per level-1 label.
    """
    input:
        expand(
            [
                "results_groups/{group}/rna/seurat_name_group_1/proj.rds",
                "results_groups/{group}/rna/seurat_plot_genes/dotplot.pdf",
            ],
            group=group_names,
        ),
        # expand("results_subcluster/{supergroup}/{label}/rna/seurat_subcluster_supergroup/proj.rds", supergroup=supergroup_names, label=config["l1_labels"]),
        # expand("results_subcluster/{supergroup}/{label}/rna/seurat_subcluster_supergroup/label_plots.pdf", supergroup=supergroup_names, label=config["l1_labels"]),
        expand(
            [
                "results_subcluster_unified/{label}/rna/seurat_subcluster_unified/proj.rds",
                "results_subcluster_unified/{label}/rna/seurat_subcluster_unified/label_plots.pdf",
                "results_subcluster_unified/{label}/rna/seurat_integrate_subgroups_unified/umap_group_harmony.pdf",
            ],
            label=config["l1_labels"],
        ),

rule pre_name:
# rule all:
    """
    Target rule: every per-group output that can be produced before clusters
    are named (integration, clustering, label transfer plots, test embedding
    and confusion-matrix plots).
    """
    input:
        # expand("results/{sample}/rna/seurat_transfer_rna/labels_kramann.tsv", sample=samples),
        # expand("results/{sample}/rna/seurat_transfer_rna/labels_teichmann.tsv", sample=samples),
        # expand("results/{sample}/rna/seurat_transfer_rna/labels_ellinor.tsv", sample=samples),
        # expand("results/{sample}/rna/seurat_transfer_rna/labels_azimuth.tsv", sample=samples),
        expand(
            [
                "results_groups/{group}/rna/seurat_integrate_rna/proj.rds",
                "results_groups/{group}/rna/seurat_integrate_rna/integration_plots.pdf",
                "results_groups/{group}/rna/seurat_cluster_rna/proj.rds",
                "results_groups/{group}/rna/seurat_load_transfer_labels/reference_plots.pdf",
                "results_groups/{group}/rna/seurat_embed_test/proj.rds",
                "results_groups/{group}/rna/seurat_plot_cm/mats.pdf",
            ],
            group=group_names,
        ),
        # expand("results_groups/{group}/atac/archr_embed", group=groups.keys()),
        # expand("results_groups/{group}/atac/archr_rna_clustering", group=groups.keys()),
        # expand("results_groups/{group}/rna/seurat_add_batch_data/proj.rds", group=groups.keys()),
        # expand("results_groups/{group}/rna/seurat_add_multiplexing_data/proj.rds", group=groups.keys()),
        # expand("results_groups/{group}/rna/seurat_plot_cluster_markers/dotplot.pdf", group=groups.keys()),
        # expand("results_groups/{group}/atac/archr_plot_genes/umaps", group=groups.keys()),

        # "results_merged/atac/archr_init",
        # "results_merged/atac/archr_label",
        # "export/atac/embeddings/harmony.tsv.gz",
        # "export/atac/embeddings/umap.tsv.gz",
        # "export/atac/labels/cell_types.tsv.gz",
        # "export/atac/markers",
        # "export/atac/metadata.tsv.gz",
        # "export/atac/figures.tar.gz",
        # "export/atac/datasets.txt",
        # "results_merged/rna/seurat_name_rna/proj.rds",
        # "export/rna/embeddings/harmony.tsv.gz",
        # "export/rna/embeddings/umap.tsv.gz",
        # "export/rna/labels/cell_types.tsv.gz",
        # "export/rna/markers",
        # "export/rna/metadata.tsv.gz",
        # "export/rna/figures.tar.gz",
        # "export/rna/datasets.txt"

rule harmonize_seq_emb:
    """
    Harmonize sequence embeddings

    Target rule: requests the seurat_add_seq_emb project of every group.
    """
    input:
        expand("results_groups/{group}/rna/seurat_add_seq_emb/proj.rds", group=group_names),

rule download_gtf:
    """
    Download GTF data

    Fetches the URL in config["gtf"] to results_merged/fetch/GRCh38.gtf.gz.
    """
    output:
        "results_merged/fetch/GRCh38.gtf.gz"
    params:
        url = config["gtf"]
    conda:
        "envs/fetch.yaml"
    shell:
        # -f makes curl fail on an HTTP error instead of saving the error page
        # as the GTF; the URL is quoted like in the other download rules so
        # characters such as '&' or '?' are not interpreted by the shell.
        "curl --no-progress-meter -L -f '{params.url}' > {output}"

def get_series(w):
    """Return the part of the `sample` wildcard before its first "-"."""
    series, _, _ = w.sample.partition("-")
    return series

def get_replicate(w):
    """
    Return the second "-"-separated field of the `sample` wildcard.

    Raises IndexError when the sample name contains no "-".
    """
    fields = w.sample.split("-", 2)
    return fields[1]

rule get_dataset_info:
    """
    Query the ENCODE portal for information about the sample's dataset
    """
    output: "results/{sample}/fetch/dataset_info.json"
    log: directory("logs/{sample}/get_dataset_info")
    conda: "envs/fetch.yaml"
    params:
        # Series is the text of {sample} before the first "-".
        series = get_series,
        # Fixed value; the get_replicate helper is not used here.
        replicate = 1,
        dcc_mode = config["dcc_mode"],
        # Credentials come from the environment and are None when unset.
        dcc_api_key = os.environ.get("DCC_API_KEY"),
        dcc_secret_key = os.environ.get("DCC_SECRET_KEY"),
    script: "scripts/get_dataset_info.py"


rule query_fragments:
    """
    Query the ENCODE portal for the URL of the sample's fragments file
    """
    output: "results/{sample}/fetch/fragments_url.txt"
    log: directory("logs/{sample}/query_fragments")
    conda: "envs/fetch.yaml"
    params:
        # Series is the text of {sample} before the first "-".
        series = get_series,
        # Fixed value; the get_replicate helper is not used here.
        replicate = 1,
        dcc_mode = config["dcc_mode"],
        # Credentials come from the environment and are None when unset.
        dcc_api_key = os.environ.get("DCC_API_KEY"),
        dcc_secret_key = os.environ.get("DCC_SECRET_KEY"),
    script: "scripts/get_fragments_url.py"

rule download_fragments:
    """
    Download the fragments tarball from the URL stored in the input file.

    Three attempts are chained with `||`: the URL without credentials, the
    URL with DCC credentials, and finally an alternate @@download URL built
    from the file name.
    """
    input: "results/{sample}/fetch/fragments_url.txt"
    output: "results/{sample}/fetch/fragments.tar.gz"
    conda: "envs/fetch.yaml"
    params:
        # NOTE(review): these are interpolated into the command line below
        # (and are the string "None" when the variable is unset) - confirm
        # that exposing them in the rendered command is acceptable.
        usr = os.environ.get("DCC_API_KEY"),
        pwd = os.environ.get("DCC_SECRET_KEY"),
    shell:
        "filename=$(basename -- \"$(< {input})\"); "
        "extension=\"${{filename##*.}}\"; "
        # fid is the file name up to its first "."
        "fid=\"${{filename%%.*}}\"; "
        "alturl=\"https://www.encodeproject.org/files/${{fid}}/@@download/${{filename}}\"; "
        # Debugging aid: print the fallback URL.
        "echo $alturl; "
        "curl --no-progress-meter -L -f $(< {input}) > {output} || "
        "curl --no-progress-meter -L -f -u {params.usr}:{params.pwd} $(< {input}) > {output} || "
        "curl --no-progress-meter -L -f -u {params.usr}:{params.pwd} \"$alturl\" > {output} "

rule extract_fragments:
    """
    Extract fragments file from tarball

    Members are unpacked flat into the output directory: the --transform
    expression strips everything up to the last "/" from each member name.
    """
    input:
        "results/{sample}/fetch/fragments.tar.gz"
    output:
        directory("results/{sample}/fetch/fragments_extracted")
    conda:
        "envs/fetch.yaml"
    shell:
        "mkdir -p {output}; "
        # Raw string: "\/" is not a valid Python escape sequence and triggers
        # a warning in a normal string literal; the command text is unchanged.
        r"tar -xzf {input} --transform='s/.*\///' -C {output}"

rule move_fragments:
    """
    Copy the extracted fragments file and its .tbi index to their final
    location.
    """
    input: "results/{sample}/fetch/fragments_extracted"
    output:
        frag = "results/{sample}/fetch/fragments.tsv.gz",
        frag_ind = "results/{sample}/fetch/fragments.tsv.gz.tbi",
    conda: "envs/fetch.yaml"
    shell:
        # The files are copied, so the extracted directory stays in place.
        "cp {input}/fragments.tsv.gz {output.frag}; "
        "cp {input}/fragments.tsv.gz.tbi {output.frag_ind};"


rule query_bam_atac:
    """
    Query the ENCODE portal for the URL of the sample's ATAC BAM file
    """
    output: "results/{sample}/fetch/bam_atac_url.txt"
    log: directory("logs/{sample}/query_bam_atac")
    conda: "envs/fetch.yaml"
    params:
        # Series is the text of {sample} before the first "-".
        series = get_series,
        # Fixed value; the get_replicate helper is not used here.
        replicate = 1,
        dcc_mode = config["dcc_mode"],
        # Credentials come from the environment and are None when unset.
        dcc_api_key = os.environ.get("DCC_API_KEY"),
        dcc_secret_key = os.environ.get("DCC_SECRET_KEY"),
    script: "scripts/get_bam_atac_url.py"


rule download_bam_atac:
    """
    Download the ATAC BAM from the URL stored in the input file.

    Three attempts are chained with `||`: the URL without credentials, the
    URL with DCC credentials, and finally an alternate @@download URL built
    from the file name.
    """
    input: "results/{sample}/fetch/bam_atac_url.txt"
    output: "results/{sample}/fetch/atac.bam"
    conda: "envs/fetch.yaml"
    params:
        # NOTE(review): these are interpolated into the command line below
        # (and are the string "None" when the variable is unset) - confirm
        # that exposing them in the rendered command is acceptable.
        usr = os.environ.get("DCC_API_KEY"),
        pwd = os.environ.get("DCC_SECRET_KEY"),
    shell:
        "filename=$(basename -- \"$(< {input})\"); "
        "extension=\"${{filename##*.}}\"; "
        # fid is the file name up to its first "."
        "fid=\"${{filename%%.*}}\"; "
        "alturl=\"https://www.encodeproject.org/files/${{fid}}/@@download/${{filename}}\"; "
        # "echo $alturl; " ####
        "curl --no-progress-meter -L -f $(< {input}) > {output} || "
        "curl --no-progress-meter -L -f -u {params.usr}:{params.pwd} $(< {input}) > {output} || "
        "curl --no-progress-meter -L -f -u {params.usr}:{params.pwd} \"$alturl\" > {output} "



rule query_expression:
    """
    Query the ENCODE portal for the URL of the sample's gene expression matrix
    """
    output: "results/{sample}/fetch/expression_url.txt"
    log: directory("logs/{sample}/query_expression")
    conda: "envs/fetch.yaml"
    params:
        # Series is the text of {sample} before the first "-".
        series = get_series,
        # Fixed value; the get_replicate helper is not used here.
        replicate = 1,
        dcc_mode = config["dcc_mode"],
        # Credentials come from the environment and are None when unset.
        dcc_api_key = os.environ.get("DCC_API_KEY"),
        dcc_secret_key = os.environ.get("DCC_SECRET_KEY"),
    script: "scripts/get_expression_url.py"

rule query_expression_raw:
    """
    Query the ENCODE portal for the URL of the sample's raw gene expression
    matrix
    """
    output: "results/{sample}/fetch/expression_raw_url.txt"
    log: directory("logs/{sample}/query_expression_raw")
    conda: "envs/fetch.yaml"
    params:
        # Series is the text of {sample} before the first "-".
        series = get_series,
        # Fixed value; the get_replicate helper is not used here.
        replicate = 1,
        dcc_mode = config["dcc_mode"],
        # Credentials come from the environment and are None when unset.
        dcc_api_key = os.environ.get("DCC_API_KEY"),
        dcc_secret_key = os.environ.get("DCC_SECRET_KEY"),
    script: "scripts/get_expression_raw_url.py"

rule download_expression:
    """
    Download the expression tarball from the URL stored in the input file.

    Three attempts are chained with `||`: the URL without credentials, the
    URL with DCC credentials, and finally an alternate @@download URL built
    from the file name.
    """
    input: "results/{sample}/fetch/expression_url.txt"
    output: "results/{sample}/fetch/expression.tar.gz"
    conda: "envs/fetch.yaml"
    params:
        # NOTE(review): these are interpolated into the command line below
        # (and are the string "None" when the variable is unset) - confirm
        # that exposing them in the rendered command is acceptable.
        usr = os.environ.get("DCC_API_KEY"),
        pwd = os.environ.get("DCC_SECRET_KEY"),
    shell:
        "filename=$(basename -- \"$(< {input})\"); "
        "extension=\"${{filename##*.}}\"; "
        # fid is the file name up to its first "."
        "fid=\"${{filename%%.*}}\"; "
        "alturl=\"https://www.encodeproject.org/files/${{fid}}/@@download/${{filename}}\"; "
        # Debugging aid: print the fallback URL.
        "echo $alturl; "
        "curl --no-progress-meter -L -f $(< {input}) > {output} || "
        "curl --no-progress-meter -L -f -u {params.usr}:{params.pwd} $(< {input}) > {output} || "
        "curl --no-progress-meter -L -f -u {params.usr}:{params.pwd} \"$alturl\" > {output} "

rule download_expression_raw:
    """
    Download the raw expression tarball from the URL stored in the input file.

    Three attempts are chained with `||`: the URL without credentials, the
    URL with DCC credentials, and finally an alternate @@download URL built
    from the file name.
    """
    input: "results/{sample}/fetch/expression_raw_url.txt"
    output: "results/{sample}/fetch/expression_raw.tar.gz"
    conda: "envs/fetch.yaml"
    params:
        # NOTE(review): these are interpolated into the command line below
        # (and are the string "None" when the variable is unset) - confirm
        # that exposing them in the rendered command is acceptable.
        usr = os.environ.get("DCC_API_KEY"),
        pwd = os.environ.get("DCC_SECRET_KEY"),
    shell:
        "filename=$(basename -- \"$(< {input})\"); "
        "extension=\"${{filename##*.}}\"; "
        # fid is the file name up to its first "."
        "fid=\"${{filename%%.*}}\"; "
        "alturl=\"https://www.encodeproject.org/files/${{fid}}/@@download/${{filename}}\"; "
        # Debugging aid: print the fallback URL.
        "echo $alturl; "
        "curl --no-progress-meter -L -f $(< {input}) > {output} || "
        "curl --no-progress-meter -L -f -u {params.usr}:{params.pwd} $(< {input}) > {output} || "
        "curl --no-progress-meter -L -f -u {params.usr}:{params.pwd} \"$alturl\" > {output} "

rule extract_expression:
    """
    Extract expression data from tarball

    Members are unpacked flat into the output directory: the --transform
    expression strips everything up to the last "/" from each member name.
    """
    input:
        "results/{sample}/fetch/expression.tar.gz"
    output:
        directory("results/{sample}/fetch/expression_extracted")
    conda:
        "envs/fetch.yaml"
    shell:
        "mkdir -p {output}; "
        # Raw string: "\/" is not a valid Python escape sequence and triggers
        # a warning in a normal string literal; the command text is unchanged.
        r"tar -xzf {input} --transform='s/.*\///' -C {output}"

rule extract_expression_raw:
    """
    Extract raw expression data from tarball

    Members are unpacked flat into the output directory: the --transform
    expression strips everything up to the last "/" from each member name.
    """
    input:
        "results/{sample}/fetch/expression_raw.tar.gz"
    output:
        directory("results/{sample}/fetch/expression_raw_extracted")
    conda:
        "envs/fetch.yaml"
    shell:
        "mkdir -p {output}; "
        # Raw string: "\/" is not a valid Python escape sequence and triggers
        # a warning in a normal string literal; the command text is unchanged.
        r"tar -xzf {input} --transform='s/.*\///' -C {output}"

rule move_expression:
    """
    Copy the extracted matrix, features and barcodes files to their final
    location.
    """
    input: "results/{sample}/fetch/expression_extracted"
    output:
        matrix = "results/{sample}/fetch/matrix.mtx",
        features = "results/{sample}/fetch/features.tsv",
        barcodes = "results/{sample}/fetch/barcodes.tsv",
    conda: "envs/fetch.yaml"
    shell:
        # The files are copied, so the extracted directory stays in place.
        "cp {input}/matrix.mtx {output.matrix}; "
        "cp {input}/features.tsv {output.features}; "
        "cp {input}/barcodes.tsv {output.barcodes};"

rule move_expression_raw:
    """
    Copy the extracted raw matrix, features and barcodes files to their final
    location, adding a `_raw` suffix to each file name.
    """
    input: "results/{sample}/fetch/expression_raw_extracted"
    output:
        matrix = "results/{sample}/fetch/matrix_raw.mtx",
        features = "results/{sample}/fetch/features_raw.tsv",
        barcodes = "results/{sample}/fetch/barcodes_raw.tsv",
    conda: "envs/fetch.yaml"
    shell:
        # The files are copied, so the extracted directory stays in place.
        "cp {input}/matrix.mtx {output.matrix}; "
        "cp {input}/features.tsv {output.features}; "
        "cp {input}/barcodes.tsv {output.barcodes};"

rule download_kramann_data:
    """
    Download the unified Kramann reference dataset (h5ad) from
    config["rna_ref_kramann"]["unified"].
    """
    output: "reference/kramann/fetch/data.h5ad"
    conda: "envs/fetch.yaml"
    params:
        url = config["rna_ref_kramann"]["unified"],
    shell:
        "curl --no-progress-meter -L -f '{params.url}' > {output}"

rule download_kramann_subcluster_data:
    """
    Download one Kramann reference subcluster dataset (rds); the URL is looked
    up by the {subtype} wildcard in config["rna_ref_kramann"]["subtypes"].
    """
    output: "reference/kramann/fetch/{subtype}/data.rds"
    conda: "envs/fetch.yaml"
    params:
        url = lambda w: config["rna_ref_kramann"]["subtypes"][w.subtype],
    shell:
        "curl --no-progress-meter -L -f '{params.url}' > {output}"

rule download_ellinor_data:
    """
    Download the Ellinor reference data: filtered and raw matrix, features
    and barcodes files plus the metadata table, each from its own URL under
    config["rna_ref_ellinor"].
    """
    output:
        mat = "reference/ellinor/fetch/matrix.mtx.gz",
        features = "reference/ellinor/fetch/features.tsv.gz",
        cells = "reference/ellinor/fetch/barcodes.tsv.gz",
        mat_raw = "reference/ellinor/fetch/matrix_raw.mtx.gz",
        features_raw = "reference/ellinor/fetch/features_raw.tsv.gz",
        cells_raw = "reference/ellinor/fetch/barcodes_raw.tsv.gz",
        metadata = "reference/ellinor/fetch/metadata.tsv.gz",
    conda: "envs/fetch.yaml"
    params:
        url_mat = config["rna_ref_ellinor"]["mat"],
        url_features = config["rna_ref_ellinor"]["features"],
        url_cells = config["rna_ref_ellinor"]["cells"],
        url_mat_raw = config["rna_ref_ellinor"]["mat_raw"],
        url_features_raw = config["rna_ref_ellinor"]["features_raw"],
        url_cells_raw = config["rna_ref_ellinor"]["cells_raw"],
        url_metadata = config["rna_ref_ellinor"]["metadata"],
    shell:
        # One curl invocation per output file, run in sequence.
        "curl --no-progress-meter -L -f '{params.url_mat}' > {output.mat}; "
        "curl --no-progress-meter -L -f '{params.url_features}' > {output.features}; "
        "curl --no-progress-meter -L -f '{params.url_cells}' > {output.cells}; "
        "curl --no-progress-meter -L -f '{params.url_mat_raw}' > {output.mat_raw}; "
        "curl --no-progress-meter -L -f '{params.url_features_raw}' > {output.features_raw}; "
        "curl --no-progress-meter -L -f '{params.url_cells_raw}' > {output.cells_raw}; "
        "curl --no-progress-meter -L -f '{params.url_metadata}' > {output.metadata}; "

rule download_teichmann_data:
    """
    Download the Teichmann reference dataset (h5ad) from
    config["rna_ref_teichmann_data"].
    """
    output: "reference/teichmann/fetch/data.h5ad"
    conda: "envs/fetch.yaml"
    params:
        url = config["rna_ref_teichmann_data"],
    shell:
        "curl --no-progress-meter -L -f '{params.url}' > {output}"

rule download_azimuth_data:
    """
    Download the Azimuth heart reference tarball from config["azimuth_ref"]
    and unpack it into reference/azimuth/heartref.
    """
    output:
        tarball = "reference/azimuth/heartref.SeuratData_1.0.0.tar.gz",
        data = directory("reference/azimuth/heartref"),
    conda: "envs/fetch.yaml"
    params:
        url = config["azimuth_ref"],
    shell:
        # The tarball is kept as an output alongside the unpacked directory.
        "curl --no-progress-meter -L -f '{params.url}' > {output.tarball}; "
        "mkdir -p {output.data}; "
        "tar -xf {output.tarball} -C {output.data}"

rule download_barcode_wl:
    """
    Download the RNA and ATAC barcode whitelists as plain text files.
    """
    output:
        rna = "resources/whitelist_rna.txt",
        atac = "resources/whitelist_atac.txt",
    conda: "envs/fetch.yaml"
    params:
        url_rna = config["bc_wl_rna"],
        url_atac = config["bc_wl_atac"],
    shell:
        # Each download is piped through `zcat -f` before being written.
        "curl --no-progress-meter -L -f '{params.url_rna}' | zcat -f > {output.rna}; "
        "curl --no-progress-meter -L -f '{params.url_atac}' | zcat -f > {output.atac}"

rule download_archr_resources:
    """
    Download the BSgenome tarball and gene annotation file used by ArchR.
    """
    output:
        bsgenome = "resources/bsgenome.tar.gz",
        anno = "resources/gene_annotation.rda",
    conda: "envs/fetch.yaml"
    params:
        url_bsgenome = config["bsgenome"],
        url_anno = config["archr_gene_anno"],
    shell:
        "curl --no-progress-meter -L -f '{params.url_bsgenome}' > {output.bsgenome}; "
        "curl --no-progress-meter -L -f '{params.url_anno}' > {output.anno}"

rule download_chromsizes:
    """
    Download the chromosome sizes table and derive the list of chromosome
    names from its first column.
    """
    output:
        sizes = "resources/GRCh38.chrom.sizes.tsv",
        chroms = "resources/GRCh38.chroms.tsv",
    conda: "envs/fetch.yaml"
    params:
        url = config["chrom_sizes"],
    shell:
        "curl --no-progress-meter -L -f '{params.url}' > {output.sizes}; "
        # The doubled braces survive Snakemake formatting as literal awk braces.
        "awk '{{ print $1 }}' {output.sizes} > {output.chroms}"

rule download_blacklist:
    """
    Download the genome blacklist from config["blacklist"] as a BED file.
    """
    output: "resources/blacklist.bed"
    conda: "envs/fetch.yaml"
    params:
        url = config["blacklist"],
    shell:
        # The download is piped through `zcat -f` before being written.
        "curl --no-progress-meter -L -f '{params.url}' | zcat -f > {output}"

# NOTE(review): despite its name, this rule clones config["amulet_repo"] into
# resources/amulet; the low-memory repository (config["amulet_repo_lowmem"])
# is cloned by `rule download_amulet`. The two rule names look swapped -
# confirm before renaming, since other files may reference them.
rule download_amulet_lowmem:
    """
    Clone the AMULET repository into resources/amulet
    """
    output: directory("resources/amulet")
    conda: "envs/fetch.yaml"
    params:
        repo = config["amulet_repo"],
    shell:
        "git clone {params.repo} {output}"

# NOTE(review): despite its name, this rule clones the low-memory repository
# (config["amulet_repo_lowmem"]) into resources/amulet_lowmem; the regular
# repository is cloned by `rule download_amulet_lowmem`. The two rule names
# look swapped - confirm before renaming, since other files may reference them.
rule download_amulet:
    """
    Clone the low-memory-patched AMULET repository into resources/amulet_lowmem
    """
    output: directory("resources/amulet_lowmem")
    conda: "envs/fetch.yaml"
    params:
        repo = config["amulet_repo_lowmem"],
    shell:
        "git clone {params.repo} {output}"
