#!/usr/bin/env python3
"""
Simple script to run finemo motif annotation for models.
By default, runs for all models on their prioritized variants.

Configuration is read from environment variables (set by config.sh).
You can source the config file before running:
    source config.sh
    python run_finemo.py

Or set environment variables directly.

Usage: python run_finemo.py [--model "KUN_FB.*"] [--variant-ids "chr1:123:A:G,chr2:456:C:T"]
"""

import argparse
import pandas as pd
import os
import re
import subprocess
import sys
import time
from pathlib import Path
from multiprocessing import Pool, Lock
from functools import partial

# Configuration - read from environment variables (set by config.sh).
# NOTE(review): no defaults are actually provided — any unset variable is
# None and would render as the literal string "None" inside the path
# f-strings below. Source config.sh (or export these) before running.
VENV_PYTHON = os.environ.get('VENV_PYTHON')          # Python interpreter of the project venv
MODEL_PATHS_TSV = os.environ.get('MODEL_PATHS_TSV')  # TSV with a 'model_name' column listing models
SPLITS_DIR = os.environ.get('SPLITS_DIR')            # directory holding the variant split TSV files
MODISCO_H5 = os.environ.get('MODISCO_H5')            # modisco motif HDF5 passed to finemo
VARBOOK_DIR = os.environ.get('VARBOOK_DIR')          # working directory for the varbook invocation


def get_available_gpus(skip_gpus=None):
    """Get list of all available GPUs, excluding those in skip_gpus.

    Queries ``nvidia-smi`` for device indices. If the query fails (binary
    missing, driver error, unparseable output), falls back to GPU 0 so the
    caller always receives a non-empty list.

    Parameters:
    -----------
    skip_gpus : list of int, optional
        List of GPU IDs to skip/exclude

    Returns:
    --------
    list of int
        Sorted list of available GPU IDs (excluding skipped ones).
        Never empty: falls back to [0] when nothing usable is found.
    """
    if skip_gpus is None:
        skip_gpus = []

    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=index', '--format=csv,noheader,nounits'],
            capture_output=True,
            text=True,
            check=True
        )

        # One GPU index per output line; ignore blank or unparseable lines.
        gpu_ids = []
        for line in result.stdout.strip().split('\n'):
            if line.strip():
                try:
                    gpu_id = int(line.strip())
                    if gpu_id not in skip_gpus:
                        gpu_ids.append(gpu_id)
                except ValueError:
                    continue

        if not gpu_ids:
            # Every detected GPU is in the skip list (or none were parsed).
            # BUG FIX: the original re-scanned the full list for a
            # non-skipped GPU here, but any such GPU would already be in
            # gpu_ids, so that loop was dead code and GPU 0 was always the
            # outcome. Return it directly.
            print(f"Warning: All GPUs are skipped ({skip_gpus}), defaulting to GPU 0")
            return [0]

        return sorted(gpu_ids)

    except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e:
        # BUG FIX: the original tested `0 not in skip_gpus` here, but both
        # branches returned [0]. Keep the fallback and warn explicitly when
        # GPU 0 was requested to be skipped so the caller can tell.
        print(f"Warning: Could not determine available GPUs ({e}), defaulting to GPU 0")
        if 0 in skip_gpus:
            print("Warning: GPU 0 is in the skip list but is the only known fallback")
        return [0]


