#!/usr/bin/env python3
"""
Script to generate clustered TSV files from filtered TSV files.

This script:
1. Runs KMeans clustering on score columns
2. Adds cluster assignments (kmeans_35 column)
3. Adds organs metadata column
4. Adds finemo motif columns (ref_specific_motifs, alt_specific_motifs, unchanged_motifs_in_region)
   - Maps motif names using hdma_motif_mapping.tsv
5. Drops logfc and score columns

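Configuration is read from the SPLITS_DIR, LOCAL_SPLITS_DIR, VARBOOK_DIR and FINEMO_DIR
environment variables (normally set by config.sh); FINEMO_DIR falls back to a built-in default path.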
"""

import argparse
import functools
import json
import os
import random
import sys
import time
from pathlib import Path
from typing import Dict, List, Set, Optional

import numpy as np
import pandas as pd

# Configuration comes from the environment (populated by config.sh).
# Every lookup yields None when the variable is unset, except FINEMO_DIR,
# which falls back to a hard-coded default directory.
SPLITS_DIR = os.getenv('SPLITS_DIR')
LOCAL_SPLITS_DIR = os.getenv('LOCAL_SPLITS_DIR')
VARBOOK_DIR = os.getenv('VARBOOK_DIR')
FINEMO_DIR = os.getenv(
    'FINEMO_DIR',
    '/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/curation/broad/splits/finemo/',
)

# Put VARBOOK_DIR first on sys.path so the varbook imports resolve from it.
if VARBOOK_DIR:
    sys.path.insert(0, VARBOOK_DIR)

from sklearn.cluster import KMeans
from varbook.annotate.kmeans import sort_clusters_by_size


def load_motif_mapping(mapping_file: str = 'hdma_motif_mapping.tsv') -> Dict[str, str]:
    """
    Load motif name mapping from hdma_motif_mapping.tsv.

    The file is searched for as given, then relative to this script's
    directory, then relative to the current working directory (candidates
    that resolve to the same absolute path are tried only once).

    Parameters:
    -----------
    mapping_file : str
        Path to the mapping TSV file (default: hdma_motif_mapping.tsv).
        Must contain ``pattern`` and ``motif_name`` columns.

    Returns:
    --------
    dict
        Dictionary mapping pattern -> motif_name (both whitespace-stripped).
        Rows with a missing or blank pattern/motif_name are skipped; for a
        repeated pattern the last row wins. An empty dict is returned (with a
        warning on stderr) if the file is missing, unreadable, or lacks the
        required columns, so callers keep the original pattern names.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    # dict.fromkeys de-duplicates while preserving search order; e.g. an
    # absolute mapping_file makes all three candidates identical.
    possible_paths = list(dict.fromkeys(
        os.path.abspath(path)
        for path in (
            mapping_file,
            os.path.join(script_dir, mapping_file),
            os.path.join(os.getcwd(), mapping_file),
        )
    ))

    mapping_path = next((path for path in possible_paths if os.path.exists(path)), None)
    if mapping_path is None:
        print(f"Warning: Motif mapping file not found at any of: {possible_paths}", file=sys.stderr)
        print("Continuing without motif name mapping...", file=sys.stderr)
        return {}

    try:
        df = pd.read_csv(mapping_path, sep='\t')
    except Exception as e:  # best effort: the mapping is optional
        print(f"Warning: Error loading motif mapping file: {e}", file=sys.stderr)
        print("Continuing without motif name mapping...", file=sys.stderr)
        return {}

    if 'pattern' not in df.columns or 'motif_name' not in df.columns:
        print("Warning: Mapping file missing required columns (pattern, motif_name)", file=sys.stderr)
        return {}

    # Drop missing cells *before* stringifying; otherwise str(NaN) == 'nan'
    # would be registered as a pattern key.
    pairs = df[['pattern', 'motif_name']].dropna()
    patterns = pairs['pattern'].astype(str).str.strip()
    names = pairs['motif_name'].astype(str).str.strip()
    mapping = {
        pattern: name
        for pattern, name in zip(patterns, names)
        if pattern and name and name != 'nan'
    }

    print(f"Loaded {len(mapping)} motif mappings from {mapping_path}")
    return mapping


def map_motif_name(motif_name: str, mapping: Dict[str, str]) -> str:
    """
    Translate a pattern-style motif name through ``mapping``.

    Parameters:
    -----------
    motif_name : str
        Name as it appears in finemo output (pattern format, e.g.
        "pos.Average_305__merged_pattern_0").
    mapping : dict
        pattern -> motif_name lookup table; may be empty.

    Returns:
    --------
    str
        ``mapping[motif_name]`` when that key exists; otherwise
        ``motif_name`` unchanged (also when ``mapping`` is empty/None).
    """
    return mapping.get(motif_name, motif_name) if mapping else motif_name


def get_filtered_tsv_path(variant_dataset, model_dataset, splits_dir=None):
    """
    Construct path to the ``{variant_dataset}.{model_dataset}.filtered.tsv`` file.

    The data directory is the first existing directory among:
      1. ``data`` (relative to the current working directory)
      2. ``<parent of splits_dir>/varbook-container/snakemake/data``
      3. ``<parent of splits_dir>/snakemake/data``

    Parameters:
    -----------
    variant_dataset : str
        Variant dataset name (first component of the file name).
    model_dataset : str
        Model dataset name (second component of the file name).
    splits_dir : str, optional
        Splits directory used to infer candidates 2 and 3; defaults to the
        SPLITS_DIR environment setting. A trailing separator is tolerated.

    Returns:
    --------
    str
        Path to the filtered TSV. The TSV itself is not checked for existence.

    Raises:
    -------
    ValueError
        If none of the candidate data directories exists.
    """
    if splits_dir is None:
        splits_dir = SPLITS_DIR

    candidates = ["data"]
    if splits_dir:
        # normpath strips a trailing separator: dirname("/x/splits/") would
        # otherwise return "/x/splits" instead of the parent "/x".
        splits_parent = os.path.dirname(os.path.normpath(splits_dir))
        candidates.append(os.path.join(splits_parent, "varbook-container", "snakemake", "data"))
        candidates.append(os.path.join(splits_parent, "snakemake", "data"))

    data_dir = next((candidate for candidate in candidates if os.path.isdir(candidate)), None)

    if data_dir is None:
        tried = "\n".join(f"  - {candidate}" for candidate in candidates)
        raise ValueError(
            f"Could not find data directory. Tried:\n"
            f"{tried}\n"
            f"  (cwd={os.getcwd()}, SPLITS_DIR={splits_dir})\n"
            f"Please specify input file directly with --input or ensure data directory exists."
        )

    return os.path.join(data_dir, f"{variant_dataset}.{model_dataset}.filtered.tsv")


def get_model_tissues_path():
    """
    Locate model_tissues.tsv.

    Looks in SPLITS_DIR first, then LOCAL_SPLITS_DIR, skipping whichever is
    unset/empty. Returns the first path that exists, or None if neither does.
    """
    for base_dir in (SPLITS_DIR, LOCAL_SPLITS_DIR):
        if not base_dir:
            continue
        candidate = os.path.join(base_dir, "model_tissues.tsv")
        if os.path.exists(candidate):
            return candidate
    return None


def get_finemo_file_path(model_name: str) -> Optional[str]:
    """Return FINEMO_DIR/broad.finemo.<model_name>.tsv if it exists, else None."""
    candidate = os.path.join(FINEMO_DIR, f"broad.finemo.{model_name}.tsv")
    return candidate if os.path.exists(candidate) else None


@functools.lru_cache(maxsize=2)
def _load_finemo_rows(finemo_file_path: str) -> Optional[tuple]:
    """
    Read a finemo TSV once and index its rows by ``variant_id``.

    Cached because callers such as run_clustering invoke
    process_model_finemo_file once per variant with the same file; without
    the cache the whole TSV would be re-parsed for every variant. Read errors
    propagate to the caller and are not cached.

    Returns ``(rows, row_index)``. ``rows`` is a list of
    ``(variant_id, finemo_motif_positions)`` pairs in file order, and
    ``row_index`` maps each variant_id to the list of its row positions.
    Returns ``None`` if a required column is missing; the warning is then
    printed only once, because the result is cached.
    """
    finemo_df = pd.read_csv(finemo_file_path, sep='\t')
    for required in ('variant_id', 'finemo_motif_positions'):
        if required not in finemo_df.columns:
            print(f"Warning: {finemo_file_path} missing {required} column, skipping", file=sys.stderr)
            return None

    rows = list(zip(finemo_df['variant_id'].tolist(), finemo_df['finemo_motif_positions'].tolist()))
    row_index: Dict = {}
    for position, (variant_id, _) in enumerate(rows):
        row_index.setdefault(variant_id, []).append(position)
    return rows, row_index


def _merge_allele_specific(
    target: Dict[str, Dict],
    own_counts: Dict[str, int],
    other_counts: Dict[str, int],
    model_name: str,
) -> None:
    """Add motifs present in ``own_counts`` but absent from ``other_counts`` to ``target``."""
    for motif_name, count in own_counts.items():
        if motif_name in other_counts:
            continue
        entry = target.setdefault(motif_name, {'count': 0, 'models': set()})
        entry['count'] += count
        entry['models'].add(model_name)


def process_model_finemo_file(
    model_name: str,
    finemo_file_path: str,
    prioritized_variant_ids: Set[str],
    motif_dicts: Dict[str, Dict],
    motif_mapping: Dict[str, str],
    min_score: float = 1.0,
) -> Dict[str, Dict]:
    """
    Process a single model's finemo file and update motif dictionaries.

    Each row of ``finemo_file_path`` whose ``variant_id`` is in
    ``prioritized_variant_ids`` is visited in file order. Its
    ``finemo_motif_positions`` JSON list is parsed. Hits with a numeric
    ``score >= min_score`` and a non-empty ``motif`` are counted per
    ``allele`` ('ref' / 'alt'), after mapping names through ``motif_mapping``.
    For each row:

    * ``motif_dicts['ref_specific_motifs']`` / ``['alt_specific_motifs']``
      receive motifs seen on only one allele. ``count`` grows by the number of
      hits, and ``model_name`` is added to ``models``.
    * ``motif_dicts['unchanged_motifs_in_region']`` receives motifs seen on
      both alleles. Its value grows by 1 per row, not per hit.

    This is best effort. File-level problems are reported on stderr and the
    function returns without raising. Rows whose JSON cannot be parsed are
    skipped with a warning. Hits with a missing, non-numeric or NaN score are
    skipped.

    Parameters:
    -----------
    model_name : str
        Model the finemo file belongs to; recorded in the ``models`` sets.
    finemo_file_path : str
        Path to the model's finemo TSV.
    prioritized_variant_ids : set
        Variant IDs whose rows should be processed.
    motif_dicts : dict
        Must hold the three keys listed above; mutated in place.
    motif_mapping : dict
        pattern -> motif_name mapping (see map_motif_name).
    min_score : float
        Minimum hit score to count (default 1.0).

    Returns:
    --------
    dict
        ``motif_dicts`` (the same object, updated).
    """
    try:
        loaded = _load_finemo_rows(finemo_file_path)
        if loaded is None:
            return motif_dicts
        rows, row_index = loaded

        # Sorting the row positions processes rows in file order, so motif
        # insertion order (and therefore tie order when formatting) does not
        # depend on set iteration order.
        matching_positions = sorted({
            position
            for variant_id in prioritized_variant_ids
            for position in row_index.get(variant_id, ())
        })

        for position in matching_positions:
            variant_id, motif_positions_str = rows[position]

            if pd.isna(motif_positions_str) or not motif_positions_str:
                continue

            try:
                motif_positions = json.loads(motif_positions_str)
            except (json.JSONDecodeError, TypeError) as e:
                print(f"Warning: Failed to parse finemo_motif_positions for {variant_id}: {e}", file=sys.stderr)
                continue
            if not isinstance(motif_positions, list):
                continue

            # Per-allele hit counts keyed by mapped motif name.
            ref_motifs: Dict[str, int] = {}
            alt_motifs: Dict[str, int] = {}

            for hit in motif_positions:
                if not isinstance(hit, dict):
                    continue

                try:
                    score = float(hit.get('score', 0))
                except (TypeError, ValueError):
                    continue
                # "not >=" (rather than "<") also rejects NaN scores.
                if not score >= min_score:
                    continue

                motif_name = hit.get('motif', '')
                if not motif_name:
                    continue
                mapped_motif_name = map_motif_name(motif_name, motif_mapping)

                allele = hit.get('allele', '')
                if allele == 'ref':
                    ref_motifs[mapped_motif_name] = ref_motifs.get(mapped_motif_name, 0) + 1
                elif allele == 'alt':
                    alt_motifs[mapped_motif_name] = alt_motifs.get(mapped_motif_name, 0) + 1

            _merge_allele_specific(motif_dicts['ref_specific_motifs'], ref_motifs, alt_motifs, model_name)
            _merge_allele_specific(motif_dicts['alt_specific_motifs'], alt_motifs, ref_motifs, model_name)

            # Motifs on both alleles, iterated in ref insertion order for determinism.
            unchanged = motif_dicts['unchanged_motifs_in_region']
            for motif_name in ref_motifs:
                if motif_name in alt_motifs:
                    unchanged[motif_name] = unchanged.get(motif_name, 0) + 1

    except FileNotFoundError:
        print(f"Warning: Finemo file not found: {finemo_file_path}, skipping", file=sys.stderr)
    except Exception as e:
        print(f"Warning: Error processing finemo file {finemo_file_path}: {e}", file=sys.stderr)

    return motif_dicts


def format_motif_distribution(
    motif_dict: Dict,
    include_model: bool = False,
    rng: Optional[random.Random] = None,
) -> str:
    """
    Format motif distribution as string.

    For ref/alt specific (``include_model=True``), ``motif_dict`` maps
    motif name -> {'count': int, 'models': set}. The output is
    "MOTIF_NAME (count, RANDOM_MODEL), ...", where RANDOM_MODEL is one model
    drawn at random from the set (empty string if the set is empty).

    For unchanged (``include_model=False``), ``motif_dict`` maps
    motif name -> count. The output is "MOTIF_NAME (count), ...".

    Entries are sorted by count descending. Ties keep the dict's insertion
    order. An empty ``motif_dict`` yields "".

    Parameters:
    -----------
    motif_dict : dict
        Distribution to format (shape depends on ``include_model``).
    include_model : bool
        Whether entries carry a model set to sample from.
    rng : random.Random, optional
        Generator used to pick the model; defaults to the global ``random``
        module. Candidates are sorted before drawing, so a seeded generator
        reproduces the same output across interpreter runs. Iterating a raw
        set of strings would not, because that order depends on hash seeding.
    """
    if not motif_dict:
        return ""

    if not include_model:
        # sorted() is stable even with reverse=True, so ties keep dict order.
        ranked = sorted(motif_dict.items(), key=lambda item: item[1], reverse=True)
        return ", ".join(f"{name} ({count})" for name, count in ranked)

    chooser = random if rng is None else rng
    entries = []
    for motif_name, data in motif_dict.items():
        models = data['models']
        model = chooser.choice(sorted(models, key=str)) if models else ""
        entries.append((motif_name, data['count'], model))

    entries.sort(key=lambda entry: entry[1], reverse=True)
    return ", ".join(f"{name} ({count}, {model})" for name, count, model in entries)


_PRIORITIZED_PREFIX = 'model_prioritized_by_any-'


def _default_output_path(input_file: str) -> str:
    """Derive the clustered TSV path, swapping only a trailing '.filtered.tsv'."""
    suffix = '.filtered.tsv'
    if input_file.endswith(suffix):
        return input_file[:-len(suffix)] + '.clustered.tsv'
    return f"{os.path.splitext(input_file)[0]}.clustered.tsv"


def _add_organs_column(df: pd.DataFrame, models: List[str]) -> None:
    """
    Add an 'organs' column to ``df`` in place.

    Each cell is ``str({model: organ})`` over the models whose
    ``model_prioritized_by_any-<model>`` value is non-null and truthy and that
    appear in model_tissues.tsv. Cells are "" when no model qualifies. The
    whole column is "" if the table is missing, unreadable, or lacks the
    ``model``/``organ`` columns.
    """
    model_tissues_path = get_model_tissues_path()
    if not model_tissues_path or not os.path.exists(model_tissues_path):
        print("Warning: model_tissues.tsv not found, skipping organs column", file=sys.stderr)
        df['organs'] = ""
        return

    try:
        model_tissues_df = pd.read_csv(model_tissues_path, sep='\t')
    except Exception as e:  # best effort: organs are optional metadata
        print(f"Warning: Error reading model_tissues.tsv: {e}", file=sys.stderr)
        df['organs'] = ""
        return

    if 'model' not in model_tissues_df.columns or 'organ' not in model_tissues_df.columns:
        print("Warning: model_tissues.tsv missing required columns (model, organ)", file=sys.stderr)
        df['organs'] = ""
        return

    model_to_organ = dict(zip(model_tissues_df['model'], model_tissues_df['organ']))

    # Fill column-by-column (models in column order) so each per-variant dict
    # keeps the same key order a row-by-row scan would produce.
    per_variant = [{} for _ in range(len(df))]
    for model in models:
        if model not in model_to_organ:
            continue
        organ = model_to_organ[model]
        for variant_organs, flag in zip(per_variant, df[f'{_PRIORITIZED_PREFIX}{model}'].tolist()):
            if pd.notna(flag) and flag:
                variant_organs[model] = organ

    df['organs'] = [str(variant_organs) if variant_organs else "" for variant_organs in per_variant]
    print(f"Added organs column from {model_tissues_path}")


def _add_finemo_motif_columns(df: pd.DataFrame, models: List[str], motif_mapping: Dict[str, str]) -> None:
    """
    Add ref_specific_motifs / alt_specific_motifs / unchanged_motifs_in_region
    columns to ``df`` in place, aggregating each model's finemo hits over the
    variants that model prioritized.
    """
    variant_motif_dicts = {
        variant_id: {
            'ref_specific_motifs': {},
            'alt_specific_motifs': {},
            'unchanged_motifs_in_region': {},
        }
        for variant_id in df['variant_id'].values
    }

    print(f"Processing {len(models)} models...")

    for model_name in models:
        col_name = f'{_PRIORITIZED_PREFIX}{model_name}'
        # Elementwise comparison: only cells equal to True count as prioritized.
        prioritized_mask = df[col_name].notna() & (df[col_name] == True)  # noqa: E712
        prioritized_variant_ids = set(df.loc[prioritized_mask, 'variant_id'].values)
        if not prioritized_variant_ids:
            continue

        finemo_file_path = get_finemo_file_path(model_name)
        if finemo_file_path is None:
            continue

        print(f"  Processing {model_name} ({len(prioritized_variant_ids)} prioritized variants)...")

        # One call per variant keeps each variant's motif tallies separate.
        for variant_id in prioritized_variant_ids:
            motif_dicts = variant_motif_dicts.get(variant_id)
            if motif_dicts is None:
                continue
            process_model_finemo_file(model_name, finemo_file_path, {variant_id}, motif_dicts, motif_mapping)

    print("Formatting motif distributions...")
    ref_specific_strs = []
    alt_specific_strs = []
    unchanged_strs = []
    for variant_id in df['variant_id'].values:
        motif_dicts = variant_motif_dicts.get(variant_id)
        if motif_dicts is None:
            ref_specific_strs.append("")
            alt_specific_strs.append("")
            unchanged_strs.append("")
            continue
        ref_specific_strs.append(format_motif_distribution(motif_dicts['ref_specific_motifs'], include_model=True))
        alt_specific_strs.append(format_motif_distribution(motif_dicts['alt_specific_motifs'], include_model=True))
        unchanged_strs.append(format_motif_distribution(motif_dicts['unchanged_motifs_in_region'], include_model=False))

    df['ref_specific_motifs'] = ref_specific_strs
    df['alt_specific_motifs'] = alt_specific_strs
    df['unchanged_motifs_in_region'] = unchanged_strs


def run_clustering(
    input_file: str,
    output_file: Optional[str] = None,
    n_clusters: int = 35,
    random_state: int = 42,
    motif_mapping_file: str = 'hdma_motif_mapping.tsv'
):
    """
    Run clustering and generate clustered TSV file.

    Steps:
      1. KMeans on the ``score-*`` columns (NaN filled with 0); labels are
         renumbered by sort_clusters_by_size and stored in ``kmeans_35``.
      2. Add the ``organs`` column from model_tissues.tsv.
      3. Add the finemo motif columns.
      4. Drop every ``logfc-*`` and ``score-*`` column and write a TSV.

    Parameters:
    -----------
    input_file : str
        Filtered TSV with ``score-*``, ``model_prioritized_by_any-*`` and
        ``variant_id`` columns.
    output_file : str, optional
        Destination; by default a trailing '.filtered.tsv' is replaced with
        '.clustered.tsv' (otherwise '.clustered.tsv' replaces the extension).
    n_clusters, random_state : int
        Passed to KMeans.
    motif_mapping_file : str
        Motif name mapping TSV (see load_motif_mapping).

    Returns:
    --------
    str
        Path of the written file.

    Raises:
    -------
    FileNotFoundError
        If ``input_file`` does not exist.
    ValueError
        If no ``score-*`` columns are present (or KMeans rejects the input).
    """
    start_time = time.time()

    print(f"\n{'='*80}")
    print(f"[TIMING] Starting clustering at {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'='*80}")
    print(f"Input file: {input_file}")

    print(f"\nLoading motif mapping from {motif_mapping_file}...")
    motif_mapping = load_motif_mapping(motif_mapping_file)

    if not os.path.exists(input_file):
        raise FileNotFoundError(f"Input file not found: {input_file}")

    if output_file is None:
        output_file = _default_output_path(input_file)
    print(f"Output file: {output_file}")

    print(f"\nReading filtered TSV: {input_file}")
    df = pd.read_csv(input_file, sep='\t')
    print(f"Loaded {len(df)} variants with {len(df.columns)} columns")

    score_cols = [col for col in df.columns if col.startswith('score-')]
    if not score_cols:
        raise ValueError(f"No score-{{model}} columns found in {input_file}")
    print(f"Found {len(score_cols)} score columns")

    print("\nPreparing data for clustering (filling NaN with 0)...")
    X = df[score_cols].fillna(0).values

    print(f"Running KMeans clustering (k={n_clusters}, random_state={random_state})...")
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=10)
    labels = kmeans.fit_predict(X)

    print("Sorting clusters by size...")
    labels, _ = sort_clusters_by_size(labels)

    # NOTE(review): the column name stays 'kmeans_35' whatever n_clusters is;
    # presumably downstream consumers expect this exact name - confirm before renaming.
    df['kmeans_35'] = labels

    print("\nCluster distribution (sorted by size):")
    cluster_counts = pd.Series(labels).value_counts().sort_index()
    for cluster_id, count in cluster_counts.items():
        print(f"  Cluster {cluster_id}: {count} variants")

    models = [col[len(_PRIORITIZED_PREFIX):] for col in df.columns if col.startswith(_PRIORITIZED_PREFIX)]

    print("\nAdding organs column...")
    _add_organs_column(df, models)

    print("\nAdding finemo motif columns...")
    _add_finemo_motif_columns(df, models, motif_mapping)

    print("\nDropping logfc and score columns...")
    logfc_cols = [col for col in df.columns if col.startswith('logfc-')]
    score_cols_to_drop = [col for col in df.columns if col.startswith('score-')]
    df = df.drop(columns=logfc_cols + score_cols_to_drop)
    print(f"Dropped {len(logfc_cols)} logfc columns and {len(score_cols_to_drop)} score columns")

    # dirname is "" for a bare file name; os.makedirs("") would raise.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    df.to_csv(output_file, sep='\t', index=False)

    duration = time.time() - start_time
    print(f"\n[TIMING] Clustering completed in {duration:.1f}s")
    print(f"Saved {len(df)} variants to {output_file}")
    print(f"{'='*80}\n")

    return output_file


def main():
    """
    Command-line entry point.

    Resolves the input TSV, either from ``--input`` or from
    ``--variant-dataset`` + ``--model-dataset``, and runs run_clustering.
    Returns a process exit code: 0 on success, 1 on any error. Argument
    errors exit via argparse with status 2.
    """
    parser = argparse.ArgumentParser(
        description='Generate clustered TSV files from filtered TSV files. '
                    'Maps motif names using hdma_motif_mapping.tsv.'
    )

    # Input options (mutually exclusive)
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument(
        '--input',
        help='Path to filtered TSV file (must contain score-{model} columns)'
    )
    input_group.add_argument(
        '--variant-dataset',
        help='Variant dataset name (e.g., "Broad neurodevelopmental and neuromuscular disorders"). '
             'Requires --model-dataset.'
    )

    parser.add_argument(
        '--model-dataset',
        help='Model dataset name (e.g., "Fetal Brain"). Required if using --variant-dataset.'
    )

    parser.add_argument(
        '--output',
        help='Output file path (default: auto-generated from input file, replacing .filtered.tsv with .clustered.tsv)'
    )

    parser.add_argument(
        '--n-clusters',
        type=int,
        default=35,
        help='Number of clusters (default: 35)'
    )

    parser.add_argument(
        '--random-state',
        type=int,
        default=42,
        help='Random state for reproducibility (default: 42)'
    )

    parser.add_argument(
        '--motif-mapping-file',
        default='hdma_motif_mapping.tsv',
        help='Path to motif mapping TSV file (default: hdma_motif_mapping.tsv)'
    )

    args = parser.parse_args()

    # Validate arguments
    if args.variant_dataset and not args.model_dataset:
        parser.error("--model-dataset is required when using --variant-dataset")
    if args.input and args.model_dataset:
        print("Warning: --model-dataset is ignored when --input is given", file=sys.stderr)

    # Input-path resolution is inside the try so that a missing data
    # directory is reported like any other failure (exit code 1).
    try:
        if args.input:
            input_file = args.input
        else:
            input_file = get_filtered_tsv_path(args.variant_dataset, args.model_dataset)
            print(f"Auto-constructed input path: {input_file}")

        output_file = run_clustering(
            input_file=input_file,
            output_file=args.output,
            n_clusters=args.n_clusters,
            random_state=args.random_state,
            motif_mapping_file=args.motif_mapping_file
        )
        print(f"Success! Output saved to: {output_file}")
        return 0
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        return 1


if __name__ == '__main__':
    # main() returns the exit status; SystemExit hands it to the interpreter.
    raise SystemExit(main())
