"""
Snakemake pipeline implementing the hierarchical varbook spec from CLAUDE.md.

Implements the 3-level hierarchy:
  variant_dataset -> model_dataset -> cluster (optional) -> variants

Directory structure generated:
  varbook_gen/
    {variant_dataset}/                           # e.g., Broad neurodevelopmental and neuromuscular disorders
      {model_dataset}/                           # e.g., KUN_FB
        heatmap/
          {model_dataset}.png
          {model_dataset}.md
        {cluster_name}/                          # e.g., "microglia-specific cluster (#3)"
          {variant_id}/
            model-scatterplot/
              {model_dataset}.html
              {model_dataset}.md
            model-specificity-barplot/
              {model_dataset}.png
              {model_dataset}.md
            profiles/
              {model}.png                        # One per prioritized model

Variant filtering logic:
  1. Start with all variants in variant_dataset TSV (superset)
  2. Filter to variants prioritized by ANY model matching model_dataset patterns
  3. If cluster specified, further filter to variants in that cluster

Cluster level is UNDER model_dataset because clusters are derived from
KMeans clustering on the specific models in that model_dataset.

Workflow pattern:
  for variant_dataset in variant_datasets:
    for model_dataset in model_datasets:
      - Generate heatmap for all prioritized variants
      - For each cluster (optional):
        - For each prioritized variant in cluster:
          - Generate model-scatterplot
          - Generate model-specificity-barplot
          - Generate profiles (one per prioritized model)
"""

import pandas as pd
import os
import hashlib
import json
from pathlib import Path

# ============================================================================
# Wildcard Constraints (prevent PeriodicWildcardError for variant IDs with colons)
# ============================================================================

# Global wildcard constraints to prevent infinite recursion errors
# Variant IDs like "chr15:42010132:G:GCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC" contain
# colons and repeated characters that trigger Snakemake's pattern detection
#
# Every wildcard below gets the same regex, "[^/]+": one or more characters,
# none of which is a '/'. A wildcard value can therefore never span more than
# one path component of an output path.
wildcard_constraints:
    variant_id="[^/]+",
    model_name="[^/]+",
    variant_dataset="[^/]+",
    model_dataset="[^/]+",
    variant_subdataset="[^/]+"

# ============================================================================
# Configuration
# ============================================================================

# Python interpreter of the varbook virtualenv. VARBOOK_CMD runs the varbook
# package as a module ("<python> -m varbook") with that interpreter.
VENV_PYTHON = "/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/curation/broad/varbook-container/.venv/bin/python"
VARBOOK_CMD = f"{VENV_PYTHON} -m varbook"

# Input data
# VARIANTS_TSV is now automatically generated per variant_dataset (see get_variants_tsv_path function)
MODEL_PATHS_TSV = "/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/curation/metadata_renaming/broad.model_paths.tsv"
GENOME_FA = "/oak/stanford/groups/akundaje/soumyak/refs/hg38/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta"

# Directory for split files (symlinked to match variant dataset names)
# NOTE(review): relative path, distinct from the absolute SPLITS_DIR below;
# the helper functions in this part of the file read from SPLITS_DIR.
LOCAL_SPLITS_DIR = "splits"

# Output directory (VARBOOK_DEFAULT_OUTPUT_DIR)
# Used as the root of every path built by get_base_output_path().
OUTPUT_DIR = "varbook_gen"

# Use batch processing for profile generation (faster but requires testing)
USE_BATCH_PROFILES = True  # Batch mode is currently enabled; set to False to disable it

# Finemo annotation configuration
# SPLITS_DIR is the directory the helper functions scan for per-dataset
# annotation TSVs ("<dataset>.general.tsv", "<dataset>.logfc.tsv",
# "<dataset>.aaq.tsv", "*.model_prioritized_by_any-*.tsv", ...).
SPLITS_DIR = "/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/curation/broad/splits"
# The two TSVs below hard-code the "broad" dataset prefix (and, for the
# prioritization file, the KUN_FB model set).
GENERAL_VARIANTS_TSV = f"{SPLITS_DIR}/broad.general.tsv"
PRIORITIZATION_TSV = f"{SPLITS_DIR}/broad.model_prioritized_by_any-KUN_FB.tsv"
# NOTE(review): presumably the TF-MoDISco motif compendium consumed by the
# Finemo annotation rules -- confirm against the rules that use it.
MODISCO_H5 = "/oak/stanford/groups/akundaje/airanman/refs/motif-compendium/modisco_invitro_avg.070325.h5"
VARBOOK_DIR = "/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/curation/broad/varbook-container/varbook"

# ============================================================================
# Hierarchical Structure Configuration
# ============================================================================

# Configuration: each variant dataset has its own set of model datasets
#
# Format:
# VARIANT_DATASET_CONFIGS = {
#     "variant_dataset_name": {
#         'scatterplot_context_tsv': 'path/to/context.tsv',  # Optional: larger variant set for scatterplot context
#         'model_datasets': [
#             {
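#                 'name': 'model_dataset_name',
#                 'models': ['pattern1', 'pattern2', ...],   # glob patterns (matched with fnmatch)
#                 'model_superset': ['pattern1', ...],       # Optional: wider pattern set for superset-level plots
#                 'clusters': [                              # Optional
#                     {'id': 'cluster_1', 'name': 'human-readable name (#1)'},
#                     ...
#                 ],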
#             },
#             ...
#         ],
#     },
#     ...
# }
#
# OR (backward compatible):
# VARIANT_DATASET_CONFIGS = {
#     "variant_dataset_name": [
#         {
#             'name': 'model_dataset_name',
#             'models': ['pattern1', 'pattern2', ...],
#             'clusters': ['cluster_1', 'cluster_2', ...],  # Optional
#         },
#         ...
#     ],
#     ...
# }
#
# Examples:
#   models = ['KUN_FB*']                              -> name = "KUN_FB"
#   models = ['KUN_FB*', 'KUN_HDMA_Eye_c13*']        -> name = "KUN_FB and KUN_HDMA_Eye"
#   models = ['KUN_FB_microglia', 'KUN_FB_neuron']   -> name = "KUN_FB_microglia and KUN_FB_neuron"
#
# To include multiple model sets (e.g., KUN_FB & KUN_HDMA) in scatterplot context:
# 1. Run: snakemake "data/Broad neurodevelopmental and neuromuscular disorders_kun_fb_kun_hdma_scatterplot_context.tsv"
# 2. Add 'scatterplot_context_tsv' to your variant dataset config (see format above)
#    Example: 'scatterplot_context_tsv': 'data/Broad neurodevelopmental and neuromuscular disorders_kun_fb_kun_hdma_scatterplot_context.tsv'
#    This will show variants from all specified model sets in scatterplots for better context.
#
# Active configuration (list format): one variant dataset with a single model
# dataset, "Fetal Brain", and seven named clusters.
VARIANT_DATASET_CONFIGS = {
    "Broad neurodevelopmental and neuromuscular disorders": [
        {
            'name': 'Fetal Brain',
            'models': ['KUN_FB*'],  # Glob patterns (matched with fnmatch) against model names
            'model_superset': ['KUN_FB*', 'KUN_HDMA*'],  # Glob patterns used instead of 'models' when get_models_args(..., use_superset=True)
            'clusters': [
                {
                    'id': 'cluster_3',  # Cluster identifier; its numeric suffix is matched against values in the TSV's kmeans_* columns
                    'name': 'glutamatergic neuron 7 GoF cluster (#3)',  # Human-readable name
                },
                {
                    'id': 'cluster_9',
                    'name': 'nIPC GoF cluster (#9)',
                },
                {
                    'id': 'cluster_10',
                    'name': 'early & late radial glia + oIPC GoF cluster (#10)',
                },
                {
                    'id': 'cluster_14',
                    'name': 'glutamatergic neuron 2 to 7 GoF cluster (#14)',
                },
                {
                    'id': 'cluster_15',
                    'name': 'glutamatergic neuron 2 to 7 + nIPC + interneurons 2 & 4 LoF cluster (#15)',
                },
                {
                    'id': 'cluster_17',
                    'name': 'early & late radial glia + nIPC + oIPC LoF cluster (#17)',
                },
                {
                    'id': 'cluster_27',
                    'name': 'glutamatergic neuron 1 to 7 + early & late radial glia + oIPC + nIPC + more LoF cluster (#27)',
                },
            ],
        },
    ],
}

# ============================================================================
# Helper Functions
# ============================================================================

def has_periodic_pattern(variant_id, threshold=10):
    """Report whether any single character repeats `threshold` times in a row.

    Long homopolymer runs in a wildcard value (for example an insertion allele
    such as GCCCCCCCCCCCCCCCC...) are what trip Snakemake's
    PeriodicWildcardError, so such IDs need special treatment.

    Parameters:
    -----------
    variant_id : str
        Variant ID to inspect
    threshold : int
        Run length that counts as "periodic" (default: 10)

    Returns:
    --------
    bool
        True when some character occurs at least `threshold` times consecutively

    Examples:
    ---------
    >>> has_periodic_pattern("chr1:12345:A:T")
    False
    >>> has_periodic_pattern("chr15:42010132:G:GCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC")
    True
    """
    # Only characters actually present can form a run, so probe each distinct one.
    distinct_chars = set(variant_id)
    return any(symbol * threshold in variant_id for symbol in distinct_chars)

def variant_id_to_hash(variant_id):
    """Derive a short, filesystem-safe token from a variant ID.

    The token is a truncated MD5 digest; it replaces variant IDs whose
    repeated characters would otherwise be unusable as directory names.

    Parameters:
    -----------
    variant_id : str
        Original variant ID

    Returns:
    --------
    str
        The first 16 hexadecimal characters of the MD5 digest of the
        UTF-8 encoded ID
    """
    raw_bytes = variant_id.encode('utf-8')
    full_digest = hashlib.md5(raw_bytes).hexdigest()
    return full_digest[:16]

def get_safe_variant_id(variant_id):
    """Return the form of a variant ID that is safe to embed in file paths.

    IDs without long character runs are passed through untouched. IDs that
    has_periodic_pattern() flags are replaced by "variant_" followed by the
    truncated hash from variant_id_to_hash(), so they cannot trigger
    Snakemake's PeriodicWildcardError.

    Parameters:
    -----------
    variant_id : str
        Original variant ID

    Returns:
    --------
    str
        The original ID, or "variant_<16 hex chars>" for a flagged ID

    Examples:
    ---------
    >>> get_safe_variant_id("chr1:12345:A:T")
    'chr1:12345:A:T'
    """
    needs_hashing = has_periodic_pattern(variant_id)
    return "variant_" + variant_id_to_hash(variant_id) if needs_hashing else variant_id

def get_original_variant_id(safe_variant_id, model_dataset_config, cluster_id=None, variant_dataset=None):
    """Translate a path-safe variant ID back into the original variant ID.

    An ID that does not begin with 'variant_' was never hashed and is
    returned unchanged. A hashed ID cannot be inverted directly, so the
    prioritized variants for the given model dataset / cluster are scanned
    for the one whose safe form equals `safe_variant_id`.

    Parameters:
    -----------
    safe_variant_id : str
        Safe variant ID (original or hashed form)
    model_dataset_config : dict
        Model dataset configuration, forwarded to get_prioritized_variants()
    cluster_id : str, optional
        Cluster ID restricting the candidate variants
    variant_dataset : str, optional
        Variant dataset name

    Returns:
    --------
    str
        The original variant ID; if no candidate matches, a warning is
        printed and `safe_variant_id` itself is returned
    """
    if safe_variant_id.startswith('variant_'):
        candidates = get_prioritized_variants(model_dataset_config, cluster_id, variant_dataset)

        # A private sentinel distinguishes "no match" from any real variant value.
        not_found = object()
        resolved = next(
            (candidate for candidate in candidates
             if get_safe_variant_id(candidate) == safe_variant_id),
            not_found,
        )
        if resolved is not not_found:
            return resolved

        # Not expected in normal operation: the hash matched no known variant.
        print(f"Warning: Could not find original variant for {safe_variant_id}")

    return safe_variant_id

def get_variants_tsv_path(variant_dataset):
    """Build the relative path of the merged variants TSV for a variant dataset.

    The file lives under data/ and is named
    "<variant_dataset>.variants_merged.tsv"; it is produced elsewhere in the
    pipeline by merging the annotation files required by
    VARIANT_DATASET_CONFIGS.
    """
    return "data/{}.variants_merged.tsv".format(variant_dataset)

def get_variant_dataset_config(variant_dataset):
    """Look up the raw VARIANT_DATASET_CONFIGS entry for a variant dataset.

    The entry is returned as stored, in either supported layout:
    - list (backward compatible): [{'name': 'model_dataset', ...}, ...]
    - dict (new): {'scatterplot_context_tsv': 'path', 'model_datasets': [...]}

    Returns:
    --------
    dict or list
        The config value for the variant dataset

    Raises:
    -------
    ValueError
        If `variant_dataset` is not a key of VARIANT_DATASET_CONFIGS
    """
    try:
        return VARIANT_DATASET_CONFIGS[variant_dataset]
    except KeyError:
        raise ValueError(f"Unknown variant_dataset: {variant_dataset}") from None

def get_model_datasets_list(variant_dataset):
    """Return the model dataset configs belonging to a variant dataset.

    Normalizes the two layouts accepted in VARIANT_DATASET_CONFIGS: a bare
    list is returned as-is, and a dict contributes its 'model_datasets' value.

    Returns:
    --------
    list
        List of model dataset config dicts

    Raises:
    -------
    ValueError
        If the entry is neither a list nor a dict with a 'model_datasets' key
        (or if the variant dataset is unknown)
    """
    entry = get_variant_dataset_config(variant_dataset)

    if isinstance(entry, list):
        return entry
    if isinstance(entry, dict) and 'model_datasets' in entry:
        return entry['model_datasets']

    raise ValueError(f"Invalid config format for variant_dataset '{variant_dataset}': expected list or dict with 'model_datasets' key")

def get_scatterplot_context_tsv(variant_dataset):
    """Return the optional scatterplot context TSV configured for a variant dataset.

    The context TSV holds a larger variant set that scatterplots can draw as
    background. It can only be configured in the dict layout of
    VARIANT_DATASET_CONFIGS, under the 'scatterplot_context_tsv' key.

    Returns:
    --------
    str or None
        Path to the context TSV, or None when the entry uses the list layout
        or does not define the key
    """
    entry = get_variant_dataset_config(variant_dataset)
    if not isinstance(entry, dict):
        return None
    return entry.get('scatterplot_context_tsv')

def get_least_used_gpu():
    """
    Find the GPU with the least memory usage.

    Queries `nvidia-smi` for (index, memory.used) of every visible GPU and
    picks the index with the smallest used memory. Output lines that cannot
    be parsed (wrong field count or a non-integer field such as "[N/A]") are
    skipped individually, so one odd device does not hide the others.

    Returns:
    --------
    int
        Index of the GPU with the least used memory, as reported by
        nvidia-smi. Falls back to GPU 2 (with a printed warning) when
        nvidia-smi is missing, exits with an error, or yields no parsable line.
    """
    import subprocess

    try:
        # Run nvidia-smi to get GPU memory usage
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=index,memory.used', '--format=csv,noheader,nounits'],
            capture_output=True,
            text=True,
            check=True
        )
    except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e:
        # Fallback to GPU 2 if nvidia-smi fails
        print(f"Warning: Could not determine GPU usage ({e}), defaulting to GPU 2")
        return 2

    # Parse output: "0, 1234\n1, 5678\n..."
    gpu_usage = []
    for line in result.stdout.strip().split('\n'):
        fields = line.split(',')
        if len(fields) != 2:
            # Blank or unexpected line
            continue
        try:
            gpu_usage.append((int(fields[0].strip()), int(fields[1].strip())))
        except ValueError:
            # e.g. "[N/A]" for memory.used: ignore this device only
            continue

    if not gpu_usage:
        # Fallback to GPU 2 if we can't parse
        print("Warning: Could not parse nvidia-smi output, defaulting to GPU 2")
        return 2

    # Sort by memory usage (ascending); the stable sort keeps the lowest index on ties
    gpu_usage.sort(key=lambda item: item[1])
    least_used_gpu, least_used_memory = gpu_usage[0]

    print(f"GPU usage: {dict(gpu_usage)}")
    print(f"Selected GPU {least_used_gpu} (least used: {least_used_memory} MB)")

    return least_used_gpu

def get_scatterplot_input_tsv(wildcards, default_tsv):
    """Choose the TSV that feeds scatterplot generation.

    A scatterplot context TSV configured for the variant dataset takes
    precedence because it supplies a larger background set of variants; when
    none is configured the caller's default TSV is used.

    Parameters:
    -----------
    wildcards : object
        Snakemake wildcards object; only its `variant_dataset` attribute is read
    default_tsv : str
        TSV path returned when no context TSV is configured

    Returns:
    --------
    str
        Path of the TSV to use for the scatterplot
    """
    return get_scatterplot_context_tsv(wildcards.variant_dataset) or default_tsv

def get_clustered_variants_tsv_path(variant_dataset, model_dataset):
    """Build the relative path of the clustered variants TSV.

    The file is data/<variant_dataset>.<model_dataset>.clustered.tsv. Per the
    pipeline's conventions it carries extra columns such as 'organs' and
    'kmeans_35' that varbook needs for cluster-level variant plots.
    """
    file_name = "{}.{}.clustered.tsv".format(variant_dataset, model_dataset)
    return "data/" + file_name

def _get_prioritization_tsv_path(variant_dataset, model_dataset):
    """Locate the prioritization TSV for a variant dataset / model dataset pair.

    Two naming schemes are probed inside SPLITS_DIR, the plural
    ("models_prioritized_by_any") first and the singular
    ("model_prioritized_by_any") second. If neither file exists, a placeholder
    TSV containing only a "variant_id" header is created under /tmp (unless it
    is already there) and its path is returned, so that Snakemake can still
    build the DAG.
    """
    candidates = (
        f"{SPLITS_DIR}/{variant_dataset}.models_prioritized_by_any-{model_dataset}.tsv",
        f"{SPLITS_DIR}/{variant_dataset}.model_prioritized_by_any-{model_dataset}.tsv",
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate

    # Neither naming scheme matched: hand back an always-present placeholder.
    placeholder = f"/tmp/.dummy_prioritization_{variant_dataset}_{model_dataset}.tsv"
    if os.path.exists(placeholder):
        return placeholder

    os.makedirs(os.path.dirname(placeholder), exist_ok=True)
    with open(placeholder, 'w') as handle:
        handle.write("variant_id\n")  # Minimal valid TSV
    return placeholder

def _list_prioritization_candidates(prefix):
    """List (filename, model_name) for prioritization TSVs in SPLITS_DIR that start with `prefix.`."""
    try:
        filenames = os.listdir(SPLITS_DIR)
    except OSError as e:
        # Best effort: the directory might not exist or be accessible.
        print(f"Warning: Could not list {SPLITS_DIR} ({e}); skipping pattern-based prioritization file discovery")
        return []

    candidates = []
    for filename in filenames:
        if not filename.endswith('.tsv') or not filename.startswith(prefix + '.'):
            continue
        # Accept both model_prioritized_by_any-* and models_prioritized_by_any-*
        if 'model_prioritized_by_any-' not in filename and 'models_prioritized_by_any-' not in filename:
            continue
        # Format: {prefix}.model[s]_prioritized_by_any-{name}.tsv
        parts = filename.replace('.tsv', '').split('.model')
        if len(parts) < 2:
            continue
        name_part = parts[1].split('_prioritized_by_any-')
        if len(name_part) != 2:
            continue
        candidates.append((filename, name_part[1]))
    return candidates


def get_required_annotation_files(variant_dataset):
    """Determine which annotation files are needed for a variant_dataset.

    The result is assembled from files in SPLITS_DIR, in this order:
    1. "<variant_dataset>.general.tsv" (mandatory).
    2. Every prioritization TSV of this variant dataset whose model/dataset
       name matches one of the 'models' glob patterns configured in
       VARIANT_DATASET_CONFIGS (sorted by path).
    3. "<variant_dataset>.logfc.tsv" and "<variant_dataset>.aaq.tsv", each
       falling back to the "broad.*" file of the same kind; a kind is
       omitted when neither file exists.

    Returns:
    --------
    list of str
        List of file paths to merge (all located under SPLITS_DIR)

    Raises:
    -------
    ValueError
        If `variant_dataset` is not configured
    FileNotFoundError
        If the mandatory general TSV is missing
    """
    import fnmatch

    if variant_dataset not in VARIANT_DATASET_CONFIGS:
        raise ValueError(f"Unknown variant_dataset: {variant_dataset}")

    # 1. General variant information (always needed)
    general_file = f"{SPLITS_DIR}/{variant_dataset}.general.tsv"
    if not os.path.exists(general_file):
        raise FileNotFoundError(f"General file not found for variant_dataset '{variant_dataset}': {general_file}")
    required_files = [general_file]

    # 2. Collect all unique model patterns from all model_datasets
    all_model_patterns = set()
    for model_dataset_config in get_model_datasets_list(variant_dataset):
        all_model_patterns.update(model_dataset_config.get('models', []))

    # 3. Find all prioritization files that match the model patterns
    prioritization_files = set()

    # 3a. Exact match with wildcards stripped (e.g., "KUN_FB*" -> ...-KUN_FB.tsv)
    for pattern in all_model_patterns:
        exact_file = f"{SPLITS_DIR}/{variant_dataset}.model_prioritized_by_any-{pattern.replace('*', '')}.tsv"
        if os.path.exists(exact_file):
            prioritization_files.add(exact_file)

    # 3b. Glob match (e.g., KUN_FB* matches KUN_FB_microglia). The directory
    # is listed once and every candidate is tested against all patterns.
    if all_model_patterns:
        for filename, model_name in _list_prioritization_candidates(variant_dataset):
            if any(fnmatch.fnmatch(model_name, pattern) for pattern in all_model_patterns):
                filepath = os.path.join(SPLITS_DIR, filename)
                if os.path.exists(filepath):
                    prioritization_files.add(filepath)

    required_files.extend(sorted(prioritization_files))

    # 4. Add logfc and aaq files (needed for scatterplots), preferring the
    # dataset-specific file over the shared "broad" one.
    for kind in ('logfc', 'aaq'):
        specific_file = f"{SPLITS_DIR}/{variant_dataset}.{kind}.tsv"
        fallback_file = f"{SPLITS_DIR}/broad.{kind}.tsv"
        if os.path.exists(specific_file):
            required_files.append(specific_file)
        elif os.path.exists(fallback_file):
            required_files.append(fallback_file)

    # 5. Clustering columns are not merged here; for now, clustering info is
    # added during heatmap generation.

    return required_files

def get_variant_datasets_param(variant_dataset, cluster_id=None):
    """Format the value passed to varbook's --variant-datasets option.

    The value is the variant dataset name, with ":<cluster_id>" appended when
    a (non-empty) cluster ID is given.

    Note: The model_dataset is passed separately via --model-dataset parameter.
    """
    return f"{variant_dataset}:{cluster_id}" if cluster_id else variant_dataset

def get_base_output_path(variant_dataset, model_dataset_name, cluster=None):
    """Compose the output directory for a model dataset, optionally for one cluster.

    Returns:
      - {OUTPUT_DIR}/{variant_dataset}/{model_dataset}/ (for heatmaps)
      - {OUTPUT_DIR}/{variant_dataset}/{model_dataset}/{cluster}/ (for clustered variants)

    The returned string carries no trailing slash.
    """
    dataset_root = f"{OUTPUT_DIR}/{variant_dataset}/{model_dataset_name}"
    return f"{dataset_root}/{cluster}" if cluster else dataset_root

def get_models_string(model_dataset_config):
    """Return the model dataset's name in a form usable as a file name.

    Spaces and forward slashes in model_dataset_config['name'] are each
    replaced by an underscore; every other character is kept.
    """
    to_underscore = str.maketrans(" /", "__")
    return model_dataset_config['name'].translate(to_underscore)

def get_models_args(model_dataset_config, use_superset=False):
    """Render the value of varbook's --models option for a model dataset.

    Parameters:
    -----------
    model_dataset_config : dict
        Model dataset configuration
    use_superset : bool, optional
        True selects the 'model_superset' patterns (superset-level plots),
        falling back to 'models' when no superset is configured.
        False (default) selects the 'models' patterns (cluster-level plots).

    Returns:
    --------
    str
        The selected patterns, each wrapped in double quotes and separated
        by single spaces.

    Raises:
    -------
    ValueError
        If the selected pattern list is missing or empty
    """
    patterns = model_dataset_config.get('models', [])
    if use_superset:
        # A configured (non-empty) superset wins over the plain model list.
        patterns = model_dataset_config.get('model_superset') or patterns

    if patterns:
        # Double quotes keep the shell from glob-expanding the patterns.
        quoted = ['"' + str(pattern) + '"' for pattern in patterns]
        return " ".join(quoted)

    field_name = 'models'
    if use_superset:
        field_name = 'model_superset'
    raise ValueError(f"Model dataset config must have '{field_name}': {model_dataset_config}")

# Alias for backward compatibility (if needed elsewhere)
def get_scatterplot_superset_models_args(model_dataset_config):
    """Return the --models value for superset-level scatterplots.

    Thin wrapper kept for older call sites; it delegates to get_models_args()
    with the superset patterns selected.
    """
    superset_args = get_models_args(model_dataset_config, True)
    return superset_args

def get_prioritized_variants(model_dataset_config, cluster_id=None, variant_dataset=None):
    """Get list of variants prioritized by ANY model in the model_dataset.

    Filtering logic:
    1. Start with all variants in the TSV (superset from variant_dataset)
    2. Filter to variants prioritized by at least 1 model matching model_dataset patterns
    3. If cluster_id specified, further filter to variants in that cluster

    Input selection: with a cluster_id the table is
    data/<variant_dataset>.<model_dataset name>.clustered.tsv, otherwise the
    merged variants TSV. When that file does not exist yet (e.g. during DAG
    construction, before the checkpoint rules have run) a prioritization TSV
    from SPLITS_DIR is read instead. That fallback file is not expected to
    carry kmeans_* columns, in which case the cluster filter cannot be applied
    and all prioritized variants are returned with a warning.

    Parameters:
    -----------
    model_dataset_config : dict
        Configuration with 'models' patterns (and 'name' when cluster_id is given)
    cluster_id : str, optional
        Cluster identifier, either 'cluster_<N>' or a name containing '(#<N>)'
    variant_dataset : str, optional
        Variant dataset name. If None, it is looked up as the dataset in
        VARIANT_DATASET_CONFIGS that contains model_dataset_config.

    Returns:
    --------
    list of str
        List of variant IDs (empty, with a warning, if no column matches the
        model patterns)

    Raises:
    -------
    ValueError
        If variant_dataset cannot be determined, or cluster_id is given but
        the config has no 'name'
    FileNotFoundError
        If neither the data/ file nor a fallback file in SPLITS_DIR exists
    """
    import fnmatch
    import re

    # Determine variant_dataset if not provided
    if variant_dataset is None:
        for vd in VARIANT_DATASET_CONFIGS:
            # get_model_datasets_list() handles both the list and the dict
            # layout; a bare `in` on the dict layout would test its keys and
            # fail on the unhashable config dict.
            if model_dataset_config in get_model_datasets_list(vd):
                variant_dataset = vd
                break
        if variant_dataset is None:
            raise ValueError("Could not determine variant_dataset. Please provide it explicitly.")

    model_patterns = model_dataset_config.get('models', [])

    # When cluster_id is specified, read from clustered.tsv (has kmeans_35 column)
    # Otherwise read from base variants_merged.tsv
    if cluster_id:
        model_dataset_name = model_dataset_config.get('name')
        if not model_dataset_name:
            raise ValueError("model_dataset_config must have 'name' field for cluster filtering")
        expected_tsv = f"data/{variant_dataset}.{model_dataset_name}.clustered.tsv"
    else:
        expected_tsv = get_variants_tsv_path(variant_dataset)

    if os.path.exists(expected_tsv):
        df = pd.read_csv(expected_tsv, sep='\t')
    else:
        # Fallback: read from SPLITS_DIR (for DAG construction before the
        # checkpoint rules have created the data/ files).
        # Prefer the dataset-specific file named after the first pattern with
        # its trailing wildcards removed (e.g., "KUN_FB*" -> "KUN_FB").
        splits_file = f"{SPLITS_DIR}/broad.model_prioritized_by_any.tsv"
        dataset_prefix = model_patterns[0].rstrip('*') if model_patterns else None
        if dataset_prefix:
            specific_file = f"{SPLITS_DIR}/broad.model_prioritized_by_any-{dataset_prefix}.tsv"
            if os.path.exists(specific_file):
                splits_file = specific_file

        if not os.path.exists(splits_file):
            raise FileNotFoundError(
                f"Neither data/ file nor splits/ file found.\n"
                f"Expected: {expected_tsv}\n"
                f"Or fallback: {splits_file}\n"
                f"Run merge_variants_tsv checkpoint first to create data/ files."
            )

        df = pd.read_csv(splits_file, sep='\t')

    # Step 1: Find the per-model prioritization columns whose model name
    # matches a pattern. Column format is "model_prioritized_by_any-{model_name}"
    # (singular "model").
    column_prefix = 'model_prioritized_by_any-'
    matched_columns = [
        col for col in df.columns
        if col.startswith(column_prefix)
        and any(fnmatch.fnmatch(col[len(column_prefix):], pattern) for pattern in model_patterns)
    ]

    if not matched_columns:
        print(f"Warning: No models matched patterns {model_patterns}")
        return []

    # Step 2: Keep variants flagged "true" (string, case-insensitive) by at
    # least one matched model. The mask shares df's index so the OR and the
    # row selection stay aligned whatever the index is.
    prioritized_mask = pd.Series(False, index=df.index)
    for col in matched_columns:
        prioritized_mask |= (df[col].astype(str).str.lower() == 'true')

    df_prioritized = df[prioritized_mask].copy()

    if not cluster_id:
        return df_prioritized['variant_id'].tolist()

    # Step 3: Filter to the requested cluster.
    # Extract the numeric cluster number from cluster_id; handles
    # 'cluster_3', 'glutamatergic neuron 7 GoF cluster (#3)', etc.
    cluster_num = None
    if cluster_id.startswith('cluster_'):
        try:
            cluster_num = int(cluster_id.replace('cluster_', ''))
        except ValueError:
            pass
    else:
        match = re.search(r'\(#(\d+)\)', cluster_id)
        if match:
            cluster_num = int(match.group(1))

    # Use the first kmeans_* column that contains the cluster, trying the
    # numeric value first and the literal cluster_id string second.
    for col in df_prioritized.columns:
        if not col.startswith('kmeans_'):
            continue
        values = df_prioritized[col]
        if cluster_num is not None and cluster_num in values.unique():
            return df_prioritized[values == cluster_num]['variant_id'].tolist()
        if cluster_id in values.astype(str).unique():
            return df_prioritized[values == cluster_id]['variant_id'].tolist()

    print(f"Warning: No cluster column found for {cluster_id}, using all prioritized variants")
    return df_prioritized['variant_id'].tolist()

def get_variant_prioritized_models(variant_id, model_dataset_config, variant_dataset=None):
    """Return the model names recorded as prioritizing ``variant_id``.

    Names are parsed from every ``models_prioritized_by_any-*`` column of the
    variant dataset's TSV. Each cell is expected to look like
    ``;MODEL1(score);MODEL2(score);...``; entries without a ``(`` are ignored.

    Parameters:
    -----------
    variant_id : str
        Variant ID to look up in the ``variant_id`` column.
    model_dataset_config : dict
        Only used to infer ``variant_dataset`` when that argument is omitted.
        The returned models are NOT filtered by this config's patterns yet.
    variant_dataset : str, optional
        Variant dataset name. If None, the first entry of
        VARIANT_DATASET_CONFIGS containing ``model_dataset_config`` is used.

    Returns:
    --------
    list of str
        Model names in column/entry order (duplicates are kept); empty when
        the variant is not present in the TSV.

    Raises:
    -------
    ValueError
        If ``variant_dataset`` is None and cannot be inferred.
    """
    if variant_dataset is None:
        # Reverse lookup: first variant dataset whose configs contain this one
        variant_dataset = next(
            (vd for vd, configs in VARIANT_DATASET_CONFIGS.items()
             if model_dataset_config in configs),
            None,
        )
        if variant_dataset is None:
            raise ValueError("Could not determine variant_dataset. Please provide it explicitly.")

    table = pd.read_csv(get_variants_tsv_path(variant_dataset), sep='\t')
    hits = table[table['variant_id'] == variant_id]
    if hits.empty:
        return []

    found = []
    for col in table.columns:
        if not col.startswith('models_prioritized_by_any-'):
            continue

        # Only the first matching row is consulted
        cell = hits[col].iloc[0]
        if pd.isna(cell) or not str(cell).strip():
            continue

        for raw_entry in str(cell).split(';'):
            entry = raw_entry.strip()
            if not entry or '(' not in entry:
                continue
            # "MODEL(score)" -> "MODEL"
            name = entry.partition('(')[0]
            if name:
                found.append(name)

    # TODO: In production, filter models based on model_dataset_config['models'] patterns
    # For now, return all prioritized models
    return found

def get_prioritized_models_for_variant(variant_dataset, variant_id, model_dataset_name, prioritization_tsv=None):
    """
    Get list of models that prioritize a specific variant.

    Only returns models that:
    1. Prioritize the variant (listed in a ``models_prioritized_by_any-*``
       column of the prioritization TSV)
    2. Match at least one fnmatch pattern in the model_dataset's 'models' list

    Parameters:
    -----------
    variant_dataset : str
        Variant dataset name (e.g., "Broad neurodevelopmental and neuromuscular disorders")
    variant_id : str
        Variant ID (e.g., "chr1:123:A:G")
    model_dataset_name : str
        Model dataset name (e.g., "Fetal Brain")
    prioritization_tsv : str, optional
        Path to prioritization TSV. If None, uses
        get_variants_tsv_path(variant_dataset).

    Returns:
    --------
    list of str
        Unique model names that prioritize this variant and match the
        model_dataset, in first-seen (column, then entry) order. Empty when
        the model_dataset config or the variant is not found.
    """
    import pandas as pd
    import fnmatch

    # Get model_patterns for this model_dataset from config
    dataset_configs = get_model_datasets_list(variant_dataset) if variant_dataset in VARIANT_DATASET_CONFIGS else []
    model_dataset_config = next(
        (config for config in dataset_configs if config['name'] == model_dataset_name),
        None,
    )

    if model_dataset_config is None:
        print(f"Warning: No config found for model_dataset '{model_dataset_name}' in variant_dataset '{variant_dataset}'")
        return []

    model_patterns = model_dataset_config.get('models', [])

    # Read prioritization file
    if prioritization_tsv is None:
        prioritization_tsv = get_variants_tsv_path(variant_dataset)

    df = pd.read_csv(prioritization_tsv, sep='\t')
    variant_row = df[df['variant_id'] == variant_id]

    if variant_row.empty:
        return []

    # A dict is used as an insertion-ordered set so the result is de-duplicated
    # AND stable across runs (plain set order of strings depends on hash seed).
    prioritized = {}

    for col in df.columns:
        if not col.startswith('models_prioritized_by_any-'):
            continue

        models_str = variant_row[col].iloc[0]
        if pd.isna(models_str) or not str(models_str).strip():
            continue

        # Parse: ";MODEL1(score);MODEL2(score);..."
        for raw_entry in str(models_str).split(';'):
            entry = raw_entry.strip()
            if not entry or '(' not in entry:
                continue
            model = entry.split('(')[0]
            # Keep non-empty names matching ANY pattern of this model_dataset
            if model and any(fnmatch.fnmatch(model, pattern) for pattern in model_patterns):
                prioritized.setdefault(model, None)

    return list(prioritized)

def get_all_models_for_barplot(variant_dataset):
    """
    Collect every model name found in the KUN_HDMA and KUN_FB split TSV headers.

    Model names are taken from column headers such as
    'model_prioritized_by_any-KUN_FB_microglia' (the prefix is removed). Only
    the header row of each file is read. KUN_HDMA is processed first, then
    KUN_FB; a missing file only produces a warning.

    TEMPORARY: the two file paths are hard-coded and ``variant_dataset`` is
    currently not used to locate them (ENC_ENCODE / KUN_THYROID are excluded).

    Parameters:
    -----------
    variant_dataset : str
        Variant dataset name (currently unused, see note above)

    Returns:
    --------
    list of str
        Model names from the KUN_HDMA header followed by those from KUN_FB
    """
    import pandas as pd

    column_prefix = 'model_prioritized_by_any-'
    # Shared stem of the per-family split files
    splits_stem = (
        "/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/curation/broad/splits/"
        "Broad neurodevelopmental and neuromuscular disorders.model_prioritized_by_any-"
    )

    collected = []
    for family in ('KUN_HDMA', 'KUN_FB'):
        tsv_path = f"{splits_stem}{family}.tsv"

        if not os.path.exists(tsv_path):
            print(f"Warning: {family} file not found: {tsv_path}")
            continue

        # nrows=0 -> header only
        header = pd.read_csv(tsv_path, sep='\t', nrows=0)
        collected.extend(
            col.replace(column_prefix, '')
            for col in header.columns
            if col.startswith(column_prefix)
        )
        family_count = sum(1 for name in collected if name.startswith(family))
        print(f"Found {family_count} {family} models")

    print(f"Total: {len(collected)} models (KUN_HDMA + KUN_FB)")
    return collected

def get_all_models_args_for_barplot(variant_dataset):
    """
    Build the value of the --models argument covering ALL barplot models.

    Each name returned by get_all_models_for_barplot() is wrapped in double
    quotes and the results are joined with single spaces.

    Returns:
    --------
    str
        Space-separated, double-quoted model names; "" when there are none
    """
    model_names = get_all_models_for_barplot(variant_dataset)
    if not model_names:
        return ""

    # Double quotes keep each name a single shell word
    quoted_names = ['"{}"'.format(name) for name in model_names]
    return " ".join(quoted_names)

def get_cluster_id_from_name(variant_dataset, model_dataset_name, cluster_name):
    """Reverse lookup: human-readable cluster name -> cluster ID.

    Parameters:
    -----------
    variant_dataset : str
        Variant dataset name (e.g., "Broad neurodevelopmental and neuromuscular disorders")
    model_dataset_name : str
        Model dataset name (e.g., "Fetal Brain")
    cluster_name : str
        Human-readable cluster name (e.g., "microglia-specific cluster (#3)")

    Returns:
    --------
    str or None
        The 'id' of the first dict-style cluster whose 'name' matches
        (e.g., "cluster_3"), or None if the dataset, model dataset or
        cluster is not found
    """
    if variant_dataset not in VARIANT_DATASET_CONFIGS:
        return None

    # Lazily walk matching model datasets and their dict-style clusters
    matching_ids = (
        cluster.get('id')
        for config in get_model_datasets_list(variant_dataset)
        if config['name'] == model_dataset_name
        for cluster in config.get('clusters', [])
        if isinstance(cluster, dict) and cluster.get('name') == cluster_name
    )
    return next(matching_ids, None)

def get_cluster_name_from_id(variant_dataset, model_dataset_name, cluster_id):
    """Lookup: cluster ID -> human-readable cluster name.

    Parameters:
    -----------
    variant_dataset : str
        Variant dataset name
    model_dataset_name : str
        Model dataset name
    cluster_id : str or int
        Cluster ID (e.g., "cluster_3" or 3); an int N is looked up as
        "cluster_N"

    Returns:
    --------
    str or None
        The 'name' of the first dict-style cluster whose 'id' matches
        (e.g., "microglia-specific cluster (#3)"), or None if not found
    """
    if variant_dataset not in VARIANT_DATASET_CONFIGS:
        return None

    # Config IDs are strings, so normalise integer IDs to the "cluster_N" form
    wanted_id = f"cluster_{cluster_id}" if isinstance(cluster_id, int) else cluster_id

    matching_names = (
        cluster.get('name')
        for config in get_model_datasets_list(variant_dataset)
        if config['name'] == model_dataset_name
        for cluster in config.get('clusters', [])
        if isinstance(cluster, dict) and cluster.get('id') == wanted_id
    )
    return next(matching_names, None)

def get_variants_prioritized_by_model(variant_dataset, model_name, model_dataset_config, cluster_id=None):
    """Get list of variant IDs prioritized by a specific model.

    This is the INVERSE of get_prioritized_models_for_variant():
    - get_prioritized_models_for_variant: variant -> [models]
    - get_variants_prioritized_by_model: model -> [variants]

    The candidate set comes from get_prioritized_variants() for the model
    dataset (and cluster); it is then narrowed to rows whose
    ``model_prioritized_by_any-{model_name}`` column reads "true"
    (case-insensitive) in the variant dataset's TSV.

    Parameters:
    -----------
    variant_dataset : str
        Variant dataset name. Used for both the candidate lookup and the TSV
        read, so the two steps always refer to the same dataset. If None, the
        dataset is inferred from VARIANT_DATASET_CONFIGS.
    model_name : str
        Specific model name (e.g., "KUN_FB_microglia")
    model_dataset_config : dict
        Model dataset configuration with 'models' patterns
    cluster_id : str, optional
        Cluster ID for filtering (e.g., 'cluster_3')

    Returns:
    --------
    list of str
        Variant IDs prioritized by this model (and optionally in cluster);
        empty when there are no candidates or the model column is absent

    Raises:
    -------
    ValueError
        If ``variant_dataset`` is None and cannot be inferred from
        ``model_dataset_config``.
    """
    import pandas as pd

    # Step 1: Get ALL prioritized variants for the model_dataset + cluster
    all_variants = get_prioritized_variants(model_dataset_config, cluster_id, variant_dataset)

    if not all_variants:
        return []

    # Step 2: Resolve which variant dataset's TSV to read. Prefer the caller's
    # explicit argument; the reverse lookup is only a fallback, because it can
    # fail or pick a different dataset than the one used in step 1.
    variant_dataset_name = variant_dataset
    if variant_dataset_name is None:
        variant_dataset_name = next(
            (vd for vd, configs in VARIANT_DATASET_CONFIGS.items()
             if any(cfg == model_dataset_config for cfg in configs)),
            None,
        )
        if variant_dataset_name is None:
            raise ValueError("Could not determine variant_dataset from model_dataset_config")

    df = pd.read_csv(get_variants_tsv_path(variant_dataset_name), sep='\t')

    # Step 3: Direct column lookup for this specific model
    model_col = f'model_prioritized_by_any-{model_name}'
    if model_col not in df.columns:
        return []

    in_candidates = df['variant_id'].isin(all_variants)
    # Values may be the string "true" rather than a boolean, so compare as text
    is_prioritized = df[model_col].astype(str).str.lower() == 'true'

    return df.loc[in_candidates & is_prioritized, 'variant_id'].tolist()

# ============================================================================
# Finemo Helper Functions
# ============================================================================

# Cache for clustered.tsv reads performed by get_demanded_variants_for_model()
# Key: (variant_dataset, model_dataset_name, model_name)
# Value: DataFrame read from the clustered TSV, or None when the file was
#        missing / unusable (None is cached too, to avoid re-checking the file)
_clustered_tsv_cache = {}

def get_kun_fb_models(model_paths_tsv):
    """Get list of all KUN_FB models from model_paths_tsv.

    Parameters:
    -----------
    model_paths_tsv : str or file-like
        Tab-separated file with a ``model_name`` column.

    Returns:
    --------
    list of str
        Unique model names starting with 'KUN_FB', in order of first
        appearance. Rows with a missing model name are ignored.
    """
    import pandas as pd

    df = pd.read_csv(model_paths_tsv, sep='\t')
    # Drop missing names first: a NaN would otherwise end up in the boolean
    # mask and make the row selection raise.
    names = df['model_name'].dropna().astype(str)
    return names[names.str.startswith('KUN_FB')].unique().tolist()

def get_finemo_prioritized_variants(prioritization_tsv, model_name):
    """Get list of variants prioritized by this model.

    Parameters:
    -----------
    prioritization_tsv : str or file-like
        Tab-separated file with a ``variant_id`` column and, per model, a
        ``model_prioritized_by_any-{model_name}`` flag column.
    model_name : str
        Model whose flag column is inspected.

    Returns:
    --------
    list of str
        Variant IDs whose flag is truthy; empty if the column is missing.
        A flag counts as truthy when it is boolean True, numeric 1, or the
        text "true"/"1" (case-insensitive), matching the other
        prioritization helpers in this file.
    """
    import pandas as pd

    df = pd.read_csv(prioritization_tsv, sep='\t')
    priority_col = f'model_prioritized_by_any-{model_name}'
    if priority_col not in df.columns:
        return []

    # Compare as normalised text so bool, int/float and string encodings of
    # the flag are all recognised ('1.0' covers a float-typed column).
    flags = df[priority_col].astype(str).str.strip().str.lower()
    return df.loc[flags.isin(['true', '1', '1.0']), 'variant_id'].tolist()

def get_finemo_tsv_path(wildcards):
    """Return the finemo split TSV path for the model named in ``wildcards``.

    The model is read from ``wildcards.model_name`` when that attribute
    exists, otherwise from ``wildcards.model``. This lets rules that use
    either wildcard name share one input function.

    The returned path is ``{SPLITS_DIR}/finemo/broad.finemo.{model}.tsv``.
    NOTE(review): this is intended to be identical to the output path of the
    annotate_finemo_split_file rule so Snakemake can link the dependency -
    confirm against that rule.

    Raises:
    -------
    ValueError
        If neither attribute yields a model name.
    """
    # Prefer 'model_name'; fall back to 'model' (None if both are absent)
    fallback_model = getattr(wildcards, 'model', None)
    model = getattr(wildcards, 'model_name', fallback_model)

    if model is None:
        raise ValueError("Wildcards must have either 'model_name' or 'model' attribute")

    return f"{SPLITS_DIR}/finemo/broad.finemo.{model}.tsv"

def get_finemo_new_variants(prioritized_variants, output_file):
    """Return the prioritized variants not yet present in ``output_file``.

    If ``output_file`` does not exist, every variant is considered new and
    the input list is returned as-is. Otherwise the file is read as a TSV and
    any variant already listed in its ``variant_id`` column is dropped, with
    the input order preserved.
    """
    import pandas as pd
    import os

    if os.path.exists(output_file):
        done_ids = set(pd.read_csv(output_file, sep='\t')['variant_id'].tolist())
        return [variant for variant in prioritized_variants if variant not in done_ids]

    # Nothing written yet -> everything still needs processing
    return prioritized_variants

# List of all KUN_FB model names found in MODEL_PATHS_TSV.
# Evaluated once at module level (this reads the TSV); helper functions in
# this file match model_dataset 'models' patterns against it.
KUN_FB_MODELS = get_kun_fb_models(MODEL_PATHS_TSV)

def get_demanded_variants_by_model():
    """Get variants demanded by VARIANT_DATASET_CONFIGS, grouped by model.

    For every model dataset, each KUN_FB model matching one of its 'models'
    patterns (fnmatch) is paired with every configured cluster (or a single
    un-clustered pass when no clusters are configured), and the variants that
    THAT SPECIFIC MODEL prioritized are collected via
    get_variants_prioritized_by_model().

    Each model only gets variants that it actually prioritized, not all
    variants from clusters where any matching model prioritized variants.

    NOTE: This processes ALL models. For a single model, use
    get_demanded_variants_for_model() instead.

    Returns:
    --------
    dict of str -> set of str
        model_name -> variant IDs to process for that model. Every matched
        model has an entry, possibly an empty set.
    """
    import fnmatch

    demanded_variants = {}  # model_name -> set of variant_ids

    for variant_dataset in VARIANT_DATASET_CONFIGS.keys():
        for model_dataset_config in get_model_datasets_list(variant_dataset):
            model_patterns = model_dataset_config.get('models', [])

            # Matching does not depend on the cluster, so do it once per model
            # dataset. De-duplicate (order-preserving) so a model matching
            # several patterns is not looked up repeatedly for the same result.
            matched_models = list(dict.fromkeys(
                model
                for pattern in model_patterns
                for model in KUN_FB_MODELS
                if fnmatch.fnmatch(model, pattern)
            ))

            # Get clusters for this model_dataset (or [None] if no clustering)
            clusters = model_dataset_config.get('clusters', []) or [None]

            for cluster in clusters:
                # Clusters may be dicts ({'id': ..., 'name': ...}), bare IDs, or None
                cluster_id = cluster.get('id') if isinstance(cluster, dict) else cluster

                for model in matched_models:
                    # Variants prioritized by THIS SPECIFIC MODEL (not all variants in cluster)
                    model_variants = get_variants_prioritized_by_model(
                        variant_dataset, model, model_dataset_config, cluster_id
                    )
                    demanded_variants.setdefault(model, set()).update(model_variants)

    return demanded_variants

def get_demanded_variants_for_model(model_name):
    """Get variants demanded for a SPECIFIC model only.

    Walks every model dataset in VARIANT_DATASET_CONFIGS whose 'models'
    patterns match ``model_name`` (fnmatch) and, for each configured cluster
    (or once without cluster filtering when none are configured), collects the
    variants that are in that cluster AND flagged as prioritized by this model.

    Data source: ``data/{variant_dataset}.{model_dataset_name}.clustered.tsv``,
    from which only ``variant_id``, ``kmeans_35`` and
    ``model_prioritized_by_any-{model_name}`` are kept.

    Caching: the loaded frame is stored in ``_clustered_tsv_cache`` under the
    key (variant_dataset, model_dataset_name, model_name), so the file is read
    at most once per such key and reused for all clusters and later calls.
    A None result (file missing or no ``kmeans_35`` column) is cached as well.

    Skipped with a warning (no error): model datasets whose clustered.tsv does
    not exist yet or lacks ``kmeans_35``, and 'cluster_*' IDs whose suffix is
    not an integer.

    Progress is reported through DEBUG print statements.

    Parameters:
    -----------
    model_name : str
        Specific model name (e.g., "KUN_FB_microglia")

    Returns:
    --------
    set of str
        Set of variant IDs that should be processed for this model

    Raises:
    -------
    ValueError
        If a clustered.tsv exists and has ``kmeans_35`` but lacks the
        ``model_prioritized_by_any-{model_name}`` column.
    """
    import fnmatch
    import pandas as pd
    import os

    demanded_variants = set()
    print(f"DEBUG get_demanded_variants_for_model: Processing model '{model_name}'")

    for variant_dataset in VARIANT_DATASET_CONFIGS.keys():
        model_dataset_configs = get_model_datasets_list(variant_dataset)
        for model_dataset_config in model_dataset_configs:
            model_patterns = model_dataset_config.get('models', [])
            model_dataset_name = model_dataset_config.get('name')

            # Check if this model matches any pattern in this config
            # (first matching pattern wins; it is only used for the DEBUG message)
            model_matches = False
            matched_pattern = None
            for pattern in model_patterns:
                if fnmatch.fnmatch(model_name, pattern):
                    model_matches = True
                    matched_pattern = pattern
                    break

            if not model_matches:
                continue  # Skip this config if model doesn't match

            print(f"DEBUG: Model '{model_name}' matches pattern '{matched_pattern}' in {variant_dataset}/{model_dataset_name}")

            # OPTIMIZATION: Read clustered TSV ONCE and cache it
            # The clustered TSV contains both cluster assignments AND prioritization columns
            # This avoids reading multiple files and is more efficient
            # The cache key includes model_name because only this model's column is loaded
            model_col = f'model_prioritized_by_any-{model_name}'
            cache_key = (variant_dataset, model_dataset_name, model_name)
            print(f"DEBUG: Looking for column '{model_col}' in clustered.tsv")

            # Check cache first
            if cache_key in _clustered_tsv_cache:
                df_clustered = _clustered_tsv_cache[cache_key]
            else:
                clustered_tsv = f"data/{variant_dataset}.{model_dataset_name}.clustered.tsv"
                df_clustered = None

                if os.path.exists(clustered_tsv):
                    print(f"DEBUG: Found clustered.tsv: {clustered_tsv}")
                    # Read clustered TSV with variant_id, kmeans_35, AND prioritization column
                    # This file should contain model_prioritized_by_any-{model} for all matched models
                    required_cols = ['variant_id', 'kmeans_35', model_col]
                    try:
                        # Fast path: load only the three needed columns
                        df_clustered = pd.read_csv(clustered_tsv, sep='\t', usecols=required_cols)
                        print(f"DEBUG: Successfully read clustered.tsv with {len(df_clustered)} variants")
                    except ValueError as e:
                        # If required columns don't exist, read all columns and check
                        # (any ValueError from the read above lands here, so the
                        # full re-read is what decides which column is missing)
                        print(f"DEBUG: Required columns not found, reading all columns...")
                        df_clustered = pd.read_csv(clustered_tsv, sep='\t')
                        print(f"DEBUG: Available columns: {[c for c in df_clustered.columns if 'model_prioritized' in c or 'kmeans' in c][:20]}")
                        if 'kmeans_35' not in df_clustered.columns:
                            # No cluster assignments: treated as "not ready", not as an error
                            print(f"Warning: clustered.tsv {clustered_tsv} missing kmeans_35 column. Skipping.")
                            df_clustered = None
                        elif model_col not in df_clustered.columns:
                            # Prioritization column missing - this is an error
                            available_prio_cols = [c for c in df_clustered.columns if 'model_prioritized' in c]
                            print(f"ERROR: clustered.tsv {clustered_tsv} missing required column '{model_col}'.")
                            print(f"  Available prioritization columns: {available_prio_cols[:10]}")
                            raise ValueError(
                                f"clustered.tsv {clustered_tsv} missing required column '{model_col}'. "
                                f"This column should be present for model '{model_name}' in model_dataset '{model_dataset_name}'. "
                                f"Available columns: {available_prio_cols[:10]}..."
                            )
                        else:
                            # Keep only the needed columns
                            df_clustered = df_clustered[required_cols].copy()
                else:
                    # Clustered TSV doesn't exist yet (checkpoint hasn't run)
                    print(f"Warning: clustered.tsv {clustered_tsv} doesn't exist yet. Skipping.")
                    df_clustered = None

                # Cache the result (even if None, to avoid re-checking file existence)
                # Not reached when the ValueError above is raised, so that case is not cached
                _clustered_tsv_cache[cache_key] = df_clustered

            # If clustered TSV doesn't exist or couldn't be read, skip this model_dataset
            if df_clustered is None:
                continue

            # Get clusters for this model_dataset (or [None] if no clustering)
            clusters = model_dataset_config.get('clusters', [])
            if not clusters:
                clusters = [None]

            for cluster in clusters:
                # Extract cluster_id
                # (clusters may be dicts with an 'id' key, bare IDs, or None)
                if cluster is None:
                    cluster_id = None
                elif isinstance(cluster, dict):
                    cluster_id = cluster.get('id')
                else:
                    cluster_id = cluster

                print(f"DEBUG: Processing cluster_id={cluster_id} (type: {type(cluster_id)})")

                # Filter clustered TSV to this cluster AND prioritized by this model
                # Use df_clustered which has both kmeans_35 and model_prioritized_by_any-{model} columns
                if cluster_id is not None:
                    # Filter to variants in this cluster
                    # Note: kmeans_35 contains integers 0-34, but cluster_id from config might be string like "cluster_3"
                    # Convert cluster_id to integer if it's a string like "cluster_3" -> 3
                    if isinstance(cluster_id, str) and cluster_id.startswith('cluster_'):
                        try:
                            cluster_id_int = int(cluster_id.replace('cluster_', ''))
                        except ValueError:
                            print(f"Warning: Could not parse cluster_id '{cluster_id}' as integer. Skipping.")
                            continue
                    else:
                        # Any other ID is compared against kmeans_35 unchanged
                        cluster_id_int = cluster_id

                    cluster_mask = (df_clustered['kmeans_35'] == cluster_id_int)
                    df_filtered = df_clustered[cluster_mask].copy()
                    print(f"DEBUG: After cluster filter (kmeans_35=={cluster_id_int}): {len(df_filtered)} variants")
                else:
                    # No clustering - use all variants
                    df_filtered = df_clustered.copy()
                    print(f"DEBUG: No cluster filter, using all {len(df_filtered)} variants")

                if len(df_filtered) == 0:
                    print(f"DEBUG: No variants in cluster {cluster_id}, skipping")
                    continue

                # Filter to variants prioritized by this model
                # The model_col should exist (we checked when reading the file)
                # Defensive check only; both read paths above keep model_col
                if model_col not in df_filtered.columns:
                    raise ValueError(
                        f"Missing column '{model_col}' in filtered dataframe. "
                        f"This should not happen if clustered.tsv was read correctly."
                    )

                # Check prioritization value format
                # (diagnostic output only; does not affect the result)
                sample_values = df_filtered[model_col].dropna().unique()[:5]
                print(f"DEBUG: Sample prioritization values in {model_col}: {sample_values} (types: {[type(v).__name__ for v in sample_values]})")

                # Filter variants where this model's column indicates prioritization
                # Handle multiple formats: True, "True", "true", 1, "1"
                # Convert to string and check for 'true' (case-insensitive) or check for True/1 directly
                col_values = df_filtered[model_col]
                prioritized_mask = (
                    (col_values.astype(str).str.lower() == 'true') |
                    (col_values == True) |
                    (col_values == 1) |
                    (col_values.astype(str) == '1')
                ).fillna(False)
                model_variants = set(df_filtered[prioritized_mask]['variant_id'])
                print(f"DEBUG: After prioritization filter: {len(model_variants)} variants prioritized by {model_name}")
                # Union across clusters and model datasets
                demanded_variants.update(model_variants)

    print(f"DEBUG get_demanded_variants_for_model: Returning {len(demanded_variants)} total variants for model '{model_name}'")
    return demanded_variants

def get_all_finemo_split_files():
    """List the finemo split files needed for HTML generation.

    Every KUN_FB model (from KUN_FB_MODELS) matching at least one 'models'
    pattern of any configured model dataset contributes one path of the form
    ``{SPLITS_DIR}/finemo/broad.finemo.{model}.tsv``. Paths are returned
    sorted by model name, without duplicates.

    NOTE: No TSV files are read here, to avoid stalling during DAG
    construction; only pattern matching against KUN_FB_MODELS is done, so the
    result covers models that *might* have finemo files.
    """
    import fnmatch

    # Unique models across all variant datasets / model datasets / patterns
    wanted_models = {
        model
        for variant_dataset in VARIANT_DATASET_CONFIGS.keys()
        for model_dataset_config in get_model_datasets_list(variant_dataset)
        for pattern in model_dataset_config.get('models', [])
        for model in KUN_FB_MODELS
        if fnmatch.fnmatch(model, pattern)
    }

    split_files = []
    for model in sorted(wanted_models):
        split_files.append(f"{SPLITS_DIR}/finemo/broad.finemo.{model}.tsv")
    return split_files

# ============================================================================
# Generate all expected outputs
# ============================================================================

def get_comprehensive_variants_tsvs():
    """Return one comprehensive TSV path per configured variant dataset.

    The variant dataset name is used verbatim in the file name:
    ``data/{variant_dataset}.comprehensive.tsv``.
    """
    return [
        f"data/{variant_dataset}.comprehensive.tsv"
        for variant_dataset in VARIANT_DATASET_CONFIGS.keys()
    ]

def get_all_variants_merged_tsvs():
    """Return the merged variant TSV path of every configured variant dataset.

    Paths come from get_variants_tsv_path(), in VARIANT_DATASET_CONFIGS order.
    """
    return [
        get_variants_tsv_path(variant_dataset)
        for variant_dataset in VARIANT_DATASET_CONFIGS.keys()
    ]

def get_all_heatmap_outputs():
    """List the intro and heatmap outputs of every model dataset.

    Heatmaps live at model_dataset level (one per model_dataset, not per
    cluster). For each ``{OUTPUT_DIR}/{variant_dataset}/{model_dataset}``
    directory three numerically prefixed files are expected, in this order:
    00-intro.md, 01-heatmap.png, 01-heatmap.md.
    """
    file_names = ("00-intro.md", "01-heatmap.png", "01-heatmap.md")

    outputs = []
    for variant_dataset in VARIANT_DATASET_CONFIGS.keys():
        for model_dataset_config in get_model_datasets_list(variant_dataset):
            base_path = f"{OUTPUT_DIR}/{variant_dataset}/{model_dataset_config['name']}"
            outputs.extend(f"{base_path}/{file_name}" for file_name in file_names)
    return outputs

def _core_variant_outputs(variant_path):
    """Return the five non-profile outputs expected inside one variant directory."""
    return [
        f"{variant_path}/00-intro.md",
        f"{variant_path}/01-model-specificity-barplot.md",
        f"{variant_path}/01-model-specificity-barplot.png",
        f"{variant_path}/02-model-scatterplot.md",
        f"{variant_path}/02-model-scatterplot.html",
    ]


def get_all_variant_outputs():
    """Generate list of all per-variant outputs with flattened structure and numeric prefixes.

    For every variant dataset / model dataset / cluster (a single
    un-clustered pass when no clusters are configured) this adds:

    - ``{cluster}/00-intro.md`` when the cluster has a name;
    - for each prioritized variant, the intro, barplot and scatterplot files;
    - profiles, depending on the mode:
        * BATCH MODE (USE_BATCH_PROFILES and a named cluster): one
          ``.profiles_symlinked_{model}.done`` sentinel per matched KUN_FB
          model in the cluster directory. Individual profile files are NOT
          listed; they are side effects of the batch rule.
        * INDIVIDUAL MODE (otherwise): ``03-profile-{model}.md/.png`` for
          each model returned by get_prioritized_models_for_variant().

    Returns:
    --------
    list of str
        Output paths. Sentinel paths follow pattern order, then
        KUN_FB_MODELS order, so the list does not vary between runs.
    """
    import fnmatch

    outputs = []

    for variant_dataset in VARIANT_DATASET_CONFIGS.keys():
        for model_dataset_config in get_model_datasets_list(variant_dataset):
            model_dataset_name = model_dataset_config['name']
            model_dataset_path = f"{OUTPUT_DIR}/{variant_dataset}/{model_dataset_name}"

            # Matched models depend only on the model dataset, so they are
            # computed at most once here (lazily, on the first batch cluster).
            batch_models = None

            # Get clusters for this model_dataset (or [None] if no clustering)
            clusters = model_dataset_config.get('clusters', []) or [None]

            for cluster in clusters:
                # Clusters may be dicts ({'id', 'name'}), bare IDs, or None
                if isinstance(cluster, dict):
                    cluster_id = cluster.get('id')
                    cluster_name = cluster.get('name', cluster_id)
                else:
                    cluster_id = cluster_name = cluster

                # Variant directories sit under the cluster directory when
                # there is one, otherwise directly under the model dataset.
                if cluster_name:
                    parent_path = f"{model_dataset_path}/{cluster_name}"
                    outputs.append(f"{parent_path}/00-intro.md")
                else:
                    parent_path = model_dataset_path

                # Get prioritized variants for this model_dataset and cluster
                variants = get_prioritized_variants(model_dataset_config, cluster_id, variant_dataset)

                batch_mode = bool(USE_BATCH_PROFILES and cluster_name)

                if batch_mode:
                    if batch_models is None:
                        # Order-preserving de-duplication keeps the target
                        # list stable (a plain set would reorder between runs).
                        batch_models = list(dict.fromkeys(
                            model
                            for pattern in model_dataset_config.get('models', [])
                            for model in KUN_FB_MODELS
                            if fnmatch.fnmatch(model, pattern)
                        ))

                    # Sentinels ensure profiles are generated AND symlinked
                    # before other steps.
                    outputs.extend(
                        f"{parent_path}/.profiles_symlinked_{model}.done"
                        for model in batch_models
                    )

                for variant in variants:
                    variant_path = f"{parent_path}/{variant}"
                    outputs.extend(_core_variant_outputs(variant_path))

                    if batch_mode:
                        # Profile outputs are created by the batch rule
                        continue

                    # 03-profile-{model} for each prioritized model
                    prioritized_models = get_prioritized_models_for_variant(
                        variant_dataset, variant, model_dataset_name
                    )
                    for model in prioritized_models:
                        outputs.append(f"{variant_path}/03-profile-{model}.md")
                        outputs.append(f"{variant_path}/03-profile-{model}.png")

    return outputs

def get_all_plot_outputs():
    """Build the list of plot-only targets for every configured dataset.

    "Plot-only" means the figure files themselves, without the markdown
    wrappers, intro pages or the final report.

    Included per model_dataset:
    - ``01-heatmap.png`` and ``upset/hpo_overlaps.png``
    - ``{cluster_name}/upset/hpo_overlaps.png`` for every named cluster
    - per variant: ``01-model-specificity-barplot.png`` and
      ``02-model-scatterplot.html``
    - profiles: in batch mode (``USE_BATCH_PROFILES`` and a named cluster) one
      ``.profiles_symlinked_{model}.done`` sentinel per matched model;
      otherwise one ``03-profile-{model}.png`` per prioritized model
    - the finemo TSV of every model whose profiles are requested

    Excluded:
    - 00-intro.md files
    - variant_report.html
    - Main .md files (these are just references to plots)

    Returns:
        list[str]: Target paths. Entries may repeat (in individual mode the
        same finemo TSV is appended once per variant/model pair).
    """
    import fnmatch  # kept function-local, as in the original block

    outputs = []

    for variant_dataset in VARIANT_DATASET_CONFIGS:
        for model_dataset_config in get_model_datasets_list(variant_dataset):
            model_dataset_name = model_dataset_config['name']
            base_path = f"{OUTPUT_DIR}/{variant_dataset}/{model_dataset_name}"

            # model_dataset-level figures (PNG only, no .md)
            outputs.append(f"{base_path}/01-heatmap.png")
            outputs.append(f"{base_path}/upset/hpo_overlaps.png")

            # A model_dataset without clusters is walked once with cluster=None
            clusters = model_dataset_config.get('clusters', []) or [None]

            for cluster in clusters:
                # A cluster entry is None, a dict with 'id'/'name', or a bare value
                if cluster is None:
                    cluster_id = cluster_name = None
                elif isinstance(cluster, dict):
                    cluster_id = cluster.get('id')
                    cluster_name = cluster.get('name', cluster_id)
                else:
                    cluster_id = cluster_name = cluster

                # Variant directories live under the cluster directory when one is named
                if cluster_name:
                    parent_path = f"{base_path}/{cluster_name}"
                    outputs.append(f"{parent_path}/upset/hpo_overlaps.png")
                else:
                    parent_path = base_path

                variants = get_prioritized_variants(model_dataset_config, cluster_id, variant_dataset)

                batch_mode = bool(USE_BATCH_PROFILES and cluster_name)

                if batch_mode:
                    # BATCH MODE: one symlink sentinel per model matching the
                    # model_dataset patterns (ensures profiles are generated)
                    model_patterns = model_dataset_config.get('models', [])
                    matched_models = {
                        model
                        for pattern in model_patterns
                        for model in KUN_FB_MODELS
                        if fnmatch.fnmatch(model, pattern)
                    }

                    # sorted() keeps the target order stable between runs; the
                    # iteration order of a bare set of strings is not
                    # reproducible across interpreter processes.
                    for model in sorted(matched_models):
                        outputs.append(f"{parent_path}/.profiles_symlinked_{model}.done")
                        # Ensure finemo file is generated before profiles
                        outputs.append(f"{SPLITS_DIR}/finemo/broad.finemo.{model}.tsv")

                for variant in variants:
                    variant_path = f"{parent_path}/{get_safe_variant_id(variant)}"

                    outputs.append(f"{variant_path}/01-model-specificity-barplot.png")
                    outputs.append(f"{variant_path}/02-model-scatterplot.html")

                    if batch_mode:
                        # Profile PNGs are side effects of the batch rule; the
                        # sentinels added above stand in for them.
                        continue

                    # INDIVIDUAL MODE: one profile PNG per prioritized model
                    prioritized_models = get_prioritized_models_for_variant(
                        variant_dataset, variant, model_dataset_name
                    )
                    for model in prioritized_models:
                        outputs.append(f"{variant_path}/03-profile-{model}.png")
                        # Ensure finemo file is generated before profiles
                        outputs.append(f"{SPLITS_DIR}/finemo/broad.finemo.{model}.tsv")

    return outputs


def aggregate_plots_after_checkpoints(wildcards):
    """
    Checkpoint-aware input function for plots_only rule.

    This function is called AFTER Snakemake executes all checkpoints.
    It can safely read from data/ files to determine which plots to generate.

    Execution flow:
    1. Snakemake builds initial DAG with checkpoint rules
    2. Executes merge_variants_tsv checkpoint → creates .variants_merged.tsv
    3. Executes cluster_model_dataset checkpoint → creates .clustered.tsv with kmeans columns
    4. DAG re-evaluation triggered
    5. This function called → reads checkpoint outputs
    6. Returns list of plot files to generate

    On top of the plot targets (heatmap, upset plots, barplots, scatterplots,
    profile PNGs or batch sentinels, finemo TSVs) this also requests, for each
    named cluster, a human-readable TSV and one ``00-summary.html`` per variant.

    Parameters:
    -----------
    wildcards : snakemake wildcards object
        Required by the input-function signature; not read by this function.

    Returns:
    --------
    list of str
        Target paths. Entries may repeat (in individual mode the same finemo
        TSV is appended once per variant/model pair).
    """
    import fnmatch  # kept function-local, as in the original block

    outputs = []

    for variant_dataset in VARIANT_DATASET_CONFIGS:
        model_dataset_configs = get_model_datasets_list(variant_dataset)

        # CRITICAL: Reference checkpoint to trigger execution and wait for completion
        checkpoints.merge_variants_tsv.get(variant_dataset=variant_dataset)

        for model_dataset_config in model_dataset_configs:
            model_dataset_name = model_dataset_config['name']

            # CRITICAL: Reference checkpoint to trigger execution and wait for completion
            checkpoints.cluster_model_dataset.get(
                variant_dataset=variant_dataset,
                model_dataset=model_dataset_name
            )

            # From here on the checkpoint outputs exist.
            base_path = f"{OUTPUT_DIR}/{variant_dataset}/{model_dataset_name}"

            # model_dataset-level figures (PNG only, no .md)
            outputs.append(f"{base_path}/01-heatmap.png")
            outputs.append(f"{base_path}/upset/hpo_overlaps.png")

            # A model_dataset without clusters is walked once with cluster=None
            clusters = model_dataset_config.get('clusters', []) or [None]

            for cluster in clusters:
                # A cluster entry is None, a dict with 'id'/'name', or a bare value
                if cluster is None:
                    cluster_id = cluster_name = None
                elif isinstance(cluster, dict):
                    cluster_id = cluster.get('id')
                    cluster_name = cluster.get('name', cluster_id)
                else:
                    cluster_id = cluster_name = cluster

                # Variant directories live under the cluster directory when one is named
                if cluster_name:
                    parent_path = f"{base_path}/{cluster_name}"
                    outputs.append(f"{parent_path}/upset/hpo_overlaps.png")
                else:
                    parent_path = base_path

                variants = get_prioritized_variants(model_dataset_config, cluster_id, variant_dataset)

                if cluster_name:
                    # Human-readable TSV for this cluster
                    outputs.append(
                        f"{OUTPUT_DIR}/{variant_dataset}/human_readable_spreadsheets/{model_dataset_name}.{cluster_name}.human_readable.tsv"
                    )
                    # Variant summary HTML for each variant in the cluster
                    for variant in variants:
                        outputs.append(f"{parent_path}/{get_safe_variant_id(variant)}/00-summary.html")

                batch_mode = bool(USE_BATCH_PROFILES and cluster_name)

                if batch_mode:
                    # BATCH MODE: one symlink sentinel per model matching the
                    # model_dataset patterns (ensures profiles are generated)
                    model_patterns = model_dataset_config.get('models', [])
                    matched_models = {
                        model
                        for pattern in model_patterns
                        for model in KUN_FB_MODELS
                        if fnmatch.fnmatch(model, pattern)
                    }

                    # sorted() keeps the target order stable between runs; the
                    # iteration order of a bare set of strings is not
                    # reproducible across interpreter processes.
                    for model in sorted(matched_models):
                        outputs.append(f"{parent_path}/.profiles_symlinked_{model}.done")
                        # Ensure finemo file is generated before profiles
                        outputs.append(f"{SPLITS_DIR}/finemo/broad.finemo.{model}.tsv")

                for variant in variants:
                    variant_path = f"{parent_path}/{get_safe_variant_id(variant)}"

                    outputs.append(f"{variant_path}/01-model-specificity-barplot.png")
                    outputs.append(f"{variant_path}/02-model-scatterplot.html")

                    if batch_mode:
                        # Profile PNGs are side effects of the batch rule; the
                        # sentinels added above stand in for them.
                        continue

                    # INDIVIDUAL MODE: one profile PNG per prioritized model
                    prioritized_models = get_prioritized_models_for_variant(
                        variant_dataset, variant, model_dataset_name
                    )
                    for model in prioritized_models:
                        outputs.append(f"{variant_path}/03-profile-{model}.png")
                        # Ensure finemo file is generated before profiles
                        outputs.append(f"{SPLITS_DIR}/finemo/broad.finemo.{model}.tsv")

    return outputs


# ============================================================================
# Rules
# ============================================================================

# Default target: requests the merged TSVs, the heatmap outputs, the per-variant
# outputs and the final "variant_report.html".
# The three helpers below are written as calls (with parentheses), so their
# return values are computed when this input block is evaluated, not deferred
# the way the lambda input of plots_only is.
rule all:
    input:
        get_all_variants_merged_tsvs(),  # Ensure merged TSVs are generated first
        get_all_heatmap_outputs(),
        get_all_variant_outputs(),
        "variant_report.html"

rule plots_only:
    """Generate only plot outputs (PNG, HTML) without intro, before/after MD files, or HTML report.

    Usage:
        snakemake plots_only --cores 6 --resources gpu_per_model=1

    This rule uses a checkpoint-aware input function pattern:
    1. Initial DAG: Snakemake identifies checkpoint dependencies
    2. Executes merge_variants_tsv and cluster_model_dataset checkpoints
    3. DAG re-evaluates after checkpoints complete
    4. Input function reads checkpoint outputs to determine plot targets
    5. Generates all plot outputs

    This rule generates:
    - Heatmap PNG files
    - Model-specificity barplot PNG files
    - Model-scatterplot HTML files
    - Profile PNG files (via batch processing)

    Does NOT generate:
    - 00-intro.md files
    - Main .md reference files
    - variant_report.html
    """
    input:
        # The lambda defers evaluation until Snakemake calls it with the
        # wildcards, and lets the two lists be concatenated into one input.
        # aggregate_plots_after_checkpoints() also returns upset plot PNGs,
        # finemo TSVs and, for named clusters, human-readable TSVs and
        # 00-summary.html pages, in addition to the targets listed in the
        # docstring above.
        # _get_all_variant_dataset_sentinels() is defined elsewhere in this
        # file; presumably it returns per-variant-dataset sentinel paths.
        files = lambda w: aggregate_plots_after_checkpoints(w) + _get_all_variant_dataset_sentinels()

# ----------------------------------------------------------------------------
# Merged TSV generation (combine prioritization + score/logfc/aaq columns)
# ----------------------------------------------------------------------------

def parse_model_sets_str(model_sets_str):
    """
    Turn an underscore-joined, lowercase model-set string into model set names.

    The string is split on '_' and every piece is upper-cased. A piece equal
    to "KUN" is glued to the piece that follows it (when there is one), so
    two-part names survive the split.

    Parameters:
    -----------
    model_sets_str : str
        Underscore-separated model set names (e.g., "kun_fb_kun_hdma" -> ["KUN_FB", "KUN_HDMA"])

    Returns:
    --------
    list of str
        Model set names, in the order they appear in the input
    """
    tokens = iter(piece.upper() for piece in model_sets_str.split('_'))
    model_sets = []

    for token in tokens:
        if token == 'KUN':
            # Pull the next piece off the same iterator to form "KUN_<suffix>";
            # a trailing "KUN" with nothing after it is kept on its own.
            suffix = next(tokens, None)
            if suffix is not None:
                model_sets.append(f"KUN_{suffix}")
                continue
        model_sets.append(token)

    return model_sets

rule merge_model_sets_for_scatterplot_context:
    """
    Merge multiple model sets for a variant dataset to create a larger variant set for scatterplot context.
    
    This rule combines data from multiple model sets (e.g., KUN_FB, KUN_HDMA) for a given variant dataset:
    - General variant information
    - Prioritization columns from all specified model sets
    - logfc and aaq columns from all specified model sets
    
    The output can be used as scatterplot_context_tsv in VARIANT_DATASET_CONFIGS
    to show a larger set of variants in scatterplots for better context.
    
    Wildcards:
    ----------
    variant_dataset : str
        The variant dataset name (e.g., "Broad neurodevelopmental and neuromuscular disorders")
    model_sets_str : str
        Underscore-separated model set names (e.g., "kun_fb_kun_hdma" for KUN_FB and KUN_HDMA)
    
    Usage in VARIANT_DATASET_CONFIGS:
    {
        "variant_dataset_name": {
            'scatterplot_context_tsv': 'data/{variant_dataset}_kun_fb_kun_hdma_scatterplot_context.tsv',
            'model_datasets': [...]
        }
    }
    """
    output:
        merged_tsv = "data/{variant_dataset}_{model_sets_str}_scatterplot_context.tsv"
    params:
        variant_dataset = lambda w: w.variant_dataset,
        model_sets = lambda w: parse_model_sets_str(w.model_sets_str)
    run:
        import pandas as pd
        import os
        import time
        
        variant_dataset = params.variant_dataset
        model_sets = params.model_sets
        
        start_time = time.time()
        print(f"[TIMING] Starting merge_model_sets_for_scatterplot_context for {variant_dataset} (model_sets: {model_sets}) at {time.strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"Merging model sets {model_sets} for variant dataset '{variant_dataset}' for scatterplot context...")
        
        # Start with general file for the variant dataset
        # Use variant_dataset-specific general.tsv file
        general_file = f"{SPLITS_DIR}/{variant_dataset}.general.tsv"
        if not os.path.exists(general_file):
            raise FileNotFoundError(f"General file not found for variant_dataset '{variant_dataset}': {general_file}")
        
        print(f"Reading {general_file}...")
        df_merged = pd.read_csv(general_file, sep='\t')
        print(f"  Initial variants: {len(df_merged)}")
        
        # Read logfc and aaq files once (they contain columns for all models)
        logfc_file = f"{SPLITS_DIR}/{variant_dataset}.logfc.tsv"
        
        aaq_file = f"{SPLITS_DIR}/{variant_dataset}.aaq.tsv"
        
        # For each model set, merge prioritization files
        for model_set in model_sets:
            print(f"\nProcessing model set: {model_set}")
            
            # Prioritization file
            prio_file = f"{SPLITS_DIR}/{variant_dataset}.model_prioritized_by_any-{model_set}.tsv"
            if not os.path.exists(prio_file):
                # Try fallback to 'broad' prefix
                prio_file = f"{SPLITS_DIR}/broad.model_prioritized_by_any-{model_set}.tsv"
            
            if os.path.exists(prio_file):
                print(f"  Reading prioritization: {prio_file}")
                df_prio = pd.read_csv(prio_file, sep='\t')
                df_merged = df_merged.merge(df_prio, on='variant_id', how='outer')
                print(f"    Added {len(df_prio)} variants (total: {len(df_merged)})")
            else:
                print(f"  Warning: Prioritization file not found: {prio_file}")
        
        # Merge logfc file once, keeping columns for all model sets
        if os.path.exists(logfc_file):
            print(f"\nReading logfc: {logfc_file}")
            df_logfc = pd.read_csv(logfc_file, sep='\t')
            # Keep columns for all model sets
            logfc_cols = ['variant_id'] + [c for c in df_logfc.columns if any(model_set in c for model_set in model_sets)]
            if len(logfc_cols) > 1:  # More than just variant_id
                df_logfc_filtered = df_logfc[logfc_cols]
                df_merged = df_merged.merge(df_logfc_filtered, on='variant_id', how='left')
                print(f"    Added {len([c for c in logfc_cols if c != 'variant_id'])} logfc columns")
        else:
            print(f"  Warning: Logfc file not found: {logfc_file}")
        
        # Merge aaq file once, keeping columns for all model sets
        if os.path.exists(aaq_file):
            print(f"Reading aaq: {aaq_file}")
            df_aaq = pd.read_csv(aaq_file, sep='\t')
            # Keep columns for all model sets
            aaq_cols = ['variant_id'] + [c for c in df_aaq.columns if any(model_set in c for model_set in model_sets)]
            if len(aaq_cols) > 1:  # More than just variant_id
                df_aaq_filtered = df_aaq[aaq_cols]
                df_merged = df_merged.merge(df_aaq_filtered, on='variant_id', how='left')
                print(f"    Added {len([c for c in aaq_cols if c != 'variant_id'])} aaq columns")
        else:
            print(f"  Warning: Aaq file not found: {aaq_file}")
        
        # Ensure output directory exists
        os.makedirs(os.path.dirname(output.merged_tsv), exist_ok=True)
        
        # Save merged TSV
        df_merged.to_csv(output.merged_tsv, sep='\t', index=False)
        end_time = time.time()
        duration = end_time - start_time
        print(f"[TIMING] Finished merge_model_sets_for_scatterplot_context for {variant_dataset} in {duration:.1f}s")
        print(f"\nCreated merged scatterplot context TSV: {output.merged_tsv}")
        print(f"  Total variants: {len(df_merged)}")
        print(f"  Total columns: {len(df_merged.columns)}")

checkpoint merge_variants_tsv:
    """
    Automatically merge all required annotation files for a variant_dataset.

    This checkpoint creates prerequisite data files before DAG re-evaluation.
    Snakemake will wait for this checkpoint to complete, then re-evaluate the DAG
    to determine which plots need to be generated based on the merged data.

    The merged TSV includes:
    - General variant information (coordinates, alleles, etc.)
    - Prioritization columns for all models matching patterns in configs
    - logfc and aaq columns (for scatterplots)
    - Any other annotations needed by downstream rules
    """
    output:
        "data/{variant_dataset}.variants_merged.tsv"
    run:
        import subprocess
        import os
        
        variant_dataset = wildcards.variant_dataset
        required_files = get_required_annotation_files(variant_dataset)
        output_file = output[0]
        
        # Verify all required files exist
        missing_files = [f for f in required_files if not os.path.exists(f)]
        if missing_files:
            raise FileNotFoundError(
                f"Missing required annotation files for {variant_dataset}:\n" +
                "\n".join(f"  - {f}" for f in missing_files)
            )
        
        if not required_files:
            raise ValueError(f"No annotation files found for variant_dataset: {variant_dataset}")
        
        # Ensure output directory exists
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        
        # Build merge-columns command
        cmd = [
            "merge-columns",
            "--merge-column", "variant_id",
            "--join-type", "outer",  # Use outer join to keep all variants
            "--output", output_file
        ] + required_files
        
        print(f"Merging {len(required_files)} files for {variant_dataset}:")
        for f in required_files:
            print(f"  - {f}")
        
        # Run merge-columns with timing
        import time
        start_time = time.time()
        print(f"[TIMING] Starting filter_and_score_for_model_dataset merge for {variant_dataset} ({len(required_files)} files) at {time.strftime('%Y-%m-%d %H:%M:%S')}")
        result = subprocess.run(cmd, capture_output=True, text=True)
        end_time = time.time()
        duration = end_time - start_time
        
        if result.returncode != 0:
            raise RuntimeError(
                f"merge-columns failed for {variant_dataset}:\n"
                f"stdout: {result.stdout}\n"
                f"stderr: {result.stderr}"
            )
        
        print(f"[TIMING] Finished filter_and_score_for_model_dataset merge for {variant_dataset} in {duration:.1f}s")
        print(f"Successfully merged {len(required_files)} files -> {output_file}")

rule filter_and_score_for_model_dataset:
    """
    Filter variants and generate score columns for a model_dataset.

    Optimization: Filter to prioritized variants FIRST, then read other split files
    only for those variants. This avoids loading huge TSVs unnecessarily.

    Steps:
    1. Read prioritization split file (lightweight)
    2. Filter to variants prioritized by ANY model matching model_dataset patterns
    3. Read other split files (general, logfc, hpo) - ONLY for filtered variants
    4. Generate score-{model} = logfc-{model} × model_prioritized_by_any-{model}
    5. Keep only relevant columns

    Output: ~3,880 variants × ~100 columns (vs 30k × 1000+)
    """
    output:
        "data/{variant_dataset}.{model_dataset}.filtered.tsv"
    run:
        import pandas as pd
        import fnmatch
        import os
        import time

        start_time = time.time()
        variant_dataset = wildcards.variant_dataset
        model_dataset_name = wildcards.model_dataset

        # Get model patterns from config
        model_dataset_config = None
        for config in (get_model_datasets_list(variant_dataset) if variant_dataset in VARIANT_DATASET_CONFIGS else []):
            if config['name'] == model_dataset_name:
                model_dataset_config = config
                break

        if model_dataset_config is None:
            raise ValueError(f"No config found for model_dataset '{model_dataset_name}' in variant_dataset '{variant_dataset}'")

        model_patterns = model_dataset_config.get('models', [])

        # Determine which dataset prefix to use for prioritization file (KUN_FB, KUN_HDMA, etc.)
        # Extract prefix from first model pattern (e.g., "KUN_FB*" -> "KUN_FB")
        dataset_prefix = None
        for pattern in model_patterns:
            # Remove wildcards to get base prefix
            prefix = pattern.replace('*', '').strip('_')
            if prefix:
                dataset_prefix = prefix
                break

        if dataset_prefix is None:
            raise ValueError(f"Could not determine dataset prefix from model patterns: {model_patterns}")

        # STEP 1: Read prioritization file FIRST (small file, only variant_id + boolean columns)
        prioritization_file = f"{SPLITS_DIR}/{variant_dataset}.model_prioritized_by_any-{dataset_prefix}.tsv"

        if not os.path.exists(prioritization_file):
            # Try LOCAL_SPLITS_DIR as fallback
            prioritization_file = f"{LOCAL_SPLITS_DIR}/{variant_dataset}.model_prioritized_by_any-{dataset_prefix}.tsv"

        print(f"Reading prioritization file: {prioritization_file}")
        df_prio = pd.read_csv(prioritization_file, sep='\t')

        # STEP 2: Identify models matching patterns and filter to prioritized variants
        matched_models = []
        for col in df_prio.columns:
            if col.startswith('model_prioritized_by_any-'):
                model_name = col.replace('model_prioritized_by_any-', '')
                # Check if this model matches ANY pattern
                for pattern in model_patterns:
                    if fnmatch.fnmatch(model_name, pattern):
                        matched_models.append(model_name)
                        break

        if not matched_models:
            raise ValueError(f"No models matched patterns {model_patterns} in {prioritization_file}")

        print(f"Matched {len(matched_models)} models: {matched_models[:5]}..." if len(matched_models) > 5 else matched_models)

        # Filter to variants prioritized by at least one matched model
        prioritized_cols = [f'model_prioritized_by_any-{m}' for m in matched_models]
        # Convert to boolean (handle both "True"/"False" strings and True/False booleans)
        prioritized_mask = df_prio[prioritized_cols].isin(['True', True, 'true', 1]).any(axis=1)
        prioritized_variant_ids = df_prio[prioritized_mask]['variant_id'].tolist()

        # Suppress verbose output - timing message already indicates what's happening
        # print(f"Filtered to {len(prioritized_variant_ids)} prioritized variants (from {len(df_prio)} total)")

        # STEP 3: NOW read and filter other split files (only for prioritized variants)
        # General
        general_file = f"{SPLITS_DIR}/{variant_dataset}.general.tsv"
        if not os.path.exists(general_file):
            general_file = f"{LOCAL_SPLITS_DIR}/{variant_dataset}.general.tsv"
        print(f"Reading general file: {general_file}")
        df_general = pd.read_csv(general_file, sep='\t')
        df_general = df_general[df_general['variant_id'].isin(prioritized_variant_ids)]

        # Logfc
        logfc_file = f"{SPLITS_DIR}/{variant_dataset}.logfc.tsv"
        if not os.path.exists(logfc_file):
            logfc_file = f"{LOCAL_SPLITS_DIR}/{variant_dataset}.logfc.tsv"
        print(f"Reading logfc file: {logfc_file}")
        df_logfc = pd.read_csv(logfc_file, sep='\t')
        df_logfc = df_logfc[df_logfc['variant_id'].isin(prioritized_variant_ids)]

        # HPO
        hpo_file = f"{SPLITS_DIR}/{variant_dataset}.patient_hpo_expanded.tsv"
        print(f"Reading HPO file: {hpo_file}")
        df_hpo = pd.read_csv(hpo_file, sep='\t')
        df_hpo = df_hpo[df_hpo['variant_id'].isin(prioritized_variant_ids)]

        # STEP 4: Merge filtered dataframes
        df = df_general.merge(df_prio[prioritized_mask], on='variant_id', how='inner')
        df = df.merge(df_logfc, on='variant_id', how='inner')
        df = df.merge(df_hpo, on='variant_id', how='left')  # left join for HPO (not all variants have HPO)

        # STEP 5: Keep only relevant model columns
        keep_cols = ['variant_id', 'chr', 'pos', 'ref', 'alt']
        for model in matched_models:
            if f'model_prioritized_by_any-{model}' in df.columns:
                keep_cols.append(f'model_prioritized_by_any-{model}')
            if f'logfc-{model}' in df.columns:
                keep_cols.append(f'logfc-{model}')

        # Add HPO columns
        hpo_cols = [c for c in df.columns if c.startswith('has_hpo_')]
        keep_cols.extend(hpo_cols)

        # Filter to keep_cols that actually exist
        keep_cols = [c for c in keep_cols if c in df.columns]
        df = df[keep_cols]

        # STEP 6: Generate score columns
        print("Generating score-{model} columns...")
        for model in matched_models:
            prio_col = f'model_prioritized_by_any-{model}'
            logfc_col = f'logfc-{model}'

            if prio_col in df.columns and logfc_col in df.columns:
                # Convert prioritization to numeric (True→1, False→0)
                prio_numeric = df[prio_col].isin(['True', True, 'true', 1]).astype(int)
                # score = logfc × prioritization
                df[f'score-{model}'] = df[logfc_col] * prio_numeric

        # Write output
        df.to_csv(output[0], sep='\t', index=False)
        end_time = time.time()
        duration = end_time - start_time
        print(f"[TIMING] Finished filter_and_score_for_model_dataset for {variant_dataset}/{model_dataset_name} in {duration:.1f}s")
        print(f"Wrote {len(df)} variants × {len(df.columns)} columns to {output[0]}")

checkpoint cluster_model_dataset:
    """
    Run KMeans clustering on score columns for a model_dataset.

    This checkpoint performs clustering before DAG re-evaluation, allowing
    downstream rules to be generated dynamically based on cluster assignments.

    Steps:
    1. Read filtered TSV (has score-{model} columns)
    2. Extract all score-{model} columns (NaN scores are treated as 0)
    3. Run KMeans (k=35, random_state=42) on score matrix. k is capped at the
       number of variants so inputs with fewer than 35 rows do not make
       KMeans raise; an empty input skips clustering entirely.
    4. Relabel clusters with sort_clusters_by_size and add kmeans_35 column
       with cluster assignments (0-34)
    5. Best-effort: add an organs column built from the model_tissues TSV
       (a failure here is only logged as a warning)
    6. DROP logfc-{model} columns (no longer needed after score generation)
    7. Write clustered TSV

    Output: Same TSV + kmeans_35 column (+ organs when available), minus logfc columns

    Raises: ValueError if the filtered TSV has no score-{model} columns.
    """
    input:
        "data/{variant_dataset}.{model_dataset}.filtered.tsv"
    output:
        "data/{variant_dataset}.{model_dataset}.clustered.tsv"
    run:
        import pandas as pd
        import numpy as np
        import time
        from sklearn.cluster import KMeans
        from varbook.annotate.kmeans import sort_clusters_by_size

        # Target number of clusters; the output column name (kmeans_35) is tied to it.
        target_k = 35

        start_time = time.time()
        print(f"[TIMING] Starting cluster_model_dataset for {wildcards.variant_dataset}/{wildcards.model_dataset} at {time.strftime('%Y-%m-%d %H:%M:%S')}")

        # Read filtered TSV
        df = pd.read_csv(input[0], sep='\t')

        # Extract score columns
        score_cols = [col for col in df.columns if col.startswith('score-')]

        if not score_cols:
            raise ValueError(f"No score-{{model}} columns found in {input[0]}")

        print(f"Running KMeans clustering on {len(score_cols)} score columns...")

        # Prepare data for clustering (fill NaN with 0)
        X = df[score_cols].fillna(0).values

        # KMeans requires n_samples >= n_clusters, so cap k at the row count.
        # With >= 35 variants this is exactly the original k=35 behaviour.
        n_clusters = min(target_k, X.shape[0])

        if n_clusters == 0:
            # Nothing to cluster: still emit the kmeans_35 column so downstream
            # readers that require it find it (with zero rows).
            print(f"Warning: {input[0]} has no variants; skipping KMeans and writing an empty kmeans_35 column")
            labels = np.empty(0, dtype=int)
        else:
            if n_clusters < target_k:
                print(f"Warning: only {X.shape[0]} variants available; using k={n_clusters} instead of k={target_k}")

            # Run KMeans (random_state for reproducibility)
            kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
            labels = kmeans.fit_predict(X)

            # Sort clusters by size (largest → smallest); the remap table is not needed here
            labels, _ = sort_clusters_by_size(labels)

        df['kmeans_35'] = labels

        print("Cluster distribution (sorted by size):")
        print(df['kmeans_35'].value_counts().sort_index())

        # Add organs column from model_tissues.tsv
        # This is required for model_specificity_barplot to work
        # NOTE(review): other rules in this file build split paths from SPLITS_DIR,
        # while this path hard-codes "splits/" — confirm the two always agree.
        model_tissues_path = f"splits/{wildcards.variant_dataset}.model_tissues.tsv"
        print(f"Loading model tissues metadata from {model_tissues_path}...")
        try:
            df_tissues = pd.read_csv(model_tissues_path, sep='\t')
            # Create dictionary mapping model names to organs
            model_to_organs = dict(zip(df_tissues['model_name'], df_tissues['organs']))
            # Add as string representation of dictionary (same for all variants)
            df['organs'] = str(model_to_organs)
            print(f"Added organs column with {len(model_to_organs)} model-to-organ mappings")
        except Exception as e:
            # Deliberate best-effort: missing/malformed metadata must not fail clustering.
            print(f"Warning: Could not load organs metadata: {e}")
            print("Barplot generation may fail without organs column")

        # DROP logfc columns (no longer needed - score columns are sufficient for heatmap)
        logfc_cols = [col for col in df.columns if col.startswith('logfc-')]
        if logfc_cols:
            print(f"Dropping {len(logfc_cols)} logfc columns to save space...")
            df = df.drop(columns=logfc_cols)

        # Write output
        df.to_csv(output[0], sep='\t', index=False)
        end_time = time.time()
        duration = end_time - start_time
        print(f"[TIMING] Finished cluster_model_dataset for {wildcards.variant_dataset}/{wildcards.model_dataset} in {duration:.1f}s")
        print(f"Wrote {len(df)} variants × {len(df.columns)} columns to {output[0]}")

        # Always print the "Columns: " prefix; only the trailing "..." depends on
        # whether the column list was truncated to the first 10 entries.
        shown_cols = list(df.columns[:10])
        ellipsis = "..." if len(df.columns) > 10 else ""
        print(f"Columns: {shown_cols}{ellipsis}")

# Resolve ambiguity: filter_variants_for_cluster should be preferred when the pattern
# matches cluster files (data/{variant_dataset}.{model_dataset}.{cluster_id}.filtered.tsv)
# Without this declaration Snakemake would presumably report both rules as able to
# produce such a file — verify against the two rules' output patterns.
ruleorder: filter_variants_for_cluster > filter_and_score_for_model_dataset

# Prioritize human-readable TSV generation - it doesn't depend on finemo and can run early
# This ensures it runs before finemo when both are requested
# NOTE(review): ruleorder only decides which rule Snakemake picks when several rules
# can produce the same output file; it does not schedule one job ahead of another.
# If execution order is the goal, confirm whether the `priority` directive is what
# is actually wanted here.
ruleorder: generate_human_readable_cluster_tsv > annotate_finemo_split_file
ruleorder: generate_human_readable_cluster_tsv > ensure_finemo_files

rule generate_human_readable_cluster_tsv:
    """
    Generate human-readable TSV for a cluster containing:
    - Variants & general info (chr, pos, ref, alt)
    - Kmeans cluster assignments
    - Prioritized models lists and counts (aggregated at dataset level)
    - Per-model columns: logfc-{model}, aaq-{model}, model_prioritized_by_any-{model}
    - Closest elements (genes, miRNA, lncRNA with distances)
    - VEP most severe consequences columns
    - GENCODE region type columns
    - gnomAD columns
    - HPO columns
    
    Excludes per-model score columns and organs.
    Uses merge-columns CLI to merge all input files.
    """
    input:
        # Use general.tsv as base (has ALL variants with chr, pos, ref, alt)
        general_tsv = lambda w: f"{SPLITS_DIR}/{w.variant_dataset}.general.tsv",
        clustered_tsv = "data/{variant_dataset}.{model_dataset}.clustered.tsv",
        closest_elements_tsv = lambda w: f"{SPLITS_DIR}/{w.variant_dataset}.closest_elements.tsv",
        prioritization_tsv = lambda w: _get_prioritization_tsv_path(w.variant_dataset, w.model_dataset),
        vep_most_severe_csqs_tsv = lambda w: f"{SPLITS_DIR}/{w.variant_dataset}.VEP.most_severe_csqs.tsv",
        gencode_region_type_tsv = lambda w: f"{SPLITS_DIR}/{w.variant_dataset}.GENCODE.region_type.tsv",
        gnomad_tsv = lambda w: f"{SPLITS_DIR}/{w.variant_dataset}.gnomad.tsv",
        patient_hpo_expanded_tsv = lambda w: f"{SPLITS_DIR}/{w.variant_dataset}.patient_hpo_expanded.tsv"
    output:
        "{OUTPUT_DIR}/{variant_dataset}/human_readable_spreadsheets/{model_dataset}.{cluster_id}.human_readable.tsv"
    params:
        variant_dataset = lambda w: w.variant_dataset,
        model_dataset = lambda w: w.model_dataset,
        cluster_id = lambda w: w.cluster_id
    run:
        import pandas as pd
        import os
        import subprocess
        import time
        import re
        import tempfile
        
        variant_dataset = params.variant_dataset
        model_dataset_name = params.model_dataset
        cluster_id = params.cluster_id
        
        start_time = time.time()
        print(f"[TIMING] Starting generate_human_readable_cluster_tsv for {variant_dataset}/{model_dataset_name}/{cluster_id} at {time.strftime('%Y-%m-%d %H:%M:%S')}")
        
        # Get model_dataset config
        model_dataset_config = None
        for config in (get_model_datasets_list(variant_dataset) if variant_dataset in VARIANT_DATASET_CONFIGS else []):
            if config['name'] == model_dataset_name:
                model_dataset_config = config
                break
        
        if model_dataset_config is None:
            raise ValueError(f"No config found for model_dataset '{model_dataset_name}' in variant_dataset '{variant_dataset}'")
        
        # Parse cluster number from cluster_id
        cluster_num = None
        if cluster_id.startswith('cluster_'):
            try:
                cluster_num = int(cluster_id.replace('cluster_', ''))
            except ValueError:
                pass
        else:
            # Try to extract number from parentheses (e.g., "(#3)" → 3)
            match = re.search(r'\(#(\d+)\)', cluster_id)
            if match:
                cluster_num = int(match.group(1))
        
        if cluster_num is None:
            raise ValueError(f"Could not parse cluster number from cluster_id: {cluster_id}")
        
        # STEP 1: Read general.tsv (has ALL variants with chr, pos, ref, alt)
        print(f"Reading general TSV: {input.general_tsv}")
        df_general = pd.read_csv(input.general_tsv, sep='\t')
        print(f"Loaded {len(df_general)} variants from general.tsv")
        
        # STEP 2: Read clustered.tsv and get cluster assignments
        print(f"Reading clustered TSV: {input.clustered_tsv}")
        df_clustered = pd.read_csv(input.clustered_tsv, sep='\t')
        
        if 'kmeans_35' not in df_clustered.columns:
            raise ValueError(f"kmeans_35 column not found in {input.clustered_tsv}")
        
        # Filter to cluster and extract columns we need from clustered.tsv
        # Include variant_id, kmeans_35, and model_prioritized_by_any-{model} columns
        # Don't include HPO columns from clustered.tsv - they'll come from patient_hpo_expanded.tsv
        cluster_variants = df_clustered[df_clustered['kmeans_35'] == cluster_num]
        
        # Get model_prioritized_by_any-{model} columns from clustered.tsv
        model_prio_cols = [col for col in df_clustered.columns if col.startswith('model_prioritized_by_any-')]
        cols_to_extract = ['variant_id', 'kmeans_35'] + model_prio_cols
        df_cluster_assignments = cluster_variants[[col for col in cols_to_extract if col in cluster_variants.columns]].copy()
        print(f"Found {len(df_cluster_assignments)} variants in cluster {cluster_num} from clustered.tsv")
        if model_prio_cols:
            print(f"Extracted {len([c for c in model_prio_cols if c in df_cluster_assignments.columns])} model_prioritized_by_any- columns from clustered.tsv")
        
        if len(df_cluster_assignments) == 0:
            print(f"Warning: No variants found in cluster {cluster_num}. Creating empty file.")
            # Create empty file with correct structure
            os.makedirs(os.path.dirname(output[0]), exist_ok=True)
            empty_df = pd.DataFrame(columns=['variant_id', 'chr', 'pos', 'ref', 'alt', 'kmeans_35'])
            empty_df.to_csv(output[0], sep='\t', index=False)
            return
        
        # STEP 3: Merge general.tsv with cluster assignments
        # This ensures ALL variants in the cluster have chr, pos, ref, alt filled in
        df_base = df_general.merge(df_cluster_assignments, on='variant_id', how='inner')
        print(f"After merging with cluster assignments: {len(df_base)} variants")
        
        # Debug: Check what columns we have and if they have data
        print(f"Columns in merged df_base: {list(df_base.columns)}")
        for col in ['chr', 'pos', 'ref', 'alt', 'kmeans_35', 'most_active_pos_LFC_celltype', 'most_active_pos_LFC', 'count_of_pos_LFC_models', 'count_of_pos_prio_LFC_models', 'most_active_neg_LFC_celltype', 'most_active_neg_LFC']:
            if col in df_base.columns:
                non_null_count = df_base[col].notna().sum()
                print(f"  {col}: {non_null_count}/{len(df_base)} non-null values")
            else:
                print(f"  {col}: MISSING from df_base")
        
        # Select only needed columns from base (exclude per-model score columns, organs, and HPO columns)
        # HPO columns will come from patient_hpo_expanded.tsv to avoid duplicates
        # Include model_prioritized_by_any-{model} columns from clustered.tsv
        keep_cols = ['variant_id', 'chr', 'pos', 'ref', 'alt', 'kmeans_35', 'most_active_pos_LFC_celltype', 'most_active_pos_LFC', 'count_of_pos_LFC_models', 'count_of_pos_prio_LFC_models', 'most_active_neg_LFC_celltype', 'most_active_neg_LFC']
        
        # Add model_prioritized_by_any-{model} columns from clustered.tsv (already merged into df_base)
        model_prio_cols_in_base = [col for col in df_base.columns if col.startswith('model_prioritized_by_any-')]
        keep_cols.extend(model_prio_cols_in_base)
        
        # Filter to columns that actually exist
        keep_cols = [col for col in keep_cols if col in df_base.columns]
        df_base_filtered = df_base[keep_cols].copy()
        
        # Debug: Verify the filtered dataframe has data
        print(f"Columns in df_base_filtered: {list(df_base_filtered.columns)}")
        print(f"Sample data (first 3 rows):")
        print(df_base_filtered.head(3).to_string())
        
        # Write filtered base to temp file
        tmp_base_file = tempfile.NamedTemporaryFile(mode='w', suffix='.tsv', delete=False)
        tmp_base_path = tmp_base_file.name
        tmp_base_file.close()
        df_base_filtered.to_csv(tmp_base_path, sep='\t', index=False)
        print(f"Wrote temp file with {len(df_base_filtered)} rows and {len(df_base_filtered.columns)} columns to {tmp_base_path}")
        
        # STEP 4: Build list of files to merge
        files_to_merge = [tmp_base_path]
        
        # Add models_prioritized_by_any-{model_dataset_in_superset}.tsv for each model dataset in superset
        # Also add models_prioritized_by_peak-{set}, models_prioritized_by_promoter-{set}, models_prioritized_by_outofpeak-{set}
        if 'model_superset' in model_dataset_config:
            superset_datasets = get_superset_model_datasets(model_dataset_config)
            print(f"Found model_superset with {len(superset_datasets)} model datasets: {superset_datasets}")
            
            for superset_dataset in superset_datasets:
                # Skip if this is the same as the main model_dataset (already included)
                if superset_dataset == model_dataset_name:
                    continue
                
                # Add models_prioritized_by_any-{superset_dataset}.tsv
                superset_prio_file = f"{SPLITS_DIR}/{variant_dataset}.models_prioritized_by_any-{superset_dataset}.tsv"
                if os.path.exists(superset_prio_file):
                    files_to_merge.append(superset_prio_file)
                else:
                    print(f"Warning: Superset prioritization file not found: {superset_prio_file}")
                
                # Add models_prioritized_by_peak-{superset_dataset}.tsv
                superset_peak_file = f"{SPLITS_DIR}/{variant_dataset}.models_prioritized_by_peak-{superset_dataset}.tsv"
                if os.path.exists(superset_peak_file):
                    files_to_merge.append(superset_peak_file)
                else:
                    print(f"Warning: Superset peak prioritization file not found: {superset_peak_file}")
                
                # Add models_prioritized_by_promoter-{superset_dataset}.tsv
                superset_promoter_file = f"{SPLITS_DIR}/{variant_dataset}.models_prioritized_by_promoter-{superset_dataset}.tsv"
                if os.path.exists(superset_promoter_file):
                    files_to_merge.append(superset_promoter_file)
                else:
                    print(f"Warning: Superset promoter prioritization file not found: {superset_promoter_file}")
                
                # Add models_prioritized_by_outofpeak-{superset_dataset}.tsv
                superset_outofpeak_file = f"{SPLITS_DIR}/{variant_dataset}.models_prioritized_by_outofpeak-{superset_dataset}.tsv"
                if os.path.exists(superset_outofpeak_file):
                    files_to_merge.append(superset_outofpeak_file)
                else:
                    print(f"Warning: Superset outofpeak prioritization file not found: {superset_outofpeak_file}")
        
        # Add closest_elements.tsv
        if os.path.exists(input.closest_elements_tsv):
            files_to_merge.append(input.closest_elements_tsv)
        else:
            print(f"Warning: closest_elements.tsv not found: {input.closest_elements_tsv}")
        
        # Add models_prioritized_by_any-{model_dataset}.tsv
        if os.path.exists(input.prioritization_tsv):
            files_to_merge.append(input.prioritization_tsv)
        else:
            print(f"Warning: prioritization.tsv not found: {input.prioritization_tsv}")
        
        # Add logfc.tsv to get logfc-{model} columns
        logfc_file = f"{SPLITS_DIR}/{variant_dataset}.logfc.tsv"
        if os.path.exists(logfc_file):
            files_to_merge.append(logfc_file)
            print(f"Adding logfc.tsv for logfc-{{model}} columns")
        else:
            print(f"Warning: logfc.tsv not found: {logfc_file}")
        
        # Add aaq.tsv to get aaq-{model} columns
        aaq_file = f"{SPLITS_DIR}/{variant_dataset}.aaq.tsv"
        if os.path.exists(aaq_file):
            files_to_merge.append(aaq_file)
            print(f"Adding aaq.tsv for aaq-{{model}} columns")
        else:
            print(f"Warning: aaq.tsv not found: {aaq_file}")
        
        # Add models_prioritized_by_any-{set}.tsv
        models_prioritized_by_any_file = f"{SPLITS_DIR}/{variant_dataset}.models_prioritized_by_any-{set}.tsv"
        if os.path.exists(models_prioritized_by_any_file):
            files_to_merge.append(models_prioritized_by_any_file)
        else:
            print(f"Warning: models_prioritized_by_any-{set}.tsv not found: {models_prioritized_by_any_file}")
        
        # Add VEP most severe consequences TSV
        if os.path.exists(input.vep_most_severe_csqs_tsv):
            print(f"Adding VEP most severe consequences TSV: {input.vep_most_severe_csqs_tsv}")
            files_to_merge.append(input.vep_most_severe_csqs_tsv)
        else:
            print(f"Warning: VEP most severe consequences TSV not found: {input.vep_most_severe_csqs_tsv}")
        
        # Add GENCODE region type TSV
        if os.path.exists(input.gencode_region_type_tsv):
            print(f"Adding GENCODE region type TSV: {input.gencode_region_type_tsv}")
            files_to_merge.append(input.gencode_region_type_tsv)
        else:
            print(f"Warning: GENCODE region type TSV not found: {input.gencode_region_type_tsv}")
        
        # Add gnomAD TSV
        if os.path.exists(input.gnomad_tsv):
            print(f"Adding gnomAD TSV: {input.gnomad_tsv}")
            files_to_merge.append(input.gnomad_tsv)
        else:
            print(f"Warning: gnomAD TSV not found: {input.gnomad_tsv}")
        
        # Add patient_hpo_expanded.tsv if it exists
        # We always merge it if it exists because we excluded HPO columns from the base
        # This avoids duplicates and ensures we get the correct HPO data
        if os.path.exists(input.patient_hpo_expanded_tsv):
            print(f"Adding patient_hpo_expanded.tsv (HPO columns excluded from base to avoid duplicates)")
            files_to_merge.append(input.patient_hpo_expanded_tsv)
        else:
            print(f"Warning: patient_hpo_expanded.tsv not found: {input.patient_hpo_expanded_tsv}")
        
        # STEP 5: Use merge-columns CLI to merge all files
        print(f"Merging {len(files_to_merge)} files:")
        for f in files_to_merge:
            print(f"  - {f}")
        
        # Create output directory
        os.makedirs(os.path.dirname(output[0]), exist_ok=True)
        
        # Build merge-columns command
        # Use 'left' join type to preserve all columns from the first file (tmp_base_path)
        # This ensures chr, pos, ref, alt, kmeans_35 are preserved
        cmd = [
            "merge-columns",
            "--merge-column", "variant_id",
            "--join-type", "left",  # Use left join to preserve base columns
            "--output", output[0]
        ] + files_to_merge
        
        merge_start = time.time()
        print(f"[TIMING] Starting merge-columns at {time.strftime('%Y-%m-%d %H:%M:%S')}")
        result = subprocess.run(cmd, capture_output=True, text=True)
        merge_end = time.time()
        merge_duration = merge_end - merge_start
        
        # Clean up temp file
        try:
            os.unlink(tmp_base_path)
        except:
            pass
        
        if result.returncode != 0:
            raise RuntimeError(
                f"merge-columns failed for {variant_dataset}/{model_dataset_name}/{cluster_id}:\n"
                f"stdout: {result.stdout}\n"
                f"stderr: {result.stderr}"
            )
        
        print(f"[TIMING] Finished merge-columns in {merge_duration:.1f}s")
        
        # STEP 6: Read merged dataframe and filter per-model columns to only prioritized models
        print("Reading merged dataframe and filtering per-model columns...")
        df_merged = pd.read_csv(output[0], sep='\t')
        original_col_count = len(df_merged.columns)
        
        # Identify models that are prioritized by at least one variant in the cluster
        prioritized_models = set()
        for col in df_merged.columns:
            if col.startswith('model_prioritized_by_any-'):
                model_name = col.replace('model_prioritized_by_any-', '')
                # Check if any variant in cluster is prioritized by this model
                # Handle multiple formats: True, "True", "true", 1, "1"
                col_values = df_merged[col]
                is_prioritized = (
                    (col_values.astype(str).str.lower() == 'true') |
                    (col_values == True) |
                    (col_values == 1) |
                    (col_values.astype(str) == '1')
                ).fillna(False)
                
                if is_prioritized.any():
                    prioritized_models.add(model_name)
        
        print(f"Found {len(prioritized_models)} prioritized models in cluster: {sorted(list(prioritized_models))[:10]}..." if len(prioritized_models) > 10 else f"Found {len(prioritized_models)} prioritized models in cluster: {sorted(list(prioritized_models))}")
        
        # Filter per-model columns to only include prioritized models
        # Keep all non-per-model columns
        non_per_model_cols = [
            col for col in df_merged.columns 
            if not (col.startswith('logfc-') or 
                   col.startswith('aaq-') or 
                   col.startswith('model_prioritized_by_any-'))
        ]
        
        # Filter model_prioritized_by_any-{model} columns
        model_prio_cols_to_keep = [
            col for col in df_merged.columns 
            if col.startswith('model_prioritized_by_any-') and 
               col.replace('model_prioritized_by_any-', '') in prioritized_models
        ]
        
        # Filter logfc-{model} columns
        logfc_cols_to_keep = [
            col for col in df_merged.columns 
            if col.startswith('logfc-') and 
               col.replace('logfc-', '') in prioritized_models
        ]
        
        # Filter aaq-{model} columns
        aaq_cols_to_keep = [
            col for col in df_merged.columns 
            if col.startswith('aaq-') and 
               col.replace('aaq-', '') in prioritized_models
        ]
        
        # Combine all columns to keep
        cols_to_keep = non_per_model_cols + model_prio_cols_to_keep + logfc_cols_to_keep + aaq_cols_to_keep
        
        # Filter dataframe
        df_merged = df_merged[cols_to_keep].copy()
        
        print(f"Filtered per-model columns: kept {len(model_prio_cols_to_keep)} model_prioritized_by_any-, {len(logfc_cols_to_keep)} logfc-, {len(aaq_cols_to_keep)} aaq- columns")
        print(f"Total columns after filtering: {len(df_merged.columns)} (was {original_col_count}, removed {original_col_count - len(df_merged.columns)} columns)")
        
        # STEP 7: Add HTML URL column using mitra-utils url
        print("Adding HTML URL column...")
        
        # Get URL pattern by calling mitra-utils once on a dummy variant
        # Then replace the dummy variant_id with each actual variant_id
        html_urls = []
        dummy_variant_id = "DUMMY_VARIANT_ID_FOR_URL_PATTERN"
        dummy_html_path = f"{OUTPUT_DIR}/{variant_dataset}/{model_dataset_name}/{cluster_id}/{dummy_variant_id}/00-summary.html"
        
        result = subprocess.run(
            ['mitra-utils', 'url', dummy_html_path],
            capture_output=True,
            text=True,
            check=True
        )
        # Strip whitespace (including newlines) from the URL pattern
        url_pattern = result.stdout.strip()
        
        # Generate URLs for all variants using the pattern
        for variant_id in df_merged['variant_id']:
            # Replace dummy variant_id with actual variant_id in URL
            url = url_pattern.replace(dummy_variant_id, variant_id)
            html_urls.append(url)
        
        # Add HTML URL column
        df_merged['variant_html_url'] = html_urls
        
        # Reorder columns so variant_html_url is the second column (after variant_id)
        cols = list(df_merged.columns)
        if 'variant_id' in cols and 'variant_html_url' in cols:
            # Remove variant_html_url from its current position
            cols.remove('variant_html_url')
            # Find variant_id index and insert variant_html_url right after it
            variant_id_idx = cols.index('variant_id')
            cols.insert(variant_id_idx + 1, 'variant_html_url')
            # Reorder dataframe
            df_merged = df_merged[cols]
            print(f"Reordered columns: variant_html_url is now the second column (after variant_id)")
        
        # STEP 8: Create merged closest_genes, closest_miRNA, and closest_lncRNA columns
        print("Creating merged closest elements columns...")
        
        def merge_closest_elements(row, name_prefix, max_count):
            """Merge all closest_{name_prefix}_{1-max_count} and closest_{name_prefix}_distance_{1-max_count} into a single formatted string."""
            element_distance_pairs = []
            
            # Collect all element-distance pairs from columns 1 to max_count
            for i in range(1, max_count + 1):
                element_col = f'closest_{name_prefix}_{i}'
                dist_col = f'closest_{name_prefix}_distance_{i}'
                
                if element_col in row.index and dist_col in row.index:
                    element = row[element_col]
                    distance = row[dist_col]
                    
                    # Skip if either element or distance is missing/NaN/empty
                    if pd.notna(element) and pd.notna(distance) and str(element).strip() != '' and str(element) != '.':
                        try:
                            # Convert distance to numeric, handling string representations
                            dist_value = float(distance) if distance != '.' else None
                            if dist_value is not None:
                                element_distance_pairs.append((str(element).strip(), dist_value))
                        except (ValueError, TypeError):
                            # Skip if distance can't be converted to float
                            pass
            
            # Sort by absolute distance
            element_distance_pairs.sort(key=lambda x: abs(x[1]))
            
            # Format as "ELEMENT (distance bp), ELEMENT (distance bp), ..."
            formatted_pairs = [f"{element} ({int(dist)} bp)" for element, dist in element_distance_pairs]
            
            return ", ".join(formatted_pairs) if formatted_pairs else ""
        
        # Apply the function to create merged columns for genes (1-8), miRNA (1-5), and lncRNA (1-5)
        df_merged['closest_genes'] = df_merged.apply(lambda row: merge_closest_elements(row, 'genes', 8), axis=1)
        print(f"Created merged closest_genes column: {df_merged['closest_genes'].notna().sum()}/{len(df_merged)} rows have data")
        
        df_merged['closest_miRNA'] = df_merged.apply(lambda row: merge_closest_elements(row, 'miRNA', 5), axis=1)
        print(f"Created merged closest_miRNA column: {df_merged['closest_miRNA'].notna().sum()}/{len(df_merged)} rows have data")
        
        df_merged['closest_lncRNA'] = df_merged.apply(lambda row: merge_closest_elements(row, 'lncRNA', 5), axis=1)
        print(f"Created merged closest_lncRNA column: {df_merged['closest_lncRNA'].notna().sum()}/{len(df_merged)} rows have data")
        
        # Remove individual closest element columns (keep only the merged columns)
        cols_to_remove = []
        
        # Remove closest_genes_{1-8} and closest_genes_distance_{1-8}
        for i in range(1, 9):
            cols_to_remove.extend([f'closest_genes_{i}', f'closest_genes_distance_{i}'])
        
        # Remove closest_miRNA_{1-5} and closest_miRNA_distance_{1-5}
        for i in range(1, 6):
            cols_to_remove.extend([f'closest_miRNA_{i}', f'closest_miRNA_distance_{i}'])
        
        # Remove closest_lncRNA_{1-5} and closest_lncRNA_distance_{1-5}
        for i in range(1, 6):
            cols_to_remove.extend([f'closest_lncRNA_{i}', f'closest_lncRNA_distance_{i}'])
        
        # Filter to only columns that actually exist in the dataframe
        cols_to_remove = [col for col in cols_to_remove if col in df_merged.columns]
        
        # Drop the columns
        df_merged = df_merged.drop(columns=cols_to_remove)
        print(f"Removed {len(cols_to_remove)} individual closest element columns (kept merged columns: closest_genes, closest_miRNA, closest_lncRNA)")
        
        # Reorder columns so merged closest element columns appear after variant_html_url
        cols = list(df_merged.columns)
        merged_closest_cols = ['closest_genes', 'closest_miRNA', 'closest_lncRNA']
        
        if 'variant_html_url' in cols:
            # Remove merged closest columns from their current positions
            for col in merged_closest_cols:
                if col in cols:
                    cols.remove(col)
            
            # Find variant_html_url index and insert merged closest columns right after it
            variant_html_url_idx = cols.index('variant_html_url')
            # Insert in reverse order so they appear in the desired order
            for i, col in enumerate(merged_closest_cols):
                if col in df_merged.columns:
                    cols.insert(variant_html_url_idx + 1 + i, col)
            
            # Reorder dataframe
            df_merged = df_merged[cols]
            print(f"Reordered columns: merged closest element columns are now after variant_html_url")
        
        # STEP 9: Create simons_searchlight_genes_in_closest_genes column
        print("Creating simons_searchlight_genes_in_closest_genes column...")
        
        # Read simons's searchlight genes file
        searchlight_genes_file = "/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/curation/broad/varbook-container/snakemake/refs/simonss_searchlight_genes.txt"
        searchlight_genes = set()
        
        if os.path.exists(searchlight_genes_file):
            with open(searchlight_genes_file, 'r') as f:
                for line in f:
                    # Clean up gene name (remove whitespace, trailing commas)
                    gene = line.strip().rstrip(',').strip()
                    if gene:
                        searchlight_genes.add(gene.upper())  # Store as uppercase for case-insensitive matching
            print(f"Loaded {len(searchlight_genes)} genes from simons's searchlight genes file")
        else:
            print(f"Warning: simons's searchlight genes file not found: {searchlight_genes_file}")
        
        def filter_searchlight_genes(closest_genes_str):
            """Filter closest_genes string to only include genes in searchlight_genes set."""
            if pd.isna(closest_genes_str) or str(closest_genes_str).strip() == '':
                return ""
            
            # Parse the closest_genes string (format: "GENE1 (dist bp), GENE2 (dist bp), ...")
            filtered_pairs = []
            
            # Split by comma to get individual gene-distance pairs
            pairs = str(closest_genes_str).split(',')
            
            for pair in pairs:
                pair = pair.strip()
                if not pair:
                    continue
                
                # Extract gene name (everything before the opening parenthesis)
                if '(' in pair:
                    gene_name = pair.split('(')[0].strip()
                    
                    # Check if gene is in searchlight genes (case-insensitive)
                    if gene_name.upper() in searchlight_genes:
                        # Keep the original pair format
                        filtered_pairs.append(pair)
            
            return ", ".join(filtered_pairs) if filtered_pairs else ""
        
        # Apply the function to create the filtered column
        df_merged['simons_searchlight_genes_in_closest_genes'] = df_merged['closest_genes'].apply(filter_searchlight_genes)
        print(f"Created simons_searchlight_genes_in_closest_genes column: {df_merged['simons_searchlight_genes_in_closest_genes'].apply(lambda x: len(str(x).strip()) > 0 if pd.notna(x) else False).sum()}/{len(df_merged)} rows have matching genes")
        
        # Reorder columns so simons_searchlight_genes_in_closest_genes appears after closest_genes
        cols = list(df_merged.columns)
        if 'closest_genes' in cols and 'simons_searchlight_genes_in_closest_genes' in cols:
            # Remove simons_searchlight_genes_in_closest_genes from its current position
            cols.remove('simons_searchlight_genes_in_closest_genes')
            # Find closest_genes index and insert simons_searchlight_genes_in_closest_genes right after it
            closest_genes_idx = cols.index('closest_genes')
            cols.insert(closest_genes_idx + 1, 'simons_searchlight_genes_in_closest_genes')
            # Reorder dataframe
            df_merged = df_merged[cols]
            print(f"Reordered columns: simons_searchlight_genes_in_closest_genes is now after closest_genes")
        
        # STEP 10: Create G2P intersection columns for genes, miRNA, and lncRNA
        print("Creating G2P intersection columns...")
        
        # Load G2P CSV file
        g2p_file = "/oak/stanford/groups/akundaje/airanman/refs/gene2phenotype/2025-12-12/AllG2P.csv"
        g2p_lookup = {}
        
        if os.path.exists(g2p_file):
            try:
                df_g2p = pd.read_csv(g2p_file, sep=',', low_memory=False)
                print(f"Loaded G2P file with {len(df_g2p)} entries")
                
                # Create lookup dictionary: gene_symbol -> list of G2P entries
                # Also index by previous gene symbols (semicolon-delimited)
                for idx, row in df_g2p.iterrows():
                    gene_symbol = str(row.get('gene symbol', '')).strip().upper()
                    if not gene_symbol or gene_symbol == 'NAN' or pd.isna(row.get('gene symbol')):
                        continue
                    
                    # Extract relevant fields
                    g2p_entry = {
                        'disease_name': str(row.get('disease name', '')).strip() if pd.notna(row.get('disease name')) else 'N/A',
                        'confidence': str(row.get('confidence', '')).strip() if pd.notna(row.get('confidence')) else 'N/A',
                        'molecular mechanism': str(row.get('molecular mechanism', '')).strip() if pd.notna(row.get('molecular mechanism')) else 'N/A',
                        'molecular mechanism categorisation': str(row.get('molecular mechanism categorisation', '')).strip() if pd.notna(row.get('molecular mechanism categorisation')) else 'N/A',
                        'variant_consequence': str(row.get('variant consequence', '')).strip() if pd.notna(row.get('variant consequence')) else 'N/A',
                        'variant_types': str(row.get('variant types', '')).strip() if pd.notna(row.get('variant types')) else 'N/A',
                        'phenotypes': str(row.get('phenotypes', '')).strip() if pd.notna(row.get('phenotypes')) else 'N/A',
                        'panel': str(row.get('panel', '')).strip() if pd.notna(row.get('panel')) else 'N/A'
                    }
                    
                    # Add to lookup by gene symbol (store as list to handle multiple entries per gene)
                    if gene_symbol not in g2p_lookup:
                        g2p_lookup[gene_symbol] = []
                    g2p_lookup[gene_symbol].append(g2p_entry)
                    
                    # Also add to lookup by previous gene symbols (semicolon-delimited)
                    previous_symbols_str = str(row.get('previous gene symbols', '')).strip()
                    if previous_symbols_str and previous_symbols_str != 'NAN' and pd.notna(row.get('previous gene symbols')):
                        # Split by semicolon and add each previous symbol to lookup
                        previous_symbols = [s.strip().upper() for s in previous_symbols_str.split(';') if s.strip()]
                        for prev_symbol in previous_symbols:
                            if prev_symbol:  # Make sure it's not empty
                                if prev_symbol not in g2p_lookup:
                                    g2p_lookup[prev_symbol] = []
                                g2p_lookup[prev_symbol].append(g2p_entry)
                
                print(f"Created G2P lookup with {len(g2p_lookup)} unique gene symbols (including previous gene symbols)")
            except Exception as e:
                print(f"Error loading G2P file: {e}")
                g2p_lookup = {}
        else:
            print(f"Warning: G2P file not found: {g2p_file}")
            g2p_lookup = {}
        
        def format_g2p_match(element_name, distance, g2p_entries):
            """Render one closest-element hit, optionally annotated with G2P metadata.

            Args:
                element_name: Name of the element (gene/miRNA/lncRNA) as it should be displayed.
                distance: Distance in base pairs; truncated with int() for display.
                g2p_entries: List of G2P entry dictionaries (may be empty or None).

            Returns:
                "ELEMENT (N bp)" when there are no entries. Otherwise
                "ELEMENT (N bp; <entry 1> // <entry 2> ...)", where every entry is
                rendered as its fields in a fixed order ("field: value" joined by "; ").
                Missing, empty, whitespace-only or 'N/A' values are rendered as "N/A".
            """
            # Nothing to annotate: emit the plain "name (distance)" form.
            if not g2p_entries:
                return f"{element_name} ({int(distance)} bp)"

            # Fixed field order so every entry reads the same way in the spreadsheet.
            ordered_fields = (
                'disease_name',
                'confidence',
                'molecular mechanism',
                'molecular mechanism categorisation',
                'variant_consequence',
                'variant_types',
                'phenotypes',
                'panel',
            )

            def _describe(entry):
                """Return all fields of a single G2P entry as one "; "-joined string."""
                rendered_fields = []
                for field in ordered_fields:
                    value = entry.get(field, 'N/A')
                    has_value = bool(value and str(value).strip() and value != 'N/A')
                    rendered_fields.append(f"{field}: {value}" if has_value else f"{field}: N/A")
                return "; ".join(rendered_fields)

            # Each entry stays together; " // " separates distinct G2P entries for the same element.
            metadata = " // ".join([_describe(entry) for entry in g2p_entries])
            return f"{element_name} ({int(distance)} bp; {metadata})"
        
        def intersect_with_g2p(closest_elements_str, g2p_lookup_dict, element_type='gene'):
            """Keep only the closest elements that have G2P entries and annotate them.

            Args:
                closest_elements_str: String like "ELEMENT1 (dist bp), ELEMENT2 (dist bp), ...".
                    NaN/None or blank input yields an empty result.
                g2p_lookup_dict: Dictionary mapping upper-cased element names to lists of G2P entries.
                element_type: Type of element ('gene', 'miRNA', 'lncRNA'). Accepted for
                    call-site readability; not used by the matching logic.

            Returns:
                The matching elements formatted by format_g2p_match and joined with ", ",
                or "" when nothing matches.
            """
            if pd.isna(closest_elements_str) or str(closest_elements_str).strip() == '':
                return ""

            annotated = []
            for raw_chunk in str(closest_elements_str).split(','):
                chunk = raw_chunk.strip()
                # Skip empty fragments and fragments without a "(distance)" part.
                if not chunk or '(' not in chunk:
                    continue

                open_at = chunk.find('(')
                close_at = chunk.find(')')
                element_name = chunk[:open_at].strip()

                # Text between the first "(" and the first ")" with the unit removed.
                distance_text = chunk[open_at + 1:close_at].replace('bp', '').strip()
                try:
                    distance = float(distance_text)
                except (ValueError, TypeError):
                    # Unparseable distance: drop this element rather than guessing.
                    continue

                # Lookup keys are upper-case, so the comparison is case-insensitive.
                lookup_key = element_name.upper()
                if lookup_key not in g2p_lookup_dict:
                    continue
                annotated.append(format_g2p_match(element_name, distance, g2p_lookup_dict[lookup_key]))

            # Joining an empty list already yields "".
            return ", ".join(annotated)
        
        # Apply to genes, miRNA, and lncRNA
        df_merged['g2p_genes_in_closest_genes'] = df_merged['closest_genes'].apply(
            lambda x: intersect_with_g2p(x, g2p_lookup, 'gene')
        )
        print(f"Created g2p_genes_in_closest_genes column: {df_merged['g2p_genes_in_closest_genes'].apply(lambda x: len(str(x).strip()) > 0 if pd.notna(x) else False).sum()}/{len(df_merged)} rows have matching genes")
        
        df_merged['g2p_miRNA_in_closest_miRNA'] = df_merged['closest_miRNA'].apply(
            lambda x: intersect_with_g2p(x, g2p_lookup, 'miRNA')
        )
        print(f"Created g2p_miRNA_in_closest_miRNA column: {df_merged['g2p_miRNA_in_closest_miRNA'].apply(lambda x: len(str(x).strip()) > 0 if pd.notna(x) else False).sum()}/{len(df_merged)} rows have matching miRNA")
        
        df_merged['g2p_lncRNA_in_closest_lncRNA'] = df_merged['closest_lncRNA'].apply(
            lambda x: intersect_with_g2p(x, g2p_lookup, 'lncRNA')
        )
        print(f"Created g2p_lncRNA_in_closest_lncRNA column: {df_merged['g2p_lncRNA_in_closest_lncRNA'].apply(lambda x: len(str(x).strip()) > 0 if pd.notna(x) else False).sum()}/{len(df_merged)} rows have matching lncRNA")
        
        # Reorder columns so G2P columns appear after their respective closest element columns
        cols = list(df_merged.columns)
        g2p_cols = ['g2p_genes_in_closest_genes', 'g2p_miRNA_in_closest_miRNA', 'g2p_lncRNA_in_closest_lncRNA']
        
        # Place g2p_genes_in_closest_genes after closest_genes (or after simons_searchlight_genes_in_closest_genes if it exists)
        if 'g2p_genes_in_closest_genes' in cols:
            cols.remove('g2p_genes_in_closest_genes')
            if 'simons_searchlight_genes_in_closest_genes' in cols:
                insert_idx = cols.index('simons_searchlight_genes_in_closest_genes') + 1
            elif 'closest_genes' in cols:
                insert_idx = cols.index('closest_genes') + 1
            else:
                insert_idx = len(cols)
            cols.insert(insert_idx, 'g2p_genes_in_closest_genes')
        
        # Place g2p_miRNA_in_closest_miRNA after closest_miRNA
        if 'g2p_miRNA_in_closest_miRNA' in cols:
            cols.remove('g2p_miRNA_in_closest_miRNA')
            if 'closest_miRNA' in cols:
                insert_idx = cols.index('closest_miRNA') + 1
            else:
                insert_idx = len(cols)
            cols.insert(insert_idx, 'g2p_miRNA_in_closest_miRNA')
        
        # Place g2p_lncRNA_in_closest_lncRNA after closest_lncRNA
        if 'g2p_lncRNA_in_closest_lncRNA' in cols:
            cols.remove('g2p_lncRNA_in_closest_lncRNA')
            if 'closest_lncRNA' in cols:
                insert_idx = cols.index('closest_lncRNA') + 1
            else:
                insert_idx = len(cols)
            cols.insert(insert_idx, 'g2p_lncRNA_in_closest_lncRNA')
        
        # Reorder dataframe
        df_merged = df_merged[cols]
        print(f"Reordered columns: G2P intersection columns are now after their respective closest element columns")
        
        # Write updated TSV
        df_merged.to_csv(output[0], sep='\t', index=False)
        print(f"Added HTML URLs: {sum(1 for url in html_urls if url)}/{len(html_urls)} variants have URLs")
        
        end_time = time.time()
        duration = end_time - start_time
        print(f"[TIMING] Finished generate_human_readable_cluster_tsv for {variant_dataset}/{model_dataset_name}/{cluster_id} in {duration:.1f}s")
        print(f"✓ Created human-readable TSV: {output[0]}")
        
        # Create sentinel file to indicate this TSV is complete
        # This allows ensure_finemo_files to depend on all human-readable TSVs
        sentinel_dir = f"{OUTPUT_DIR}/{variant_dataset}/human_readable_spreadsheets/.sentinels"
        os.makedirs(sentinel_dir, exist_ok=True)
        sentinel_file = f"{sentinel_dir}/{model_dataset_name}.{cluster_id}.done"
        with open(sentinel_file, 'w') as f:
            f.write(f"{output[0]}\n")
        print(f"Created sentinel: {sentinel_file}")

rule ensure_human_readable_tsvs_complete:
    """Sentinel rule: one marker file per variant_dataset once all human-readable TSVs exist.

    The input function lists every per-cluster human-readable TSV for the
    variant_dataset, so this rule runs only after all of them have been produced.
    Other rules can depend on the single `.all_tsvs_complete` marker instead of
    on many individual TSV files.

    The input function calls `checkpoints.<name>.get(...)`, so the TSV list is
    only resolved after those checkpoints have finished.

    The output path is anchored to the global OUTPUT_DIR (string concatenation,
    like the other rules in this file). Writing "{OUTPUT_DIR}/..." inside a plain
    string would instead declare an unconstrained wildcard named OUTPUT_DIR.
    """
    input:
        # Depend on the TSV files themselves (not the per-cluster .done sentinels)
        # so that the rules producing them are scheduled in the DAG.
        human_readable_tsvs = lambda w: _get_all_human_readable_tsvs_after_checkpoints(w.variant_dataset)
    output:
        sentinel = OUTPUT_DIR + "/{variant_dataset}/human_readable_spreadsheets/.all_tsvs_complete"
    run:
        import os
        # Write a small marker recording what was completed.
        os.makedirs(os.path.dirname(output.sentinel), exist_ok=True)
        with open(output.sentinel, 'w') as f:
            f.write(f"All human-readable TSVs complete for {wildcards.variant_dataset}\n")
            f.write(f"Total TSV files: {len(input.human_readable_tsvs)}\n")
        print(f"Created sentinel for all human-readable TSVs: {output.sentinel} ({len(input.human_readable_tsvs)} TSV files)")

def _get_all_human_readable_tsvs_after_checkpoints(variant_dataset):
    """List the human-readable TSV paths for every configured cluster of a variant_dataset.

    Checkpoint-aware: `checkpoints.merge_variants_tsv.get(...)` is called once and
    `checkpoints.cluster_model_dataset.get(...)` once per model_dataset before any
    path is derived, so this is meant to be used from a rule input function.

    Returns an empty list for an unknown variant_dataset. Model datasets without a
    name, without an existing clustered TSV, or without configured clusters
    contribute nothing.
    """
    import os

    if variant_dataset not in VARIANT_DATASET_CONFIGS:
        return []

    # Reference checkpoint to ensure it completes first
    checkpoints.merge_variants_tsv.get(variant_dataset=variant_dataset)

    collected = []
    for md_config in get_model_datasets_list(variant_dataset):
        model_dataset_name = md_config.get('name')
        if not model_dataset_name:
            continue

        # Reference the clustering checkpoint for this model_dataset as well
        checkpoints.cluster_model_dataset.get(
            variant_dataset=variant_dataset,
            model_dataset=model_dataset_name
        )

        # The clustered TSV is the checkpoint's product; skip if it is not on disk
        if not os.path.exists(f"data/{variant_dataset}.{model_dataset_name}.clustered.tsv"):
            continue

        # A missing/empty cluster list simply yields no iterations
        for cluster in (md_config.get('clusters', []) or []):
            # Clusters are either dicts (name, falling back to id) or plain names
            cluster_name = cluster.get('name', cluster.get('id')) if isinstance(cluster, dict) else cluster
            if not cluster_name:
                continue
            collected.append(
                f"{OUTPUT_DIR}/{variant_dataset}/human_readable_spreadsheets/{model_dataset_name}.{cluster_name}.human_readable.tsv"
            )

    return collected

def _get_human_readable_sentinel_for_variant_dataset(variant_dataset):
    """List the per-cluster `.done` sentinel paths for a variant_dataset.

    One path is produced for every named cluster of every named model_dataset in
    the configuration; an unknown variant_dataset yields an empty list.

    Note: no checkpoint is referenced here and nothing is checked on disk; the
    paths are derived from configuration only. The sentinels themselves are
    written at the end of the human-readable TSV rule.
    """
    import os

    if variant_dataset not in VARIANT_DATASET_CONFIGS:
        return []

    sentinel_paths = []
    for md_config in get_model_datasets_list(variant_dataset):
        model_dataset_name = md_config.get('name')
        if not model_dataset_name:
            continue

        # A missing/empty cluster list simply yields no iterations
        for cluster in (md_config.get('clusters', []) or []):
            # Clusters are either dicts (name, falling back to id) or plain names
            if isinstance(cluster, dict):
                cluster_name = cluster.get('name', cluster.get('id'))
            else:
                cluster_name = cluster
            if not cluster_name:
                continue
            sentinel_paths.append(
                f"{OUTPUT_DIR}/{variant_dataset}/human_readable_spreadsheets/.sentinels/{model_dataset_name}.{cluster_name}.done"
            )

    return sentinel_paths

def _get_all_human_readable_tsvs(variant_dataset):
    """List the human-readable TSV paths for a variant_dataset without touching checkpoints.

    Only model_datasets whose clustered TSV already exists on disk contribute
    paths, so the result depends on the filesystem state at call time. An unknown
    variant_dataset yields an empty list.
    """
    import os

    if variant_dataset not in VARIANT_DATASET_CONFIGS:
        return []

    found = []
    for md_config in get_model_datasets_list(variant_dataset):
        model_dataset_name = md_config.get('name')
        if not model_dataset_name:
            continue

        # Skip model_datasets that have not been clustered yet
        if not os.path.exists(f"data/{variant_dataset}.{model_dataset_name}.clustered.tsv"):
            continue

        # A missing/empty cluster list simply yields no iterations
        for cluster in (md_config.get('clusters', []) or []):
            # Clusters are either dicts (name, falling back to id) or plain names
            cluster_name = cluster.get('name', cluster.get('id')) if isinstance(cluster, dict) else cluster
            if not cluster_name:
                continue
            found.append(
                f"{OUTPUT_DIR}/{variant_dataset}/human_readable_spreadsheets/{model_dataset_name}.{cluster_name}.human_readable.tsv"
            )

    return found

rule filter_variants_for_cluster:
    """
    Filter to a specific cluster and add back the columns needed for variant plots.

    Lazy loading optimization: the heavy split files (logfc, aaq, prioritization)
    are restricted to the variants of this cluster before they are merged.

    Steps:
    1. Read data/{variant_dataset}.{model_dataset}.clustered.tsv and keep the rows
       whose kmeans_35 equals the number parsed from cluster_id
       ("cluster_3" -> 3, "microglia-specific cluster (#3)" -> 3).
    2. Read the general, logfc, aaq and peak/promoter/outofpeak prioritization
       split files (SPLITS_DIR first, LOCAL_SPLITS_DIR as fallback), restricted
       to the cluster's variants.
    3. Left-merge them onto the cluster rows by variant_id.
    4. Keep the core columns, ref/alt, every column ending in "-{model}" for the
       models matching the model_dataset patterns, the has_hpo_* columns and all
       model_prioritized_by_* columns. Each column is written once.

    Raises ValueError when the model_dataset has no config, the cluster number
    cannot be parsed, kmeans_35 is missing, or ref/alt cannot be provided.

    NOTE(review): the rule declares no `input:` although it reads the clustered
    TSV and the split files, so Snakemake does not track those dependencies
    here - confirm this is intentional.
    """
    output:
        "data/{variant_dataset}.{model_dataset}.{cluster_id}.filtered.tsv"
    run:
        import pandas as pd
        import fnmatch
        import os
        import re
        import time

        variant_dataset = wildcards.variant_dataset
        model_dataset_name = wildcards.model_dataset
        cluster_id = wildcards.cluster_id

        start_time = time.time()
        print(f"[TIMING] Starting filter_variants_for_cluster for {variant_dataset}/{model_dataset_name}/{cluster_id} at {time.strftime('%Y-%m-%d %H:%M:%S')}")

        def _split_path(stem):
            """Return the split file for *stem*, preferring SPLITS_DIR over LOCAL_SPLITS_DIR."""
            primary = f"{SPLITS_DIR}/{variant_dataset}.{stem}.tsv"
            if os.path.exists(primary):
                return primary
            return f"{LOCAL_SPLITS_DIR}/{variant_dataset}.{stem}.tsv"

        # Get model patterns from config
        model_dataset_config = None
        for config in (get_model_datasets_list(variant_dataset) if variant_dataset in VARIANT_DATASET_CONFIGS else []):
            if config['name'] == model_dataset_name:
                model_dataset_config = config
                break

        if model_dataset_config is None:
            raise ValueError(f"No config found for model_dataset '{model_dataset_name}' in variant_dataset '{variant_dataset}'")

        model_patterns = model_dataset_config.get('models', [])

        # Determine dataset prefix for split files: first pattern that is non-empty
        # once wildcards and surrounding underscores are removed
        dataset_prefix = None
        for pattern in model_patterns:
            prefix = pattern.replace('*', '').strip('_')
            if prefix:
                dataset_prefix = prefix
                break

        # STEP 1: Read clustered TSV and filter to cluster
        clustered_file = f"data/{variant_dataset}.{model_dataset_name}.clustered.tsv"
        df_clustered = pd.read_csv(clustered_file, sep='\t')

        # Parse cluster number from cluster_id
        # Supports two formats:
        #   1. "cluster_3" → 3
        #   2. "microglia-specific cluster (#3)" → 3
        cluster_num = None
        if cluster_id.startswith('cluster_'):
            try:
                cluster_num = int(cluster_id.replace('cluster_', ''))
            except ValueError:
                # Not numeric; reported by the check below
                pass
        else:
            # Try to extract number from parentheses (e.g., "(#3)" → 3)
            match = re.search(r'\(#(\d+)\)', cluster_id)
            if match:
                cluster_num = int(match.group(1))

        if cluster_num is None:
            raise ValueError(f"Could not parse cluster number from cluster_id: {cluster_id}")

        # Filter to cluster
        if 'kmeans_35' not in df_clustered.columns:
            raise ValueError(f"kmeans_35 column not found in {clustered_file}")

        df_cluster = df_clustered[df_clustered['kmeans_35'] == cluster_num].copy()
        cluster_variant_ids = df_cluster['variant_id'].tolist()

        # STEP 2: Add back columns from split files (ONLY for cluster variants - lazy loading!)
        # General (for ref/alt columns if missing); this file is optional
        general_file = _split_path("general")
        df_general = None
        if os.path.exists(general_file):
            print(f"Reading general file: {general_file}")
            df_general = pd.read_csv(general_file, sep='\t')
            df_general = df_general[df_general['variant_id'].isin(cluster_variant_ids)]
            # Only keep the id plus whichever allele columns exist
            general_cols = [c for c in ('variant_id', 'ref', 'alt', 'allele1', 'allele2') if c in df_general.columns]
            df_general = df_general[general_cols]

        # Logfc (required: read_csv raises if neither location has the file)
        logfc_file = _split_path("logfc")
        print(f"Reading logfc file: {logfc_file}")
        df_logfc = pd.read_csv(logfc_file, sep='\t')
        df_logfc = df_logfc[df_logfc['variant_id'].isin(cluster_variant_ids)]

        # Aaq (required as well)
        aaq_file = _split_path("aaq")
        print(f"Reading aaq file: {aaq_file}")
        df_aaq = pd.read_csv(aaq_file, sep='\t')
        df_aaq = df_aaq[df_aaq['variant_id'].isin(cluster_variant_ids)]

        # Prioritization types (any, promoter, peak, outofpeak)
        # NOTE: We only load peak, promoter, outofpeak for scatterplot labels (skip 'any')
        prioritization_dfs = []
        for prio_type in ['peak', 'promoter', 'outofpeak']:
            # Try the model-specific file first (only meaningful when a prefix exists),
            # then fall back to the non-model-specific file
            prio_file = None
            if dataset_prefix:
                prio_file = _split_path(f"model_prioritized_by_{prio_type}-{dataset_prefix}")
            if prio_file is None or not os.path.exists(prio_file):
                prio_file = _split_path(f"model_prioritized_by_{prio_type}")

            if os.path.exists(prio_file):
                print(f"Reading {prio_type} prioritization file: {prio_file}")
                df_prio = pd.read_csv(prio_file, sep='\t')
                df_prio = df_prio[df_prio['variant_id'].isin(cluster_variant_ids)]
                prioritization_dfs.append(df_prio)
            else:
                print(f"Warning: {prio_type} prioritization file not found: {prio_file}")

        # STEP 3: Merge with cluster data
        df = df_cluster.copy()
        df = df.merge(df_logfc, on='variant_id', how='left')
        df = df.merge(df_aaq, on='variant_id', how='left')

        for df_prio in prioritization_dfs:
            df = df.merge(df_prio, on='variant_id', how='left')

        # Merge general file to ensure ref/alt columns are present (after all other merges)
        if df_general is not None and len(df_general) > 0:
            # Handle both ref/alt and allele1/allele2 naming
            if 'ref' not in df.columns or 'alt' not in df.columns:
                if 'ref' in df_general.columns and 'alt' in df_general.columns:
                    df = df.merge(df_general[['variant_id', 'ref', 'alt']], on='variant_id', how='left')
                elif 'allele1' in df_general.columns and 'allele2' in df_general.columns:
                    df = df.merge(df_general[['variant_id', 'allele1', 'allele2']], on='variant_id', how='left')
                    # Rename to ref/alt for consistency
                    df = df.rename(columns={'allele1': 'ref', 'allele2': 'alt'})

        # STEP 4: Keep only model columns matching patterns
        model_col_prefixes = ('logfc-', 'aaq-', 'model_prioritized_by_any-', 'model_prioritized_by_promoter-',
                              'model_prioritized_by_peak-', 'model_prioritized_by_outofpeak-')
        matched_models = []
        for col in df.columns:
            # Extract model name from columns like "logfc-{model}", "aaq-{model}", etc.
            for prefix in model_col_prefixes:
                if col.startswith(prefix):
                    # Strip only the leading prefix; str.replace would also remove
                    # later occurrences inside the model name
                    model_name = col[len(prefix):]
                    if model_name not in matched_models and any(
                        fnmatch.fnmatch(model_name, pattern) for pattern in model_patterns
                    ):
                        matched_models.append(model_name)

        # Build keep_cols
        keep_cols = ['variant_id', 'chr', 'pos', 'kmeans_35', 'organs']

        # Add ref/alt columns (handle both naming conventions)
        if 'ref' in df.columns and 'alt' in df.columns:
            keep_cols.extend(['ref', 'alt'])
        elif 'allele1' in df.columns and 'allele2' in df.columns:
            # Rename to ref/alt for consistency
            df = df.rename(columns={'allele1': 'ref', 'allele2': 'alt'})
            keep_cols.extend(['ref', 'alt'])

        for model in matched_models:
            keep_cols.extend(col for col in df.columns if col.endswith(f'-{model}'))

        # Add HPO columns
        keep_cols.extend(c for c in df.columns if c.startswith('has_hpo_'))

        # Add prioritization columns (for scatterplot labels)
        keep_cols.extend(c for c in df.columns if c.startswith('model_prioritized_by_'))

        # Filter to keep_cols that exist, keeping each name once in first-seen order.
        # Prioritization columns of matched models are collected twice above (by
        # "-{model}" suffix and by "model_prioritized_by_" prefix); selecting the
        # same label twice would write duplicate columns to the output TSV.
        keep_cols = list(dict.fromkeys(c for c in keep_cols if c in df.columns))
        df = df[keep_cols]

        # Final check: ensure ref/alt columns exist
        if 'ref' not in df.columns or 'alt' not in df.columns:
            raise ValueError(f"Missing ref/alt columns in filtered TSV. Available columns: {list(df.columns)}")

        # Write output
        df.to_csv(output[0], sep='\t', index=False)
        end_time = time.time()
        duration = end_time - start_time
        print(f"[TIMING] Finished filter_variants_for_cluster for {variant_dataset}/{model_dataset_name}/{cluster_id} in {duration:.1f}s")

# DEPRECATED: merge_heatmap_data rule removed
# Replaced by: filter_and_score_for_model_dataset → cluster_model_dataset → prepare_heatmap_data
# The new pipeline auto-generates score-{model} columns and uses KMeans clustering

rule merge_comprehensive_variants:
    """
    Merge comprehensive variant data for HTML report generation.

    Outer-joins three split files on variant_id with the `merge-columns` CLI:
    - General variant information
    - Patient HPO annotations (expanded)
    - Closest genomic elements

    The merged table is written to data/{variant_dataset}.comprehensive.tsv.
    Start/finish timing lines are written to stderr.
    """
    input:
        general = SPLITS_DIR + "/{variant_dataset}.general.tsv",
        patient_hpo = SPLITS_DIR + "/{variant_dataset}.patient_hpo_expanded.tsv",
        closest_elements = SPLITS_DIR + "/{variant_dataset}.closest_elements.tsv"
    output:
        "data/{variant_dataset}.comprehensive.tsv"
    shell:
        # Shell variables that need braces are written ${{VAR}}: Snakemake formats
        # this string first, turning the doubled braces into literal ones.
        """
        START_TIME=$(date +%s)
        echo "[TIMING] Starting merge_comprehensive_variants at $(date)" >&2
        merge-columns \
            --merge-column variant_id \
            --join-type outer \
            --output '{output}' \
            '{input.general}' \
            '{input.patient_hpo}' \
            '{input.closest_elements}'
        END_TIME=$(date +%s)
        DURATION=$((END_TIME - START_TIME))
        echo "[TIMING] Finished merge_comprehensive_variants in ${{DURATION}}s at $(date)" >&2
        """

# ----------------------------------------------------------------------------
# Heatmap generation (model_dataset level, before clustering)
# ----------------------------------------------------------------------------

def get_heatmap_config(wildcards):
    """Return the model_dataset config entry selected by the wildcards.

    Args:
        wildcards: Snakemake wildcards providing ``variant_dataset`` and
            ``model_dataset``.

    Returns:
        The first entry from ``get_model_datasets_list(variant_dataset)``
        whose ``'name'`` equals ``wildcards.model_dataset``.

    Raises:
        ValueError: If the variant_dataset is not a key of
            ``VARIANT_DATASET_CONFIGS``, or no entry has the requested name.
    """
    dataset = wildcards.variant_dataset
    wanted_name = wildcards.model_dataset

    if dataset not in VARIANT_DATASET_CONFIGS:
        raise ValueError(f"Unknown variant_dataset: {dataset}")

    # First matching entry wins, same as a linear scan with early return.
    found = next(
        (entry for entry in get_model_datasets_list(dataset) if entry['name'] == wanted_name),
        None,
    )
    if found is None:
        raise ValueError(f"Unknown model_dataset '{wanted_name}' for variant_dataset '{dataset}'")
    return found

rule create_model_dataset_intro:
    """Create the 00-intro.md placeholder file for each model_dataset.

    The file is produced by ``echo " "``, so its content is a single space
    followed by a newline. The rule has no inputs.
    """
    output:
        intro_md = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/00-intro.md"
    wildcard_constraints:
        model_dataset="[^/]+"  # No slashes - prevents matching cluster paths
    shell:
        'echo " " > "{output.intro_md}"'

rule generate_heatmap:
    """Generate heatmap at model_dataset level with numeric prefixes and HPO annotations.

    UPDATED: Now uses .clustered.tsv with pre-computed kmeans_35 clusters.
    - Input: data/{variant_dataset}.{model_dataset}.clustered.tsv
    - Generated by cluster_model_dataset checkpoint
    - Heatmap column: score (matches score-{model} pattern)
    - logfc/aaq args are placeholders (not used by heatmap, but required by CLI)

    Outputs:
    - png: the heatmap image written by the varbook CLI (-o)
    - md: a small markdown page that embeds 01-heatmap.png

    The markdown file is written before the URL helper is invoked on it, so
    the file exists when its URL is requested. URL generation is best-effort
    (failures are ignored via "|| true").
    """
    input:
        variants_tsv = "data/{variant_dataset}.{model_dataset}.clustered.tsv"
    output:
        png = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/01-heatmap.png",
        md = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/01-heatmap.md"
    params:
        variant_dataset_name = lambda w: w.variant_dataset,
        model_dataset_name = lambda w: w.model_dataset,
        models_args = lambda w: get_models_args(get_heatmap_config(w))
    # NOTE: the elapsed-time variable is written with doubled braces so that
    # Snakemake's formatting emits a braced bash variable. The previous
    # "$$DURATIONs" printed the shell PID followed by the text "DURATIONs".
    shell:
        """
        START_TIME=$(date +%s)
        echo "[TIMING] Starting generate_heatmap for {params.model_dataset_name} at $(date)" >&2
        {VARBOOK_CMD} plot models heatmap \
            "{input.variants_tsv}" \
            score \
            logfc \
            aaq \
            --variant-datasets "{params.variant_dataset_name}" \
            --model-dataset "{params.model_dataset_name}" \
            --models {params.models_args} \
            --kmeans-clusters 35 \
            --row-labels "has_hpo_Abnormality of the nervous system:bool" \
            --row-labels-mode cluster-mean \
            -o "{output.png}"
        END_TIME=$(date +%s)
        DURATION=$((END_TIME - START_TIME))
        echo "[TIMING] Finished generate_heatmap for {params.model_dataset_name} in ${{DURATION}}s at $(date)" >&2

        # Create markdown file first so it exists when its URL is requested
        echo "# Heatmap: {params.model_dataset_name}" > "{output.md}"
        echo "" >> "{output.md}"
        echo "![Heatmap](01-heatmap.png)" >> "{output.md}"

        # Generate URLs for heatmap (best-effort)
        mitra-utils url "{output.png}" || true
        mitra-utils url "{output.md}" || true
        """

rule generate_upset_plot:
    """Generate UpSet plot for HPO overlaps at model_dataset level.

    Visualizes overlaps between HPO phenotype annotations across variants.
    Uses .clustered.tsv (generated by cluster_model_dataset checkpoint) with boolean has_hpo_* columns.

    The plot is restricted to subsets of at least 5 variants
    (--min-subset-size 5). URL generation is best-effort (failures are
    ignored via "|| true").

    NOTE(review): only output.png is passed to the CLI (-o); the svg and md
    outputs are declared here but not written by this shell block, so they
    are presumably produced alongside the PNG by the upset command - confirm.
    """
    input:
        variants_tsv = "data/{variant_dataset}.{model_dataset}.clustered.tsv"
    output:
        png = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/upset/hpo_overlaps.png",
        svg = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/upset/hpo_overlaps.svg",
        md = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/upset/hpo_overlaps.md"
    params:
        variant_dataset_name = lambda w: w.variant_dataset,
        model_dataset_name = lambda w: w.model_dataset
    # NOTE: the elapsed-time variable is written with doubled braces so that
    # Snakemake's formatting emits a braced bash variable. The previous
    # "$$DURATIONs" printed the shell PID followed by the text "DURATIONs".
    shell:
        """
        START_TIME=$(date +%s)
        echo "[TIMING] Starting generate_upset_plot for {params.model_dataset_name} at $(date)" >&2
        {VARBOOK_CMD} plot models upset \
            --bool-cols "has_hpo_*" \
            --variant-datasets "{params.variant_dataset_name}" \
            --model-dataset "{params.model_dataset_name}" \
            --clean-labels \
            --min-subset-size 5 \
            -o "{output.png}" \
            "{input.variants_tsv}" \
            variant_id
        END_TIME=$(date +%s)
        DURATION=$((END_TIME - START_TIME))
        echo "[TIMING] Finished generate_upset_plot for {params.model_dataset_name} in ${{DURATION}}s at $(date)" >&2

        # Generate URLs for upset plot (best-effort)
        mitra-utils url "{output.png}" || true
        mitra-utils url "{output.svg}" || true
        mitra-utils url "{output.md}" || true
        """

rule generate_upset_plot_clustered:
    """Generate UpSet plot for HPO overlaps at cluster level.

    Visualizes overlaps between HPO phenotype annotations for variants within a specific cluster.
    Uses the cluster-filtered TSV with boolean has_hpo_* columns.

    The plot is restricted to subsets of at least 3 variants
    (--min-subset-size 3). The --variant-datasets value is built by
    get_variant_datasets_param from the variant_dataset and cluster_id
    wildcards. URL generation is best-effort (failures are ignored via
    "|| true").

    NOTE(review): only output.png is passed to the CLI (-o); the svg and md
    outputs are declared here but not written by this shell block, so they
    are presumably produced alongside the PNG by the upset command - confirm.
    """
    input:
        variants_tsv = "data/{variant_dataset}.{model_dataset}.{cluster_id}.filtered.tsv"
    output:
        png = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{cluster_id}/upset/hpo_overlaps.png",
        svg = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{cluster_id}/upset/hpo_overlaps.svg",
        md = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{cluster_id}/upset/hpo_overlaps.md"
    params:
        variant_datasets_param = lambda w: get_variant_datasets_param(w.variant_dataset, w.cluster_id),
        model_dataset_name = lambda w: w.model_dataset
    # NOTE: the elapsed-time variable is written with doubled braces so that
    # Snakemake's formatting emits a braced bash variable. The previous
    # "$$DURATIONs" printed the shell PID followed by the text "DURATIONs".
    shell:
        """
        START_TIME=$(date +%s)
        echo "[TIMING] Starting generate_upset_plot_clustered for {params.model_dataset_name}/{wildcards.cluster_id} at $(date)" >&2
        {VARBOOK_CMD} plot models upset \
            --bool-cols "has_hpo_*" \
            --variant-datasets "{params.variant_datasets_param}" \
            --model-dataset "{params.model_dataset_name}" \
            --clean-labels \
            --min-subset-size 3 \
            -o "{output.png}" \
            "{input.variants_tsv}" \
            variant_id
        END_TIME=$(date +%s)
        DURATION=$((END_TIME - START_TIME))
        echo "[TIMING] Finished generate_upset_plot_clustered for {params.model_dataset_name}/{wildcards.cluster_id} in ${{DURATION}}s at $(date)" >&2

        # Generate URLs for upset plot (best-effort)
        mitra-utils url "{output.png}" || true
        mitra-utils url "{output.svg}" || true
        mitra-utils url "{output.md}" || true
        """


# ----------------------------------------------------------------------------
# Per-variant plots (flattened structure with numeric prefixes)
# ----------------------------------------------------------------------------

# Specify rule order to resolve ambiguity. The three intro rules all write a
# file named 00-intro.md under OUTPUT_DIR, so a single target path may match
# more than one of their output patterns; the rule listed first is preferred:
# 1. Variant intro (most specific - has variant_id)
# 2. Cluster intro (middle - has variant_subdataset)
# 3. Model dataset intro (least specific - only has model_dataset)
ruleorder: create_variant_intro > create_cluster_intro > create_model_dataset_intro

rule create_cluster_intro:
    """Create the 00-intro.md placeholder file for each cluster/variant_subdataset.

    The file is produced by ``echo " "``, so its content is a single space
    followed by a newline. The rule has no inputs.
    """
    output:
        intro_md = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{variant_subdataset}/00-intro.md"
    wildcard_constraints:
        model_dataset="[^/]+",  # No slashes in model_dataset
        variant_subdataset=".*[cC]luster.*"  # Must contain "cluster" or "Cluster" to distinguish from model_dataset
    shell:
        'echo " " > "{output.intro_md}"'

rule create_variant_intro:
    """Create the 00-intro.md placeholder file for each variant.

    The file is produced by ``echo " "``, so its content is a single space
    followed by a newline. The rule has no inputs and declares no wildcard
    constraints of its own.
    """
    output:
        intro_md = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{variant_subdataset}/{variant_id}/00-intro.md"
    shell:
        'echo " " > "{output.intro_md}"'

# COMMENTED OUT: Original rule that only uses models from current model_dataset
# rule generate_variant_barplot:
#     """Generate model-specificity barplot for a variant with flattened structure."""
#     input:
#         variants_tsv = VARIANTS_TSV
#     output:
#         png = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{variant_subdataset}/{variant_id}/01-model-specificity-barplot.png",
#         md = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{variant_subdataset}/{variant_id}/01-model-specificity-barplot.md",
#         before_md = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{variant_subdataset}/{variant_id}/01-model-specificity-barplot.before.md",
#         after_md = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{variant_subdataset}/{variant_id}/01-model-specificity-barplot.after.md"
#     params:
#         variant = lambda w: w.variant_id,
#         variant_datasets_param = lambda w: f"{w.variant_dataset}:{w.variant_subdataset}" if hasattr(w, 'variant_subdataset') and w.variant_subdataset else w.variant_dataset,
#         model_dataset_name = lambda w: w.model_dataset,
#         models_args = lambda w: get_models_args(get_variant_config(w))
#     shell:
#         """
#         {VARBOOK_CMD} plot variant model-specificity-barplot \
#             {input.variants_tsv} \
#             "{params.variant}" \
#             --variant-datasets "{params.variant_datasets_param}" \
#             --model-dataset "{params.model_dataset_name}" \
#             --models {params.models_args} \
#             -o "{output.png}"
#
#         # Create markdown reference
#         echo "![Model Specificity Barplot](01-model-specificity-barplot.png)" > "{output.md}"
#
#         # Create before/after markdown files
#         echo " " > "{output.before_md}"
#         echo "## Model Specificity Barplot" > "{output.after_md}"
#         """

rule create_barplot_merged_tsv:
    """
    Create merged TSV with prioritization data and organs metadata for barplot.

    Merges KUN_HDMA and KUN_FB prioritization TSVs with model_tissues.tsv
    to create a file that has both per-model prioritization columns and organs.

    Steps:
    1. Outer-join the KUN_HDMA and KUN_FB tables on variant_id.
    2. Build a model_name -> organs mapping from model_tissues.tsv.
    3. Store that mapping, JSON-serialized, in an ``organs`` column. The same
       string is repeated on every row.

    NOTE: the input paths are hard-coded absolute paths and do not depend on
    any wildcard.
    """
    input:
        kun_hdma = "/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/curation/broad/splits/broad.model_prioritized_by_any-KUN_HDMA.tsv",
        kun_fb = "/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/curation/broad/splits/broad.model_prioritized_by_any-KUN_FB.tsv",
        model_tissues = "/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/curation/broad/splits/broad.model_tissues.tsv"
    output:
        merged_tsv = "data/barplot_merged_kun_hdma_fb.tsv"
    run:
        import json
        import pandas as pd

        # Read the three files
        df_hdma = pd.read_csv(input.kun_hdma, sep='\t')
        df_fb = pd.read_csv(input.kun_fb, sep='\t')
        df_tissues = pd.read_csv(input.model_tissues, sep='\t')

        # Merge KUN_HDMA and KUN_FB on variant_id
        df_merged = df_hdma.merge(df_fb, on='variant_id', how='outer')

        # Create organs column by mapping model names to organs
        # The organs column should be a dict-like structure mapping model_name -> organs
        model_to_organs = dict(zip(df_tissues['model_name'], df_tissues['organs']))

        # The mapping does not depend on the row, so serialize it once and
        # broadcast the scalar. A row-wise apply() would re-serialize the same
        # dict for every variant, and on an empty frame apply(axis=1) returns
        # a DataFrame rather than a Series, which cannot be assigned to a
        # single column.
        organs_json = json.dumps(model_to_organs)
        df_merged['organs'] = organs_json

        # Save merged TSV
        df_merged.to_csv(output.merged_tsv, sep='\t', index=False)
        print(f"Created merged TSV with {len(df_merged)} variants and organs metadata")

rule generate_variant_barplot:
    """
    Generate comprehensive model-specificity barplot for a variant.

    TEMPORARY CHANGE: This rule passes ALL KUN_HDMA and KUN_FB models to show
    the full organ/tissue context. The barplot displays:
    - Total number of models per organ/tissue type
    - Number of prioritized models per organ (highlighted)

    Uses a merged TSV that combines prioritization data with organs metadata.

    Outputs:
    - png: the barplot image written by the varbook CLI (-o)
    - md: a one-line markdown reference embedding the PNG
    """
    wildcard_constraints:
        variant_id=r"chr[^/]+:[^/]+:[^/]+:[^/]+"  # chr:pos:ref:alt format
    input:
        variants_tsv = "data/barplot_merged_kun_hdma_fb.tsv"
    output:
        png = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{variant_subdataset}/{variant_id}/01-model-specificity-barplot.png",
        md = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{variant_subdataset}/{variant_id}/01-model-specificity-barplot.md"
    params:
        variant = lambda w: w.variant_id,
        variant_dataset = lambda w: w.variant_dataset,
        # "<variant_dataset>:<variant_subdataset>" when a subdataset is set,
        # otherwise just the variant_dataset
        variant_datasets_param = lambda w: f"{w.variant_dataset}:{w.variant_subdataset}" if hasattr(w, 'variant_subdataset') and w.variant_subdataset else w.variant_dataset,
        model_dataset_name = lambda w: w.model_dataset,
        # Get ALL KUN_HDMA and KUN_FB models
        models_args = lambda w: get_all_models_args_for_barplot(w.variant_dataset)
    # NOTE: the elapsed-time variable is written with doubled braces so that
    # Snakemake's formatting emits a braced bash variable. The previous
    # "$$DURATIONs" printed the shell PID followed by the text "DURATIONs".
    shell:
        """
        START_TIME=$(date +%s)
        echo "[TIMING] Starting generate_variant_barplot for {params.variant} at $(date)" >&2
        {VARBOOK_CMD} plot variant model-specificity-barplot \
            "{input.variants_tsv}" \
            "{params.variant}" \
            --variant-datasets "{params.variant_datasets_param}" \
            --model-dataset "{params.model_dataset_name}" \
            --models {params.models_args} \
            -o "{output.png}"
        END_TIME=$(date +%s)
        DURATION=$((END_TIME - START_TIME))
        echo "[TIMING] Finished generate_variant_barplot for {params.variant} in ${{DURATION}}s at $(date)" >&2

        # Create markdown reference
        echo "![Model Specificity Barplot](01-model-specificity-barplot.png)" > "{output.md}"
        """

def get_finemo_dummy_file():
    """Return the path of an empty placeholder file, creating it if absent.

    The placeholder stands in for an optional finemo dependency. An existing
    file is left untouched (it is not rewritten, so its mtime is preserved).

    Returns:
        The fixed placeholder path as a string.
    """
    import os

    placeholder = "/tmp/.finemo_dummy_always_exists"
    if os.path.exists(placeholder):
        return placeholder

    # Make sure the parent directory is present, then create an empty file.
    os.makedirs(os.path.dirname(placeholder), exist_ok=True)
    with open(placeholder, 'w') as handle:
        handle.write("")
    return placeholder

# Memoized result of _get_optional_models(); None means "not computed yet".
_optional_models_cache = None

def _get_optional_models():
    """Return the optional model names from ``config['models']`` (memoized).

    A comma-separated string is split into stripped, non-empty names; a list
    is returned as-is. A missing/empty value, any other type, or any error
    while reading the config yields an empty list. The result is computed on
    the first call and reused afterwards.
    """
    global _optional_models_cache
    if _optional_models_cache is not None:
        return _optional_models_cache

    try:
        raw_models = config.get('models', None)
        if raw_models and isinstance(raw_models, str):
            parsed = [name.strip() for name in raw_models.split(',') if name.strip()]
        elif raw_models and isinstance(raw_models, list):
            parsed = raw_models
        else:
            parsed = []
    except Exception:
        # Best-effort: an unreadable config is treated as "no optional models".
        parsed = []

    _optional_models_cache = parsed
    return _optional_models_cache

def get_finemo_input_for_profile(wildcards):
    """Choose the finemo motifs input for a profile job.

    Decision order:
      1. The model is listed in the optional models from config: return the
         real finemo TSV path.
      2. The model prioritizes the variant: return the real finemo TSV path.
      3. The model does not prioritize the variant: return the placeholder
         file so no finemo dependency is created.
      4. Prioritization could not be determined (any exception): fall back to
         the real finemo TSV path.

    Args:
        wildcards: Snakemake wildcards with ``model_name``,
            ``variant_dataset``, ``variant_id`` and ``model_dataset``.

    Returns:
        A file path: either the finemo TSV or the placeholder file.
    """
    import os

    # Models named in config.models always get their real finemo file.
    configured_optional = _get_optional_models()
    if wildcards.model_name in configured_optional:
        return get_finemo_tsv_path(wildcards)

    try:
        prioritizing = get_prioritized_models_for_variant(
            wildcards.variant_dataset,
            wildcards.variant_id,
            wildcards.model_dataset
        )
        if wildcards.model_name not in prioritizing:
            # Not prioritized: the placeholder avoids a finemo dependency.
            return get_finemo_dummy_file()
        return get_finemo_tsv_path(wildcards)
    except Exception:
        # Prioritization unknown (e.g. files not generated yet): require the
        # real finemo file.
        return get_finemo_tsv_path(wildcards)

def get_finemo_ready_for_profile(wildcards):
    """Choose the finemo "ready" dependency for a profile job.

    Decision order:
      1. The model is listed in the optional models from config: return the
         model's finemo TSV path.
      2. The model prioritizes the variant: return the model's finemo TSV
         path (the per-model file, not a global sentinel).
      3. The model does not prioritize the variant: return the placeholder
         file so no finemo dependency is created.
      4. Prioritization could not be determined (any exception): return the
         finemo TSV path only if that file already exists on disk, otherwise
         the placeholder so the job is not blocked on finemo.

    Args:
        wildcards: Snakemake wildcards with ``model_name``,
            ``variant_dataset``, ``variant_id`` and ``model_dataset``.

    Returns:
        A file path: either the finemo TSV or the placeholder file.
    """
    import os

    # Models named in config.models always get their real finemo file.
    configured_optional = _get_optional_models()
    if wildcards.model_name in configured_optional:
        return get_finemo_tsv_path(wildcards)

    try:
        prioritizing = get_prioritized_models_for_variant(
            wildcards.variant_dataset,
            wildcards.variant_id,
            wildcards.model_dataset
        )
        if wildcards.model_name not in prioritizing:
            # Not prioritized: the placeholder avoids a finemo dependency.
            return get_finemo_dummy_file()
        # Prioritized: depend on this model's own finemo file.
        return get_finemo_tsv_path(wildcards)
    except Exception:
        # Prioritization unknown: only depend on the finemo file when it is
        # already present; otherwise use the placeholder.
        candidate = get_finemo_tsv_path(wildcards)
        return candidate if os.path.exists(candidate) else get_finemo_dummy_file()

# rule generate_variant_profile:
#     """Generate profile plot for a variant and specific model.
    
#     If the model prioritizes the variant, generates profile with finemo motifs.
#     If the model doesn't prioritize the variant (e.g., optional model), generates
#     profile without requiring finemo files (avoids triggering unnecessary finemo annotation).
#     """
#     input:
#         variants_tsv = lambda w: get_variants_tsv_path(w.variant_dataset),
#         model_paths = MODEL_PATHS_TSV,
#         motifs_tsv = get_finemo_input_for_profile,
#         finemo_ready = get_finemo_ready_for_profile
#     output:
#         png = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{variant_subdataset}/{variant_id}/03-profile-{model_name}.png",
#         md = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{variant_subdataset}/{variant_id}/03-profile-{model_name}.md"
#     params:
#         variant = lambda w: w.variant_id,
#         model = lambda w: w.model_name,
#         variant_dataset = lambda w: w.variant_dataset,
#         model_dataset = lambda w: w.model_dataset
#     run:
#         # Check if this model prioritizes this variant
#         prioritized_models = get_prioritized_models_for_variant(
#             params.variant_dataset,
#             params.variant,
#             params.model_dataset
#         )

#         import os
        
#         # Check if finemo file is a dummy file (indicates optional model)
#         # The dummy file path is always /tmp/.finemo_dummy_always_exists
#         is_dummy_finemo = input.motifs_tsv == "/tmp/.finemo_dummy_always_exists" or os.path.basename(input.motifs_tsv) == ".finemo_dummy_always_exists"
        
#         # Check if this model is in optional models from config
#         is_optional_model = False
#         try:
#             models_config = config.get('models', None)
#             if models_config:
#                 if isinstance(models_config, str):
#                     optional_models_list = [m.strip() for m in models_config.split(',') if m.strip()]
#                 elif isinstance(models_config, list):
#                     optional_models_list = models_config
#                 else:
#                     optional_models_list = []
#                 is_optional_model = params.model in optional_models_list
#         except:
#             pass
        
#         # Use finemo if: (1) model prioritizes variant, OR (2) model is in optional models list
#         # AND finemo is not a dummy file
#         # Note: If finemo file exists (not dummy), we should use it regardless of prioritization
#         use_finemo = not is_dummy_finemo and (params.model in prioritized_models or is_optional_model)
        
#         # Debug output
#         print(f"DEBUG generate_variant_profile: model={params.model}, prioritized={params.model in prioritized_models}, is_optional={is_optional_model}, is_dummy={is_dummy_finemo}, use_finemo={use_finemo}")
#         print(f"DEBUG generate_variant_profile: motifs_tsv={input.motifs_tsv}")
#         if not is_dummy_finemo:
#             import os
#             if os.path.exists(input.motifs_tsv):
#                 import pandas as pd
#                 try:
#                     finemo_df = pd.read_csv(input.motifs_tsv, sep='\t', nrows=5)
#                     print(f"DEBUG generate_variant_profile: Finemo file exists, has {len(finemo_df)} rows (showing first 5), columns: {list(finemo_df.columns)}")
#                     if 'variant_id' in finemo_df.columns:
#                         print(f"DEBUG generate_variant_profile: Sample variant IDs in finemo: {list(finemo_df['variant_id'].head())}")
#                         variant_in_finemo = params.variant in finemo_df['variant_id'].values
#                         print(f"DEBUG generate_variant_profile: Variant {params.variant} in finemo file: {variant_in_finemo}")
#                 except Exception as e:
#                     print(f"DEBUG generate_variant_profile: Error reading finemo file: {e}")
#             else:
#                 print(f"DEBUG generate_variant_profile: Finemo file does not exist: {input.motifs_tsv}")
        
#         if use_finemo:
#             # Generate profile plot with motifs (finemo required)
#             # Try GPUs in order: 1, 3, 0, 2 (based on typical free memory), then fall back to CPU
#             shell("""
#                 START_TIME=$(date +%s)
#                 echo "[TIMING] Starting generate_variant_profile for {params.variant}/{params.model} at $(date)" >&2
#                 (export CUDA_VISIBLE_DEVICES=1 && {VARBOOK_CMD} plot variant profiles \
#                     "{input.variants_tsv}" \
#                     "{params.variant}" \
#                     "{params.model}" \
#                     --model-paths-tsv "{input.model_paths}" \
#                     --motifs-tsv "{input.motifs_tsv}" \
#                     -o "{output.png}" \
#                     --n-shuffles 20 \
#                     --device cuda) || \
#                 (export CUDA_VISIBLE_DEVICES=3 && {VARBOOK_CMD} plot variant profiles \
#                     "{input.variants_tsv}" \
#                     "{params.variant}" \
#                     "{params.model}" \
#                     --model-paths-tsv "{input.model_paths}" \
#                     --motifs-tsv "{input.motifs_tsv}" \
#                     -o "{output.png}" \
#                     --n-shuffles 20 \
#                     --device cuda) || \
#                 (export CUDA_VISIBLE_DEVICES=0 && {VARBOOK_CMD} plot variant profiles \
#                     "{input.variants_tsv}" \
#                     "{params.variant}" \
#                     "{params.model}" \
#                     --model-paths-tsv "{input.model_paths}" \
#                     --motifs-tsv "{input.motifs_tsv}" \
#                     -o "{output.png}" \
#                     --n-shuffles 20 \
#                     --device cuda) || \
#                 ({VARBOOK_CMD} plot variant profiles \
#                     "{input.variants_tsv}" \
#                     "{params.variant}" \
#                     "{params.model}" \
#                     --model-paths-tsv "{input.model_paths}" \
#                     --motifs-tsv "{input.motifs_tsv}" \
#                     -o "{output.png}" \
#                     --n-shuffles 20 \
#                     --device cpu)
#                 END_TIME=$(date +%s)
#                 DURATION=$((END_TIME - START_TIME))
#                 echo "[TIMING] Finished generate_variant_profile for {params.variant}/{params.model} in $$DURATIONs at $(date)" >&2

#                 # Generate URLs for profile
#                 mitra-utils url "{output.png}" || true
#                 mitra-utils url "{output.md}" || true

#                 # Create markdown reference
#                 echo "![Profile for {params.model}](03-profile-{params.model}.png)" > "{output.md}"
#             """)
#         else:
#             # Model doesn't prioritize this variant AND finemo is dummy - generate profile without finemo
#             # Try GPUs in order: 1, 3, 0, 2 (based on typical free memory), then fall back to CPU
#             shell("""
#                 START_TIME=$(date +%s)
#                 echo "[TIMING] Starting generate_variant_profile for {params.variant}/{params.model} (no finemo) at $(date)" >&2
#                 (export CUDA_VISIBLE_DEVICES=1 && {VARBOOK_CMD} plot variant profiles \
#                     "{input.variants_tsv}" \
#                     "{params.variant}" \
#                     "{params.model}" \
#                     --model-paths-tsv "{input.model_paths}" \
#                     -o "{output.png}" \
#                     --n-shuffles 20 \
#                     --device cuda) || \
#                 (export CUDA_VISIBLE_DEVICES=3 && {VARBOOK_CMD} plot variant profiles \
#                     "{input.variants_tsv}" \
#                     "{params.variant}" \
#                     "{params.model}" \
#                     --model-paths-tsv "{input.model_paths}" \
#                     -o "{output.png}" \
#                     --n-shuffles 20 \
#                     --device cuda) || \
#                 (export CUDA_VISIBLE_DEVICES=0 && {VARBOOK_CMD} plot variant profiles \
#                     "{input.variants_tsv}" \
#                     "{params.variant}" \
#                     "{params.model}" \
#                     --model-paths-tsv "{input.model_paths}" \
#                     -o "{output.png}" \
#                     --n-shuffles 20 \
#                     --device cuda) || \
#                 ({VARBOOK_CMD} plot variant profiles \
#                     "{input.variants_tsv}" \
#                     "{params.variant}" \
#                     "{params.model}" \
#                     --model-paths-tsv "{input.model_paths}" \
#                     -o "{output.png}" \
#                     --n-shuffles 20 \
#                     --device cpu)
#                 END_TIME=$(date +%s)
#                 DURATION=$((END_TIME - START_TIME))
#                 echo "[TIMING] Finished generate_variant_profile for {params.variant}/{params.model} in $$DURATIONs at $(date)" >&2

#                 # Generate URLs for profile
#                 mitra-utils url "{output.png}" || true
#                 mitra-utils url "{output.md}" || true

#                 # Create markdown reference
#                 echo "![Profile for {params.model}](03-profile-{params.model}.png)" > "{output.md}"
#             """)

rule generate_variant_profile_batch:
    """Generate profile plots for ALL variants prioritized by a model (batch mode).

    This rule loads the model ONCE and processes all variants sequentially,
    dramatically reducing model loading overhead (~25-30% speedup).

    Output is a sentinel file; actual PNG/MD files are created as side effects.
    Profiles are generated at variant_dataset level (shared across all model datasets).

    IMPORTANT: Only one instance per model can run at a time to avoid GPU memory issues.
    """
    input:
        variants_tsv = "data/{variant_dataset}.{model_dataset}.{variant_subdataset}.filtered.tsv",
        model_paths = MODEL_PATHS_TSV,
        motifs_tsv = get_finemo_tsv_path,
        finemo_ready = SPLITS_DIR + "/finemo/.all_finemo_files_ready"
    output:
        sentinel = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{variant_subdataset}/.profiles_batch_{model_name}.done"
    params:
        variant_dataset = lambda w: w.variant_dataset,
        model_dataset_name = lambda w: w.model_dataset,
        cluster_name = lambda w: w.variant_subdataset,
        model_name = lambda w: w.model_name
    resources:
        # Only allow 1 job per model_name at a time (prevents multiple batch processes for same model)
        gpu_per_model = lambda w: w.model_name
    run:
        import os

        # Get cluster_id from cluster_name
        cluster_id = get_cluster_id_from_name(
            params.variant_dataset,
            params.model_dataset_name,
            params.cluster_name
        )

        if cluster_id is None:
            print(f"Warning: Could not find cluster_id for '{params.cluster_name}'")
            shell("touch {output.sentinel}")
            return

        # Get model_dataset_config
        config = None
        for cfg in get_model_datasets_list(params.variant_dataset):
            if cfg['name'] == params.model_dataset_name:
                config = cfg
                break

        if config is None:
            print(f"Warning: Could not find config for model_dataset '{params.model_dataset_name}'")
            shell("touch {output.sentinel}")
            return

        # Get variants prioritized by this model in this cluster
        variants = get_variants_prioritized_by_model(
            params.variant_dataset,
            params.model_name,
            config,
            cluster_id
        )

        if not variants:
            print(f"No variants prioritized by {params.model_name} in {params.cluster_name}")
            shell("touch '{output.sentinel}'")
            return

        # Base output directory - profiles at variant_dataset level (NOT model_dataset level!)
        base_dir = f"{OUTPUT_DIR}/{params.variant_dataset}"

        # Filter to only variants that need regeneration (hybrid approach)
        # This mimics Snakemake's dependency checking but at a granular level
        # Profiles are stored at: {variant_dataset}/profiles/{variant_id}/{model}.png
        variants_to_process = []
        for variant_id in variants:
            png_file = os.path.join(base_dir, "profiles", variant_id, f"{params.model_name}.png")

            # Check if PNG file exists
            if not os.path.exists(png_file):
                variants_to_process.append(variant_id)
            else:
                # File exists - check if it's older than any input
                try:
                    png_mtime = os.path.getmtime(png_file)
                    input_mtime = max(
                        os.path.getmtime(str(input.variants_tsv)),
                        os.path.getmtime(str(input.model_paths)),
                        os.path.getmtime(str(input.motifs_tsv))
                    )

                    # If PNG is older than inputs, it needs regeneration
                    if png_mtime < input_mtime:
                        variants_to_process.append(variant_id)
                except OSError:
                    # If we can't check timestamps, regenerate to be safe
                    variants_to_process.append(variant_id)

        # If all variants are up-to-date, just update sentinel and return
        if not variants_to_process:
            print(f"All {len(variants)} variants up-to-date for {params.model_name}")
            shell("touch '{output.sentinel}'")
            return

        print(f"Batch processing {len(variants_to_process)}/{len(variants)} variants for {params.model_name}")
        if len(variants_to_process) < len(variants):
            print(f"  Skipping {len(variants) - len(variants_to_process)} already up-to-date variants")

        # Call batch command
        # NOTE: Do NOT pass --model-dataset here! Profiles are stored at variant_dataset level
        # to be shared across all model_datasets (avoids duplication)
        import time
        import subprocess
        import sys
        start_time = time.time()
        
        # Select least-used GPU
        selected_gpu = get_least_used_gpu()
        print(f"[TIMING] Starting batch profile generation for {params.model_name} ({len(variants_to_process)} variants) at {time.strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"Using GPU {selected_gpu} (least used)")
        
        # Build command as list for better error handling
        # --batch-variants uses nargs='+' so variant IDs should be separate arguments
        cmd = [
            VENV_PYTHON, "-m", "varbook", "plot", "variant", "profiles",
            str(input.variants_tsv),
            params.model_name,
            "--model-paths-tsv", str(input.model_paths),
            "--motifs-tsv", str(input.motifs_tsv),
            "--batch-variants"] + variants_to_process + [
            "--variant-dataset", params.variant_dataset,
            "-o", base_dir,
            "--n-shuffles", "20",
            "--device", "cuda"
        ]
        
        print(f"Running varbook command with {len(variants_to_process)} variants...")
        
        # Set CUDA_VISIBLE_DEVICES environment variable
        env = os.environ.copy()
        env['CUDA_VISIBLE_DEVICES'] = str(selected_gpu)
        
        try:
            result = subprocess.run(
                cmd,
                check=True,
                capture_output=True,
                text=True,
                env=env
            )
            if result.stdout:
                print(result.stdout)
            if result.stderr:
                print(result.stderr, file=sys.stderr)
        except subprocess.CalledProcessError as e:
            print(f"ERROR: varbook command failed with exit code {e.returncode}", file=sys.stderr)
            if e.stdout:
                print(f"STDOUT:\n{e.stdout}", file=sys.stderr)
            if e.stderr:
                print(f"STDERR:\n{e.stderr}", file=sys.stderr)
            raise
        
        end_time = time.time()
        duration = end_time - start_time
        print(f"[TIMING] Batch profile generation for {params.model_name} completed in {duration:.1f}s")
        
        # Generate URLs for all profile PNG files that were created
        import glob
        for variant_id in variants_to_process:
            png_file = os.path.join(base_dir, "profiles", variant_id, f"{params.model_name}.png")
            if os.path.exists(png_file):
                try:
                    subprocess.run(['mitra-utils', 'url', png_file], check=False, capture_output=True, text=True)
                except Exception as e:
                    print(f"Warning: Could not generate URL for {png_file}: {e}")
        
        # Create sentinel file
        shell("touch '{output.sentinel}'")

rule symlink_variant_profiles:
    """Symlink centrally stored profile PNGs into per-variant display directories.

    Profiles are stored at:
        {OUTPUT_DIR}/{variant_dataset}/profiles/{variant_id}/{model_name}.png
    Symlinks are created at:
        {OUTPUT_DIR}/{variant_dataset}/{model_dataset}/{variant_subdataset}/{safe_variant_id}/03-profile-{model_name}.png

    Only variants returned by get_variants_prioritized_by_model() for this
    model and cluster are considered, and a variant is skipped when its source
    PNG does not exist. The sentinel is touched on every path, including the
    early exits (unknown cluster, unknown model_dataset config, no variants),
    so those cases never fail the rule.

    NOTE: generation of the companion .md file is currently disabled (see the
    commented-out code at the end of the loop).
    """
    input:
        sentinel = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{variant_subdataset}/.profiles_batch_{model_name}.done"
    output:
        symlink_sentinel = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{variant_subdataset}/.profiles_symlinked_{model_name}.done"
    params:
        variant_dataset = lambda w: w.variant_dataset,
        model_dataset_name = lambda w: w.model_dataset,
        cluster_name = lambda w: w.variant_subdataset,
        model_name = lambda w: w.model_name
    run:
        import os

        # Get cluster_id from cluster_name
        cluster_id = get_cluster_id_from_name(
            params.variant_dataset,
            params.model_dataset_name,
            params.cluster_name
        )

        # NOTE: every `touch` below quotes the sentinel path. Dataset and
        # cluster names may contain spaces, parentheses or '#', which an
        # unquoted path would expose to shell word-splitting/parsing.
        if cluster_id is None:
            print(f"Warning: Could not find cluster_id for '{params.cluster_name}'")
            shell("touch '{output.symlink_sentinel}'")
            return

        # Get model_dataset_config
        config = None
        for cfg in get_model_datasets_list(params.variant_dataset):
            if cfg['name'] == params.model_dataset_name:
                config = cfg
                break

        if config is None:
            print(f"Warning: Could not find config for model_dataset '{params.model_dataset_name}'")
            shell("touch '{output.symlink_sentinel}'")
            return

        # Get variants prioritized by this model in this cluster
        variants = get_variants_prioritized_by_model(
            params.variant_dataset,
            params.model_name,
            config,
            cluster_id
        )

        if not variants:
            print(f"No variants to symlink for {params.model_name} in {params.cluster_name}")
            shell("touch '{output.symlink_sentinel}'")
            return

        # Storage location (centralized)
        storage_base = f"{OUTPUT_DIR}/{params.variant_dataset}"

        # Display location (distributed per model_dataset/cluster)
        display_base = f"{OUTPUT_DIR}/{params.variant_dataset}/{params.model_dataset_name}/{params.cluster_name}"

        # Create symlinks for each variant
        # Profiles are stored at: {variant_dataset}/profiles/{variant_id}/{model}.png
        # Symlinks use safe variant IDs to avoid PeriodicWildcardError
        symlinks_created = 0
        for variant_id in variants:
            # Source: centralized storage (uses raw variant_id as created by varbook CLI)
            source = os.path.join(storage_base, "profiles", variant_id, f"{params.model_name}.png")

            # Only create symlink if source exists
            if not os.path.exists(source):
                continue

            # Destination: distributed display location (uses safe variant_id for Snakemake)
            safe_variant_id = get_safe_variant_id(variant_id)
            dest_dir = os.path.join(display_base, safe_variant_id)
            dest_symlink = os.path.join(dest_dir, f"03-profile-{params.model_name}.png")

            # Create destination directory
            os.makedirs(dest_dir, exist_ok=True)

            # Calculate relative path from symlink to source
            # This makes symlinks more portable
            rel_source = os.path.relpath(source, dest_dir)

            # Create symlink (remove existing if present).
            # islink() is checked as well as exists() because exists() is
            # False for a dangling symlink, which must still be replaced.
            if os.path.islink(dest_symlink) or os.path.exists(dest_symlink):
                os.remove(dest_symlink)
            os.symlink(rel_source, dest_symlink)

            # Disabled: companion .md file at the symlink location.
            # # Create .md file at symlink location
            # md = os.path.join(dest_dir, f"03-profile-{params.model_name}.md")

            # # Header goes in main .md file
            # with open(md, 'w') as f:
            #     f.write(f"## Profile: {params.model_name}\n\n")
            #     f.write(f"![Profile for {params.model_name}](03-profile-{params.model_name}.png)\n")

            symlinks_created += 1

        print(f"Created {symlinks_created} symlinks for {params.model_name} in {params.cluster_name}")

        # Create sentinel file
        shell("touch '{output.symlink_sentinel}'")

# ----------------------------------------------------------------------------
# Old rules (to be removed)
# ----------------------------------------------------------------------------

def get_variant_config(wildcards):
    """Return the model_dataset config dict addressed by the rule wildcards.

    Reads ``wildcards.variant_dataset`` and ``wildcards.model_dataset`` and
    returns the first entry of ``get_model_datasets_list(variant_dataset)``
    whose ``'name'`` equals the requested model_dataset.

    Raises
    ------
    ValueError
        If the variant_dataset is not a key of ``VARIANT_DATASET_CONFIGS``,
        or if no config entry carries the requested model_dataset name.
    """
    variant_dataset = wildcards.variant_dataset
    model_dataset_name = wildcards.model_dataset

    if variant_dataset not in VARIANT_DATASET_CONFIGS:
        raise ValueError(f"Unknown variant_dataset: {variant_dataset}")

    # Lazy scan: stop at the first entry whose name matches.
    found = next(
        (
            entry
            for entry in get_model_datasets_list(variant_dataset)
            if entry['name'] == model_dataset_name
        ),
        None,
    )
    if found is None:
        raise ValueError(f"Unknown model_dataset '{model_dataset_name}' for variant_dataset '{variant_dataset}'")
    return found

def get_common_variant_plot_params(wildcards, cluster_id=None, include_superset=False):
    """Assemble the parameter dict shared by the variant plot rules.

    Parameters:
    -----------
    wildcards : object
        Snakemake wildcards object; ``variant``, ``variant_dataset`` and
        ``model_dataset`` are read from it.
    cluster_id : str or None, optional
        Cluster ID forwarded to the variant-ID and variant-datasets helpers
        (None when there is no cluster).
    include_superset : bool, optional
        When True, the result additionally carries ``models_superset_args``.
        Default is False.

    Returns:
    --------
    dict
        Keys:
        - variant: original variant ID (from get_original_variant_id)
        - variant_datasets_param: value from get_variant_datasets_param
        - model_dataset_name: ``wildcards.model_dataset``
        - models_args: get_models_args(config, use_superset=False)
        - models_superset_args: get_models_args(config, use_superset=True),
          present only if include_superset=True
    """
    model_cfg = get_variant_config(wildcards)

    original_variant = get_original_variant_id(
        wildcards.variant, model_cfg, cluster_id, wildcards.variant_dataset
    )
    datasets_param = get_variant_datasets_param(wildcards.variant_dataset, cluster_id)

    common = {
        'variant': original_variant,
        'variant_datasets_param': datasets_param,
        'model_dataset_name': wildcards.model_dataset,
        'models_args': get_models_args(model_cfg, use_superset=False),
    }
    if not include_superset:
        return common

    common['models_superset_args'] = get_models_args(model_cfg, use_superset=True)
    return common

# Ruleorder: Prefer clustered rules over unclustered/generic rules for variant plots
ruleorder: generate_variant_barplot_clustered > generate_variant_barplot > generate_variant_barplot_unclustered

# Ensure finemo files are generated before profiles
# This rule creates a sentinel file that depends on all required finemo files
rule ensure_finemo_files:
    """Gate rule: touch a sentinel once every required finemo file exists.

    Inputs are the per-variant_dataset human-readable-TSV sentinels returned
    by _get_all_variant_dataset_sentinels() plus one
    {SPLITS_DIR}/finemo/broad.finemo.{model}.tsv per entry of KUN_FB_MODELS.
    The rule performs no computation of its own; it only touches the sentinel
    so that other rules can depend on a single file.
    """
    input:
        # Ensure human-readable TSVs are complete first (all variant_datasets).
        # Wrapped in a lambda so the helper, which is defined further down the
        # file, is only looked up when the input function is called.
        human_readable_sentinels = lambda w: _get_all_variant_dataset_sentinels(),
        # Use expand to reference finemo rule outputs for all KUN_FB models
        finemo_files = expand(
            SPLITS_DIR + "/finemo/broad.finemo.{model}.tsv",
            model=KUN_FB_MODELS
        )
    output:
        sentinel = SPLITS_DIR + "/finemo/.all_finemo_files_ready"
    shell:
        "touch {output.sentinel}"

def _get_all_variant_dataset_sentinels():
    """Build the list of variant_dataset-level sentinel paths.

    One path is produced per key of VARIANT_DATASET_CONFIGS, in iteration
    order:
    ``{OUTPUT_DIR}/{variant_dataset}/human_readable_spreadsheets/.all_tsvs_complete``.
    These are created by ensure_human_readable_tsvs_complete rule.
    """
    return [
        f"{OUTPUT_DIR}/{variant_dataset}/human_readable_spreadsheets/.all_tsvs_complete"
        for variant_dataset in VARIANT_DATASET_CONFIGS.keys()
    ]

def _get_human_readable_sentinel(wildcards):
    """Collect the per-cluster human-readable TSV sentinel paths.

    Walks every variant_dataset in VARIANT_DATASET_CONFIGS, every model_dataset
    config returned for it, and every entry of that config's ``'clusters'``
    list, emitting
    ``{OUTPUT_DIR}/{variant_dataset}/human_readable_spreadsheets/.sentinels/{model_dataset_name}.{cluster_name}.done``.
    Configs without a name or without clusters, and clusters without a usable
    name, are skipped. ``wildcards`` is accepted but not read.

    NOTE: This function is kept for backward compatibility, but the recommended approach
    is to use _get_all_variant_dataset_sentinels() which depends on variant_dataset-level
    sentinels instead of individual cluster sentinels.
    """
    import os

    def _label_of(cluster):
        # Dict entries carry their label under 'name' (falling back to 'id');
        # any other entry is used as the label itself.
        if isinstance(cluster, dict):
            return cluster.get('name', cluster.get('id'))
        return cluster

    sentinels = []
    for variant_dataset in VARIANT_DATASET_CONFIGS.keys():
        for model_dataset_config in get_model_datasets_list(variant_dataset):
            model_dataset_name = model_dataset_config.get('name')
            if not model_dataset_name:
                continue

            clusters = model_dataset_config.get('clusters', [])
            if not clusters:
                continue

            labels = (_label_of(cluster) for cluster in clusters)
            sentinels.extend(
                f"{OUTPUT_DIR}/{variant_dataset}/human_readable_spreadsheets/.sentinels/{model_dataset_name}.{cluster_name}.done"
                for cluster_name in labels
                if cluster_name
            )

    return sentinels

# ruleorder: ensure_finemo_files > generate_variant_profile
ruleorder: ensure_finemo_files > generate_variant_profile_batch

# Rule to merge logfc and aaq files from multiple model datasets for superset scatterplots
rule merge_superset_logfc_aaq:
    """Merge logfc, aaq, and prioritization files from multiple model datasets for superset-level scatterplots.

    Takes logfc, aaq, and prioritization files like:
    - '{variant_dataset}.logfc-KUN_FB.tsv'
    - '{variant_dataset}.logfc-KUN_HDMA.tsv'
    - '{variant_dataset}.aaq-KUN_FB.tsv'
    - '{variant_dataset}.aaq-KUN_HDMA.tsv'
    - '{variant_dataset}.model_prioritized_by_peak-KUN_FB.tsv'
    - '{variant_dataset}.model_prioritized_by_promoter-KUN_FB.tsv'
    - '{variant_dataset}.model_prioritized_by_outofpeak-KUN_FB.tsv'
    - (and same for KUN_HDMA)

    And merges them with the external `merge-columns` command (outer join on
    the `variant_id` column) into a single TSV holding the logfc, aaq, and
    prioritization columns of every model dataset in the superset.
    Prioritization columns are needed for scatterplot labels.

    Raises FileNotFoundError if any input path is missing on disk, and
    RuntimeError (including the command's stdout/stderr) if `merge-columns`
    exits with a non-zero status.
    """
    input:
        # Both lists are resolved by input functions from the wildcards; the
        # helpers are defined further down the file.
        logfc_aaq_files = lambda w: get_superset_logfc_aaq_files_for_rule(w),
        prioritization_files = lambda w: get_superset_prioritization_files_for_rule(w)
    output:
        merged_tsv = "data/{variant_dataset}.logfc_aaq_superset_{model_datasets_str}.tsv"
    params:
        variant_dataset = lambda w: w.variant_dataset,
        model_datasets_str = lambda w: w.model_datasets_str
    run:
        import subprocess
        import os
        
        logfc_aaq_files = input.logfc_aaq_files
        prioritization_files = input.prioritization_files
        # logfc/aaq files first, then prioritization files; this is also the
        # order in which they are passed to merge-columns.
        all_files = logfc_aaq_files + prioritization_files
        output_file = output.merged_tsv
        
        # Ensure output directory exists
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        
        # Verify all input files exist
        # (the prioritization input function may list an expected path that is
        # not on disk, precisely so that it is reported here)
        missing_files = [f for f in all_files if not os.path.exists(f)]
        if missing_files:
            raise FileNotFoundError(
                f"Missing files for superset merge:\n" +
                "\n".join(f"  - {f}" for f in missing_files)
            )
        
        # Build merge-columns command
        # (argument list, not a shell string, so paths need no quoting)
        cmd = [
            "merge-columns",
            "--merge-column", "variant_id",
            "--join-type", "outer",  # Use outer join to keep all variants
            "--output", output_file
        ] + all_files
        
        print(f"Merging {len(logfc_aaq_files)} logfc/aaq files and {len(prioritization_files)} prioritization files for superset scatterplot:")
        for f in logfc_aaq_files:
            print(f"  - {f}")
        for f in prioritization_files:
            print(f"  - {f}")
        
        # Run merge-columns with timing
        import time
        start_time = time.time()
        print(f"[TIMING] Starting merge_superset_logfc_aaq for {params.variant_dataset} ({len(all_files)} files) at {time.strftime('%Y-%m-%d %H:%M:%S')}")
        # No check=True: the return code is inspected below so the error can
        # carry both captured streams.
        result = subprocess.run(cmd, capture_output=True, text=True)
        end_time = time.time()
        duration = end_time - start_time
        
        if result.returncode != 0:
            raise RuntimeError(
                f"merge-columns failed for superset logfc/aaq/prioritization merge:\n"
                f"stdout: {result.stdout}\n"
                f"stderr: {result.stderr}"
            )
        
        print(f"[TIMING] Finished merge_superset_logfc_aaq for {params.variant_dataset} in {duration:.1f}s")
        print(f"✓ Created merged logfc/aaq/prioritization TSV: {output_file}")

def get_superset_logfc_aaq_files_for_rule(wildcards):
    """Resolve the logfc and aaq input files for a superset merge.

    Input function for the merge_superset_logfc_aaq rule. The
    ``model_datasets_str`` wildcard (lower-cased) is compared against every
    model_dataset config of the variant_dataset that has a ``'model_superset'``
    key: the config's superset dataset names are sorted, lower-cased and
    joined with ``'_'``, and the first config whose string matches wins.

    Returns a flat list with, for each matched model dataset in order, the
    ``{SPLITS_DIR}/{variant_dataset}.logfc-{model_dataset}.tsv`` path followed
    by the ``{SPLITS_DIR}/{variant_dataset}.aaq-{model_dataset}.tsv`` path.

    Raises ValueError when no config matches.
    """
    import os

    variant_dataset = wildcards.variant_dataset
    model_datasets_str = wildcards.model_datasets_str.lower()

    # Find the first config whose model_superset produces the requested string
    # (same construction as in get_superset_scatterplot_tsv).
    model_datasets = None
    if variant_dataset in VARIANT_DATASET_CONFIGS:
        for model_config in get_model_datasets_list(variant_dataset):
            if 'model_superset' not in model_config:
                continue
            superset_datasets = get_superset_model_datasets(model_config)
            superset_str = '_'.join(d.lower() for d in sorted(superset_datasets))
            if superset_str == model_datasets_str:
                model_datasets = superset_datasets
                break

    if model_datasets is None:
        raise ValueError(
            f"Could not find model_superset matching '{model_datasets_str}' "
            f"for variant_dataset '{variant_dataset}'. "
            f"Available configs: {[c.get('name', 'unnamed') for c in get_model_datasets_list(variant_dataset)]}"
        )

    # Always use variant_dataset name (not "broad" fallback); logfc precedes
    # aaq for each model dataset.
    files = []
    for model_dataset in model_datasets:
        files.extend((
            f"{SPLITS_DIR}/{variant_dataset}.logfc-{model_dataset}.tsv",
            f"{SPLITS_DIR}/{variant_dataset}.aaq-{model_dataset}.tsv",
        ))
    return files

def get_superset_prioritization_files_for_rule(wildcards):
    """Resolve the prioritization input files (peak, promoter, outofpeak) for a superset merge.

    Input function for the merge_superset_logfc_aaq rule; these files supply
    the prioritization columns needed for scatterplot labels. The superset is
    located exactly as in get_superset_logfc_aaq_files_for_rule, except that an
    unmatched ``model_datasets_str`` yields an empty list instead of an error.

    For every matched model dataset and every prioritization type the
    model-specific file is used when it exists on disk; otherwise the generic
    (non-model-specific) file is used when it exists, added at most once;
    otherwise the model-specific path is listed anyway so the merge rule
    reports it as missing.
    """
    import os

    variant_dataset = wildcards.variant_dataset
    model_datasets_str = wildcards.model_datasets_str.lower()

    # Get model datasets (same logic as get_superset_logfc_aaq_files_for_rule)
    model_datasets = None
    if variant_dataset in VARIANT_DATASET_CONFIGS:
        for model_config in get_model_datasets_list(variant_dataset):
            if 'model_superset' not in model_config:
                continue
            superset_datasets = get_superset_model_datasets(model_config)
            superset_str = '_'.join(d.lower() for d in sorted(superset_datasets))
            if superset_str == model_datasets_str:
                model_datasets = superset_datasets
                break

    if model_datasets is None:
        return []  # Return empty list if no match (rule will handle gracefully)

    files = []
    for model_dataset in model_datasets:
        # Prioritization files are REQUIRED for scatterplot labels
        for prio_type in ('peak', 'promoter', 'outofpeak'):
            # Preferred: model-specific file
            prio_file = f"{SPLITS_DIR}/{variant_dataset}.model_prioritized_by_{prio_type}-{model_dataset}.tsv"
            if os.path.exists(prio_file):
                files.append(prio_file)
                continue

            # Fallback: non-model-specific file
            prio_file_generic = f"{SPLITS_DIR}/{variant_dataset}.model_prioritized_by_{prio_type}.tsv"
            if not os.path.exists(prio_file_generic):
                # Neither exists: list the expected path so the merge rule
                # fails with a clear error message.
                files.append(prio_file)
            elif prio_file_generic not in files:
                files.append(prio_file_generic)

    return files

# Helpers for superset-level inputs, used by the rules for variants WITH a cluster subdirectory
def get_superset_model_datasets(model_dataset_config):
    """Extract model dataset names from model_superset patterns.

    For patterns like ['KUN_FB*', 'KUN_HDMA*'], extracts ['KUN_FB', 'KUN_HDMA'].

    Parameters:
    -----------
    model_dataset_config : dict
        Model dataset config. The patterns are read from ``'model_superset'``;
        when that key is missing or empty, ``'models'`` is used instead.
        A single pattern given as a bare string is treated as a one-element
        list, and a missing/None/empty value yields no patterns.

    Returns:
    --------
    list of str
        Base dataset names (trailing ``*`` then trailing ``_`` stripped), in
        first-seen order, without duplicates and without empty names.
    """
    patterns = (
        model_dataset_config.get('model_superset')
        or model_dataset_config.get('models')
        or []
    )
    # A bare string would otherwise be iterated character by character.
    if isinstance(patterns, str):
        patterns = [patterns]

    # Remove trailing * and extract base name
    bases = (pattern.rstrip('*').rstrip('_') for pattern in patterns)
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(base for base in bases if base))

def get_superset_scatterplot_tsv(wildcards):
    """Pick the TSV file path used for superset-level plots.

    Resolution order:
    1. the value of get_scatterplot_context_tsv(variant_dataset), when truthy;
    2. the merged logfc/aaq TSV
       ``data/{variant_dataset}.logfc_aaq_superset_{model_datasets_str}.tsv``
       (generated by the merge_superset_logfc_aaq rule), where
       ``model_datasets_str`` is the sorted, lower-cased, ``'_'``-joined list
       of superset dataset names;
    3. the cluster's filtered TSV, when the config yields no superset names.
    """
    import os

    # First try scatterplot_context_tsv (has logfc/aaq for multiple model sets)
    context_tsv = get_scatterplot_context_tsv(wildcards.variant_dataset)
    if context_tsv:
        return context_tsv

    superset_names = get_superset_model_datasets(get_variant_config(wildcards))
    if superset_names:
        # Merged TSV path keyed by the model datasets (includes both logfc and aaq)
        model_datasets_str = '_'.join(name.lower() for name in sorted(superset_names))
        return f"data/{wildcards.variant_dataset}.logfc_aaq_superset_{model_datasets_str}.tsv"

    # No superset, fall back to filtered TSV
    return f"data/{wildcards.variant_dataset}.{wildcards.model_dataset}.{wildcards.cluster}.filtered.tsv"

# NOTE: get_superset_scatterplot_tsv_with_rule was inlined - it just called get_superset_scatterplot_tsv

rule generate_variant_scatterplot_clustered:
    """Generate model scatterplot for a variant in a cluster.

    Runs `{VARBOOK_CMD} plot variant model-scatterplot` twice (logfc vs aaq,
    interactive, labelled by the three model_prioritized_by_* columns):
    1. Cluster-level scatterplot (using 'models' field) - shows models in the current cluster/model dataset
    2. Superset-level scatterplot (using 'model_superset' if available) - shows broader context across model sets

    For the superset scatterplot, uses the TSV chosen by get_superset_scatterplot_tsv(),
    which is expected to contain columns for all models in the superset
    (e.g., both KUN_FB and KUN_HDMA columns).

    Each plot writes a .md and an .html file. The `mitra-utils url` calls that
    follow are best-effort (`|| true`), so their failure does not fail the rule.
    Timing information is written to stderr.
    """
    input:
        # Cluster-level TSV: chosen by get_scatterplot_input_tsv, which is given
        # the cluster's filtered TSV path as its second argument.
        variants_tsv = lambda w: get_scatterplot_input_tsv(w, f"data/{w.variant_dataset}.{w.model_dataset}.{w.cluster}.filtered.tsv"),
        variants_tsv_superset = lambda w: get_superset_scatterplot_tsv(w)
    output:
        html = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{cluster}/{variant}/02-model-scatterplot.html",
        md = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{cluster}/{variant}/02-model-scatterplot.md",
        html_superset = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{cluster}/{variant}/02-model-scatterplot-superset.html",
        md_superset = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{cluster}/{variant}/02-model-scatterplot-superset.md"
    params:
        # The {variant} wildcard is mapped back to the original variant ID
        # before being passed to the CLI.
        variant = lambda w: get_original_variant_id(w.variant, get_variant_config(w), w.cluster, w.variant_dataset),
        variant_datasets_param = lambda w: get_variant_datasets_param(w.variant_dataset, w.cluster),
        model_dataset_name = lambda w: w.model_dataset,
        # NOTE: both *_args values are inserted into the shell command unquoted.
        models_args = lambda w: get_models_args(get_variant_config(w)),
        models_superset_args = lambda w: get_models_args(get_variant_config(w), use_superset=True)
    shell:
        """
        START_TIME=$(date +%s)
        echo "[TIMING] Starting generate_variant_scatterplot_clustered for {params.variant} at $(date)" >&2
        
        # Generate cluster-level scatterplot (using 'models' field)
        CLUSTER_START=$(date +%s)
        {VARBOOK_CMD} plot variant model-scatterplot \
            "{input.variants_tsv}" \
            "{params.variant}" \
            logfc \
            aaq \
            --variant-datasets "{params.variant_datasets_param}" \
            --model-dataset "{params.model_dataset_name}" \
            --models {params.models_args} \
            --interactive-plot \
            --label-cols model_prioritized_by_peak model_prioritized_by_promoter model_prioritized_by_outofpeak \
            --label-names "Prioritized in Peak" "Prioritized in Promoter" "Prioritized in Out-of-peak" \
            --label-colors "green" "blue" "orange" \
            --output "{output.md}" "{output.html}"
        CLUSTER_END=$(date +%s)
        CLUSTER_DURATION=$((CLUSTER_END - CLUSTER_START))
        echo "[TIMING] Cluster-level scatterplot took ${{CLUSTER_DURATION}}s" >&2
        
        # Generate URLs for cluster-level scatterplot
        mitra-utils url "{output.html}" || true
        mitra-utils url "{output.md}" || true
        
        # Generate superset-level scatterplot (using 'model_superset' if available)
        # Uses a TSV file that contains columns for all models in the superset
        SUPERSET_START=$(date +%s)
        {VARBOOK_CMD} plot variant model-scatterplot \
            "{input.variants_tsv_superset}" \
            "{params.variant}" \
            logfc \
            aaq \
            --variant-datasets "{params.variant_datasets_param}" \
            --model-dataset "{params.model_dataset_name}" \
            --models {params.models_superset_args} \
            --interactive-plot \
            --label-cols model_prioritized_by_peak model_prioritized_by_promoter model_prioritized_by_outofpeak \
            --label-names "Prioritized in Peak" "Prioritized in Promoter" "Prioritized in Out-of-peak" \
            --label-colors "green" "blue" "orange" \
            --output "{output.md_superset}" "{output.html_superset}"
        SUPERSET_END=$(date +%s)
        SUPERSET_DURATION=$((SUPERSET_END - SUPERSET_START))
        echo "[TIMING] Superset-level scatterplot took ${{SUPERSET_DURATION}}s" >&2
        
        # Generate URLs for superset-level scatterplot
        mitra-utils url "{output.html_superset}" || true
        mitra-utils url "{output.md_superset}" || true
        
        END_TIME=$(date +%s)
        DURATION=$((END_TIME - START_TIME))
        echo "[TIMING] Finished generate_variant_scatterplot_clustered for {params.variant} in ${{DURATION}}s (cluster: ${{CLUSTER_DURATION}}s, superset: ${{SUPERSET_DURATION}}s) at $(date)" >&2
        """

# Rule for variants WITH cluster subdirectory (clustered barplot; the unclustered case is handled by a separate rule)
rule generate_variant_barplot_clustered:
    """Generate model-specificity barplot for a variant in a cluster.
    
    Generates two versions:
    1. Cluster-level barplot (using 'models' field) - shows models in the current cluster/model dataset
    2. Superset-level barplot (using 'model_superset' if available) - shows broader context across model sets
    
    For the superset barplot, uses a TSV file that contains columns for all models in the superset
    (e.g., both KUN_FB and KUN_HDMA columns).
    """
    input:
        variants_tsv = "data/{variant_dataset}.{model_dataset}.{cluster}.filtered.tsv",
        variants_tsv_superset = lambda w: get_superset_scatterplot_tsv(w)
    output:
        png = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{cluster}/{variant}/01-model-specificity-barplot.png",
        png_superset = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{cluster}/{variant}/01-model-specificity-barplot-superset.png"
    params:
        variant = lambda w: get_original_variant_id(w.variant, get_variant_config(w), w.cluster, w.variant_dataset),
        variant_datasets_param = lambda w: get_variant_datasets_param(w.variant_dataset, w.cluster),
        variant_dataset = lambda w: w.variant_dataset,
        model_dataset_name = lambda w: w.model_dataset,
        models_args = lambda w: get_models_args(get_variant_config(w)),
        models_superset_args = lambda w: get_models_args(get_variant_config(w), use_superset=True)
    run:
        import pandas as pd
        import ast
        import os
        import subprocess
        import time
        from pathlib import Path
        
        start_time = time.time()
        variant_id = params.variant
        print(f"[TIMING] Starting generate_variant_barplot_clustered for {variant_id} at {time.strftime('%Y-%m-%d %H:%M:%S')}")
        
        # Generate cluster-level barplot (using 'models' field)
        cluster_start = time.time()
        # Parse models_args (space-separated quoted strings like '"KUN_FB*"')
        import shlex
        models_list = shlex.split(params.models_args) if params.models_args else []
        # Split VARBOOK_CMD into components (it's a string like "{VENV_PYTHON} -m varbook")
        varbook_cmd_parts = shlex.split(VARBOOK_CMD)
        cmd_cluster = varbook_cmd_parts + [
            'plot', 'variant', 'model-specificity-barplot',
            input.variants_tsv,
            variant_id,
            '--variant-datasets', params.variant_datasets_param,
            '--model-dataset', params.model_dataset_name,
            '--models'] + models_list + [
            '-o', output.png
        ]
        result = subprocess.run(cmd_cluster, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"Cluster-level barplot failed:\nstdout: {result.stdout}\nstderr: {result.stderr}")
        cluster_duration = time.time() - cluster_start
        print(f"[TIMING] Cluster-level barplot took {cluster_duration:.1f}s")
        
        # Generate URL for cluster-level barplot
        try:
            subprocess.run(['mitra-utils', 'url', output.png], check=False, capture_output=True, text=True)
        except Exception as e:
            print(f"Warning: Could not generate URL for {output.png}: {e}")
        
        # Generate superset-level barplot (using 'model_superset' if available)
        # Merge filtered.tsv (has organs column) with logfc_aaq_superset (has all superset model columns)
        superset_start = time.time()
        
        # Read filtered.tsv (has organs column and cluster model columns)
        print(f"Reading filtered TSV: {input.variants_tsv}")
        df_filtered = pd.read_csv(input.variants_tsv, sep='\t')
        print(f"  Filtered TSV: {len(df_filtered)} variants, {len(df_filtered.columns)} columns")
        
        # Check if variant exists in filtered TSV
        if variant_id not in df_filtered['variant_id'].values:
            raise ValueError(f"Variant {variant_id} not found in filtered TSV {input.variants_tsv}")
        
        # Read superset TSV (has all superset model logfc/aaq columns)
        print(f"Reading superset TSV: {input.variants_tsv_superset}")
        df_superset = pd.read_csv(input.variants_tsv_superset, sep='\t')
        print(f"  Superset TSV: {len(df_superset)} variants, {len(df_superset.columns)} columns")
        
        # Check if variant exists in superset TSV
        if variant_id not in df_superset['variant_id'].values:
            raise ValueError(f"Variant {variant_id} not found in superset TSV {input.variants_tsv_superset}")
        
        # Merge on variant_id, keeping all columns from both
        # Use suffixes to handle overlapping columns (prefer filtered.tsv for non-logfc/aaq columns)
        df_merged = df_filtered.merge(df_superset, on='variant_id', how='inner', suffixes=('', '_superset'))
        print(f"  Merged: {len(df_merged)} variants after inner join")
        
        # Verify variant is still in merged dataframe
        if variant_id not in df_merged['variant_id'].values:
            raise ValueError(f"Variant {variant_id} lost during merge! Filtered has {len(df_filtered)} variants, superset has {len(df_superset)} variants")
        
        # For overlapping logfc/aaq columns, prefer superset version (has all models)
        # For other columns, prefer filtered version (has organs, etc.)
        # CRITICAL: Also keep model_prioritized_by_any- columns from superset (needed for barplot)
        filtered_cols = set(df_filtered.columns)
        superset_cols = set(df_superset.columns)
        superset_logfc_aaq_cols = [c for c in df_superset.columns if c.startswith(('logfc-', 'aaq-'))]
        superset_logfc_aaq_cols_set = set(superset_logfc_aaq_cols)
        superset_prio_cols = [c for c in df_superset.columns if c.startswith('model_prioritized_by_any-')]
        superset_prio_cols_set = set(superset_prio_cols)
        
        # Build final column list: all filtered columns + superset logfc/aaq columns + superset prioritization columns
        # Columns that exist in both get _superset suffix, columns only in superset don't
        # First, identify which base columns we're keeping from superset (to avoid duplicates)
        superset_kept_base_cols = set()
        for col in df_merged.columns:
            if col.endswith('_superset'):
                base_col = col.replace('_superset', '')
                if base_col in superset_logfc_aaq_cols_set or base_col in superset_prio_cols_set:
                    superset_kept_base_cols.add(base_col)
        
        # Now build keep_cols, excluding filtered columns that we're keeping from superset
        keep_cols = []
        for col in df_merged.columns:
            if col.endswith('_superset'):
                # This is a column that existed in both dataframes
                base_col = col.replace('_superset', '')
                if base_col in superset_logfc_aaq_cols_set or base_col in superset_prio_cols_set:
                    keep_cols.append(col)  # Keep superset version for logfc/aaq and prioritization
                # Otherwise skip (prefer filtered version for non-logfc/aaq/prio columns)
            elif col in filtered_cols:
                # Column exists in filtered TSV - keep it ONLY if we're not already keeping superset version
                if col not in superset_kept_base_cols:
                    keep_cols.append(col)
            elif col in superset_cols and (col.startswith(('logfc-', 'aaq-')) or col.startswith('model_prioritized_by_any-')):
                # Column ONLY exists in superset TSV (e.g., KUN_HDMA columns)
                # This is important - these columns don't get _superset suffix because they're not in filtered TSV
                keep_cols.append(col)
        
        # Rename _superset columns back to original names
        df_merged = df_merged[keep_cols]
        for col in list(df_merged.columns):
            if col.endswith('_superset'):
                df_merged = df_merged.rename(columns={col: col.replace('_superset', '')})
        
        print(f'After column selection: {len(df_merged.columns)} columns')
        kun_hdma_cols = [c for c in df_merged.columns if 'KUN_HDMA' in c]
        print(f'KUN_HDMA columns in merged dataframe: {len(kun_hdma_cols)}')
        
        # Check which models we have from logfc/aaq columns (these are the models we need prioritization for)
        models_from_logfc_aaq = set()
        for col in df_merged.columns:
            if col.startswith(('logfc-', 'aaq-')):
                model_name = col.replace('logfc-', '').replace('aaq-', '')
                models_from_logfc_aaq.add(model_name)
        
        # Check which models we already have prioritization columns for
        prio_cols_in_merged = [c for c in df_merged.columns if c.startswith('model_prioritized_by_any-')]
        models_with_prio = {c.replace('model_prioritized_by_any-', '') for c in prio_cols_in_merged}
        print(f'Found {len(prio_cols_in_merged)} model_prioritized_by_any- columns in merged dataframe')
        print(f'Models from logfc/aaq: {len(models_from_logfc_aaq)}, Models with prioritization: {len(models_with_prio)}')
        
        # Check which models are missing prioritization columns
        models_missing_prio = models_from_logfc_aaq - models_with_prio
        if models_missing_prio:
            print(f'Warning: {len(models_missing_prio)} models missing prioritization columns. Attempting to merge from prioritization files...')
            variant_dataset = params.variant_dataset
            
            # Determine which model datasets we need by checking which dataset prefixes the missing models belong to
            model_datasets_needed = set()
            for model in models_missing_prio:
                # Extract dataset prefix (e.g., "KUN_FB" from "KUN_FB_microglia", "KUN_HDMA" from "KUN_HDMA_...")
                parts = model.split('_')
                if len(parts) >= 2:
                    # Try 2-part prefix first (KUN_FB, KUN_HDMA)
                    dataset_prefix = '_'.join(parts[:2])
                    model_datasets_needed.add(dataset_prefix)
                elif len(parts) >= 1:
                    # Fallback to single part
                    model_datasets_needed.add(parts[0])
            
            print(f'Need prioritization files for datasets: {model_datasets_needed}')
            prioritization_files = []
            for dataset_prefix in model_datasets_needed:
                # Try variant_dataset-specific file first
                prio_file = f"{SPLITS_DIR}/{variant_dataset}.model_prioritized_by_any-{dataset_prefix}.tsv"
                if os.path.exists(prio_file):
                    prioritization_files.append(prio_file)
                    print(f'  Found: {prio_file}')
                else:
                    # Try fallback to 'broad' prefix
                    prio_file_fallback = f"{SPLITS_DIR}/broad.model_prioritized_by_any-{dataset_prefix}.tsv"
                    if os.path.exists(prio_file_fallback):
                        prioritization_files.append(prio_file_fallback)
                        print(f'  Found: {prio_file_fallback}')
                    else:
                        print(f'  Warning: Prioritization file not found for {dataset_prefix}')
            
            if prioritization_files:
                print(f'Merging {len(prioritization_files)} prioritization files...')
                for prio_file in prioritization_files:
                    if os.path.exists(prio_file):
                        df_prio = pd.read_csv(prio_file, sep='\t')
                        # Get all prioritization columns from this file
                        prio_cols = [c for c in df_prio.columns if c.startswith('model_prioritized_by_any-')]
                        # Filter to only columns for models we're missing AND that don't already exist in df_merged
                        prio_cols_to_merge = []
                        for c in prio_cols:
                            model_name = c.replace('model_prioritized_by_any-', '')
                            if model_name in models_missing_prio and c not in df_merged.columns:
                                prio_cols_to_merge.append(c)
                        
                        if prio_cols_to_merge:
                            print(f'  Merging {len(prio_cols_to_merge)} columns from {os.path.basename(prio_file)}')
                            df_prio_subset = df_prio[['variant_id'] + prio_cols_to_merge]
                            # Merge with explicit handling to avoid duplicates
                            # Only merge columns that don't already exist
                            df_merged = df_merged.merge(df_prio_subset, on='variant_id', how='left', suffixes=('', '_skip'))
                            # Remove any _skip suffixes (shouldn't happen, but clean up just in case)
                            cols_to_rename = [c for c in df_merged.columns if c.endswith('_skip')]
                            for col in cols_to_rename:
                                base_col = col.replace('_skip', '')
                                if base_col not in df_merged.columns:
                                    df_merged = df_merged.rename(columns={col: base_col})
                                else:
                                    df_merged = df_merged.drop(columns=[col])
                print(f'After merging prioritization files: {len(df_merged.columns)} columns')
                prio_cols_in_merged = [c for c in df_merged.columns if c.startswith('model_prioritized_by_any-')]
                print(f'Now have {len(prio_cols_in_merged)} model_prioritized_by_any- columns')
                
                # Check for any duplicate column names (pandas adds .1, .2, etc. for duplicates)
                # This can happen if pandas detects duplicate column names during merge
                duplicate_cols = [col for col in df_merged.columns if any(col.endswith(f'.{i}') for i in range(1, 10))]
                if duplicate_cols:
                    print(f'Warning: Found {len(duplicate_cols)} duplicate columns with numeric suffixes')
                    # Remove duplicates, keeping the first occurrence (without suffix)
                    for dup_col in duplicate_cols:
                        if dup_col in df_merged.columns:
                            # Extract base column name (everything before .N)
                            base_col = '.'.join(dup_col.split('.')[:-1])
                            if base_col in df_merged.columns:
                                # Base column exists, so drop the duplicate
                                df_merged = df_merged.drop(columns=[dup_col])
                            else:
                                # Base column doesn't exist, rename the duplicate to base
                                df_merged = df_merged.rename(columns={dup_col: base_col})
                    print(f'  Cleaned up duplicate columns')
        
        # CRITICAL: Extract all model names from merged dataframe AFTER merge and column renaming
        # PRIMARY SOURCE: model_prioritized_by_any- columns (most reliable, directly used by barplot)
        # SECONDARY SOURCE: logfc/aaq columns (in case some models don't have prioritization columns)
        superset_models = set()
        # First, extract from model_prioritized_by_any- columns (primary source)
        for col in df_merged.columns:
            if col.startswith('model_prioritized_by_any-'):
                model_name = col.replace('model_prioritized_by_any-', '')
                superset_models.add(model_name)
        # Then, add any models from logfc/aaq columns that weren't already found
        for col in df_merged.columns:
            if col.startswith(('logfc-', 'aaq-')):
                model_name = col.replace('logfc-', '').replace('aaq-', '')
                superset_models.add(model_name)
        
        print(f'Found {len(superset_models)} superset models: {len([c for c in df_merged.columns if c.startswith("model_prioritized_by_any-")])} from prioritization columns, {len([c for c in df_merged.columns if c.startswith(("logfc-", "aaq-"))])} from logfc/aaq columns')
        
        # Debug: Check for KUN_HDMA models specifically
        kun_hdma_models = [m for m in superset_models if m.startswith('KUN_HDMA')]
        kun_fb_models = [m for m in superset_models if m.startswith('KUN_FB')]
        print(f'  KUN_HDMA models: {len(kun_hdma_models)} (showing first 5: {kun_hdma_models[:5] if kun_hdma_models else "none"})')
        print(f'  KUN_FB models: {len(kun_fb_models)} (showing first 5: {kun_fb_models[:5] if kun_fb_models else "none"})')
        
        # Debug: Check for KUN_HDMA prioritization columns
        kun_hdma_prio_cols = [c for c in df_merged.columns if c.startswith('model_prioritized_by_any-') and 'KUN_HDMA' in c]
        print(f'  KUN_HDMA prioritization columns: {len(kun_hdma_prio_cols)} (showing first 5: {kun_hdma_prio_cols[:5] if kun_hdma_prio_cols else "none"})')
        
        # Start with existing organs column (if it exists and is parseable)
        existing_organs_dict = {}
        if 'organs' in df_merged.columns:
            # Filter to variant row first
            variant_row = df_merged[df_merged['variant_id'] == variant_id]
            if len(variant_row) > 0:
                organs_val = variant_row['organs'].iloc[0]
                if pd.notna(organs_val):
                    try:
                        if isinstance(organs_val, dict):
                            existing_organs_dict = organs_val
                        elif isinstance(organs_val, str):
                            # Try to parse string representation of dict/list
                            if (organs_val.startswith('{') and organs_val.endswith('}')) or \
                               (organs_val.startswith('[') and organs_val.endswith(']')):
                                parsed = ast.literal_eval(organs_val)
                                if isinstance(parsed, dict):
                                    existing_organs_dict = parsed
                    except Exception as e:
                        print(f'Warning: Could not parse existing organs column: {e}')
                        existing_organs_dict = {}
        
        # Load model_tissues.tsv to get organ mappings for ALL superset models
        variant_dataset = params.variant_dataset
        model_tissues_path = f'/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/curation/broad/splits/{variant_dataset}.model_tissues.tsv'
        if not os.path.exists(model_tissues_path):
            model_tissues_path = f'splits/{variant_dataset}.model_tissues.tsv'
        
        if os.path.exists(model_tissues_path):
            df_tissues = pd.read_csv(model_tissues_path, sep='\t')
            # Create dictionary mapping model names to organs for ALL models in model_tissues.tsv
            model_to_organs = dict(zip(df_tissues['model_name'], df_tissues['organs']))
            
            # Start with existing organs dict, then update/add mappings for superset models
            # This preserves any models that were in the original organs column but ensures
            # superset models have correct mappings from model_tissues.tsv
            final_organs_dict = existing_organs_dict.copy()
            
            # Track which models we found vs not found for debugging
            found_in_tissues = []  # Models found in model_tissues.tsv
            found_in_existing = []  # Models found in existing_organs_dict but not in model_tissues.tsv
            missing_models = []  # Models not found in either
            
            for model in superset_models:
                if model in model_to_organs:
                    final_organs_dict[model] = model_to_organs[model]
                    found_in_tissues.append(model)
                elif model in existing_organs_dict:
                    # Keep existing mapping if available (fallback for models not in model_tissues.tsv)
                    final_organs_dict[model] = existing_organs_dict[model]
                    found_in_existing.append(model)
                else:
                    # Only set to Unknown if we don't have a mapping from existing_organs_dict or model_tissues.tsv
                    final_organs_dict[model] = 'Unknown'
                    missing_models.append(model)
                    print(f'Missing model: {model}')
            
            # Update organs column with merged dictionary (string representation)
            df_merged['organs'] = str(final_organs_dict)
            print(f'Updated organs column with {len(final_organs_dict)} model mappings ({len(superset_models)} superset models, {len(existing_organs_dict)} from original)')
            print(f'Found organ annotations for {len(found_in_tissues)} models in model_tissues.tsv')
            if found_in_existing:
                print(f'Found organ annotations for {len(found_in_existing)} models in existing organs column (not in model_tissues.tsv)')
            if missing_models:
                # Separate KUN_HDMA models from others for better debugging
                kun_hdma_missing = [m for m in missing_models if m.startswith('KUN_HDMA')]
                other_missing = [m for m in missing_models if not m.startswith('KUN_HDMA')]
                if kun_hdma_missing:
                    print(f'Warning: {len(kun_hdma_missing)} KUN_HDMA models not found in model_tissues.tsv or existing organs (set to Unknown): {kun_hdma_missing[:10]}...' if len(kun_hdma_missing) > 10 else f'Warning: {len(kun_hdma_missing)} KUN_HDMA models not found in model_tissues.tsv or existing organs (set to Unknown): {kun_hdma_missing}')
                if other_missing:
                    print(f'Warning: {len(other_missing)} other models not found in model_tissues.tsv or existing organs (set to Unknown): {other_missing[:10]}...' if len(other_missing) > 10 else f'Warning: {len(other_missing)} other models not found in model_tissues.tsv or existing organs (set to Unknown): {other_missing}')
        else:
            print(f'Warning: model_tissues.tsv not found at {model_tissues_path}')
            if existing_organs_dict:
                # Keep existing organs dict if model_tissues.tsv not found
                df_merged['organs'] = str(existing_organs_dict)
                print(f'Kept original organs column with {len(existing_organs_dict)} model mappings')
            else:
                # Create empty dict if no organs data available
                df_merged['organs'] = str({})
                print(f'No existing organs column to fall back to - set to empty dict')
        
        # Verify variant is still present and has organs column
        variant_row = df_merged[df_merged['variant_id'] == variant_id]
        if len(variant_row) == 0:
            raise ValueError(f"Variant {variant_id} not found in merged dataframe after processing!")
        if 'organs' not in df_merged.columns:
            raise ValueError("organs column missing from merged dataframe")
        if pd.isna(variant_row['organs'].iloc[0]):
            raise ValueError(f"organs column is null for variant {variant_id}")
        
        # Save merged TSV
        merged_tsv = f'/tmp/barplot_superset_{variant_id.replace(":", "_").replace("/", "_")}.tsv'
        df_merged.to_csv(merged_tsv, sep='\t', index=False)
        print(f'Merged {len(df_merged)} variants with {len(df_merged.columns)} columns')
        
        # Use merged TSV for barplot
        # Parse models_superset_args (space-separated quoted strings like '"KUN_FB*" "KUN_HDMA*"')
        models_superset_list = shlex.split(params.models_superset_args) if params.models_superset_args else []
        # Split VARBOOK_CMD into components (it's a string like "{VENV_PYTHON} -m varbook")
        varbook_cmd_parts = shlex.split(VARBOOK_CMD)
        cmd_superset = varbook_cmd_parts + [
            'plot', 'variant', 'model-specificity-barplot',
            merged_tsv,
            variant_id,
            '--variant-datasets', params.variant_datasets_param,
            '--model-dataset', params.model_dataset_name,
            '--models'] + models_superset_list + [
            '-o', output.png_superset
        ]
        result = subprocess.run(cmd_superset, capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(f"Superset-level barplot failed:\nstdout: {result.stdout}\nstderr: {result.stderr}")
        
        # Generate URL for superset-level barplot
        try:
            subprocess.run(['mitra-utils', 'url', output.png_superset], check=False, capture_output=True, text=True)
        except Exception as e:
            print(f"Warning: Could not generate URL for {output.png_superset}: {e}")
        
        # Clean up temp file
        os.remove(merged_tsv)
        
        superset_duration = time.time() - superset_start
        print(f"[TIMING] Superset-level barplot took {superset_duration:.1f}s")
        
        total_duration = time.time() - start_time
        print(f"[TIMING] Finished generate_variant_barplot_clustered for {variant_id} in {total_duration:.1f}s (cluster: {cluster_duration:.1f}s, superset: {superset_duration:.1f}s) at {time.strftime('%Y-%m-%d %H:%M:%S')}")

rule generate_variant_barplot_unclustered:
    """Generate model-specificity barplot for a variant (no clustering).

    Invokes `varbook plot variant model-specificity-barplot` on the variant
    dataset TSV for a single variant and writes one PNG. Start/finish timing
    lines are written to stderr, and a URL is requested for the PNG via
    `mitra-utils url` on a best-effort basis (its failure does not fail the job).
    """
    input:
        variants_tsv = lambda w: get_variants_tsv_path(w.variant_dataset)
    output:
        png = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{variant}/model-specificity-barplot/{models_str}.png"
    wildcard_constraints:
        # The variant wildcard must be a single path component.
        variant = r"[^/]+"
    params:
        # Map the path-level wildcard back to the variant ID passed to varbook.
        variant = lambda w: get_original_variant_id(w.variant, get_variant_config(w), None, w.variant_dataset),
        variant_datasets_param = lambda w: get_variant_datasets_param(w.variant_dataset),
        model_dataset_name = lambda w: w.model_dataset,
        # Left unquoted in the shell command below so that it can expand to
        # several --models arguments.
        models_args = lambda w: get_models_args(get_variant_config(w))
    # Shell variables that need braces are written with doubled braces
    # (${{DURATION}}) because Snakemake formats the command string first.
    shell:
        """
        START_TIME=$(date +%s)
        echo "[TIMING] Starting generate_variant_barplot_unclustered for {params.variant} at $(date)" >&2
        {VARBOOK_CMD} plot variant model-specificity-barplot \
            "{input.variants_tsv}" \
            "{params.variant}" \
            --variant-datasets "{params.variant_datasets_param}" \
            --model-dataset "{params.model_dataset_name}" \
            --models {params.models_args} \
            -o "{output.png}"
        END_TIME=$(date +%s)
        DURATION=$((END_TIME - START_TIME))
        echo "[TIMING] Finished generate_variant_barplot_unclustered for {params.variant} in ${{DURATION}}s at $(date)" >&2
        
        # Generate URL for barplot
        mitra-utils url "{output.png}" || true
        """

rule generate_shared_profile:
    """Generate chromatin profile in shared location (dataset-agnostic).

    Invokes `varbook plot variant profiles` for one variant/model pair with
    20 shuffles on the CUDA device and writes a single PNG under
    varbook_gen/profiles/. Start/finish timing lines are written to stderr,
    and a URL is requested for the PNG via `mitra-utils url` on a best-effort
    basis (its failure does not fail the job).
    """
    input:
        variants_tsv = lambda w: get_variants_tsv_path(w.variant_dataset),
        model_paths = MODEL_PATHS_TSV
    output:
        "varbook_gen/profiles/{variant}/{model}.png"
    params:
        variant = lambda w: w.variant,
        model = lambda w: w.model
    # Every path/identifier argument is quoted so that values containing
    # spaces reach varbook as a single argument. Shell variables that need
    # braces are written with doubled braces (${{DURATION}}) because Snakemake
    # formats the command string first.
    shell:
        """
        START_TIME=$(date +%s)
        echo "[TIMING] Starting generate_shared_profile for {params.variant}/{params.model} at $(date)" >&2
        {VARBOOK_CMD} plot variant profiles \
            "{input.variants_tsv}" \
            "{params.variant}" \
            "{params.model}" \
            --model-paths-tsv "{input.model_paths}" \
            -o "{output}" \
            --n-shuffles 20 \
            --device cuda
        END_TIME=$(date +%s)
        DURATION=$((END_TIME - START_TIME))
        echo "[TIMING] Finished generate_shared_profile for {params.variant}/{params.model} in ${{DURATION}}s at $(date)" >&2
        
        # Generate URL for profile
        mitra-utils url "{output}" || true
        """

# NOTE: link_profile_clustered and link_profile_unclustered rules were removed.
# Profile symlinks are now created by the symlink_variant_profiles rule which outputs
# to 03-profile-{model}.png (not profiles/{model}.png), so these rules were unused.

# ----------------------------------------------------------------------------
# Variant Summary HTML Rules
# ----------------------------------------------------------------------------

def get_variant_summary_inputs(wildcards):
    """Collect the plot files referenced by one variant's summary page.

    The wildcards object must provide variant_dataset, model_dataset,
    variant_subdataset and variant_id.

    The returned list is ordered as follows:
    - cluster-level barplot PNG
    - superset-level barplot PNG
    - cluster-level scatterplot HTML
    - superset-level scatterplot HTML
    - one profile PNG per model returned by
      get_prioritized_models_for_variant() for this variant
    """
    # Every plot of a variant sits side by side in one directory.
    variant_dir = (
        f"{OUTPUT_DIR}/{wildcards.variant_dataset}/{wildcards.model_dataset}"
        f"/{wildcards.variant_subdataset}/{wildcards.variant_id}"
    )

    # Plots that exist once per variant, in summary-page order.
    fixed_plot_names = (
        "01-model-specificity-barplot.png",
        "01-model-specificity-barplot-superset.png",
        "02-model-scatterplot.html",
        "02-model-scatterplot-superset.html",
    )
    plot_paths = [f"{variant_dir}/{plot_name}" for plot_name in fixed_plot_names]

    # Profile plots: one per model that prioritizes this variant.
    models = get_prioritized_models_for_variant(
        wildcards.variant_dataset,
        wildcards.variant_id,
        wildcards.model_dataset
    )
    plot_paths.extend(f"{variant_dir}/03-profile-{model}.png" for model in models)

    return plot_paths

rule generate_variant_summary_html:
    """Generate HTML summary page for a variant displaying all plots in an easily-readable format.
    
    This rule creates a single HTML page that includes:
    - Model specificity barplots (cluster-level and superset-level)
    - Interactive model scatterplots (cluster-level and superset-level)
    - All profile plots for prioritized models
    
    HTML is generated early (after checkpoints complete) and references plots even if they
    don't exist yet. Plots will appear automatically when they are generated.

    The page itself is written by generate_variant_summary_html.py, which is
    invoked relative to the working directory. Start/finish timing lines are
    written to stderr.
    """
    input:
        # No plot dependencies - HTML will reference plots even if they don't exist
        # clustered.tsv is created by checkpoints, so HTML will generate after checkpoints complete
        clustered_tsv = "data/{variant_dataset}.{model_dataset}.clustered.tsv"
    output:
        html = OUTPUT_DIR + "/{variant_dataset}/{model_dataset}/{variant_subdataset}/{variant_id}/00-summary.html"
    params:
        # Directory containing this variant's plots; the summary page is written into it.
        variant_dir = lambda w: f"{OUTPUT_DIR}/{w.variant_dataset}/{w.model_dataset}/{w.variant_subdataset}/{w.variant_id}",
        variant_id = lambda w: w.variant_id,
        model_dataset = lambda w: w.model_dataset
    # Shell variables that need braces are written with doubled braces
    # (${{DURATION}}) because Snakemake formats the command string first.
    shell:
        """
        START_TIME=$(date +%s)
        echo "[TIMING] Starting generate_variant_summary_html for {params.variant_id} at $(date)" >&2
        python generate_variant_summary_html.py \
            "{params.variant_dir}" \
            "{params.variant_id}" \
            "{input.clustered_tsv}" \
            "{params.model_dataset}" \
            "{output.html}"
        END_TIME=$(date +%s)
        DURATION=$((END_TIME - START_TIME))
        echo "[TIMING] Finished generate_variant_summary_html for {params.variant_id} in ${{DURATION}}s at $(date)" >&2
        """

# ----------------------------------------------------------------------------
# Helper Functions for User-Provided Variants
# ----------------------------------------------------------------------------

def get_variant_metadata(variant_id):
    """Find every (variant_dataset, model_dataset) pair whose clustered TSV lists a variant.

    For each variant dataset in VARIANT_DATASET_CONFIGS and each of its model
    datasets, the file data/{variant_dataset}.{model_dataset}.clustered.tsv is
    consulted if it exists. Only the first row matching the variant is used.

    Parameters:
    -----------
    variant_id : str
        Variant ID (e.g., "chr1:123456:A:G")

    Returns:
    --------
    list of dict
        One entry per clustered TSV containing the variant, with keys:
            - variant_id: str
            - variant_dataset: str
            - model_dataset: str
            - cluster_id: int (kmeans_35 value of the first matching row)
            - cluster_name: value returned by get_cluster_name_from_id()

    Files that are missing, empty, lack the required columns, or whose
    kmeans_35 value cannot be converted to int are skipped silently.
    """
    import pandas as pd
    import os

    records = []

    for dataset_name in VARIANT_DATASET_CONFIGS.keys():
        for md_config in get_model_datasets_list(dataset_name):
            md_name = md_config.get('name')
            tsv_path = f"data/{dataset_name}.{md_name}.clustered.tsv"

            # Nothing to look up until the clustered TSV has been produced.
            if not os.path.exists(tsv_path):
                continue

            try:
                table = pd.read_csv(tsv_path, sep='\t', usecols=['variant_id', 'kmeans_35'])
                hits = table[table['variant_id'] == variant_id]
                if len(hits) == 0:
                    continue

                first_hit = hits.iloc[0]
                cluster_id = int(first_hit['kmeans_35'])
                cluster_name = get_cluster_name_from_id(dataset_name, md_name, cluster_id)

                records.append({
                    'variant_id': variant_id,
                    'variant_dataset': dataset_name,
                    'model_dataset': md_name,
                    'cluster_id': cluster_id,
                    'cluster_name': cluster_name
                })
            except (KeyError, ValueError, pd.errors.EmptyDataError):
                # Unusable file (missing columns, empty, non-integer cluster): skip it.
                continue

    return records

def get_variant_plot_outputs(variant_id, variant_dataset, model_dataset, cluster_id, optional_models=None):
    """Get list of all plot outputs for a variant.

    Parameters:
    -----------
    variant_id : str
        Variant ID (original format, e.g., "chr1:123456:A:G")
    variant_dataset : str
        Variant dataset name
    model_dataset : str
        Model dataset name
    cluster_id : int or None
        Cluster ID (kmeans_35 value) or None if no clustering
    optional_models : list of str or str, optional
        Extra models to generate profile plots for (in addition to prioritized
        models); a string is treated as a comma-separated list

    Returns:
    --------
    list of str
        Paths to all plot files for this variant, in this order: barplot(s),
        scatterplot(s), then a .png/.md pair per model with models sorted by
        name so the result is deterministic. Superset-level barplot and
        scatterplot are only included when cluster_id is given. If cluster_id
        is given but get_cluster_name_from_id() returns no name, the list is
        empty.
    """
    # Resolve the variant's plot directory once; every output lives inside it.
    clustered = cluster_id is not None
    if clustered:
        cluster_name = get_cluster_name_from_id(variant_dataset, model_dataset, cluster_id)
        if not cluster_name:
            # Without a cluster name there is no directory to place plots in.
            return []
        variant_dir = f"{OUTPUT_DIR}/{variant_dataset}/{model_dataset}/{cluster_name}/{variant_id}"
    else:
        variant_dir = f"{OUTPUT_DIR}/{variant_dataset}/{model_dataset}/{variant_id}"

    # Barplot outputs
    outputs = [f"{variant_dir}/01-model-specificity-barplot.png"]
    if clustered:
        outputs.append(f"{variant_dir}/01-model-specificity-barplot-superset.png")

    # Scatterplot outputs
    outputs.append(f"{variant_dir}/02-model-scatterplot.html")
    if clustered:
        outputs.append(f"{variant_dir}/02-model-scatterplot-superset.html")

    # Profile plots - collect all models to generate
    models_to_plot = set()

    # Add prioritized models. This lookup is best-effort: on failure we still
    # return the other plots (and any optional_models), but say so instead of
    # hiding the problem.
    try:
        models_to_plot.update(
            get_prioritized_models_for_variant(variant_dataset, variant_id, model_dataset)
        )
    except Exception as exc:
        print(
            f"Warning: Could not determine prioritized models for variant '{variant_id}' "
            f"({variant_dataset}/{model_dataset}): {exc}"
        )

    # Add optional models (user-specified)
    if optional_models:
        if isinstance(optional_models, str):
            # Handle comma-separated string
            optional_models = [m.strip() for m in optional_models.split(',') if m.strip()]
        models_to_plot.update(optional_models)

    # Sorted so that the path order does not depend on set iteration order.
    for model in sorted(models_to_plot):
        outputs.append(f"{variant_dir}/03-profile-{model}.png")
        outputs.append(f"{variant_dir}/03-profile-{model}.md")

    return outputs

def read_variant_ids_from_file(file_path):
    """Load variant IDs from a plain-text file holding one ID per line.

    Every line is whitespace-stripped. Lines that are empty after stripping,
    or whose first non-blank character is '#', are dropped.

    Parameters:
    -----------
    file_path : str
        Location of the text file listing variant IDs

    Returns:
    --------
    list of str
        The remaining variant IDs, in file order

    Raises:
    -------
    FileNotFoundError
        If file_path does not exist
    """
    import os

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Variant IDs file not found: {file_path}")

    with open(file_path, 'r') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        # The list is fully built before the file is closed.
        return [entry for entry in stripped_lines if entry and not entry.startswith('#')]

def get_all_plots_for_variants(variant_ids, optional_models=None):
    """Get all plot outputs for a list of variant_ids.

    Each variant is looked up with get_variant_metadata(); every metadata
    record returned for it is expanded into plot paths with
    get_variant_plot_outputs(). Variants with no metadata are reported with a
    warning and skipped.

    Parameters:
    -----------
    variant_ids : list of str or str
        List of variant IDs, or comma-separated string
    optional_models : list of str or str, optional
        Optional list of models to generate profile plots for (comma-separated string or list)

    Returns:
    --------
    list of str
        All plot file paths for the specified variants, de-duplicated and in
        first-seen order (so the result is reproducible between runs)

    Raises:
    -------
    ValueError
        If none of the requested variants has any metadata
    """
    # Parse variant_ids (handle both list and comma-separated string)
    if isinstance(variant_ids, str):
        variant_ids = [v.strip() for v in variant_ids.split(',') if v.strip()]

    # Parse optional_models if provided
    if optional_models and isinstance(optional_models, str):
        optional_models = [m.strip() for m in optional_models.split(',') if m.strip()]

    all_outputs = []
    any_variant_found = False

    for variant_id in variant_ids:
        metadata_list = get_variant_metadata(variant_id)

        if not metadata_list:
            print(f"Warning: Variant '{variant_id}' not found in any clustered.tsv files")
            continue

        any_variant_found = True
        for metadata in metadata_list:
            all_outputs.extend(
                get_variant_plot_outputs(
                    metadata['variant_id'],
                    metadata['variant_dataset'],
                    metadata['model_dataset'],
                    metadata['cluster_id'],
                    optional_models=optional_models
                )
            )

    if not any_variant_found:
        raise ValueError("No variants found in any clustered.tsv files. Check variant IDs and ensure clustered.tsv files exist.")

    # Remove duplicates (variant might be in multiple model_datasets).
    # dict.fromkeys keeps insertion order, unlike set(), whose iteration order
    # for strings varies between interpreter runs.
    return list(dict.fromkeys(all_outputs))

# ----------------------------------------------------------------------------
# Finemo Annotation Rules (Split Files Per Model)
# ----------------------------------------------------------------------------

def get_variant_dataset_for_model(model_name):
    """Look up the variant_dataset whose model patterns match a model name.

    Walks VARIANT_DATASET_CONFIGS in iteration order and, for each
    variant_dataset, every model_dataset config returned by
    get_model_datasets_list(). The config's 'models' entries are treated as
    fnmatch-style patterns.

    Returns the first variant_dataset with a matching pattern, or None when
    nothing matches.
    """
    import fnmatch

    for dataset_name in VARIANT_DATASET_CONFIGS.keys():
        for md_config in get_model_datasets_list(dataset_name):
            patterns = md_config.get('models', [])
            if any(fnmatch.fnmatch(model_name, candidate) for candidate in patterns):
                return dataset_name
    return None

def get_general_tsv_for_model(model_name):
    """Resolve the general.tsv path that belongs to a model.

    The model is mapped to its variant_dataset via
    get_variant_dataset_for_model(); the path is then
    SPLITS_DIR/<variant_dataset>.general.tsv.

    Raises:
    -------
    ValueError
        If no variant_dataset can be determined for the model
    FileNotFoundError
        If the resolved general.tsv does not exist on disk
    """
    import os

    owning_dataset = get_variant_dataset_for_model(model_name)
    if not owning_dataset:
        # Without a variant_dataset there is no way to pick a general.tsv
        raise ValueError(f"Could not determine variant_dataset for model '{model_name}'. Cannot determine which general.tsv to use.")

    tsv_path = f"{SPLITS_DIR}/{owning_dataset}.general.tsv"
    if not os.path.exists(tsv_path):
        raise FileNotFoundError(
            f"General file not found for variant_dataset '{owning_dataset}' (model: {model_name}): {tsv_path}"
        )
    return tsv_path

rule annotate_finemo_split_file:
    """Generate per-model finemo annotation split files with incremental processing.

    Only processes variants demanded by VARIANT_DATASET_CONFIGS (prioritized by
    the model AND in configured clusters).

    This rule uses incremental processing: it checks for new demanded variants and adds
    them to the existing file. The rule ALWAYS runs to check for new variants, even if
    the output file exists.

    IMPORTANT: Only one instance per model can run at a time to avoid GPU memory issues.

    Steps performed by the run block:
      1. Collect the demanded variants for the model (or, when config 'models' is set
         together with 'variant_ids' / 'variant_ids_file', the user-specified variants).
      2. Ask get_finemo_new_variants() which of them are not yet in the output file.
      3. Run `varbook annotate motif finemo` on those variants via temp files in /tmp.
      4. Strip the model suffix from the finemo_* columns and merge the new rows into
         the existing output (the newest row wins per variant_id).

    The temp files are removed whether or not the run succeeds. Whenever the rule
    finishes without running finemo, the output is left as (at least) a header-only
    TSV so that it can always be parsed again.
    """
    input:
        variants_tsv = lambda w: get_general_tsv_for_model(w.model),
        model_paths_tsv = MODEL_PATHS_TSV,
        modisco_h5 = MODISCO_H5
    output:
        SPLITS_DIR + "/finemo/broad.finemo.{model}.tsv"
    params:
        model = lambda w: w.model,
        varbook_dir = VARBOOK_DIR
    log:
        "logs/finemo/broad.finemo.{model}.log"
    resources:
        # Only allow 1 job per model at a time (prevents multiple finemo processes for same model)
        # This is a string resource (model name), so Snakemake ensures only one job per model runs
        gpu_per_model = lambda w: w.model,
        # Limit total number of finemo jobs running simultaneously (across all models)
        # This is an integer resource that can be set on command line: --resources gpu_total=3
        gpu_total = 1
    run:
        import pandas as pd
        import os
        import time
        from pathlib import Path

        # Column layout written whenever a header-only result file is needed
        empty_result_columns = [
            'variant_id',
            'finemo_motif_hits',
            'finemo_motif_hit_count',
            'finemo_motif_top_hit',
            'finemo_motif_top_score',
            'finemo_motif_positions',
            'finemo_motif_ref_summary',
            'finemo_motif_alt_summary',
            'finemo_motif_diff_summary'
        ]

        def _refresh_output(path):
            """Bump the mtime of an existing result file, or create a header-only one.

            A bare touch() on a missing file would leave a 0-byte file, which
            pd.read_csv cannot parse (EmptyDataError) on a later merge.
            """
            if os.path.exists(path) and os.path.getsize(path) > 0:
                Path(path).touch()
            else:
                os.makedirs(os.path.dirname(path), exist_ok=True)
                pd.DataFrame(columns=empty_result_columns).to_csv(path, sep='\t', index=False)

        start_time = time.time()
        model_name = params.model
        output_file = output[0]

        print(f"[TIMING] Starting annotate_finemo_split_file for {model_name} at {time.strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"Getting demanded variants for {model_name}...")
        print(f"DEBUG: Model {model_name} in KUN_FB_MODELS: {model_name in KUN_FB_MODELS}")

        # Get demanded variants for THIS SPECIFIC MODEL only (much more efficient)
        demanded_variants_set = get_demanded_variants_for_model(model_name)

        get_demanded_time = time.time()
        print(f"[TIMING] Finished getting demanded variants in {get_demanded_time - start_time:.1f}s")
        print(f"DEBUG: Found {len(demanded_variants_set)} demanded variants for {model_name}")
        if len(demanded_variants_set) > 0:
            print(f"DEBUG: Sample variant IDs: {list(demanded_variants_set)[:5]}")

        # If optional models are specified, use only user-specified variants
        # This allows users to run finemo for specific variants even if they're not prioritized
        if config.get('models'):
            user_variant_ids = []
            try:
                if config.get('variant_ids'):
                    # Parse comma-separated string
                    variant_ids_str = config.get('variant_ids')
                    if isinstance(variant_ids_str, str):
                        user_variant_ids = [v.strip() for v in variant_ids_str.split(',') if v.strip()]
                    elif isinstance(variant_ids_str, list):
                        user_variant_ids = variant_ids_str
                elif config.get('variant_ids_file'):
                    # Read from file
                    user_variant_ids = read_variant_ids_from_file(config.get('variant_ids_file'))

                if user_variant_ids:
                    # When optional models are specified, use user-specified variants directly
                    # This allows finemo to run for variants even if they're not prioritized by the model
                    original_count = len(demanded_variants_set)
                    user_variant_set = set(user_variant_ids)

                    # Use user-specified variants (they may or may not be in demanded variants)
                    # This is intentional - when optional models are specified, we want to run finemo
                    # for the user's variants regardless of prioritization status
                    demanded_variants_set = user_variant_set
                    print(f"DEBUG: Using user-specified variant_ids for optional model. Original demanded: {original_count}, User specified: {len(user_variant_set)}")
                    print(f"DEBUG: User-specified variant IDs: {list(demanded_variants_set)[:10]}")
            except Exception as e:
                # Best effort: fall back to the demanded variants computed above
                print(f"WARNING: Error processing user-specified variant_ids: {e}")
                print(f"  Continuing with all demanded variants (no filtering applied)")

        if len(demanded_variants_set) == 0:
            print(f"No variants demanded for {model_name}")
            print(f"DEBUG: This could be due to:")
            print(f"  - clustered.tsv files not existing")
            print(f"  - model_prioritized_by_any-{model_name} column missing")
            print(f"  - No variants prioritized by this model in configured clusters")
            print(f"  - Model name doesn't match any patterns in config")
            # Create empty file with correct structure
            empty_df = pd.DataFrame(columns=empty_result_columns)
            os.makedirs(os.path.dirname(output_file), exist_ok=True)
            empty_df.to_csv(output_file, sep='\t', index=False)
            # Touch file to update timestamp
            Path(output_file).touch()
            return

        demanded_variants = list(demanded_variants_set)

        # Get new variants to process (incremental)
        new_variants = get_finemo_new_variants(demanded_variants, output_file)

        if len(new_variants) == 0:
            print(f"All {len(demanded_variants)} demanded variants already processed for {model_name}")
            # Update timestamp so Snakemake knows rule ran
            _refresh_output(output_file)
            return

        print(f"Processing {len(new_variants)} new variants for {model_name} " +
              f"({len(demanded_variants) - len(new_variants)} already done)")

        # Filter variants TSV to only new variants
        variants_df = pd.read_csv(input.variants_tsv, sep='\t')
        new_variants_df = variants_df[variants_df['variant_id'].isin(new_variants)]

        # Check for missing variants
        missing_variants = set(new_variants) - set(variants_df['variant_id'])
        if missing_variants:
            print(f"Warning: {len(missing_variants)} variants not found in general.tsv: {list(missing_variants)[:5]}")
            # Filter out missing variants from processing
            new_variants = [v for v in new_variants if v not in missing_variants]
            new_variants_df = variants_df[variants_df['variant_id'].isin(new_variants)]

        if len(new_variants_df) == 0:
            print(f"Warning: No valid variants to process after filtering. All variants may be missing from general.tsv.")
            # Update timestamp (or create a parseable header-only file)
            _refresh_output(output_file)
            return

        # Rename allele columns to ref/alt for finemo compatibility
        if 'allele1' in new_variants_df.columns and 'allele2' in new_variants_df.columns:
            new_variants_df = new_variants_df.rename(columns={'allele1': 'ref', 'allele2': 'alt'})

        temp_input = f"/tmp/finemo_input_{model_name}.tsv"
        temp_output = f"/tmp/finemo_output_{model_name}.tsv"

        # Create log directory and get absolute log path
        # (absolute because the shell command below changes directory)
        log_abs_path = os.path.abspath(log[0])
        os.makedirs(os.path.dirname(log_abs_path), exist_ok=True)

        try:
            # Save to temp file
            new_variants_df.to_csv(temp_input, sep='\t', index=False)

            # Run finemo (use least-used GPU)
            finemo_start_time = time.time()

            # Select least-used GPU
            selected_gpu = get_least_used_gpu()
            print(f"[TIMING] Starting finemo annotation for {model_name} ({len(new_variants)} variants) at {time.strftime('%Y-%m-%d %H:%M:%S')}")
            print(f"Using GPU {selected_gpu} (least used)")

            # The command is formatted twice: once by the f-string and once by
            # Snakemake's shell(). A literal bash ${DURATION} therefore needs
            # four braces on each side here.
            shell(f"""
        START_TIME=$(date +%s)
        echo "[TIMING] Starting finemo annotation for {model_name} ({len(new_variants)} variants) at $(date)" >&2
        echo "Using GPU {selected_gpu} (least used)" >&2
        export CUDA_VISIBLE_DEVICES={selected_gpu} && cd {params.varbook_dir} && {VENV_PYTHON} -m varbook annotate motif finemo {temp_input} variant_id \
          --model-paths-tsv {input.model_paths_tsv} \
          --models {model_name} \
          --modisco-h5 {input.modisco_h5} \
          --alpha 0.8 \
          --hits-per-variant 20 \
          --n-shuffles 20 \
          --window-size 300 \
          --device cuda \
          -o {temp_output} \
          > {log_abs_path} 2>&1
        END_TIME=$(date +%s)
        DURATION=$((END_TIME - START_TIME))
        echo "[TIMING] Finished finemo annotation for {model_name} in ${{{{DURATION}}}}s at $(date)" >&2
        """)
            duration = time.time() - finemo_start_time
            print(f"[TIMING] Finemo annotation for {model_name} completed in {duration:.1f}s")

            # Check if finemo output exists (finemo may have failed)
            if not os.path.exists(temp_output):
                raise RuntimeError(f"Finemo output file not found: {temp_output}. Check log: {log_abs_path}")

            # Load new results
            new_results_df = pd.read_csv(temp_output, sep='\t')

            # Debug: Print available columns
            print(f"DEBUG: Columns in finemo output: {list(new_results_df.columns)}")

            # Remove model suffix from column names
            # Only rename columns that exist (varbook outputs columns with model suffix)
            rename_map = {
                f'finemo_motif_hits_{model_name}': 'finemo_motif_hits',
                f'finemo_motif_hit_count_{model_name}': 'finemo_motif_hit_count',
                f'finemo_motif_top_hit_{model_name}': 'finemo_motif_top_hit',
                f'finemo_motif_top_score_{model_name}': 'finemo_motif_top_score',
                f'finemo_motif_allele_diff_{model_name}': 'finemo_motif_allele_diff',
                f'finemo_motif_positions_{model_name}': 'finemo_motif_positions',
                f'finemo_motif_ref_summary_{model_name}': 'finemo_motif_ref_summary',
                f'finemo_motif_alt_summary_{model_name}': 'finemo_motif_alt_summary',
                f'finemo_motif_diff_summary_{model_name}': 'finemo_motif_diff_summary'
            }
            # Filter to only columns that exist
            rename_map = {k: v for k, v in rename_map.items() if k in new_results_df.columns}
            if not rename_map:
                # If no columns with suffix found, check if columns already don't have suffix
                finemo_cols_no_suffix = [c for c in new_results_df.columns if c.startswith('finemo_') and not c.endswith(f'_{model_name}')]
                if finemo_cols_no_suffix:
                    print(f"DEBUG: Columns already don't have model suffix: {finemo_cols_no_suffix}")
                    # Columns already don't have suffix, use as-is
                else:
                    raise ValueError(
                        f"Expected finemo columns with model suffix '{model_name}' not found in output. "
                        f"Available columns: {list(new_results_df.columns)}"
                    )
            else:
                new_results_df = new_results_df.rename(columns=rename_map)
                print(f"DEBUG: Renamed {len(rename_map)} columns")

            # Keep only variant_id and finemo columns
            finemo_cols = [c for c in new_results_df.columns if c.startswith('finemo_')]
            new_results_df = new_results_df[['variant_id'] + finemo_cols]

            # Merge with existing results if a non-empty file exists
            # (a 0-byte file cannot be parsed by read_csv, so it is treated as absent)
            if os.path.exists(output_file) and os.path.getsize(output_file) > 0:
                existing_df = pd.read_csv(output_file, sep='\t')
                # Deduplicate by variant_id (in case of any overlap)
                merged_df = pd.concat([existing_df, new_results_df], ignore_index=True)
                merged_df = merged_df.drop_duplicates(subset=['variant_id'], keep='last')
            else:
                merged_df = new_results_df

            # Save merged results
            os.makedirs(os.path.dirname(output_file), exist_ok=True)
            merged_df.to_csv(output_file, sep='\t', index=False)
        finally:
            # Clean up temp files, also when finemo or the merge failed
            for tmp_path in (temp_input, temp_output):
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)

        print(f"Saved {len(merged_df)} total variants for {model_name} to {output_file}")

rule all_finemo_split_files:
    """Aggregate target: request the finemo split file of every model in KUN_FB_MODELS."""
    input:
        expand(SPLITS_DIR + "/finemo/broad.finemo.{model}.tsv", model=KUN_FB_MODELS)

# ----------------------------------------------------------------------------
# Generate Plots for User-Provided Variants (Optional)
# ----------------------------------------------------------------------------

rule generate_plots_for_variants:
    """Generate all plots for user-specified variants.

    OPTIONAL entry point: builds the plots of selected variants only, without
    running the whole pipeline. The default pipeline behavior is unaffected.

    Usage:
        # From command line with variant_ids file:
        snakemake generate_plots_for_variants --config variant_ids_file=variants.txt

        # From command line with comma-separated list:
        snakemake generate_plots_for_variants --config variant_ids="chr1:123:A:G,chr2:456:C:T"

        # With optional models for profile plots:
        snakemake generate_plots_for_variants --config variant_ids="chr1:123:A:G" models="KUN_FB_glutamatergic_neuron_2,KUN_FB_glutamatergic_neuron_3"

        # From config.yaml:
        # variant_ids_file: "variants.txt"
        # OR
        # variant_ids: "chr1:123:A:G,chr2:456:C:T"
        # models: "KUN_FB_glutamatergic_neuron_2,KUN_FB_glutamatergic_neuron_3"  # Optional

    What happens:
    - The input function asks get_all_plots_for_variants() for every plot path of
      the requested variants; 'variant_ids' takes precedence over 'variant_ids_file'
    - If models are given, they are passed on as optional models for profile plots
      (in addition to prioritized models)
    - The run block prints a summary, warns about plot files that are still
      missing, and touches the sentinel output
    """
    input:
        # Plot paths are derived from the config values; the lambda defers the
        # lookup until Snakemake evaluates this rule's inputs.
        lambda wildcards: get_all_plots_for_variants(
            config.get('variant_ids') or (
                read_variant_ids_from_file(config.get('variant_ids_file'))
                if config.get('variant_ids_file') else []
            ),
            optional_models=config.get('models', None)
        )
    output:
        # Sentinel file to indicate completion
        "plots_for_variants.done"
    params:
        variant_ids = lambda wildcards: config.get('variant_ids', ''),
        variant_ids_file = lambda wildcards: config.get('variant_ids_file', ''),
        models = lambda wildcards: config.get('models', '')
    run:
        import os

        # Inline IDs win; fall back to the IDs file
        requested_ids = params.variant_ids
        if not requested_ids and params.variant_ids_file:
            requested_ids = read_variant_ids_from_file(params.variant_ids_file)

        if not requested_ids:
            raise ValueError(
                "No variant_ids provided. Use --config variant_ids='...' or "
                "--config variant_ids_file=path/to/file.txt"
            )

        # A comma-separated string becomes a list; a list is used as given
        if isinstance(requested_ids, str):
            requested_ids = [token.strip() for token in requested_ids.split(',') if token.strip()]

        extra_models = params.models
        if extra_models:
            if isinstance(extra_models, str):
                extra_models = [token.strip() for token in extra_models.split(',') if token.strip()]
            print(f"Generating profile plots for {len(extra_models)} optional model(s): {', '.join(extra_models)}")

        id_count = len(requested_ids)
        print(f"Generating plots for {id_count} variant(s):")
        for shown_id in requested_ids[:10]:  # preview at most 10
            print(f"  - {shown_id}")
        if id_count > 10:
            print(f"  ... and {id_count - 10} more")

        # Recompute the expected plot list (should match the rule's inputs)
        expected_plots = get_all_plots_for_variants(requested_ids, optional_models=extra_models)
        print(f"Total plot outputs to generate: {len(expected_plots)}")

        # They are inputs of this rule, so they should already exist; only warn otherwise
        absent_plots = [plot for plot in expected_plots if not os.path.exists(plot)]
        if absent_plots:
            print(f"Warning: {len(absent_plots)} plot outputs are missing (this may indicate a dependency issue)")
            print(f"First few missing files:")
            for plot in absent_plots[:5]:
                print(f"  - {plot}")

        # Create sentinel file
        shell("touch {output}")

# ----------------------------------------------------------------------------
# HTML Report Generation
# ----------------------------------------------------------------------------

def get_variant_dataset_paths():
    """Collect colon-separated hierarchy paths for HTML generation.

    For every variant_dataset in VARIANT_DATASET_CONFIGS and each of its
    model_dataset configs, one path is produced per configured cluster:

    - "Broad neurodevelopmental and neuromuscular disorders:Fetal Brain:microglia-specific cluster (#3)"

    A model_dataset without clusters contributes a single
    "variant_dataset:model_dataset" path instead.

    Uses colon as separator which gets converted to / by build_hierarchical_file_list().
    """
    def _cluster_label(entry):
        # Dict entries: prefer 'name', fall back to 'id'; anything else is the label itself
        if isinstance(entry, dict):
            return entry.get('name', entry.get('id'))
        return entry

    collected = []
    for variant_dataset in VARIANT_DATASET_CONFIGS.keys():
        for md_config in get_model_datasets_list(variant_dataset):
            prefix = f"{variant_dataset}:{md_config['name']}"
            cluster_entries = md_config.get('clusters', [])

            if not cluster_entries:
                # No clusters, just variant_dataset:model_dataset
                collected.append(prefix)
                continue

            collected.extend(f"{prefix}:{_cluster_label(entry)}" for entry in cluster_entries)

    return collected

rule generate_html:
    """Generate HTML report with live-editing server.

    IMPORTANT: This rule must run from the snakemake/ directory because
    generate_html() looks for variant files in varbook_gen/ relative to CWD.
    The Snakefile is already in snakemake/, so this works correctly when
    running 'snakemake' from the snakemake/ directory.

    Inputs are all heatmaps, per-variant outputs, finemo split files and the
    comprehensive variant TSVs. The command runs `write html-live` through
    VARBOOK_CMD on port 8765 and logs start/finish timing to stderr.

    Note on the shell command: Snakemake applies brace substitution to it, so a
    shell variable that needs braces must be written with doubled braces.
    """
    input:
        heatmaps = get_all_heatmap_outputs(),
        variants = get_all_variant_outputs(),
        finemo_splits = get_all_finemo_split_files(),
        comprehensive_tsvs = get_comprehensive_variants_tsvs()
    output:
        "variant_report.html"
    params:
        # Space-joined list of colon-separated hierarchy paths
        variant_datasets = " ".join(get_variant_dataset_paths()),
        # Use the first comprehensive TSV (there's only one for "Broad neurodevelopmental and neuromuscular disorders")
        variants_tsv = lambda w, input: input.comprehensive_tsvs[0] if input.comprehensive_tsvs else ""
    shell:
        """
        START_TIME=$(date +%s)
        echo "[TIMING] Starting generate_html at $(date)" >&2
        {VARBOOK_CMD} write html-live '{output}' \
            --variant-datasets '{params.variant_datasets}' \
            --variants-tsv '{params.variants_tsv}' \
            --toc \
            --debug-paths \
            --port 8765
        END_TIME=$(date +%s)
        DURATION=$((END_TIME - START_TIME))
        echo "[TIMING] Finished generate_html in ${{DURATION}}s at $(date)" >&2
        """

# ----------------------------------------------------------------------------
# Utility rules
# ----------------------------------------------------------------------------

rule clean:
    """Delete the OUTPUT_DIR tree and variant_report.html."""
    shell:
        """
        rm -rf {OUTPUT_DIR} variant_report.html
        """