def get_least_used_gpu(skip_gpus=None):
    """Find the GPU with the least memory usage, excluding those in skip_gpus.

    Parameters:
    -----------
    skip_gpus : list of int, optional
        List of GPU IDs to skip/exclude

    Returns:
    --------
    int
        GPU device ID with least memory usage (excluding skipped GPUs).
        Falls back to GPU 2 when nvidia-smi cannot be queried.
    """
    skip_gpus = list(skip_gpus) if skip_gpus else []

    try:
        query = subprocess.run(
            ['nvidia-smi', '--query-gpu=index,memory.used', '--format=csv,noheader,nounits'],
            capture_output=True,
            text=True,
            check=True
        )

        # Collect (device_id, memory_used_mb) pairs, dropping skipped devices
        # and malformed rows. An int-parse failure propagates to the outer
        # ValueError handler, matching the original behavior.
        gpu_usage = []
        for row in query.stdout.strip().split('\n'):
            if not row.strip():
                continue
            fields = row.split(',')
            if len(fields) != 2:
                continue
            device = int(fields[0].strip())
            if device in skip_gpus:
                continue
            gpu_usage.append((device, int(fields[1].strip())))

        if not gpu_usage:
            # Nothing usable parsed — defer to the general GPU discovery.
            fallbacks = get_available_gpus(skip_gpus)
            if fallbacks:
                chosen = fallbacks[0]
                print(f"Warning: Could not parse nvidia-smi output, defaulting to GPU {chosen}")
                return chosen
            print("Warning: Could not parse nvidia-smi output, defaulting to GPU 2")
            return 2

        # Sort ascending by memory so the head is the least-loaded device.
        gpu_usage.sort(key=lambda entry: entry[1])
        chosen, chosen_memory = gpu_usage[0]

        print(f"GPU usage: {dict(gpu_usage)}")
        print(f"Selected GPU {chosen} (memory used: {chosen_memory} MB)")
        return chosen

    except (subprocess.CalledProcessError, FileNotFoundError, ValueError) as e:
        print(f"Warning: Could not determine GPU usage ({e}), defaulting to GPU 2")
        return 2


def read_variant_ids_from_file(file_path):
    """Load variant IDs from a text file, one per line.

    Blank lines and lines starting with '#' are ignored. Returns an empty
    list when the file does not exist.
    """
    if not os.path.exists(file_path):
        return []
    with open(file_path, 'r') as f:
        return [
            stripped
            for stripped in (raw.strip() for raw in f)
            if stripped and not stripped.startswith('#')
        ]


def get_prioritized_variants_for_model(model_name, prioritization_tsv=None):
    """Return variant IDs prioritized for *model_name* from a prioritization TSV.

    When *prioritization_tsv* is None, conventional locations under
    SPLITS_DIR are probed (model-prefix-specific file first, then the
    general one). Returns an empty list when no file or matching column
    can be found.
    """
    if prioritization_tsv is None:
        # Derive the short prefix used in per-model filenames, e.g.
        # "KUN_FB" from "KUN_FB_microglia".
        name_parts = model_name.split('_')
        if '_' in model_name:
            model_prefix = name_parts[0] + '_' + name_parts[1]
        else:
            model_prefix = name_parts[0]
        variant_dataset_name = "Broad neurodevelopmental and neuromuscular disorders"
        candidates = [
            f"{SPLITS_DIR}/{variant_dataset_name}.model_prioritized_by_any-{model_prefix}.tsv",
            f"{SPLITS_DIR}/{variant_dataset_name}.model_prioritized_by_any.tsv",
        ]

        prioritization_tsv = next((p for p in candidates if os.path.exists(p)), None)
        if prioritization_tsv is None:
            print(f"Warning: Could not find prioritization TSV for {model_name}. Tried: {candidates}")
            return []

    if not os.path.exists(prioritization_tsv):
        print(f"Warning: Prioritization TSV not found: {prioritization_tsv}")
        return []

    df = pd.read_csv(prioritization_tsv, sep='\t')
    priority_col = f'model_prioritized_by_any-{model_name}'

    if priority_col not in df.columns:
        print(f"Warning: Column '{priority_col}' not found in {prioritization_tsv}")
        return []

    # Accept truthiness in several encodings: bool True, int 1, or the
    # strings 'true'/'True'/'1'; anything else (incl. NaN) is False.
    values = df[priority_col]
    as_text = values.astype(str)
    prioritized_mask = (
        (as_text.str.lower() == 'true')
        | (values == True)
        | (values == 1)
        | (as_text == '1')
    ).fillna(False)

    return df.loc[prioritized_mask, 'variant_id'].tolist()


def get_general_tsv_for_model(model_name):
    """Locate the general variants TSV for *model_name*.

    Probes a few conventional locations under SPLITS_DIR and returns the
    first existing path.

    Raises:
        FileNotFoundError: if none of the candidate paths exist.
    """
    variant_dataset_name = "Broad neurodevelopmental and neuromuscular disorders"
    possible_paths = [
        f"{SPLITS_DIR}/{variant_dataset_name}.general.tsv",
        f"{SPLITS_DIR}/broad.general.tsv",  # Fallback to short name
        f"{SPLITS_DIR}/{model_name.split('_')[0]}.general.tsv",
    ]

    found = next((p for p in possible_paths if os.path.exists(p)), None)
    if found is not None:
        return found

    raise FileNotFoundError(
        f"General TSV file not found for model '{model_name}'. "
        f"Tried: {possible_paths}. Please specify with --general-tsv."
    )


