#!/usr/bin/env python3
"""
Simple script to run profile plotting for variants.
By default, runs for all models on their prioritized variants.

Configuration is read from the environment variables VENV_PYTHON, MODEL_PATHS_TSV, SPLITS_DIR, VARBOOK_DIR and OUTPUT_DIR (normally set by config.sh).
You can source the config file before running:
    source config.sh
    python run_profiles.py

Or set environment variables directly.

Usage: 
    python run_profiles.py [--model "KUN_FB.*"] [--variant-id "chr10:.*"]
    python run_profiles.py --variant-id "chr10:55578669:G:C" --model "KUN_FB_microglia"
"""

import argparse
import pandas as pd
import os
import re
import subprocess
import sys
import time
from pathlib import Path
from multiprocessing import Pool
from functools import partial

# Configuration - read from environment variables (normally exported by config.sh).
# There are no built-in defaults: each name below is None when its variable is
# unset, so every consumer must be prepared to handle None.
VENV_PYTHON = os.environ.get('VENV_PYTHON')          # interpreter used to run `-m varbook`
MODEL_PATHS_TSV = os.environ.get('MODEL_PATHS_TSV')  # TSV with a `model_name` column
SPLITS_DIR = os.environ.get('SPLITS_DIR')            # holds *.general.tsv and finemo/ outputs
VARBOOK_DIR = os.environ.get('VARBOOK_DIR')          # working directory for varbook subprocesses
OUTPUT_DIR = os.environ.get('OUTPUT_DIR')            # root directory for generated profiles


def get_available_gpus():
    """Return the sorted indices of every GPU that ``nvidia-smi`` reports.

    Lines of nvidia-smi output that are blank or not an integer are ignored.
    Falls back to ``[0]`` (after printing a warning) when nvidia-smi is
    missing, exits with an error, or reports no usable index.
    """
    query = ['nvidia-smi', '--query-gpu=index', '--format=csv,noheader,nounits']
    try:
        completed = subprocess.run(
            query,
            capture_output=True,
            text=True,
            check=True
        )

        indices = []
        for raw in completed.stdout.strip().split('\n'):
            token = raw.strip()
            if not token:
                continue
            try:
                indices.append(int(token))
            except ValueError:
                pass

        if indices:
            return sorted(indices)

        print("Warning: Could not find any GPUs, defaulting to GPU 0")
        return [0]

    except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as err:
        print(f"Warning: Could not determine available GPUs ({err}), defaulting to GPU 0")
        return [0]


def get_least_used_gpu():
    """Find the GPU with the least memory currently in use.

    Queries ``nvidia-smi`` for each GPU's index and used memory and picks the
    GPU with the smallest value; among ties, the one listed first by
    nvidia-smi wins. Rows that cannot be parsed (for example when memory is
    reported as "[N/A]") are skipped rather than discarding the whole query.

    Returns:
    --------
    int
        GPU device index as reported by nvidia-smi. Falls back to 0 (with a
        printed warning) when nvidia-smi is missing, fails, or yields no
        parseable rows.
    """
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=index,memory.used', '--format=csv,noheader,nounits'],
            capture_output=True,
            text=True,
            check=True
        )
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        print(f"Warning: Could not determine GPU usage ({e}), defaulting to GPU 0")
        return 0

    gpu_usage = []
    for line in result.stdout.strip().split('\n'):
        parts = [p.strip() for p in line.split(',')]
        if len(parts) != 2:
            continue
        try:
            gpu_usage.append((int(parts[0]), int(parts[1])))
        except ValueError:
            # One malformed row should not hide the other GPUs.
            continue

    if not gpu_usage:
        print("Warning: Could not parse nvidia-smi output, defaulting to GPU 0")
        return 0

    # Stable sort: ties keep nvidia-smi's listing order.
    gpu_usage.sort(key=lambda x: x[1])
    least_used_gpu, least_used_memory = gpu_usage[0]

    print(f"GPU usage: {dict(gpu_usage)}")
    print(f"Selected GPU {least_used_gpu} (memory used: {least_used_memory} MB)")
    return least_used_gpu


def read_variant_ids_from_file(file_path):
    """Read variant IDs from a text file, one ID per line.

    Surrounding whitespace is stripped; blank lines and lines starting with
    '#' are ignored.

    Parameters:
    -----------
    file_path : str
        Path to the variant ID file.

    Returns:
    --------
    list of str
        Variant IDs in file order. An empty list is returned when the file
        does not exist; a warning is printed so that a mistyped path is not
        mistaken for an intentionally empty selection.
    """
    if not os.path.exists(file_path):
        print(f"Warning: Variant IDs file not found: {file_path}")
        return []

    with open(file_path, 'r', encoding='utf-8') as f:
        stripped = (line.strip() for line in f)
        return [line for line in stripped if line and not line.startswith('#')]


def find_clustered_tsv_files():
    """Find all clustered.tsv files in the candidate data directories.

    Searched directories, in order: ``./data``, ``data`` next to this script's
    parent directory, and ``{SPLITS_DIR}/../data`` (only when SPLITS_DIR is
    set). Filenames are parsed as
    ``{variant_dataset}.{model_dataset}.clustered.tsv``, where the variant
    dataset may itself contain dots.

    Returns:
    --------
    list of tuples
        List of (variant_dataset, model_dataset, path) tuples. Each physical
        directory is scanned once, and entries within a directory are sorted
        by filename so that callers picking the first entry get a
        deterministic result.
    """
    suffix = '.clustered.tsv'
    clustered_files = []

    # Based on Snakefile: data/{variant_dataset}.{model_dataset}.clustered.tsv
    possible_dirs = [
        "data",
        os.path.join(os.path.dirname(os.path.dirname(__file__)), "data"),
        os.path.join(SPLITS_DIR, "..", "data") if SPLITS_DIR else None,
    ]

    seen_dirs = set()
    for base_dir in possible_dirs:
        if not base_dir or not os.path.isdir(base_dir):
            continue
        # The candidates can resolve to the same directory (e.g. when run from
        # the script's parent dir); scanning it twice would duplicate entries.
        real_dir = os.path.realpath(base_dir)
        if real_dir in seen_dirs:
            continue
        seen_dirs.add(real_dir)

        for filename in sorted(os.listdir(base_dir)):
            if not filename.endswith(suffix):
                continue
            parts = filename[:-len(suffix)].split('.')
            if len(parts) >= 2:
                # Join all but last part as variant_dataset (may contain dots)
                variant_dataset = '.'.join(parts[:-1])
                model_dataset = parts[-1]
                path = os.path.join(base_dir, filename)
                clustered_files.append((variant_dataset, model_dataset, path))

    return clustered_files


