"""
Simplified Snakefile for generating finemo motif annotations.
Usage: snakemake --snakefile Snakefile.finemo splits/finemo/broad.finemo.MODEL_NAME.tsv
"""

import pandas as pd
import os
import time
from pathlib import Path

# ============================================================================
# Wildcard Constraints
# ============================================================================

wildcard_constraints:
    model="[^/]+",
    variant_id="[^/]+"

# ============================================================================
# Configuration
# ============================================================================

VENV_PYTHON = "/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/curation/broad/varbook-container/.venv/bin/python"
VARBOOK_CMD = f"{VENV_PYTHON} -m varbook"

# Input data paths
MODEL_PATHS_TSV = "/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/curation/metadata_renaming/broad.model_paths.tsv"
GENOME_FA = "/oak/stanford/groups/akundaje/soumyak/refs/hg38/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta"

# Finemo-specific paths
SPLITS_DIR = "/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/curation/broad/splits"
MODISCO_H5 = "/oak/stanford/groups/akundaje/airanman/refs/motif-compendium/modisco_invitro_avg.070325.h5"
VARBOOK_DIR = "/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/curation/broad/varbook-container/varbook"

# ============================================================================
# Helper Functions
# ============================================================================

def get_kun_fb_models(model_paths_tsv):
    """Get list of all KUN_FB models from model_paths_tsv."""
    df = pd.read_csv(model_paths_tsv, sep='\t')
    kun_fb_models = df[df['model_name'].str.startswith('KUN_FB')]['model_name'].unique().tolist()
    return kun_fb_models

def get_finemo_new_variants(prioritized_variants, output_file):
    """Get list of new variants to process (not in existing output)."""
    if not os.path.exists(output_file):
        return prioritized_variants

    existing_df = pd.read_csv(output_file, sep='\t')
    already_processed = set(existing_df['variant_id'].tolist())

    return [v for v in prioritized_variants if v not in already_processed]

def get_least_used_gpu():
    """Find the GPU with the least memory usage. Returns GPU device ID (0-3)."""
    import subprocess
    import re
    
    try:
        # Run nvidia-smi to get GPU memory usage
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=index,memory.used', '--format=csv,noheader,nounits'],
            capture_output=True,
            text=True,
            check=True
        )
        
        # Parse output: "0, 1234\n1, 5678\n..."
        gpu_usage = []
        for line in result.stdout.strip().split('\n'):
            if line.strip():
                parts = line.split(',')
                if len(parts) == 2:
                    gpu_id = int(parts[0].strip())
                    memory_used = int(parts[1].strip())
                    gpu_usage.append((gpu_id, memory_used))
        
        if not gpu_usage:
            # Fallback to GPU 2 if we can't parse
            print("Warning: Could not parse nvidia-smi output, defaulting to GPU 2")
            return 2
        
        # Sort by memory usage (ascending) and return GPU with least usage
        gpu_usage.sort(key=lambda x: x[1])
        least_used_gpu = gpu_usage[0][0]
        least_used_memory = gpu_usage[0][1]
        
        print(f"GPU usage: {dict(gpu_usage)}")
        print(f"Selected GPU {least_used_gpu} (memory used: {least_used_memory} MB)")
        return least_used_gpu
        
    except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e:
        # Fallback to GPU 2 if nvidia-smi fails
        print(f"Warning: Could not determine GPU usage ({e}), defaulting to GPU 2")
        return 2

def read_variant_ids_from_file(file_path):
    """Read variant IDs from a file (one per line)."""
    variant_ids = []
    if os.path.exists(file_path):
        with open(file_path, 'r') as f:
            for line in f:
                line = line.strip()
                if line and not line.startswith('#'):
                    variant_ids.append(line)
    return variant_ids
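
# A 'variant_ids_file' passed via --config is plain text: one variant ID per
# line, with '#'-prefixed comment lines ignored. Illustrative example:
#
#   # variants to annotate
#   chr1:123:A:G
#   chr2:456:C:T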

def get_prioritized_variants_for_model(model_name, prioritization_tsv=None):
    """Get variants prioritized by a specific model from prioritization TSV."""
    # Try to find prioritization TSV if not provided
    if prioritization_tsv is None:
        # Look for common prioritization TSV patterns
        possible_paths = [
            f"{SPLITS_DIR}/broad.model_prioritized_by_any-{model_name}.tsv",
            f"{SPLITS_DIR}/broad.model_prioritized_by_any-KUN_FB.tsv",
        ]
        
        for path in possible_paths:
            if os.path.exists(path):
                prioritization_tsv = path
                break
        
        if prioritization_tsv is None:
            print(f"Warning: Could not find prioritization TSV for {model_name}")
            return []
    
    if not os.path.exists(prioritization_tsv):
        print(f"Warning: Prioritization TSV not found: {prioritization_tsv}")
        return []
    
    # Read prioritization TSV
    df = pd.read_csv(prioritization_tsv, sep='\t')
    priority_col = f'model_prioritized_by_any-{model_name}'
    
    if priority_col not in df.columns:
        print(f"Warning: Column '{priority_col}' not found in {prioritization_tsv}")
        return []
    
    # Filter to prioritized variants
    prioritized_mask = (
        (df[priority_col].astype(str).str.lower() == 'true') |
        (df[priority_col] == True) |
        (df[priority_col] == 1) |
        (df[priority_col].astype(str) == '1')
    ).fillna(False)
    
    prioritized_variants = df[prioritized_mask]['variant_id'].tolist()
    return prioritized_variants
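
# The prioritization TSV read above is expected to carry a 'variant_id' column
# plus a boolean-like 'model_prioritized_by_any-<MODEL_NAME>' column, e.g.
# (illustrative values only):
#
#   variant_id      model_prioritized_by_any-KUN_FB_microglia
#   chr1:123:A:G    true
#   chr2:456:C:T    false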

def get_general_tsv_for_model(model_name):
    """Get the general.tsv file path for a model. Auto-detects common patterns."""
    # Try common patterns
    possible_paths = [
        f"{SPLITS_DIR}/broad.general.tsv",
        f"{SPLITS_DIR}/{model_name.split('_')[0]}.general.tsv",  # e.g., KUN_FB -> KUN_FB.general.tsv
    ]
    
    for path in possible_paths:
        if os.path.exists(path):
            return path
    
    # If not found, raise error
    raise FileNotFoundError(
        f"General TSV file not found for model '{model_name}'. "
        f"Tried: {possible_paths}. "
        f"Please specify the correct path in the config."
    )

# Get list of all KUN_FB models
KUN_FB_MODELS = get_kun_fb_models(MODEL_PATHS_TSV)

# ============================================================================
# Rules
# ============================================================================

rule finemo:
    """Generate finemo motif annotations for a specific model.
    Usage: snakemake splits/finemo/broad.finemo.MODEL_NAME.tsv --config variant_ids="chr1:123:A:G"
    """
    output:
        finemo_tsv = SPLITS_DIR + "/finemo/broad.finemo.{model}.tsv"
    params:
        model = lambda w: w.model,
        general_tsv = lambda w: config.get('general_tsv', get_general_tsv_for_model(w.model)),
        prioritization_tsv = lambda w: config.get('prioritization_tsv', None),
    log:
        "logs/finemo/broad.finemo.{model}.log"
    resources:
        # Tag each job with its model name (a string resource; passed through rather
        # than used for scheduling, so it does not by itself serialize jobs per model)
        gpu_per_model = lambda w: w.model,
        # Each job consumes one gpu_total unit; cap concurrent finemo jobs by
        # running Snakemake with --resources gpu_total=1
        gpu_total = 1
    run:
        import pandas as pd
        import os
        import time
        from pathlib import Path

        start_time = time.time()
        model_name = params.model
        output_file = output.finemo_tsv
        general_tsv = params.general_tsv
        prioritization_tsv = params.prioritization_tsv

        print(f"[TIMING] Starting finemo annotation for {model_name} at {time.strftime('%Y-%m-%d %H:%M:%S')}")

        # Determine which variants to process
        variant_ids_to_process = []
        
        # Option 1: User-specified variant IDs (highest priority)
        if config.get('variant_ids'):
            variant_ids_str = config.get('variant_ids')
            if isinstance(variant_ids_str, str):
                variant_ids_to_process = [v.strip() for v in variant_ids_str.split(',') if v.strip()]
            elif isinstance(variant_ids_str, list):
                variant_ids_to_process = variant_ids_str
            print(f"Using {len(variant_ids_to_process)} user-specified variant IDs")
        
        # Option 2: Variant IDs from file
        elif config.get('variant_ids_file'):
            variant_ids_to_process = read_variant_ids_from_file(config.get('variant_ids_file'))
            print(f"Using {len(variant_ids_to_process)} variant IDs from file: {config.get('variant_ids_file')}")
        
        # Option 3: Prioritized variants (default)
        else:
            variant_ids_to_process = get_prioritized_variants_for_model(model_name, prioritization_tsv)
            print(f"Using {len(variant_ids_to_process)} prioritized variants for {model_name}")
        
        if len(variant_ids_to_process) == 0:
            print(f"No variants to process for {model_name}")
            # Create empty file with correct structure
            empty_df = pd.DataFrame(columns=[
                'variant_id',
                'finemo_motif_hits',
                'finemo_motif_hit_count',
                'finemo_motif_top_hit',
                'finemo_motif_top_score',
                'finemo_motif_positions',
                'finemo_motif_ref_summary',
                'finemo_motif_alt_summary',
                'finemo_motif_diff_summary'
            ])
            os.makedirs(os.path.dirname(output_file), exist_ok=True)
            empty_df.to_csv(output_file, sep='\t', index=False)
            Path(output_file).touch()
            return

        # Get new variants to process (incremental)
        new_variants = get_finemo_new_variants(variant_ids_to_process, output_file)

        if len(new_variants) == 0:
            print(f"All {len(variant_ids_to_process)} variants already processed for {model_name}")
            # Touch file to update timestamp so Snakemake knows rule ran
            Path(output_file).touch()
            return

        print(f"Processing {len(new_variants)} new variants for {model_name} " +
              f"({len(variant_ids_to_process) - len(new_variants)} already done)")

        # Read variants TSV and filter to new variants
        if not os.path.exists(general_tsv):
            raise FileNotFoundError(f"General TSV file not found: {general_tsv}")
        
        variants_df = pd.read_csv(general_tsv, sep='\t')
        new_variants_df = variants_df[variants_df['variant_id'].isin(new_variants)].copy()
        
        # Check for missing variants
        missing_variants = set(new_variants) - set(variants_df['variant_id'])
        if missing_variants:
            print(f"Warning: {len(missing_variants)} variants not found in general.tsv: {list(missing_variants)[:5]}")
            # Filter out missing variants from processing
            new_variants = [v for v in new_variants if v not in missing_variants]
            new_variants_df = variants_df[variants_df['variant_id'].isin(new_variants)].copy()
        
        if len(new_variants_df) == 0:
            print(f"Warning: No valid variants to process after filtering.")
            Path(output_file).touch()
            return

        # Rename allele columns to ref/alt for finemo compatibility
        if 'allele1' in new_variants_df.columns and 'allele2' in new_variants_df.columns:
            new_variants_df = new_variants_df.rename(columns={'allele1': 'ref', 'allele2': 'alt'})

        # Save to temp file
        temp_input = f"/tmp/finemo_input_{model_name}.tsv"
        new_variants_df.to_csv(temp_input, sep='\t', index=False)

        # Create log directory and get absolute log path
        log_abs_path = os.path.abspath(log[0])
        os.makedirs(os.path.dirname(log_abs_path), exist_ok=True)

        # Run finemo (use least-used GPU)
        temp_output = f"/tmp/finemo_output_{model_name}.tsv"
        
        # Select least-used GPU
        selected_gpu = get_least_used_gpu()
        print(f"[TIMING] Starting finemo annotation for {model_name} ({len(new_variants)} variants) at {time.strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"Using GPU {selected_gpu} (least used)")
        
        shell(f"""
        START_TIME=$(date +%s)
        echo "[TIMING] Starting finemo annotation for {model_name} ({len(new_variants)} variants) at $(date)" >&2
        echo "Using GPU {selected_gpu} (least used)" >&2
        export CUDA_VISIBLE_DEVICES={selected_gpu} && cd {VARBOOK_DIR} && {VENV_PYTHON} -m varbook annotate motif finemo {temp_input} variant_id \
          --model-paths-tsv {MODEL_PATHS_TSV} \
          --models {model_name} \
          --modisco-h5 {MODISCO_H5} \
          --alpha 0.8 \
          --hits-per-variant 20 \
          --n-shuffles 20 \
          --window-size 300 \
          --device cuda \
          -o {temp_output} \
          > {log_abs_path} 2>&1
        END_TIME=$(date +%s)
        DURATION=$((END_TIME - START_TIME))
        echo "[TIMING] Finished finemo annotation for {model_name} in $$DURATIONs at $(date)" >&2
        """)
        
        end_time = time.time()
        duration = end_time - start_time
        print(f"[TIMING] Finemo annotation for {model_name} completed in {duration:.1f}s")

        # Check if finemo output exists
        if not os.path.exists(temp_output):
            raise RuntimeError(f"Finemo output file not found: {temp_output}. Check log: {log_abs_path}")

        # Load new results
        new_results_df = pd.read_csv(temp_output, sep='\t')
        
        # Remove model suffix from column names (if present)
        rename_map = {
            f'finemo_motif_hits_{model_name}': 'finemo_motif_hits',
            f'finemo_motif_hit_count_{model_name}': 'finemo_motif_hit_count',
            f'finemo_motif_top_hit_{model_name}': 'finemo_motif_top_hit',
            f'finemo_motif_top_score_{model_name}': 'finemo_motif_top_score',
            f'finemo_motif_allele_diff_{model_name}': 'finemo_motif_allele_diff',
            f'finemo_motif_positions_{model_name}': 'finemo_motif_positions',
            f'finemo_motif_ref_summary_{model_name}': 'finemo_motif_ref_summary',
            f'finemo_motif_alt_summary_{model_name}': 'finemo_motif_alt_summary',
            f'finemo_motif_diff_summary_{model_name}': 'finemo_motif_diff_summary'
        }
        # Filter to only columns that exist
        rename_map = {k: v for k, v in rename_map.items() if k in new_results_df.columns}
        if rename_map:
            new_results_df = new_results_df.rename(columns=rename_map)

        # Keep only variant_id and finemo columns
        finemo_cols = [c for c in new_results_df.columns if c.startswith('finemo_')]
        new_results_df = new_results_df[['variant_id'] + finemo_cols]

        # Merge with existing results if file exists
        if os.path.exists(output_file):
            existing_df = pd.read_csv(output_file, sep='\t')
            # Deduplicate by variant_id (in case of any overlap)
            merged_df = pd.concat([existing_df, new_results_df], ignore_index=True)
            merged_df = merged_df.drop_duplicates(subset=['variant_id'], keep='last')
        else:
            merged_df = new_results_df

        # Save merged results
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        merged_df.to_csv(output_file, sep='\t', index=False)

        # Clean up temp files
        os.remove(temp_input)
        os.remove(temp_output)

        print(f"Saved {len(merged_df)} total variants for {model_name} to {output_file}")

rule finemo_model:
    """Convenience rule to run finemo for a model specified in config.
    Usage: snakemake finemo_model --config model=KUN_FB_microglia
    """
    input:
        finemo_tsv = lambda wildcards: SPLITS_DIR + "/finemo/broad.finemo." + config.get('model', '') + ".tsv"
    output:
        sentinel = "finemo_model.done"
    shell:
        "touch {output.sentinel}"

