import os
import json
from snakemake.utils import min_version

min_version("6.6.1")

# Load the workflow settings; they are read below through the `config` dict.
configfile: "config/config.yaml"

def _read_sample_ids(path, columns, template):
    """
    Build one sample ID per data row of a tab-separated sample sheet.

    path: location of the TSV file; its first line is the header row.
    columns: header names whose values are substituted, in order, into
        `template`.
    template: `str.format` pattern with one positional `{}` per column.

    Returns the formatted IDs in file order. Lines starting with "#" and
    blank lines are skipped. Raises ValueError when a requested column is
    missing from the header and IndexError when a data row has too few fields.
    """
    with open(path) as sample_file:
        header = sample_file.readline().rstrip('\n').split('\t')
        indices = [header.index(name) for name in columns]
        sample_ids = []
        for line in sample_file:
            # A blank line (e.g. a stray empty row at the end of the sheet)
            # has no fields and would otherwise break the column lookup.
            if line.startswith("#") or not line.strip():
                continue
            entries = line.rstrip('\n').split('\t')
            sample_ids.append(template.format(*(entries[i] for i in indices)))
    return sample_ids


# "<experiment>-<replicate>" for every row of the RNA sample sheet.
samples_rna = _read_sample_ids(
    "config/samples_rna.tsv",
    ("Experiment", "Replicate"),
    "{}-{}",
)

# "<experiment>-<replicate>" for every row of the ATAC sample sheet.
samples_atac = _read_sample_ids(
    "config/samples_atac.tsv",
    ("Experiment", "Replicate"),
    "{}-{}",
)

# "<rna experiment>-<rna replicate>_<atac experiment>-<atac replicate>" for
# every row of the multiome sample sheet.
samples_multiome = _read_sample_ids(
    "config/samples_multiome.tsv",
    ("Experiment_RNA", "Replicate_RNA", "Experiment_ATAC", "Replicate_ATAC"),
    "{}-{}_{}-{}",
)

# with open("config/rna_cluster_names.txt") as f:
#     rna_cluster_names = []
#     for line in f:
#         rna_cluster_names.append(line.rstrip("\n"))
        
# Working directory for the workflow, taken from the configuration.
# NOTE(review): this directive comes after the sample sheets are opened above,
# so those opens happen before it is reached -- confirm that is intended.
workdir: config["workdir"]

# Thread cap read from the configuration; not referenced in this part of the
# file, presumably consumed by the included rule files.
max_threads = config["max_threads_per_rule"]

# def script_path(script_name):
#     return str(workflow.source_path(script_name))

# Pull in the ATAC and RNA rule files.
include: "rules/atac.smk"
include: "rules/rna.smk"

# Paths requested under both export/atac/ and export/rna/.
_EXPORT_FILES = [
    "embeddings/harmony.tsv.gz",
    "embeddings/umap.tsv.gz",
    "labels/cell_types.tsv.gz",
    "markers",
    "metadata.tsv.gz",
    "figures.tar.gz",
    "datasets.txt",
]

rule all:
    """
    Default target: request the merged ATAC and RNA results and their exports.
    """
    input:
        "results_merged/atac/archr_linkage",
        expand("export/atac/{path}", path=_EXPORT_FILES),
        "results_merged/rna/seurat_name_rna/proj.rds",
        expand("export/rna/{path}", path=_EXPORT_FILES)

rule download_gtf:
    """
    Download the GTF file from the URL configured under `gtf`.
    """
    output:
        "results_merged/fetch/GRCh38.gtf.gz"
    params:
        url = config["gtf"]
    conda:
        "envs/fetch.yaml"
    shell:
        # --fail makes curl exit non-zero on an HTTP error instead of saving
        # the server's error page as the output file. The URL is quoted so
        # characters such as `&` or `?` are not interpreted by the shell.
        "curl --fail --no-progress-meter -L '{params.url}' > {output}"

def get_experiment_rna(w):
    """
    Return the experiment field of the RNA part of `w.sample`.

    The RNA part is the text before the first "_"; the experiment field is
    the text before the first "-" within that part.
    """
    rna_part, _, _ = w.sample.partition("_")
    return rna_part.partition("-")[0]