def get_clustered_tsv_path(variant_dataset, model_dataset):
    """Locate ``data/{variant_dataset}.{model_dataset}.clustered.tsv``.

    The relative path is tried against the current directory, then this
    script's parent directory, then ``{SPLITS_DIR}/..`` (only when SPLITS_DIR
    is set); the first existing candidate wins.

    Parameters:
    -----------
    variant_dataset : str
        Variant dataset name
    model_dataset : str
        Model dataset name

    Returns:
    --------
    str or None
        Path to the clustered.tsv file, or None if no candidate exists
    """
    relative = f"data/{variant_dataset}.{model_dataset}.clustered.tsv"
    candidates = [
        relative,
        os.path.join(os.path.dirname(os.path.dirname(__file__)), relative),
    ]
    if SPLITS_DIR:
        candidates.append(os.path.join(SPLITS_DIR, "..", relative))

    return next((candidate for candidate in candidates if os.path.exists(candidate)), None)


def get_prioritized_variants_for_model(model_name, variant_dataset=None, model_dataset=None):
    """Get variants prioritized by a specific model from clustered.tsv.

    A variant counts as prioritized when its
    ``model_prioritized_by_any-{model_name}`` cell is True, equals 1, or is
    the string "true" (any case) or "1".

    Parameters:
    -----------
    model_name : str
        Model name (e.g., "KUN_FB_microglia")
    variant_dataset : str, optional
        Variant dataset name (auto-detected if not provided)
    model_dataset : str, optional
        Model dataset name (auto-detected if not provided). When only one of
        the two names is given, auto-detection only considers files matching
        the given name instead of overriding it.

    Returns:
    --------
    list of str
        List of variant IDs prioritized by this model; empty (with a printed
        warning) if no suitable clustered.tsv or column is found
    """
    model_col = f'model_prioritized_by_any-{model_name}'

    # Auto-detect variant_dataset and model_dataset if not provided
    if variant_dataset is None or model_dataset is None:
        for vd, md, path in find_clustered_tsv_files():
            # Honour whichever of the two names the caller did supply.
            if variant_dataset is not None and vd != variant_dataset:
                continue
            if model_dataset is not None and md != model_dataset:
                continue
            header = pd.read_csv(path, sep='\t', nrows=1)  # Just check columns
            if model_col in header.columns:
                variant_dataset = vd
                model_dataset = md
                break

        if variant_dataset is None or model_dataset is None:
            print(f"Warning: Could not find clustered.tsv with prioritization for {model_name}")
            return []

    clustered_tsv = get_clustered_tsv_path(variant_dataset, model_dataset)
    if not clustered_tsv or not os.path.exists(clustered_tsv):
        print(f"Warning: Clustered TSV not found: {clustered_tsv}")
        return []

    df = pd.read_csv(clustered_tsv, sep='\t')

    if model_col not in df.columns:
        print(f"Warning: Column '{model_col}' not found in {clustered_tsv}")
        return []

    # The flag column may be parsed as bool, int or string depending on content.
    prioritized_mask = (
        (df[model_col].astype(str).str.lower() == 'true') |
        (df[model_col] == True) |
        (df[model_col] == 1) |
        (df[model_col].astype(str) == '1')
    ).fillna(False)

    return df[prioritized_mask]['variant_id'].tolist()


def get_prioritized_models_for_variant(variant_id, variant_dataset=None, model_dataset=None):
    """Get list of models that prioritize a specific variant.

    Every ``model_prioritized_by_any-{model_name}`` column of the variant's
    row (the first row, if the ID repeats) is checked; a cell counts as
    prioritized when it is non-null and is True, equals 1, or is the string
    "true" (any case) or "1".

    Parameters:
    -----------
    variant_id : str
        Variant ID (e.g., "chr10:55578669:G:C")
    variant_dataset : str, optional
        Variant dataset name (auto-detected if not provided)
    model_dataset : str, optional
        Model dataset name (auto-detected if not provided). When only one of
        the two names is given, auto-detection only considers files matching
        the given name instead of overriding it.

    Returns:
    --------
    list of str
        List of model names that prioritize this variant; empty if the
        variant or the clustered.tsv cannot be found
    """
    # Auto-detect variant_dataset and model_dataset if not provided
    if variant_dataset is None or model_dataset is None:
        for vd, md, path in find_clustered_tsv_files():
            # Honour whichever of the two names the caller did supply.
            if variant_dataset is not None and vd != variant_dataset:
                continue
            if model_dataset is not None and md != model_dataset:
                continue
            df = pd.read_csv(path, sep='\t')
            if variant_id in df['variant_id'].values:
                variant_dataset = vd
                model_dataset = md
                break

        if variant_dataset is None or model_dataset is None:
            print(f"Warning: Could not find variant {variant_id} in any clustered.tsv")
            return []

    clustered_tsv = get_clustered_tsv_path(variant_dataset, model_dataset)
    if not clustered_tsv or not os.path.exists(clustered_tsv):
        print(f"Warning: Clustered TSV not found: {clustered_tsv}")
        return []

    df = pd.read_csv(clustered_tsv, sep='\t')
    variant_row = df[df['variant_id'] == variant_id]

    if variant_row.empty:
        return []

    prefix = 'model_prioritized_by_any-'
    prioritized_models = []

    for col in df.columns:
        if not col.startswith(prefix):
            continue
        value = variant_row[col].iloc[0]
        if pd.notna(value) and (
            (str(value).lower() == 'true') or
            (value == True) or
            (value == 1) or
            (str(value) == '1')
        ):
            # Column name is model_prioritized_by_any-{model_name}
            prioritized_models.append(col.replace(prefix, ''))

    return prioritized_models