def get_finemo_new_variants(prioritized_variants, output_file):
    """Return the subset of *prioritized_variants* not yet in *output_file*.

    When the output file does not exist, every variant is new.
    """
    if not os.path.exists(output_file):
        return prioritized_variants

    already_processed = set(pd.read_csv(output_file, sep='\t')['variant_id'])
    return [variant for variant in prioritized_variants
            if variant not in already_processed]


def get_all_models(model_paths_tsv):
    """Read *model_paths_tsv* and return its unique model names, sorted."""
    model_names = pd.read_csv(model_paths_tsv, sep='\t')['model_name']
    return sorted(set(model_names))


def filter_models_by_pattern(models, pattern):
    """Filter model names, keeping those matched by the regex *pattern*.

    Matching uses ``re.search``, so the pattern may hit anywhere in the name.

    Raises:
        ValueError: if *pattern* is not a valid regular expression. The
            original ``re.error`` is chained for debugging.
    """
    # Keep the try body minimal: only re.compile can raise re.error;
    # chain the cause so the traceback shows the underlying regex failure.
    try:
        regex = re.compile(pattern)
    except re.error as e:
        raise ValueError(f"Invalid regex pattern '{pattern}': {e}") from e
    return [m for m in models if regex.search(m)]


def process_model_with_gpu_wrapper(args_tuple):
    """Run one model on a round-robin-assigned GPU; picklable Pool worker.

    *args_tuple* is (model_name, job_index, args, available_gpus). Returns
    (model_name, success_flag, error_message_or_None) for the summary.
    """
    model_name, job_index, process_args, available_gpus = args_tuple
    # Round-robin assignment keeps concurrent jobs spread across GPUs.
    assigned_gpu = available_gpus[job_index % len(available_gpus)]
    print(f"\n[Job {job_index+1}] Processing model: {model_name} on GPU {assigned_gpu}")
    try:
        process_model(model_name, process_args, gpu_id=assigned_gpu)
    except Exception as e:
        print(f"Error processing {model_name}: {e}")
        return (model_name, False, str(e))
    return (model_name, True, None)


