#!/usr/bin/env python3
"""
Simple script to run KMeans clustering on score columns from filtered TSV.

Configuration is read from environment variables (set by config.sh).
You can source the config file before running:
    source config.sh
    python run_kmeans.py --input "data/variant_dataset.model_dataset.filtered.tsv"

Or set environment variables directly.

Usage:
    # Option 1: Direct input file
    python run_kmeans.py --input "data/Broad neurodevelopmental and neuromuscular disorders.Fetal Brain.filtered.tsv"
    
    # Option 2: Auto-construct path
    python run_kmeans.py \
        --variant-dataset "Broad neurodevelopmental and neuromuscular disorders" \
        --model-dataset "Fetal Brain"
"""

import argparse
import pandas as pd
import numpy as np
import os
import sys
import time
from pathlib import Path

# Configuration - read from environment variables (set by config.sh)
SPLITS_DIR = os.environ.get('SPLITS_DIR')  # directory containing split files; may be None
VARBOOK_DIR = os.environ.get('VARBOOK_DIR')  # root of the varbook checkout; may be None

# Add VARBOOK_DIR to path for importing varbook modules
if VARBOOK_DIR:
    sys.path.insert(0, VARBOOK_DIR)

# NOTE: these imports are deliberately placed *after* the sys.path tweak so
# the varbook package can be resolved from VARBOOK_DIR when it is not
# installed site-wide.
from sklearn.cluster import KMeans
from varbook.annotate.kmeans import sort_clusters_by_size


def get_filtered_tsv_path(variant_dataset, model_dataset, splits_dir=None):
    """
    Construct path to filtered TSV file.

    Parameters:
    -----------
    variant_dataset : str
        Variant dataset name (e.g., "Broad neurodevelopmental and neuromuscular disorders")
    model_dataset : str
        Model dataset name (e.g., "Fetal Brain")
    splits_dir : str, optional
        Directory containing split files. Defaults to SPLITS_DIR from environment.
        Used to infer the location of the data directory.

    Returns:
    --------
    str
        Path to filtered TSV file

    Raises:
    -------
    ValueError
        If no data directory can be located.
    """
    if splits_dir is None:
        splits_dir = SPLITS_DIR

    # Locate the directory holding {variant_dataset}.{model_dataset}.filtered.tsv.
    data_dir = None

    # Option 1: "data" relative to the current working directory.
    # (A separate check of os.path.join(os.getcwd(), "data") would be
    # redundant: a relative "data" already resolves against the cwd.)
    if os.path.exists("data"):
        data_dir = "data"
    # Option 2: infer from splits_dir. If splits_dir is like ".../broad/splits",
    # data may be at ".../broad/varbook-container/snakemake/data" or, failing
    # that, at ".../broad/snakemake/data".
    elif splits_dir:
        splits_parent = os.path.dirname(splits_dir)
        candidates = [
            os.path.join(splits_parent, "varbook-container", "snakemake", "data"),
            os.path.join(splits_parent, "snakemake", "data"),
        ]
        for candidate in candidates:
            if os.path.exists(candidate):
                data_dir = candidate
                break

    if data_dir is None:
        raise ValueError(
            f"Could not find data directory. Tried:\n"
            f"  - ./data\n"
            f"  - {os.path.join(os.getcwd(), 'data')}\n"
            f"  - (inferred from SPLITS_DIR={splits_dir})\n"
            f"Please specify input file directly with --input or ensure data directory exists."
        )

    return os.path.join(data_dir, f"{variant_dataset}.{model_dataset}.filtered.tsv")