def get_all_models(model_paths_tsv):
    """Return the distinct values of the ``model_name`` column, sorted.

    *model_paths_tsv* is a tab-separated file that must contain a
    ``model_name`` column.
    """
    table = pd.read_csv(model_paths_tsv, sep='\t')
    distinct = table['model_name'].drop_duplicates().tolist()
    return sorted(distinct)


def filter_models_by_pattern(models, pattern):
    """Return the models whose name contains a match for regex *pattern*.

    Matching uses ``re.search`` (unanchored), and input order is preserved.

    Raises:
    -------
    ValueError
        If *pattern* is not a valid regular expression; the underlying
        ``re.error`` is chained as ``__cause__``.
    """
    try:
        regex = re.compile(pattern)
    except re.error as e:
        raise ValueError(f"Invalid regex pattern '{pattern}': {e}") from e
    return [m for m in models if regex.search(m)]


def filter_variants_by_pattern(variants, pattern):
    """Return the variant IDs containing a match for regex *pattern*.

    Matching uses ``re.search`` (unanchored), and input order is preserved.

    Raises:
    -------
    ValueError
        If *pattern* is not a valid regular expression; the underlying
        ``re.error`` is chained as ``__cause__``.
    """
    try:
        regex = re.compile(pattern)
    except re.error as e:
        raise ValueError(f"Invalid regex pattern '{pattern}': {e}") from e
    return [v for v in variants if regex.search(v)]


def get_variants_tsv_path(variant_dataset):
    """Get the path to the ``*.general.tsv`` variants file under SPLITS_DIR.

    Candidates, tried in order:
      1. ``{SPLITS_DIR}/{variant_dataset}.general.tsv``
      2. ``{SPLITS_DIR}/{first whitespace-separated word}.general.tsv``
      3. ``{SPLITS_DIR}/broad.general.tsv`` (hard-coded last-resort fallback)

    The dataset-derived names are tried before the hard-coded one so that an
    unrelated dataset's file is only used when nothing more specific exists.

    Parameters:
    -----------
    variant_dataset : str
        Variant dataset name

    Returns:
    --------
    str or None
        Path to the variants TSV file, or None if SPLITS_DIR is unset or no
        candidate exists
    """
    if not SPLITS_DIR:
        return None

    candidates = [f"{SPLITS_DIR}/{variant_dataset}.general.tsv"]
    words = variant_dataset.split()
    if words:  # guard: a blank name has no first word
        candidates.append(f"{SPLITS_DIR}/{words[0]}.general.tsv")
    candidates.append(f"{SPLITS_DIR}/broad.general.tsv")

    for path in candidates:
        if os.path.exists(path):
            return path

    return None


def get_finemo_tsv_path(model_name):
    """Locate the finemo motif TSV for *model_name* under ``{SPLITS_DIR}/finemo``.

    ``broad.finemo.{model_name}.tsv`` is preferred; ``{model_name}.finemo.tsv``
    is the fallback.

    Parameters:
    -----------
    model_name : str
        Model name

    Returns:
    --------
    str or None
        First existing candidate path, or None if neither exists
    """
    finemo_dir = f"{SPLITS_DIR}/finemo"
    for candidate in (f"{finemo_dir}/broad.finemo.{model_name}.tsv",
                      f"{finemo_dir}/{model_name}.finemo.tsv"):
        if os.path.exists(candidate):
            return candidate
    return None