def get_replicate_rna(w):
    """
    Return the replicate field of the RNA part of `w.sample`.

    The RNA part is the text before the first "_"; the replicate field is the
    second "-"-separated field of that part. Raises IndexError when the part
    contains no "-".
    """
    rna_part = w.sample.partition("_")[0]
    fields = rna_part.split("-")
    return fields[1]

def get_experiment_atac(w):
    """
    Return the experiment field of the ATAC part of `w.sample`.

    The ATAC part is the text after the last "_" (the whole sample name when
    there is no "_"); the experiment field is the text before the first "-"
    within that part.
    """
    atac_part = w.sample.rpartition("_")[2]
    return atac_part.partition("-")[0]

def get_replicate_atac(w):
    """
    Return the replicate field of the ATAC part of `w.sample`.

    The ATAC part is the text after the last "_" (the whole sample name when
    there is no "_"); the replicate field is the second "-"-separated field
    of that part. Raises IndexError when the part contains no "-".
    """
    atac_part = w.sample.rpartition("_")[2]
    fields = atac_part.split("-")
    return fields[1]

rule query_fragments:
    """
    Query the ENCODE portal for the URL of a sample's fragments file.
    """
    output:
        "results/{sample}/fetch/fragments_url.txt"
    log:
        directory("logs/{sample}/query_fragments")
    conda:
        "envs/fetch.yaml"
    params:
        # Experiment and replicate come from the ATAC part of the sample name.
        experiment = get_experiment_atac,
        replicate = get_replicate_atac,
        dcc_mode = config["dcc_mode"],
        # Looked up with os.environ[...], so an unset variable raises KeyError.
        dcc_api_key = os.environ["DCC_API_KEY"],
        dcc_secret_key = os.environ["DCC_SECRET_KEY"]
    script:
        "scripts/get_fragments_url.py"

rule download_fragments:
    """
    Download the fragments tarball from the URL stored in the input file.
    """
    input:
        "results/{sample}/fetch/fragments_url.txt"
    output:
        "results/{sample}/fetch/fragments.tar.gz"
    params:
        # Only referenced by the commented-out authenticated command below.
        usr = os.environ["DCC_API_KEY"],
        pwd = os.environ["DCC_SECRET_KEY"]
    conda:
        "envs/fetch.yaml"
    shell:
        # Authenticated variant, currently disabled:
        # "curl --no-progress-meter -L -u {params.usr}:{params.pwd} $(< {input}) > {output}"
        # --fail makes curl exit non-zero on an HTTP error instead of saving
        # the server's error page as the tarball.
        "curl --fail --no-progress-meter -L $(< {input}) > {output}"

rule extract_fragments:
    """
    Extract the fragments tarball into a flat output directory.
    """
    input:
        "results/{sample}/fetch/fragments.tar.gz"
    output:
        directory("results/{sample}/fetch/fragments_extracted")
    conda:
        "envs/fetch.yaml"
    shell:
        "mkdir -p {output}; "
        # Raw string: `\/` is not a valid Python escape sequence, so a plain
        # literal triggers an invalid-escape warning. The command text is
        # unchanged. The transform strips everything up to the last "/" of
        # each member name, so files land directly in the output directory.
        r"tar -xzf {input} --transform='s/.*\///' -C {output}"

rule move_fragments:
    """
    Copy the extracted fragments file and its .tbi file to their final paths.
    """
    conda:
        "envs/fetch.yaml"
    input:
        "results/{sample}/fetch/fragments_extracted"
    output:
        frag = "results/{sample}/fetch/fragments.tsv.gz",
        frag_ind = "results/{sample}/fetch/fragments.tsv.gz.tbi"
    shell:
        "cp {input}/fragments.tsv.gz {output.frag}; cp {input}/fragments.tsv.gz.tbi {output.frag_ind};"