def process_model(model_name, args, gpu_id=None):
    """Run finemo motif annotation for a single model, incrementally.

    Resolves the variant list (explicit IDs, a file of IDs, or the model's
    prioritized variants), skips variants already present in the model's
    output TSV, shells out to ``varbook annotate motif finemo`` on a chosen
    GPU, and merges the new rows into the output file.

    Parameters:
    -----------
    model_name : str
        Model name as listed in MODEL_PATHS_TSV.
    args : argparse.Namespace
        Parsed CLI options (output, log, variant_ids, variant_ids_file,
        prioritization_tsv, general_tsv, gpu, skip_gpus).
    gpu_id : int, optional
        GPU pre-assigned by the parallel dispatcher; takes precedence over
        --gpu and auto-selection.
    """
    start_time = time.time()
    
    # Determine output file
    if args.output:
        output_file = args.output
    else:
        output_file = f"{SPLITS_DIR}/finemo/broad.finemo.{model_name}.tsv"
    
    # Determine log file (a relative "logs/..." path resolves against the CWD)
    if args.log:
        log_file = args.log
    else:
        log_file = f"logs/finemo/broad.finemo.{model_name}.log"
    
    print(f"\n{'='*80}")
    print(f"[TIMING] Starting finemo annotation for {model_name} at {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'='*80}")
    
    # Determine which variants to process, in priority order:
    # --variant-ids > --variant-ids-file > model-prioritized variants.
    variant_ids_to_process = []
    
    if args.variant_ids:
        variant_ids_to_process = [v.strip() for v in args.variant_ids.split(',') if v.strip()]
        print(f"Using {len(variant_ids_to_process)} user-specified variant IDs")
    elif args.variant_ids_file:
        variant_ids_to_process = read_variant_ids_from_file(args.variant_ids_file)
        print(f"Using {len(variant_ids_to_process)} variant IDs from file: {args.variant_ids_file}")
    else:
        variant_ids_to_process = get_prioritized_variants_for_model(model_name, args.prioritization_tsv)
        print(f"Using {len(variant_ids_to_process)} prioritized variants for {model_name}")
    
    if len(variant_ids_to_process) == 0:
        print(f"No variants to process for {model_name}")
        # Create empty file with correct structure so downstream consumers
        # can still read a well-formed (header-only) TSV.
        empty_df = pd.DataFrame(columns=[
            'variant_id',
            'finemo_motif_hits',
            'finemo_motif_hit_count',
            'finemo_motif_top_hit',
            'finemo_motif_top_score',
            'finemo_motif_positions',
            'finemo_motif_ref_summary',
            'finemo_motif_alt_summary',
            'finemo_motif_diff_summary'
        ])
        # NOTE(review): os.makedirs('') raises if output_file has no
        # directory component (e.g. --output foo.tsv) — confirm callers
        # always pass a path containing a directory.
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        empty_df.to_csv(output_file, sep='\t', index=False)
        return
    
    # Get new variants to process (incremental: skip already-annotated ones)
    new_variants = get_finemo_new_variants(variant_ids_to_process, output_file)
    
    if len(new_variants) == 0:
        print(f"All {len(variant_ids_to_process)} variants already processed for {model_name}")
        return
    
    print(f"Processing {len(new_variants)} new variants for {model_name} "
          f"({len(variant_ids_to_process) - len(new_variants)} already done)")
    
    # Get general TSV (per-variant metadata: coordinates, alleles, ...)
    if args.general_tsv:
        general_tsv = args.general_tsv
    else:
        general_tsv = get_general_tsv_for_model(model_name)
    
    if not os.path.exists(general_tsv):
        print(f"Warning: General TSV file not found: {general_tsv}. Skipping {model_name}.")
        return
    
    # Read variants and filter to new variants
    variants_df = pd.read_csv(general_tsv, sep='\t')
    new_variants_df = variants_df[variants_df['variant_id'].isin(new_variants)].copy()
    
    # Check for missing variants (requested but absent from general.tsv);
    # drop them and continue with what is available.
    missing_variants = set(new_variants) - set(variants_df['variant_id'])
    if missing_variants:
        print(f"Warning: {len(missing_variants)} variants not found in general.tsv: {list(missing_variants)[:5]}")
        new_variants = [v for v in new_variants if v not in missing_variants]
        new_variants_df = variants_df[variants_df['variant_id'].isin(new_variants)].copy()
    
    if len(new_variants_df) == 0:
        print(f"Warning: No valid variants to process after filtering.")
        return
    
    # Rename allele columns to ref/alt for finemo compatibility
    if 'allele1' in new_variants_df.columns and 'allele2' in new_variants_df.columns:
        new_variants_df = new_variants_df.rename(columns={'allele1': 'ref', 'allele2': 'alt'})
    
    # Save to temp file.
    # NOTE(review): fixed /tmp paths keyed only by model name mean two
    # concurrent runs of the same model would clobber each other — confirm
    # the scheduler never runs the same model twice in parallel.
    temp_input = f"/tmp/finemo_input_{model_name}.tsv"
    new_variants_df.to_csv(temp_input, sep='\t', index=False)
    
    # Create log directory
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    log_abs_path = os.path.abspath(log_file)
    
    # Run finemo
    temp_output = f"/tmp/finemo_output_{model_name}.tsv"
    
    # Select GPU: dispatcher assignment (gpu_id) > --gpu flag > least-used
    # auto-selection. Explicit choices are honored even if on the skip list.
    skip_gpus = getattr(args, 'skip_gpus', None) or []
    if gpu_id is not None:
        selected_gpu = gpu_id
        if selected_gpu in skip_gpus:
            print(f"Warning: GPU {selected_gpu} is in skip list, but using it anyway (assigned by parallel processing)")
    elif args.gpu is not None:
        selected_gpu = args.gpu
        if selected_gpu in skip_gpus:
            print(f"Warning: GPU {selected_gpu} is in skip list, but using it anyway (manually specified)")
    else:
        selected_gpu = get_least_used_gpu(skip_gpus=skip_gpus)
    
    print(f"[TIMING] Starting finemo annotation for {model_name} ({len(new_variants)} variants) at {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Using GPU {selected_gpu}")
    
    # Run finemo command through bash so CUDA_VISIBLE_DEVICES and the cd
    # apply only to the varbook invocation.
    cmd = [
        'bash', '-c',
        f'export CUDA_VISIBLE_DEVICES={selected_gpu} && '
        f'cd {VARBOOK_DIR} && '
        f'{VENV_PYTHON} -m varbook annotate motif finemo {temp_input} variant_id '
        f'--model-paths-tsv {MODEL_PATHS_TSV} '
        f'--models {model_name} '
        f'--modisco-h5 {MODISCO_H5} '
        f'--alpha 0.8 '
        f'--hits-per-variant 20 '
        f'--n-shuffles 20 '
        f'--window-size 300 '
        f'--device cuda '
        f'-o {temp_output}'
    ]
    
    # Capture the subprocess's combined stdout+stderr into the per-model log.
    with open(log_abs_path, 'w') as log_f:
        result = subprocess.run(
            cmd,
            stdout=log_f,
            stderr=subprocess.STDOUT,
            text=True
        )
    
    if result.returncode != 0:
        print(f"Error: Finemo command failed for {model_name}. Check log: {log_abs_path}")
        if os.path.exists(temp_input):
            os.remove(temp_input)
        return
    
    # Check if finemo output exists
    if not os.path.exists(temp_output):
        print(f"Error: Finemo output file not found: {temp_output}. Check log: {log_abs_path}")
        if os.path.exists(temp_input):
            os.remove(temp_input)
        return
    
    # Load new results
    new_results_df = pd.read_csv(temp_output, sep='\t')
    
    # Remove model suffix from column names (if present) so per-model
    # outputs share a single, uniform schema.
    rename_map = {
        f'finemo_motif_hits_{model_name}': 'finemo_motif_hits',
        f'finemo_motif_hit_count_{model_name}': 'finemo_motif_hit_count',
        f'finemo_motif_top_hit_{model_name}': 'finemo_motif_top_hit',
        f'finemo_motif_top_score_{model_name}': 'finemo_motif_top_score',
        f'finemo_motif_allele_diff_{model_name}': 'finemo_motif_allele_diff',
        f'finemo_motif_positions_{model_name}': 'finemo_motif_positions',
        f'finemo_motif_ref_summary_{model_name}': 'finemo_motif_ref_summary',
        f'finemo_motif_alt_summary_{model_name}': 'finemo_motif_alt_summary',
        f'finemo_motif_diff_summary_{model_name}': 'finemo_motif_diff_summary'
    }
    rename_map = {k: v for k, v in rename_map.items() if k in new_results_df.columns}
    if rename_map:
        new_results_df = new_results_df.rename(columns=rename_map)
    
    # Keep only variant_id and finemo columns
    finemo_cols = [c for c in new_results_df.columns if c.startswith('finemo_')]
    new_results_df = new_results_df[['variant_id'] + finemo_cols]
    
    # Merge with existing results if file exists; keep='last' lets a re-run
    # overwrite a stale row for the same variant.
    if os.path.exists(output_file):
        existing_df = pd.read_csv(output_file, sep='\t')
        merged_df = pd.concat([existing_df, new_results_df], ignore_index=True)
        merged_df = merged_df.drop_duplicates(subset=['variant_id'], keep='last')
    else:
        merged_df = new_results_df
    
    # Save merged results
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    merged_df.to_csv(output_file, sep='\t', index=False)
    
    # Clean up temp files
    os.remove(temp_input)
    os.remove(temp_output)
    
    end_time = time.time()
    duration = end_time - start_time
    print(f"[TIMING] Finemo annotation for {model_name} completed in {duration:.1f}s")
    print(f"Saved {len(merged_df)} total variants for {model_name} to {output_file}")