def process_model_profiles(model_name, args, gpu_id=None):
    """Process profile generation for a single model (batch mode).

    Selects the variants to plot: a regex over the model's prioritized
    variants, the IDs in ``args.variant_ids_file``, or all prioritized
    variants. Unless ``args.overwrite`` is set, it skips variants whose PNG
    profile already exists and is newer than the variants TSV. It then runs
    one batched ``varbook plot variant profiles`` command for the rest.
    Finally it prints ``mitra-utils url`` links for every selected variant's
    SVG profile.

    Parameters:
    -----------
    model_name : str
        Model name
    args : argparse.Namespace
        Command-line arguments. Reads ``variant_id``, ``variant_ids_file``,
        ``variant_dataset``, ``model_dataset``, ``overwrite``, ``motifs_tsv``,
        ``n_shuffles``, ``device`` and ``gpu``.
    gpu_id : int, optional
        GPU device ID to use (overrides args.gpu if provided)

    Raises:
    -------
    subprocess.CalledProcessError
        If the varbook command exits with a non-zero status.
    """
    start_time = time.time()

    print(f"\n{'='*80}")
    print(f"[TIMING] Starting profile generation for {model_name} at {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'='*80}")

    # Determine which variants to process
    variant_ids_to_process = []
    variant_dataset = args.variant_dataset
    model_dataset = args.model_dataset

    if args.variant_id:
        # Filter this model's prioritized variants by the regex pattern
        all_prioritized = get_prioritized_variants_for_model(model_name, variant_dataset, model_dataset)
        variant_ids_to_process = filter_variants_by_pattern(all_prioritized, args.variant_id)
        print(f"Using {len(variant_ids_to_process)} variants matching pattern '{args.variant_id}' for {model_name}")
    elif args.variant_ids_file:
        variant_ids_to_process = read_variant_ids_from_file(args.variant_ids_file)
        print(f"Using {len(variant_ids_to_process)} variant IDs from file: {args.variant_ids_file}")
    else:
        variant_ids_to_process = get_prioritized_variants_for_model(model_name, variant_dataset, model_dataset)
        print(f"Using {len(variant_ids_to_process)} prioritized variants for {model_name}")

    if not variant_ids_to_process:
        print(f"No variants to process for {model_name}")
        return

    # Get variants TSV path
    if variant_dataset:
        variants_tsv = get_variants_tsv_path(variant_dataset)
    else:
        # Auto-detect from clustered files
        clustered_files = find_clustered_tsv_files()
        if not clustered_files:
            print(f"Warning: Could not find variants TSV file")
            return
        variant_dataset = clustered_files[0][0]
        variants_tsv = get_variants_tsv_path(variant_dataset)

    if not variants_tsv or not os.path.exists(variants_tsv):
        print(f"Warning: Variants TSV file not found: {variants_tsv}. Skipping {model_name}.")
        return

    # Check for existing profiles (incremental processing)
    if args.overwrite:
        # Overwrite mode: process all variants regardless of existing files
        variants_to_process = variant_ids_to_process
        print(f"Overwrite mode: will regenerate all {len(variants_to_process)} variants for {model_name}")
    elif OUTPUT_DIR:
        # varbook writes profiles to {OUTPUT_DIR}/profiles/{variant_id}/{model_name}.png and
        # may also create a dataset-specific symlink under {OUTPUT_DIR}/{variant_dataset}/profiles/.
        # Check the primary location first, then the symlink. Checking only the
        # symlink would regenerate every profile whenever that symlink is absent.
        variants_to_process = []
        for variant_id in variant_ids_to_process:
            png_file = os.path.join(OUTPUT_DIR, "profiles", variant_id, f"{model_name}.png")
            if not os.path.exists(png_file) and variant_dataset:
                png_file = os.path.join(OUTPUT_DIR, variant_dataset, "profiles", variant_id, f"{model_name}.png")
            if not os.path.exists(png_file):
                variants_to_process.append(variant_id)
                continue
            # Regenerate when the PNG is older than the variants TSV
            try:
                if os.path.getmtime(png_file) < os.path.getmtime(variants_tsv):
                    variants_to_process.append(variant_id)
            except OSError:
                variants_to_process.append(variant_id)
    else:
        variants_to_process = variant_ids_to_process

    if not variants_to_process:
        print(f"All {len(variant_ids_to_process)} variants already have profiles for {model_name}")
        if not args.overwrite:
            print(f"  Use --overwrite to force regeneration")
        return

    print(f"Processing {len(variants_to_process)}/{len(variant_ids_to_process)} variants for {model_name}")

    # Get finemo TSV path (optional, for motif overlays)
    motifs_tsv = args.motifs_tsv
    if not motifs_tsv:
        motifs_tsv = get_finemo_tsv_path(model_name)

    # Build varbook command (batch mode)
    # Format: variants_tsv model_name --batch-variants variant1 variant2 ... --other-flags
    cmd = [
        VENV_PYTHON, "-m", "varbook", "plot", "variant", "profiles",
        variants_tsv,
        model_name,
        "--model-paths-tsv", MODEL_PATHS_TSV,
        "--batch-variants"
    ] + variants_to_process + [
        "--variant-dataset", variant_dataset,
        "--n-shuffles", str(args.n_shuffles),
        "--device", args.device
    ]

    # Add motifs TSV if available
    if motifs_tsv and os.path.exists(motifs_tsv):
        cmd.extend(["--motifs-tsv", motifs_tsv])

    # No -o is passed, so varbook uses its default layout:
    # {OUTPUT_DIR}/profiles/{variant_id}/{model_name}.png (+ .svg), optionally
    # symlinked under {OUTPUT_DIR}/{variant_dataset}/profiles/.

    # Select GPU if using CUDA (explicit gpu_id > args.gpu > least-used GPU)
    selected_gpu = None
    if args.device == 'cuda':
        if gpu_id is not None:
            selected_gpu = gpu_id
        elif args.gpu is not None:
            selected_gpu = args.gpu
        else:
            selected_gpu = get_least_used_gpu()

    print(f"[TIMING] Starting profile generation for {model_name} ({len(variants_to_process)} variants) at {time.strftime('%Y-%m-%d %H:%M:%S')}")
    if args.device == 'cuda':
        print(f"Using GPU {selected_gpu}")
    else:
        print(f"Using device: {args.device} (note: CPU is much slower than GPU)")
    print(f"SHAP shuffles: {args.n_shuffles} (reduce with --n-shuffles for faster processing)")
    print()

    # Set CUDA_VISIBLE_DEVICES if using GPU
    env = os.environ.copy()
    if args.device == 'cuda' and selected_gpu is not None:
        env['CUDA_VISIBLE_DEVICES'] = str(selected_gpu)

    # Set VARBOOK_DEFAULT_OUTPUT_DIR to absolute path so varbook uses absolute paths
    if OUTPUT_DIR:
        env['VARBOOK_DEFAULT_OUTPUT_DIR'] = OUTPUT_DIR

    # Run varbook; its stdout/stderr are not captured and stream straight to the
    # terminal, so there is nothing further to echo on failure.
    try:
        subprocess.run(
            cmd,
            check=True,
            text=True,
            cwd=VARBOOK_DIR,
            env=env
        )
    except subprocess.CalledProcessError as e:
        print(f"Error: Profile generation failed for {model_name}. Exit code: {e.returncode}", file=sys.stderr)
        raise

    duration = time.time() - start_time
    print(f"[TIMING] Profile generation for {model_name} completed in {duration:.1f}s")
    print(f"Generated profiles for {len(variants_to_process)} variants")

    # Generate URLs for all profile SVG files (including existing ones)
    # varbook saves dataset-agnostically to: {OUTPUT_DIR}/profiles/{variant_id}/{model_name}.svg
    # Note: varbook may use relative paths, so check both relative (from VARBOOK_DIR) and absolute paths
    print(f"\nGenerating URLs for profile SVG files...")
    for variant_id in variant_ids_to_process:
        svg_file = None
        if OUTPUT_DIR:
            # Check absolute path first
            svg_file = os.path.join(OUTPUT_DIR, "profiles", variant_id, f"{model_name}.svg")
            if not os.path.exists(svg_file):
                # Try relative path from VARBOOK_DIR (where varbook runs)
                relative_path = os.path.join("varbook_gen", "profiles", variant_id, f"{model_name}.svg")
                if VARBOOK_DIR:
                    svg_file = os.path.join(VARBOOK_DIR, relative_path)
                else:
                    svg_file = relative_path

            # Prefer the dataset-specific symlink for URL generation when it exists
            if variant_dataset:
                symlink_file = os.path.join(OUTPUT_DIR, variant_dataset, "profiles", variant_id, f"{model_name}.svg")
                if os.path.exists(symlink_file):
                    svg_file = symlink_file

        if not (svg_file and os.path.exists(svg_file)):
            print(f"  Warning: Profile SVG file not found for {variant_id}/{model_name}.svg")
            continue

        try:
            url_result = subprocess.run(
                ['mitra-utils', 'url', svg_file],
                check=False,
                capture_output=True,
                text=True
            )
            if url_result.returncode == 0:
                url_output = url_result.stdout.strip() if url_result.stdout else ""
                if not url_output and url_result.stderr:
                    url_output = url_result.stderr.strip()
                if url_output:
                    print(f"  {variant_id}/{model_name}.svg: {url_output}")
                else:
                    print(f"  {variant_id}/{model_name}.svg: (URL generation returned empty)")
            else:
                print(f"  Warning: Could not generate URL for {svg_file} (exit code {url_result.returncode})")
                if url_result.stderr:
                    print(f"    {url_result.stderr.strip()}")
        except Exception as e:
            # Best-effort: URL generation must never fail the run
            print(f"  Warning: Could not generate URL for {svg_file}: {e}")