def run_kmeans_clustering(input_file, output_file=None, log_file=None, n_clusters=35, random_state=42):
    """
    Run KMeans clustering on score columns from filtered TSV.

    Parameters:
    -----------
    input_file : str
        Path to filtered TSV file (must contain score-{model} columns and a
        variant_id column)
    output_file : str, optional
        Path to output TSV file. If None, auto-generates from input_file.
    log_file : str, optional
        Path to log file. Only the parent directory is created and the path
        echoed; progress output itself still goes to stdout.
    n_clusters : int
        Number of clusters (default: 35)
    random_state : int
        Random state for reproducibility (default: 42)

    Returns:
    --------
    str
        Path to output file

    Raises:
    -------
    FileNotFoundError
        If input_file does not exist.
    ValueError
        If the input has no variant_id column or no score-{model} columns.
    """
    start_time = time.time()

    print(f"\n{'='*80}")
    print(f"[TIMING] Starting KMeans clustering at {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'='*80}")
    print(f"Input file: {input_file}")

    # Validate input file
    if not os.path.exists(input_file):
        raise FileNotFoundError(f"Input file not found: {input_file}")

    # Determine output file
    if output_file is None:
        # Auto-generate from input file
        # e.g., "data/variant.model.filtered.tsv" -> "data/variant.model.kmeans.tsv"
        if input_file.endswith('.filtered.tsv'):
            output_file = input_file.replace('.filtered.tsv', '.kmeans.tsv')
        else:
            base_name = os.path.splitext(input_file)[0]
            output_file = f"{base_name}.kmeans.tsv"

    print(f"Output file: {output_file}")

    # Determine log file. NOTE(review): this function never writes to the log
    # file itself -- it only prepares the directory and echoes the path.
    if log_file:
        log_abs_path = os.path.abspath(log_file)
        os.makedirs(os.path.dirname(log_abs_path), exist_ok=True)
        print(f"Log file: {log_abs_path}")

    # Read filtered TSV
    print(f"\nReading filtered TSV: {input_file}")
    df = pd.read_csv(input_file, sep='\t')
    print(f"Loaded {len(df)} variants with {len(df.columns)} columns")

    # variant_id is required for the output mapping; fail early with a clear
    # message rather than with a KeyError after an expensive fit.
    if 'variant_id' not in df.columns:
        raise ValueError(f"No variant_id column found in {input_file}")

    # Extract score columns
    score_cols = [col for col in df.columns if col.startswith('score-')]

    if len(score_cols) == 0:
        raise ValueError(f"No score-{{model}} columns found in {input_file}")

    # Truncate the column listing when there are many score columns.
    if len(score_cols) > 5:
        print(f"Found {len(score_cols)} score columns: {score_cols[:5]}...")
    else:
        print(f"Found {len(score_cols)} score columns: {score_cols}")

    # Prepare data for clustering (fill NaN with 0)
    print("\nPreparing data for clustering (filling NaN with 0)...")
    X = df[score_cols].fillna(0).values
    print(f"Data shape: {X.shape} (variants × features)")

    # Run KMeans (random_state for reproducibility)
    print(f"\nRunning KMeans clustering (k={n_clusters}, random_state={random_state})...")
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=10)
    labels = kmeans.fit_predict(X)

    # Sort clusters by size (largest → smallest)
    print("Sorting clusters by size...")
    labels, remap = sort_clusters_by_size(labels)

    # Create output dataframe with only variant_id and kmeans_35.
    # NOTE(review): the column name is fixed to 'kmeans_35' regardless of
    # n_clusters -- downstream consumers appear to rely on that name.
    # Score columns are not needed after clustering and can be dropped
    # when creating the full clustered.tsv (similar to how logfc columns are dropped)
    result_df = pd.DataFrame({
        'variant_id': df['variant_id'].values,
        'kmeans_35': labels
    })

    print(f"\nCluster distribution (sorted by size):")
    cluster_counts = pd.Series(labels).value_counts().sort_index()
    for cluster_id, count in cluster_counts.items():
        print(f"  Cluster {cluster_id}: {count} variants")

    # Save output. Guard against output paths with no directory component:
    # os.path.dirname("out.tsv") == "" and os.makedirs("") raises FileNotFoundError.
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    result_df.to_csv(output_file, sep='\t', index=False)

    end_time = time.time()
    duration = end_time - start_time
    print(f"\n[TIMING] KMeans clustering completed in {duration:.1f}s")
    print(f"Saved {len(result_df)} variants to {output_file}")
    print(f"{'='*80}\n")

    return output_file


def main():
    """Parse command-line arguments, resolve the input file, and run clustering.

    Returns 0 on success, 1 on failure (suitable for sys.exit).
    """
    parser = argparse.ArgumentParser(
        description='Run KMeans clustering on score columns from filtered TSV. '
                    'By default, outputs variant_id and kmeans_35 columns.'
    )

    # The input may be given directly or built from dataset names, never both.
    source = parser.add_mutually_exclusive_group(required=True)
    source.add_argument(
        '--input',
        help='Path to filtered TSV file (must contain score-{model} columns)')
    source.add_argument(
        '--variant-dataset',
        help='Variant dataset name (e.g., "Broad neurodevelopmental and neuromuscular disorders"). '
             'Requires --model-dataset.')

    parser.add_argument(
        '--model-dataset',
        help='Model dataset name (e.g., "Fetal Brain"). Required if using --variant-dataset.')
    parser.add_argument(
        '--output',
        help='Output file path (default: auto-generated from input file, replacing .filtered.tsv with .kmeans.tsv)')
    parser.add_argument(
        '--log',
        help='Log file path (default: print to stdout)')
    parser.add_argument(
        '--n-clusters', type=int, default=35,
        help='Number of clusters (default: 35)')
    parser.add_argument(
        '--random-state', type=int, default=42,
        help='Random state for reproducibility (default: 42)')

    args = parser.parse_args()

    # --variant-dataset is only usable together with --model-dataset.
    if args.variant_dataset and not args.model_dataset:
        parser.error("--model-dataset is required when using --variant-dataset")

    # Resolve the input path: explicit --input wins, otherwise build it
    # from the dataset names.
    if args.input:
        input_file = args.input
    else:
        input_file = get_filtered_tsv_path(args.variant_dataset, args.model_dataset)
        print(f"Auto-constructed input path: {input_file}")

    # Run clustering; any failure is reported with a traceback and a
    # nonzero exit status.
    try:
        result_path = run_kmeans_clustering(
            input_file=input_file,
            output_file=args.output,
            log_file=args.log,
            n_clusters=args.n_clusters,
            random_state=args.random_state
        )
        print(f"Success! Output saved to: {result_path}")
        return 0
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        return 1


# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == '__main__':
    sys.exit(main())