rule query_expression:
    """
    Query the ENCODE portal for the URL of a sample's gene expression matrix.
    """
    output:
        "results/{sample}/fetch/expression_url.txt"
    params:
        # Experiment and replicate come from the RNA part of the sample name.
        experiment = get_experiment_rna,
        replicate = get_replicate_rna,
        dcc_mode = config["dcc_mode"],
        # Looked up with os.environ[...], so an unset variable raises KeyError.
        dcc_api_key = os.environ["DCC_API_KEY"],
        dcc_secret_key = os.environ["DCC_SECRET_KEY"]
    log:
        # Own log directory; this rule previously reused the query_fragments
        # path, so both rules logged to the same location for a sample.
        directory("logs/{sample}/query_expression")
    conda:
        "envs/fetch.yaml"
    script:
        "scripts/get_expression_url.py"

rule download_expression:
    """
    Download the expression tarball from the URL stored in the input file.
    """
    input:
        "results/{sample}/fetch/expression_url.txt"
    output:
        "results/{sample}/fetch/expression.tar.gz"
    params:
        # Only referenced by the commented-out authenticated command below.
        usr = os.environ["DCC_API_KEY"],
        pwd = os.environ["DCC_SECRET_KEY"]
    conda:
        "envs/fetch.yaml"
    shell:
        # Authenticated variant, currently disabled:
        # "curl --no-progress-meter -L -u {params.usr}:{params.pwd} $(< {input}) > {output}"
        # --fail makes curl exit non-zero on an HTTP error instead of saving
        # the server's error page as the tarball.
        "curl --fail --no-progress-meter -L $(< {input}) > {output}"

rule extract_expression:
    """
    Extract the expression tarball into a flat output directory.
    """
    input:
        "results/{sample}/fetch/expression.tar.gz"
    output:
        directory("results/{sample}/fetch/expression_extracted")
    conda:
        "envs/fetch.yaml"
    shell:
        "mkdir -p {output}; "
        # Raw string: `\/` is not a valid Python escape sequence, so a plain
        # literal triggers an invalid-escape warning. The command text is
        # unchanged. The transform strips everything up to the last "/" of
        # each member name, so files land directly in the output directory.
        r"tar -xzf {input} --transform='s/.*\///' -C {output}"

rule move_expression:
    """
    Copy the extracted matrix, features and barcodes files to their final paths.
    """
    conda:
        "envs/fetch.yaml"
    input:
        "results/{sample}/fetch/expression_extracted"
    output:
        matrix = "results/{sample}/fetch/matrix.mtx",
        features = "results/{sample}/fetch/features.tsv",
        barcodes = "results/{sample}/fetch/barcodes.tsv",
    shell:
        "cp {input}/matrix.mtx {output.matrix}; cp {input}/features.tsv {output.features}; "
        "cp {input}/barcodes.tsv {output.barcodes};"

rule download_rna_ref_counts:
    """
    Download the reference expression data from the URL configured under
    `rna_ref_counts` and save it as an .rds file.
    """
    output:
        "reference/fetch/expression.rds"
    params:
        url = config["rna_ref_counts"]
    conda:
        "envs/fetch.yaml"
    shell:
        # --fail makes curl exit non-zero on an HTTP error instead of saving
        # the server's error page as the output file.
        "curl --fail --no-progress-meter -L '{params.url}' > {output}"

rule download_rna_ref_metadata:
    """
    Download the reference metadata from the URL configured under
    `rna_ref_metadata` and save it as an .rds file.
    """
    output:
        "reference/fetch/annotations.rds"
    params:
        url = config["rna_ref_metadata"]
    conda:
        "envs/fetch.yaml"
    shell:
        # --fail makes curl exit non-zero on an HTTP error instead of saving
        # the server's error page as the output file.
        "curl --fail --no-progress-meter -L '{params.url}' > {output}"

rule download_barcode_wl:
    """
    Download the RNA and ATAC barcode whitelists and store them uncompressed.
    """
    output:
        rna = "whitelists/rna.txt",
        atac = "whitelists/atac.txt"
    params:
        url_rna = config["bc_wl_rna"],
        url_atac = config["bc_wl_atac"],
    conda:
        "envs/fetch.yaml"
    shell:
        # --fail makes curl exit non-zero on an HTTP error instead of piping
        # the server's error page into the whitelist.
        # zcat -f decompresses gzip input and passes other input through as is.
        "curl --fail --no-progress-meter -L '{params.url_rna}' | zcat -f > {output.rna}; "
        "curl --fail --no-progress-meter -L '{params.url_atac}' | zcat -f > {output.atac}"