def process_variant_profiles(variant_id, args):
    """Process profile generation for a single variant (multiple models).
    
    Parameters:
    -----------
    variant_id : str
        Variant ID
    args : argparse.Namespace
        Command-line arguments
    """
    start_time = time.time()
    
    print(f"\n{'='*80}")
    print(f"[TIMING] Starting profile generation for variant {variant_id} at {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'='*80}")
    
    # Get models to process
    models_to_process = []
    variant_dataset = args.variant_dataset
    model_dataset = args.model_dataset
    
    if args.model:
        # When --model is specified, use it directly (could be exact name or regex pattern)
        # First try to see if it matches any models in MODEL_PATHS_TSV (for regex patterns)
        all_models = get_all_models(MODEL_PATHS_TSV)
        models_to_process = filter_models_by_pattern(all_models, args.model)
        
        if len(models_to_process) > 0:
            print(f"Found {len(models_to_process)} models matching pattern '{args.model}'")
        else:
            # No matches in MODEL_PATHS_TSV - check if pattern contains | (multiple models)
            if '|' in args.model:
                # Split by | and try each model individually
                individual_models = [m.strip() for m in args.model.split('|') if m.strip()]
                print(f"No models in MODEL_PATHS_TSV match pattern '{args.model}'")
                print(f"Splitting into {len(individual_models)} individual models and using them directly (varbook will validate)")
                models_to_process = individual_models
            else:
                # Single model name - use it directly (varbook will validate if the model exists)
                print(f"No models in MODEL_PATHS_TSV match pattern '{args.model}'")
                print(f"Using model name '{args.model}' directly (varbook will validate)")
                models_to_process = [args.model]
    else:
        # Get prioritized models for this variant
        models_to_process = get_prioritized_models_for_variant(variant_id, variant_dataset, model_dataset)
        print(f"Found {len(models_to_process)} prioritized models for {variant_id}")
    
    if len(models_to_process) == 0:
        print(f"No models to process for variant {variant_id}")
        return
    
    # Get variants TSV path
    if variant_dataset:
        variants_tsv = get_variants_tsv_path(variant_dataset)
    else:
        # Auto-detect from clustered files
        clustered_files = find_clustered_tsv_files()
        if clustered_files:
            variant_dataset = clustered_files[0][0]
            variants_tsv = get_variants_tsv_path(variant_dataset)
        else:
            print(f"Warning: Could not find variants TSV file")
            return
    
    if not variants_tsv or not os.path.exists(variants_tsv):
        print(f"Warning: Variants TSV file not found: {variants_tsv}")
        return
    
    # Process each model
    for model_name in models_to_process:
        # Check if profile already exists
        # varbook saves to: {OUTPUT_DIR}/profiles/{variant_id}/{model_name}.png (primary location)
        # Symlink may exist at: {OUTPUT_DIR}/{variant_dataset}/profiles/{variant_id}/{model_name}.png
        profile_exists = False
        existing_png_file = None
        if OUTPUT_DIR:
            # Check primary dataset-agnostic location first
            png_file = os.path.join(OUTPUT_DIR, "profiles", variant_id, f"{model_name}.png")
            if not os.path.exists(png_file) and variant_dataset:
                # Fallback: check dataset-specific symlink location
                symlink_file = os.path.join(OUTPUT_DIR, variant_dataset, "profiles", variant_id, f"{model_name}.png")
                if os.path.exists(symlink_file):
                    png_file = symlink_file
            
            # Always print what we're checking (for debugging)
            print(f"  Checking for existing profile at: {png_file}")
            print(f"    Exists: {os.path.exists(png_file)}")
            
            if os.path.exists(png_file):
                # Verify it's actually a file (not a broken symlink)
                if not os.path.isfile(png_file):
                    print(f"  Warning: {png_file} exists but is not a regular file (might be broken symlink)")
                    print(f"  Will regenerate the file...")
                    # Continue to generate the file
                else:
                    try:
                        png_mtime = os.path.getmtime(png_file)
                        input_mtime = os.path.getmtime(variants_tsv)
                        file_size = os.path.getsize(png_file)
                        print(f"    File size: {file_size} bytes")
                        print(f"    File mtime: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(png_mtime))}")
                        print(f"    Input mtime: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(input_mtime))}")
                        
                        if png_mtime >= input_mtime and not args.overwrite:
                            print(f"Skipping {model_name} (profile already exists and is up-to-date)")
                            print(f"  File: {png_file}")
                            print(f"  Use --overwrite to force regeneration")
                            profile_exists = True
                            existing_png_file = png_file
                            
                            # Still generate URL even if skipping generation (use SVG)
                            existing_svg_file = existing_png_file.replace('.png', '.svg')
                            if os.path.exists(existing_svg_file):
                                try:
                                    url_result = subprocess.run(
                                        ['mitra-utils', 'url', existing_svg_file],
                                        check=False,
                                        capture_output=True,
                                        text=True
                                    )
                                    if url_result.returncode == 0:
                                        url_output = url_result.stdout.strip() if url_result.stdout else ""
                                        if not url_output and url_result.stderr:
                                            url_output = url_result.stderr.strip()
                                        if url_output:
                                            print(f"  URL: {url_output}")
                                except Exception as e:
                                    print(f"  Warning: Could not generate URL: {e}")
                            continue
                        elif args.overwrite:
                            print(f"  Overwrite mode: will regenerate {model_name} even though file exists")
                        else:
                            print(f"  File exists but is older than input, will regenerate...")
                    except OSError as e:
                        print(f"  Warning: Could not check file timestamps: {e}")
                        # Continue to generate the file
            else:
                print(f"  File does not exist, will generate...")
        
        # Get finemo TSV path (optional, for motif overlays)
        motifs_tsv = args.motifs_tsv
        if not motifs_tsv:
            motifs_tsv = get_finemo_tsv_path(model_name)
        
        # Build varbook command (single variant, single model)
        cmd = [
            VENV_PYTHON, "-m", "varbook", "plot", "variant", "profiles",
            variants_tsv,
            variant_id,
            model_name,
            "--model-paths-tsv", MODEL_PATHS_TSV,
            "--variant-dataset", variant_dataset,
            "--n-shuffles", str(args.n_shuffles),
            "--device", args.device
        ]
        
        # Add motifs TSV if available
        if motifs_tsv and os.path.exists(motifs_tsv):
            cmd.extend(["--motifs-tsv", motifs_tsv])
        
        # Don't pass -o - let varbook use its default path structure
        # Profiles will be saved to: {OUTPUT_DIR}/{variant_dataset}/profiles/{variant_id}/{model_name}.png
        
        # Select GPU if using CUDA
        selected_gpu = None
        if args.device == 'cuda':
            if args.gpu is not None:
                selected_gpu = args.gpu
            else:
                selected_gpu = get_least_used_gpu()
            print(f"Generating profile for {variant_id} with model {model_name}...")
            print(f"Using GPU {selected_gpu}")
        else:
            print(f"Generating profile for {variant_id} with model {model_name}...")
            print(f"Note: Using CPU (much slower). Steps include:")
            print(f"  1. Loading model (5 folds)")
            print(f"  2. Extracting sequences from genome")
            print(f"  3. Making predictions (forward + reverse complement)")
            print(f"  4. Calculating SHAP attributions ({args.n_shuffles} shuffles on {args.device})")
            print(f"  5. Generating plot")
        print()
        
        # Set CUDA_VISIBLE_DEVICES if using GPU
        env = os.environ.copy()
        if args.device == 'cuda' and selected_gpu is not None:
            env['CUDA_VISIBLE_DEVICES'] = str(selected_gpu)
        
        # Set VARBOOK_DEFAULT_OUTPUT_DIR to absolute path so varbook uses absolute paths
        if OUTPUT_DIR:
            env['VARBOOK_DEFAULT_OUTPUT_DIR'] = OUTPUT_DIR
            print(f"  OUTPUT_DIR: {OUTPUT_DIR}")
            print(f"  Expected primary location: {os.path.join(OUTPUT_DIR, 'profiles', variant_id, f'{model_name}.png')}")
            if variant_dataset:
                print(f"  Expected symlink location: {os.path.join(OUTPUT_DIR, variant_dataset, 'profiles', variant_id, f'{model_name}.png')}")
        
        # Run varbook command with real-time output
        try:
            result = subprocess.run(
                cmd,
                check=True,
                text=True,
                cwd=VARBOOK_DIR,
                env=env
            )
            
            # Generate URL for the profile SVG file if it was created
            # varbook saves dataset-agnostically to: {OUTPUT_DIR}/profiles/{variant_id}/{model_name}.svg
            # If variant_dataset is provided, there's also a symlink at: {OUTPUT_DIR}/{variant_dataset}/profiles/{variant_id}/{model_name}.svg
            if OUTPUT_DIR:
                # Check dataset-agnostic location first (primary) - this is where varbook actually saves
                primary_file = os.path.join(OUTPUT_DIR, "profiles", variant_id, f"{model_name}.svg")
                svg_file = None
                
                print(f"\n  Checking for generated profile SVG file...")
                print(f"    Primary location: {primary_file}")
                print(f"      Exists: {os.path.exists(primary_file)}")
                
                if os.path.exists(primary_file):
                    svg_file = primary_file
                    print(f"    ✓ Found at primary location")
                else:
                    # Check dataset-specific symlink if variant_dataset is provided
                    if variant_dataset:
                        symlink_file = os.path.join(OUTPUT_DIR, variant_dataset, "profiles", variant_id, f"{model_name}.svg")
                        print(f"    Symlink location: {symlink_file}")
                        print(f"      Exists: {os.path.exists(symlink_file)}")
                        if os.path.exists(symlink_file):
                            svg_file = symlink_file
                            print(f"    ✓ Found at symlink location")
                
                if svg_file and os.path.exists(svg_file):
                    # Verify it's actually a file
                    if os.path.isfile(svg_file):
                        file_size = os.path.getsize(svg_file)
                        print(f"    File size: {file_size} bytes")
                        try:
                            url_result = subprocess.run(
                                ['mitra-utils', 'url', svg_file],
                                check=False,
                                capture_output=True,
                                text=True
                            )
                            if url_result.returncode == 0:
                                url_output = url_result.stdout.strip() if url_result.stdout else ""
                                if not url_output and url_result.stderr:
                                    url_output = url_result.stderr.strip()
                                if url_output:
                                    print(f"\nURL: {url_output}")
                                else:
                                    print(f"  Warning: URL generation returned empty output")
                            else:
                                print(f"  Warning: mitra-utils returned exit code {url_result.returncode}")
                                if url_result.stderr:
                                    print(f"  {url_result.stderr.strip()}")
                        except Exception as e:
                            print(f"  Warning: Could not generate URL: {e}")
                    else:
                        print(f"  Warning: {svg_file} exists but is not a regular file (might be broken symlink)")
                else:
                    print(f"  ⚠ Warning: Profile SVG file not found after generation!")
                    print(f"    Checked primary: {primary_file} (exists: {os.path.exists(primary_file)})")
                    if variant_dataset:
                        symlink_file = os.path.join(OUTPUT_DIR, variant_dataset, "profiles", variant_id, f"{model_name}.svg")
                        print(f"    Checked symlink: {symlink_file} (exists: {os.path.exists(symlink_file)})")
                    print(f"    OUTPUT_DIR: {OUTPUT_DIR}")
                    print(f"    Check varbook output above for actual save location")
        except subprocess.CalledProcessError as e:
            print(f"Error: Profile generation failed for {variant_id} with {model_name}. Exit code: {e.returncode}", file=sys.stderr)
            if e.stdout:
                print(f"STDOUT:\n{e.stdout}", file=sys.stderr)
            if e.stderr:
                print(f"STDERR:\n{e.stderr}", file=sys.stderr)
            continue
    
    end_time = time.time()
    duration = end_time - start_time
    print(f"[TIMING] Profile generation for variant {variant_id} completed in {duration:.1f}s")