def main():
    """CLI entry point: select models and GPUs, then run finemo annotation."""
    parser = argparse.ArgumentParser(description='Run finemo motif annotation for models. By default, runs for all models on their prioritized variants.')
    parser.add_argument('--model', help='Regex pattern to match model names (e.g., "KUN_FB.*" or "KUN_FB_microglia"). If not specified, runs for all models.')
    parser.add_argument('--variant-ids', help='Comma-separated list of variant IDs')
    parser.add_argument('--variant-ids-file', help='Path to file with variant IDs (one per line)')
    parser.add_argument('--general-tsv', help='Path to general variants TSV (auto-detected if not provided)')
    parser.add_argument('--prioritization-tsv', help='Path to prioritization TSV (auto-detected if not provided)')
    parser.add_argument('--output', help=f'Output file path (default: {SPLITS_DIR}/finemo/broad.finemo.MODEL.tsv). Only used with single model.')
    parser.add_argument('--log', help='Log file path (default: logs/finemo/broad.finemo.MODEL.log). Only used with single model.')
    parser.add_argument('--gpu', type=int, help='GPU device ID to use (auto-selects least used if not specified). Only used when --max-jobs=1.')
    parser.add_argument('--max-jobs', type=int, default=1, help='Maximum number of parallel jobs (default: 1). Set to number of GPUs for parallel processing.')
    parser.add_argument('--skip-gpus', type=str, help='Comma-separated list of GPU IDs to skip (e.g., "0,1" to skip GPUs 0 and 1)')

    args = parser.parse_args()

    # Normalize --skip-gpus from "0,1" into a list of ints (empty if unset).
    if args.skip_gpus:
        try:
            args.skip_gpus = [int(token.strip()) for token in args.skip_gpus.split(',') if token.strip()]
        except ValueError as e:
            print(f"Error: Invalid --skip-gpus format. Use comma-separated integers (e.g., '0,1'): {e}")
            sys.exit(1)
        print(f"Skipping GPUs: {args.skip_gpus}")
    else:
        args.skip_gpus = []

    # Resolve the set of models to run.
    all_models = get_all_models(MODEL_PATHS_TSV)

    if args.model:
        models = filter_models_by_pattern(all_models, args.model)
        if not models:
            print(f"Warning: No models matched pattern '{args.model}'")
            sys.exit(1)
        print(f"Found {len(models)} models matching pattern '{args.model}'")
    else:
        models = all_models
        print(f"No model pattern specified. Running for all {len(models)} models...")

    # Discover usable GPUs (honoring the skip list).
    available_gpus = get_available_gpus(skip_gpus=args.skip_gpus)
    print(f"Available GPUs: {available_gpus}")
    if args.skip_gpus:
        print(f"Skipped GPUs: {args.skip_gpus}")

    total_start_time = time.time()

    if args.max_jobs == 1:
        # Sequential mode: one model at a time; GPU chosen inside process_model.
        for position, model_name in enumerate(models, 1):
            print(f"\n[{position}/{len(models)}] Processing model: {model_name}")
            try:
                process_model(model_name, args)
            except Exception as e:
                print(f"Error processing {model_name}: {e}")
                continue
    else:
        # Parallel mode: round-robin models over the available GPUs.
        worker_count = min(args.max_jobs, len(available_gpus), len(models))
        print(f"Running {worker_count} jobs in parallel across {len(available_gpus)} GPUs")

        # Each job spec is (model_name, index, args, available_gpus); the
        # worker derives its GPU from index % len(available_gpus).
        job_specs = [
            (model_name, job_index, args, available_gpus)
            for job_index, model_name in enumerate(models)
        ]

        with Pool(processes=worker_count) as pool:
            results = pool.map(process_model_with_gpu_wrapper, job_specs)

        # Summarize successes and failures.
        failures = [(name, error) for name, ok, error in results if not ok]
        print(f"\n{'='*80}")
        print(f"Processing complete: {len(results) - len(failures)} successful, {len(failures)} failed")
        if failures:
            print("Failed models:")
            for name, error in failures:
                print(f"  - {name}: {error}")

    total_duration = time.time() - total_start_time
    print(f"\n{'='*80}")
    print(f"[TIMING] Total time for {len(models)} models: {total_duration:.1f}s ({total_duration/60:.1f} minutes)")
    print(f"{'='*80}")


# Script entry point: run the CLI when executed directly (not on import).
if __name__ == '__main__':
    main()