def process_model_with_gpu_wrapper(args_tuple):
    """Run ``process_model_profiles`` for a single pool task.

    Kept at module level so ``multiprocessing.Pool`` can pickle it.

    Args:
        args_tuple: ``(model_name, index, process_args, available_gpus)``.
            ``index`` is the task's 0-based position. It selects a GPU
            round-robin from ``available_gpus``; an empty list means CPU.

    Returns:
        tuple: ``(model_name, True, None)`` on success, or
        ``(model_name, False, message)`` when any exception was raised.
    """
    model_name, index, process_args, available_gpus = args_tuple

    gpu_id = None
    if available_gpus:
        # NOTE(review): the GPU is chosen by task index, not by worker.
        # Pool.map hands tasks to whichever worker is free, so two tasks
        # running at the same time may map to the same GPU.
        gpu_id = available_gpus[index % len(available_gpus)]

    if gpu_id is None:
        print(f"\n[Job {index+1}] Processing model: {model_name} on CPU")
    else:
        print(f"\n[Job {index+1}] Processing model: {model_name} on GPU {gpu_id}")

    try:
        process_model_profiles(model_name, process_args, gpu_id=gpu_id)
    except Exception as exc:
        # Report the failure to the parent instead of killing the pool.
        print(f"Error processing {model_name}: {exc}")
        return (model_name, False, str(exc))
    else:
        return (model_name, True, None)


def _build_arg_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser for model/variant selection, device and
        GPU choice, SHAP shuffle count, parallelism and overwrite options.
    """
    parser = argparse.ArgumentParser(
        description='Run profile plotting for variants. By default, runs for all models on their prioritized variants.'
    )
    parser.add_argument(
        '--model',
        help='Regex pattern to match model names (e.g., "KUN_FB.*" or "KUN_FB_microglia"). If not specified, runs for all models.'
    )
    parser.add_argument(
        '--variant-id',
        help='Regex pattern to match variant IDs (e.g., "chr10:.*" or "chr10:55578669:G:C"). If not specified, uses all prioritized variants.'
    )
    parser.add_argument(
        '--variant-ids-file',
        help='Path to file with variant IDs (one per line)'
    )
    parser.add_argument(
        '--variant-dataset',
        help='Variant dataset name (auto-detect if not provided)'
    )
    parser.add_argument(
        '--model-dataset',
        help='Model dataset name (auto-detect if not provided)'
    )
    parser.add_argument(
        '--motifs-tsv',
        help='Path to finemo motifs TSV file (auto-detect if not provided)'
    )
    parser.add_argument(
        '--device',
        choices=['cpu', 'cuda'],
        default='cuda',
        help='Device for computation (default: cuda). Use cpu only if GPUs unavailable.'
    )
    parser.add_argument(
        '--gpu',
        type=int,
        help='GPU device ID to use (auto-selects least used if not specified)'
    )
    parser.add_argument(
        '--n-shuffles',
        type=int,
        default=20,
        help='Number of dinucleotide shuffles for SHAP background (default: 20)'
    )
    parser.add_argument(
        '--max-jobs',
        type=int,
        default=1,
        help='Maximum number of parallel jobs (default: 1)'
    )
    parser.add_argument(
        '--overwrite',
        action='store_true',
        help='Force regeneration of profiles even if they already exist and are up-to-date'
    )
    return parser


def _parse_cli_args(argv=None):
    """Parse and normalize command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: parsed arguments. Semicolons in ``--model`` are
        replaced by ``|``.

    Raises:
        SystemExit: via ``parser.error`` (status 2) when ``--max-jobs`` < 1.
    """
    parser = _build_arg_parser()
    args = parser.parse_args(argv)

    # Only exactly 1 takes the sequential path; anything lower would reach
    # Pool(processes=...) and fail there with a ValueError traceback.
    if args.max_jobs < 1:
        parser.error(f"--max-jobs must be at least 1 (got {args.max_jobs})")

    # Replace semicolons with pipe (OR) in model pattern for regex matching
    if args.model:
        args.model = args.model.replace(';', '|')
    return args


def _require_config():
    """Exit with status 1 if a required setting from config.sh is missing."""
    required = (
        ('VENV_PYTHON', VENV_PYTHON),
        ('MODEL_PATHS_TSV', MODEL_PATHS_TSV),
        ('OUTPUT_DIR', OUTPUT_DIR),
    )
    for name, value in required:
        if not value:
            print(f"Error: {name} not set. Please source config.sh or set environment variable.")
            sys.exit(1)


def _collect_variant_ids(args):
    """Return the variant IDs to process in variant-centric mode.

    When ``--variant-ids-file`` is given, its IDs are returned as read.
    Otherwise the candidates are the union of the ``variant_id`` columns of
    every clustered.tsv file, optionally filtered by the ``--variant-id``
    regex. Exits with status 1 if no clustered.tsv files are found.

    Args:
        args: Parsed command-line arguments.

    Returns:
        list: variant IDs to process.
    """
    if args.variant_ids_file:
        return read_variant_ids_from_file(args.variant_ids_file)

    # Find variants directly in clustered.tsv files (don't iterate through all models)
    clustered_files = find_clustered_tsv_files()
    if not clustered_files:
        print("Error: No clustered.tsv files found. Cannot find variants.")
        sys.exit(1)

    # Collect all unique variant IDs from clustered files
    all_variants = set()
    for _variant_dataset, _model_dataset, path in clustered_files:
        try:
            df = pd.read_csv(path, sep='\t', usecols=['variant_id'])
        except Exception as e:
            print(f"Warning: Could not read {path}: {e}")
            continue
        all_variants.update(df['variant_id'].tolist())

    # Sort so the processing order (and log output) is reproducible: set
    # iteration order of strings changes between runs (hash randomization).
    # key=str tolerates non-string entries such as NaN from blank cells.
    candidates = sorted(all_variants, key=str)

    # Filter by regex pattern if specified
    if args.variant_id:
        return filter_variants_by_pattern(candidates, args.variant_id)
    return candidates


def _select_models(args):
    """Return the model names to process in model-centric mode.

    With no ``--model`` pattern, every model in MODEL_PATHS_TSV is returned.
    If a ``|``-separated pattern matches nothing, its parts are returned as
    literal model names (varbook validates them). Exits with status 1 when
    a pattern without ``|`` matches nothing.

    Args:
        args: Parsed command-line arguments.

    Returns:
        list: model names (possibly empty).
    """
    all_models = get_all_models(MODEL_PATHS_TSV)
    if not args.model:
        print(f"No model pattern specified. Running for all {len(all_models)} models...")
        return all_models

    models = filter_models_by_pattern(all_models, args.model)
    if not models:
        # Check if pattern contains | (multiple models)
        if '|' not in args.model:
            print(f"Warning: No models matched pattern '{args.model}'")
            sys.exit(1)
        # Split by | and try each model individually
        individual_models = [m.strip() for m in args.model.split('|') if m.strip()]
        print(f"No models in MODEL_PATHS_TSV match pattern '{args.model}'")
        print(f"Splitting into {len(individual_models)} individual models and using them directly (varbook will validate)")
        models = individual_models
    if models:
        print(f"Found {len(models)} models matching pattern '{args.model}'")
    return models


def _run_models_in_pool(models, args, available_gpus):
    """Process ``models`` in a multiprocessing pool and print a summary.

    Args:
        models: Non-empty list of model names.
        args: Parsed command-line arguments (``args.max_jobs`` >= 1).
        available_gpus: GPU ids for CUDA mode; ignored in CPU mode.
    """
    if args.device == 'cuda':
        # At most one job per GPU, and never more jobs than models.
        max_jobs = min(args.max_jobs, len(available_gpus), len(models))
        print(f"Running {max_jobs} jobs in parallel across {len(available_gpus)} GPUs")
        gpus = available_gpus
    else:
        max_jobs = min(args.max_jobs, len(models))
        print(f"Running {max_jobs} jobs in parallel (CPU mode)")
        # For CPU mode, pass empty list for GPUs
        gpus = []

    # Each tuple contains: (model_name, index, args, available_gpus)
    process_tuples = [(model_name, i, args, gpus) for i, model_name in enumerate(models)]

    with Pool(processes=max_jobs) as pool:
        results = pool.map(process_model_with_gpu_wrapper, process_tuples)

    # Print summary
    successful = sum(1 for _, success, _ in results if success)
    failed = len(results) - successful
    print(f"\n{'='*80}")
    print(f"Processing complete: {successful} successful, {failed} failed")
    if failed > 0:
        print("Failed models:")
        for model_name, success, error in results:
            if not success:
                print(f"  - {model_name}: {error}")


def main():
    """Parse arguments, validate configuration and run profile generation.

    Variant-centric mode (``--variant-id`` or ``--variant-ids-file``) calls
    ``process_variant_profiles`` once per selected variant.

    Model-centric mode (the default) calls ``process_model_profiles`` for
    each selected model. It runs sequentially when ``--max-jobs`` is 1 and
    in a ``multiprocessing.Pool`` otherwise.

    Per-item failures are printed and skipped. Missing configuration, empty
    selections or invalid arguments terminate via ``sys.exit``.
    """
    args = _parse_cli_args()
    _require_config()

    if args.variant_id or args.variant_ids_file:
        # Variant-centric mode: process specific variant(s) with their prioritized models
        variant_ids = _collect_variant_ids(args)
        if not variant_ids:
            print("Warning: No variants found matching the specified criteria")
            sys.exit(1)

        print(f"Processing {len(variant_ids)} variant(s)...")
        for variant_id in variant_ids:
            try:
                process_variant_profiles(variant_id, args)
            except Exception as e:
                print(f"Error processing variant {variant_id}: {e}")
        return

    # Model-centric mode: process all models with their prioritized variants (default)
    models = _select_models(args)
    if not models:
        # An empty list would otherwise reach Pool(processes=0) in parallel mode.
        print("Warning: No models to process")
        sys.exit(1)

    # Get available GPUs if using CUDA
    available_gpus = []
    if args.device == 'cuda':
        available_gpus = get_available_gpus()
        print(f"Available GPUs: {available_gpus}")

    total_start_time = time.time()

    if args.max_jobs == 1:
        # Sequential processing
        for i, model_name in enumerate(models, 1):
            print(f"\n[{i}/{len(models)}] Processing model: {model_name}")
            try:
                process_model_profiles(model_name, args)
            except Exception as e:
                print(f"Error processing {model_name}: {e}")
    else:
        _run_models_in_pool(models, args, available_gpus)

    total_duration = time.time() - total_start_time
    print(f"\n{'='*80}")
    print(f"[TIMING] Total time for {len(models)} models: {total_duration:.1f}s ({total_duration/60:.1f} minutes)")
    print(f"{'='*80}")


# Run the CLI only when executed as a script. The guard also keeps
# multiprocessing worker processes that re-import this module from calling
# main() again.
if __name__ == '__main__':
    main()

