#!/usr/bin/env python3
"""
Script to generate clustered TSV files from filtered TSV files with human-readable columns.

This script:
1. Runs KMeans clustering on score columns
2. Adds cluster assignments (kmeans_35 column)
3. Adds organs metadata column
4. Adds finemo motif columns (ref_specific_motifs, alt_specific_motifs, unchanged_motifs_in_region)
   - Maps motif names using hdma_motif_mapping.tsv
5. Merges additional columns from multiple TSV files:
   - closest_elements.tsv (genes, miRNA, lncRNA with distances)
   - logfc.tsv (logfc-{model} columns)
   - aaq.tsv (aaq-{model} columns)
   - VEP.most_severe_csqs.tsv
   - GENCODE.region_type.tsv
   - gnomad.tsv
   - patient_hpo_expanded.tsv
6. Filters per-model columns to only prioritized models
7. Adds HTML URL column
8. Creates merged closest element columns (closest_genes, closest_miRNA, closest_lncRNA)
9. Creates simons_searchlight_genes_in_closest_genes column
10. Creates G2P intersection columns (g2p_genes_in_closest_genes, g2p_miRNA_in_closest_miRNA, g2p_lncRNA_in_closest_lncRNA)
11. Drops logfc and score columns

Configuration is read from environment variables (set by config.sh).
"""

import argparse
import pandas as pd
import numpy as np
import os
import sys
import json
import time
import random
import subprocess
import tempfile
import re
from pathlib import Path
from typing import Dict, List, Set, Optional

# Configuration - read from environment variables (set by config.sh)
# SPLITS_DIR, LOCAL_SPLITS_DIR and VARBOOK_DIR have no fallback and are None
# when the variable is unset; FINEMO_DIR and OUTPUT_DIR fall back to
# hard-coded absolute paths.
SPLITS_DIR: Optional[str] = os.environ.get('SPLITS_DIR')
LOCAL_SPLITS_DIR: Optional[str] = os.environ.get('LOCAL_SPLITS_DIR')
VARBOOK_DIR: Optional[str] = os.environ.get('VARBOOK_DIR')
FINEMO_DIR: str = os.environ.get('FINEMO_DIR', '/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/curation/broad/splits/finemo/')
OUTPUT_DIR: str = os.environ.get('OUTPUT_DIR', '/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/curation/broad/varbook-container/snakemake/data')

# Add VARBOOK_DIR to path for importing varbook modules.
# This must execute before the `from varbook...` import that follows, so the
# package can be resolved from VARBOOK_DIR.
if VARBOOK_DIR:
    sys.path.insert(0, VARBOOK_DIR)

from sklearn.cluster import KMeans
from varbook.annotate.kmeans import sort_clusters_by_size


def load_motif_mapping(mapping_file: str = 'hdma_motif_mapping.tsv') -> Dict[str, str]:
    """
    Load the motif pattern -> annotation_broad mapping from a TSV file.

    The file is looked up as given, then next to this script, then in the
    current working directory. It must be tab-separated and contain the
    columns ``pattern`` and ``annotation_broad``.

    Parameters:
    -----------
    mapping_file : str
        Path (or bare file name) of the mapping TSV (default: hdma_motif_mapping.tsv)

    Returns:
    --------
    dict
        Dictionary mapping pattern -> annotation_broad (both whitespace-stripped).
        Rows whose pattern or annotation is missing/blank are skipped; for
        duplicate patterns the last row wins. An empty dict is returned (with a
        warning on stderr) when the file is missing, unreadable, or lacks the
        required columns.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    possible_paths = [
        mapping_file,
        os.path.join(script_dir, mapping_file),
        os.path.join(os.getcwd(), mapping_file),
    ]

    mapping_path = next((path for path in possible_paths if os.path.exists(path)), None)
    if mapping_path is None:
        print(f"Warning: Motif mapping file not found at any of: {possible_paths}", file=sys.stderr)
        print("Continuing without motif name mapping...", file=sys.stderr)
        return {}

    try:
        df = pd.read_csv(mapping_path, sep='\t')
    except Exception as e:
        # Best-effort: a malformed mapping file should not abort the pipeline;
        # callers fall back to the unmapped motif names.
        print(f"Warning: Error loading motif mapping file: {e}", file=sys.stderr)
        print("Continuing without motif name mapping...", file=sys.stderr)
        return {}

    if 'pattern' not in df.columns or 'annotation_broad' not in df.columns:
        print("Warning: Mapping file missing required columns (pattern, annotation_broad)", file=sys.stderr)
        return {}

    mapping: Dict[str, str] = {}
    for pattern, annotation in zip(df['pattern'], df['annotation_broad']):
        # Test for missing values before str(): str(NaN) is 'nan', which would
        # otherwise turn a blank pattern cell into a literal 'nan' key.
        if pd.isna(pattern) or pd.isna(annotation):
            continue
        pattern = str(pattern).strip()
        annotation = str(annotation).strip()
        if pattern and annotation and annotation.lower() != 'nan':
            mapping[pattern] = annotation

    print(f"Loaded {len(mapping)} motif mappings from {mapping_path}")
    if mapping:
        # Show one example mapping for debugging.
        sample_pattern = next(iter(mapping))
        print(f"Sample mappings: {sample_pattern} -> {mapping[sample_pattern]}")
    return mapping


def map_motif_name(motif_name: str, mapping: Dict[str, str]) -> str:
    """
    Translate a finemo motif/pattern name into its annotation_broad label.

    Lookup keys are tried in this order, returning the first hit:

    1. the whitespace-stripped name itself;
    2. the name with one leading ``pos_patterns.`` or ``neg_patterns.`` removed,
       e.g. "pos_patterns.pos.Average_305__merged_pattern_0"
       -> "pos.Average_305__merged_pattern_0";
    3. every dotted suffix of (2) that begins at a ``pos``/``neg`` segment
       followed by at least one more segment.

    Parameters:
    -----------
    motif_name : str
        Original motif name in pattern format.
    mapping : dict
        Dictionary mapping pattern -> annotation_broad.

    Returns:
    --------
    str
        The mapped annotation if any candidate key matches, else the stripped
        input name. With an empty mapping the argument is returned untouched.
    """
    if not mapping:
        return motif_name

    name = str(motif_name).strip()

    # Remove at most one finemo-added prefix.
    stripped = name
    for prefix in ('pos_patterns.', 'neg_patterns.'):
        if stripped.startswith(prefix):
            stripped = stripped[len(prefix):]
            break

    candidates = [name, stripped]
    if '.' in stripped:
        pieces = stripped.split('.')
        candidates.extend(
            '.'.join(pieces[idx:])
            for idx, piece in enumerate(pieces)
            if piece in ('pos', 'neg') and idx + 1 < len(pieces)
        )

    for key in candidates:
        if key in mapping:
            return mapping[key]
    return name


def get_filtered_tsv_path(variant_dataset, model_dataset, splits_dir=None):
    """
    Construct the path to ``{variant_dataset}.{model_dataset}.filtered.tsv``.

    The data directory is resolved in this order:

    1. ``data`` in the current working directory (returned as the relative
       path ``data``);
    2. ``<parent of splits_dir>/varbook-container/snakemake/data``;
    3. ``<parent of splits_dir>/snakemake/data``.

    Parameters:
    -----------
    variant_dataset : str
        Variant dataset name (first component of the file name).
    model_dataset : str
        Model dataset name (second component of the file name).
    splits_dir : str, optional
        Splits directory used for options 2-3; defaults to the SPLITS_DIR
        environment setting. A trailing path separator is tolerated.

    Returns:
    --------
    str
        Path to the filtered TSV. The file itself is not checked for existence.

    Raises:
    -------
    ValueError
        If none of the candidate data directories exists.
    """
    if splits_dir is None:
        splits_dir = SPLITS_DIR

    data_dir = None

    # A relative "data" already resolves against the current working directory,
    # so a single check covers both "./data" and "<cwd>/data".
    if os.path.isdir("data"):
        data_dir = "data"
    elif splits_dir:
        # normpath drops a trailing separator; without it dirname("/x/splits/")
        # returns "/x/splits" itself rather than its parent "/x".
        splits_parent = os.path.dirname(os.path.normpath(splits_dir))
        for candidate in (
            os.path.join(splits_parent, "varbook-container", "snakemake", "data"),
            os.path.join(splits_parent, "snakemake", "data"),
        ):
            if os.path.isdir(candidate):
                data_dir = candidate
                break

    if data_dir is None:
        raise ValueError(
            f"Could not find data directory. Tried:\n"
            f"  - ./data\n"
            f"  - {os.path.join(os.getcwd(), 'data')}\n"
            f"  - (inferred from SPLITS_DIR={splits_dir})\n"
            f"Please specify input file directly with --input or ensure data directory exists."
        )

    return os.path.join(data_dir, f"{variant_dataset}.{model_dataset}.filtered.tsv")


def get_model_tissues_path():
    """
    Locate model_tissues.tsv.

    Checks SPLITS_DIR first and then LOCAL_SPLITS_DIR, skipping whichever is
    unset/empty. Returns the first existing path, or None when neither
    directory contains the file.
    """
    for base_dir in (SPLITS_DIR, LOCAL_SPLITS_DIR):
        if not base_dir:
            continue
        candidate = os.path.join(base_dir, "model_tissues.tsv")
        if os.path.exists(candidate):
            return candidate
    return None


def get_finemo_file_path(model_name: str) -> Optional[str]:
    """Return FINEMO_DIR/broad.finemo.<model_name>.tsv if it exists, else None."""
    candidate = os.path.join(FINEMO_DIR, f"broad.finemo.{model_name}.tsv")
    return candidate if os.path.exists(candidate) else None


def _finemo_hit_score(hit: Dict) -> Optional[float]:
    """Return a finemo hit's ``score`` as a float (0 if absent), or None if it is not numeric."""
    try:
        return float(hit.get('score', 0))
    except (TypeError, ValueError):
        return None


def _strip_finemo_prefix(motif_name: str) -> str:
    """Remove at most one leading ``pos_patterns.`` / ``neg_patterns.`` prefix."""
    for prefix in ('pos_patterns.', 'neg_patterns.'):
        if motif_name.startswith(prefix):
            return motif_name[len(prefix):]
    return motif_name


def _add_allele_specific(bucket: Dict, own_counts: Dict[str, int],
                         other_counts: Dict[str, int], model_name: str) -> None:
    """Add counts of motifs present in own_counts but absent from other_counts to bucket."""
    for name, count in own_counts.items():
        if name in other_counts:
            continue
        entry = bucket.setdefault(name, {'count': 0, 'models': set()})
        entry['count'] += count
        entry['models'].add(model_name)


def process_model_finemo_file(
    model_name: str,
    finemo_file_path: str,
    prioritized_variant_ids: Set[str],
    motif_dicts: Dict[str, Dict],
    motif_mapping: Dict[str, str]
) -> Dict[str, Dict]:
    """
    Process a single model's finemo file and update motif dictionaries in place.

    Reads the tab-separated ``finemo_file_path``, which needs ``variant_id`` and
    ``finemo_motif_positions`` columns. Only rows whose variant_id is in
    ``prioritized_variant_ids`` are used. ``finemo_motif_positions`` is parsed
    as JSON. Each dict hit in the resulting list is read via its ``score``,
    ``motif`` and ``allele`` keys. A hit is used only when its score is numeric
    and >= 1.0. Motif names are mapped with ``map_motif_name``. Counts from
    every matching row are accumulated into the single ``motif_dicts``:

    - ``ref_specific_motifs``: motif -> {'count', 'models'} for motifs with
      ref-allele hits but no alt-allele hits in a row
    - ``alt_specific_motifs``: same, with ref/alt swapped
    - ``unchanged_motifs_in_region``: motif -> number of rows where the motif
      has hits on both alleles

    Debug information about name mapping is stored on function attributes
    (``_debug_count``, ``_unmapped_motifs``, ``_mapped_motifs``). The caller
    reads and resets these attributes.

    Errors are best-effort: if the file is missing or unreadable, or a required
    column is absent, a warning goes to stderr and ``motif_dicts`` is returned
    as-is. A row whose JSON cannot be parsed only skips that row.

    Returns:
    --------
    dict
        The same ``motif_dicts`` object that was passed in.
    """
    state = process_model_finemo_file
    if not hasattr(state, '_debug_count'):
        state._debug_count = 0
        state._unmapped_motifs = set()
        state._mapped_motifs = set()

    try:
        finemo_df = pd.read_csv(finemo_file_path, sep='\t')

        if 'variant_id' not in finemo_df.columns:
            print(f"Warning: {finemo_file_path} missing variant_id column, skipping", file=sys.stderr)
            return motif_dicts

        if 'finemo_motif_positions' not in finemo_df.columns:
            print(f"Warning: {finemo_file_path} missing finemo_motif_positions column, skipping", file=sys.stderr)
            return motif_dicts

        # Filter to prioritized variants
        finemo_df = finemo_df[finemo_df['variant_id'].isin(prioritized_variant_ids)]

        for variant_id, motif_positions_str in zip(finemo_df['variant_id'],
                                                   finemo_df['finemo_motif_positions']):
            if pd.isna(motif_positions_str) or not motif_positions_str:
                continue

            try:
                motif_positions = json.loads(motif_positions_str)
            except (json.JSONDecodeError, TypeError) as e:
                print(f"Warning: Failed to parse finemo_motif_positions for {variant_id}: {e}", file=sys.stderr)
                continue

            # Only a JSON list carries hit records; other payloads (e.g. a bare
            # number) would otherwise raise and abort the whole file.
            if not isinstance(motif_positions, list):
                continue

            ref_motifs: Dict[str, int] = {}
            alt_motifs: Dict[str, int] = {}

            for hit in motif_positions:
                if not isinstance(hit, dict):
                    continue

                score = _finemo_hit_score(hit)
                # `not (score >= 1.0)` also rejects NaN, which `score < 1.0` lets through.
                if score is None or not (score >= 1.0):
                    continue

                motif_name = str(hit.get('motif', '') or '').strip()
                if not motif_name:
                    continue

                mapped_motif_name = map_motif_name(motif_name, motif_mapping)

                # Record mapping statistics for the caller's debug output.
                if mapped_motif_name != motif_name:
                    state._mapped_motifs.add((motif_name, mapped_motif_name))
                elif (state._debug_count < 10
                      and motif_name not in motif_mapping
                      and _strip_finemo_prefix(motif_name) not in motif_mapping):
                    state._unmapped_motifs.add(motif_name)
                    state._debug_count += 1

                allele = hit.get('allele', '')
                if allele == 'ref':
                    ref_motifs[mapped_motif_name] = ref_motifs.get(mapped_motif_name, 0) + 1
                elif allele == 'alt':
                    alt_motifs[mapped_motif_name] = alt_motifs.get(mapped_motif_name, 0) + 1

            if ref_motifs:
                _add_allele_specific(motif_dicts['ref_specific_motifs'], ref_motifs, alt_motifs, model_name)
            if alt_motifs:
                _add_allele_specific(motif_dicts['alt_specific_motifs'], alt_motifs, ref_motifs, model_name)

            shared = ref_motifs.keys() & alt_motifs.keys()
            if shared:
                unchanged = motif_dicts['unchanged_motifs_in_region']
                for motif_name in shared:
                    unchanged[motif_name] = unchanged.get(motif_name, 0) + 1

    except FileNotFoundError:
        print(f"Warning: Finemo file not found: {finemo_file_path}, skipping", file=sys.stderr)
    except Exception as e:
        print(f"Warning: Error processing finemo file {finemo_file_path}: {e}", file=sys.stderr)

    return motif_dicts


def format_motif_distribution(motif_dict: Dict, include_model: bool = False, *,
                              rng: Optional[random.Random] = None) -> str:
    """
    Format a motif distribution as a comma-separated string.

    For ref/alt specific (``include_model=True``): each value is
    ``{'count': int, 'models': set}`` and the output is
    "MOTIF_NAME (count, RANDOM_MODEL), ...". One model is picked at random from
    the set, or "" if the set is empty.
    For unchanged (``include_model=False``): each value is an int count and
    the output is "MOTIF_NAME (count), ...".

    Entries are sorted by count descending. Ties keep dict insertion order.
    An empty dict yields "".

    Parameters:
    -----------
    motif_dict : dict
        Motif name -> data as described above.
    include_model : bool
        Selects the ref/alt-specific format.
    rng : random.Random, optional
        Source of randomness for the model pick. Defaults to the global
        ``random`` module. Models are sorted before picking, so a seeded RNG
        gives reproducible output.
    """
    if not motif_dict:
        return ""

    if include_model:
        chooser = rng if rng is not None else random
        items = []
        for motif_name, data in motif_dict.items():
            models = data['models']
            # Sort first: iteration order of a set of strings depends on
            # PYTHONHASHSEED, so picking from list(models) was not reproducible
            # across runs even with a fixed random seed.
            model = chooser.choice(sorted(models)) if models else ""
            items.append((motif_name, data['count'], model))

        items.sort(key=lambda item: item[1], reverse=True)
        return ", ".join(f"{name} ({count}, {model})" for name, count, model in items)

    counted = sorted(motif_dict.items(), key=lambda item: item[1], reverse=True)
    return ", ".join(f"{name} ({count})" for name, count in counted)


def merge_closest_elements(row, name_prefix, max_count):
    """
    Merge closest_{name_prefix}_{i} / closest_{name_prefix}_distance_{i} columns into one string.

    For i in 1..max_count, pairs whose columns exist in ``row`` are collected.
    A pair is skipped when any of these holds:

    - the element or distance is missing;
    - the element is blank or the '.' placeholder (after stripping whitespace);
    - the distance is not a finite number.

    The kept pairs are sorted by absolute distance. The output is
    "ELEMENT (distance bp), ...", where distance is ``int(distance)``
    (truncated toward zero). If no pair is kept, the result is "".
    """
    element_distance_pairs = []

    for i in range(1, max_count + 1):
        element_col = f'closest_{name_prefix}_{i}'
        dist_col = f'closest_{name_prefix}_distance_{i}'
        if element_col not in row.index or dist_col not in row.index:
            continue

        element = row[element_col]
        distance = row[dist_col]
        if pd.isna(element) or pd.isna(distance):
            continue

        element_str = str(element).strip()
        if element_str in ('', '.'):
            continue

        try:
            # A '.' distance placeholder fails conversion and is skipped here.
            dist_value = float(distance)
        except (ValueError, TypeError):
            continue

        # int() below raises on inf/nan, so drop non-finite distances up front.
        if not np.isfinite(dist_value):
            continue

        element_distance_pairs.append((element_str, dist_value))

    element_distance_pairs.sort(key=lambda pair: abs(pair[1]))
    return ", ".join(f"{element} ({int(dist)} bp)" for element, dist in element_distance_pairs)


def filter_searchlight_genes(closest_genes_str, searchlight_genes):
    """
    Keep only the "GENE (dist bp)" entries whose gene is in searchlight_genes.

    ``closest_genes_str`` is a comma-separated list of "GENE (dist bp)"
    entries. Entries without "(" are dropped. The gene name is upper-cased
    before the membership test, so ``searchlight_genes`` should hold
    upper-case symbols. Kept entries are returned in their original text and
    order, joined with ", ". Missing or blank input yields "".
    """
    if pd.isna(closest_genes_str):
        return ""
    text = str(closest_genes_str)
    if text.strip() == '':
        return ""

    kept = []
    for chunk in (piece.strip() for piece in text.split(',')):
        if '(' not in chunk:
            continue
        gene, _, _ = chunk.partition('(')
        if gene.strip().upper() in searchlight_genes:
            kept.append(chunk)

    return ", ".join(kept)


def format_g2p_match(element_name, distance, g2p_entries):
    """
    Format a single element match with its G2P metadata.

    Returns "NAME (D bp)" when ``g2p_entries`` is empty or None. Otherwise the
    result is "NAME (D bp; field: value; ... // field: value; ...)":

    - each G2P entry dict becomes one "; "-joined group, with fields in a
      fixed order;
    - groups are separated by " // ";
    - D is ``int(distance)``, truncated toward zero;
    - a value that is missing, falsy, blank, NaN/None/pd.NA or the literal
      'N/A' is shown as "N/A".
    """
    if not g2p_entries:
        return f"{element_name} ({int(distance)} bp)"

    field_order = ['disease_name', 'confidence', 'molecular mechanism', 'molecular mechanism categorisation',
                   'variant_consequence', 'variant_types', 'phenotypes', 'panel']

    entry_strs = []
    for entry in g2p_entries:
        parts = []
        for field in field_order:
            value = entry.get(field, 'N/A')
            # Values taken from pandas rows may be NaN/None/pd.NA: str(NaN) is
            # 'nan' (which would be printed verbatim) and bool(pd.NA) raises,
            # so treat scalar missing values as N/A before any truthiness test.
            missing = pd.api.types.is_scalar(value) and pd.isna(value)
            if not missing and value and str(value).strip() and value != 'N/A':
                parts.append(f"{field}: {value}")
            else:
                parts.append(f"{field}: N/A")
        entry_strs.append("; ".join(parts))

    metadata = " // ".join(entry_strs)
    return f"{element_name} ({int(distance)} bp; {metadata})"


def intersect_with_g2p(closest_elements_str, g2p_lookup_dict, element_type='gene'):
    """
    Keep the closest elements that appear in a G2P lookup and annotate them.

    ``closest_elements_str`` is a comma-separated list of
    "ELEMENT (dist bp)" entries. An entry is dropped if it has no "(" or if
    its parenthesised distance does not parse as a float. Each element name
    is upper-cased and looked up as a key of ``g2p_lookup_dict``. Matches are
    rendered by ``format_g2p_match`` and joined with ", ".

    ``element_type`` is accepted for interface compatibility but is not used
    by this body. Missing or blank input, or no matches, yields "".
    """
    if pd.isna(closest_elements_str):
        return ""
    text = str(closest_elements_str)
    if not text.strip():
        return ""

    hits = []
    for raw in text.split(','):
        entry = raw.strip()
        if not entry or '(' not in entry:
            continue

        open_idx = entry.find('(')
        close_idx = entry.find(')')
        name = entry[:open_idx].strip()
        inner = entry[open_idx + 1:close_idx].replace('bp', '').strip()

        try:
            dist = float(inner)
        except (ValueError, TypeError):
            continue

        key = name.upper()
        if key in g2p_lookup_dict:
            hits.append(format_g2p_match(name, dist, g2p_lookup_dict[key]))

    return ", ".join(hits)


def run_clustering(
    input_file: str,
    output_file: Optional[str] = None,
    n_clusters: int = 35,
    random_state: int = 42,
    motif_mapping_file: str = 'hdma_motif_mapping.tsv',
    variant_dataset: Optional[str] = None,
    model_dataset: Optional[str] = None
):
    """
    Run clustering and generate clustered TSV file with human-readable columns.
    """
    start_time = time.time()
    
    print(f"\n{'='*80}")
    print(f"[TIMING] Starting clustering at {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'='*80}")
    print(f"Input file: {input_file}")
    
    # Load motif mapping
    print(f"\nLoading motif mapping from {motif_mapping_file}...")
    motif_mapping = load_motif_mapping(motif_mapping_file)
    
    # Validate input file
    if not os.path.exists(input_file):
        raise FileNotFoundError(f"Input file not found: {input_file}")
    
    # Determine output file
    if output_file is None:
        if input_file.endswith('.filtered.tsv'):
            output_file = input_file.replace('.filtered.tsv', '.clustered.tsv')
        else:
            base_name = os.path.splitext(input_file)[0]
            output_file = f"{base_name}.clustered.tsv"
    
    print(f"Output file: {output_file}")
    
    # Extract variant_dataset and model_dataset from input file if not provided
    if variant_dataset is None or model_dataset is None:
        # Try to extract from input file path
        basename = os.path.basename(input_file)
        if '.filtered.tsv' in basename:
            parts = basename.replace('.filtered.tsv', '').split('.', 1)
            if len(parts) == 2:
                if variant_dataset is None:
                    variant_dataset = parts[0]
                if model_dataset is None:
                    model_dataset = parts[1]
    
    # Read filtered TSV
    print(f"\nReading filtered TSV: {input_file}")
    df = pd.read_csv(input_file, sep='\t')
    print(f"Loaded {len(df)} variants with {len(df.columns)} columns")
    
    # Extract score columns
    score_cols = [col for col in df.columns if col.startswith('score-')]
    
    if len(score_cols) == 0:
        raise ValueError(f"No score-{{model}} columns found in {input_file}")
    
    print(f"Found {len(score_cols)} score columns")
    
    # Prepare data for clustering
    print("\nPreparing data for clustering (filling NaN with 0)...")
    X = df[score_cols].fillna(0).values
    
    # Run KMeans
    print(f"Running KMeans clustering (k={n_clusters}, random_state={random_state})...")
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=10)
    labels = kmeans.fit_predict(X)
    
    # Sort clusters by size
    print("Sorting clusters by size...")
    labels, remap = sort_clusters_by_size(labels)
    
    # Add kmeans_35 column
    df['kmeans_35'] = labels
    
    print(f"\nCluster distribution (sorted by size):")
    cluster_counts = pd.Series(labels).value_counts().sort_index()
    for cluster_id, count in cluster_counts.items():
        print(f"  Cluster {cluster_id}: {count} variants")
    
    # Add organs column
    print("\nAdding organs column...")
    model_tissues_path = get_model_tissues_path()
    if model_tissues_path and os.path.exists(model_tissues_path):
        try:
            model_tissues_df = pd.read_csv(model_tissues_path, sep='\t')
            if 'model' in model_tissues_df.columns and 'organ' in model_tissues_df.columns:
                # Create mapping from model to organ
                model_to_organ = dict(zip(model_tissues_df['model'], model_tissues_df['organ']))
                
                # Get all model columns
                model_cols = [col for col in df.columns if col.startswith('model_prioritized_by_any-')]
                models = [col.replace('model_prioritized_by_any-', '') for col in model_cols]
                
                # Create organs dictionary for each variant
                organs_dicts = []
                for _, row in df.iterrows():
                    variant_organs = {}
                    for model in models:
                        col_name = f'model_prioritized_by_any-{model}'
                        if col_name in row and pd.notna(row[col_name]) and row[col_name]:
                            if model in model_to_organ:
                                variant_organs[model] = model_to_organ[model]
                    organs_dicts.append(str(variant_organs) if variant_organs else "")
                
                df['organs'] = organs_dicts
                print(f"Added organs column from {model_tissues_path}")
            else:
                print(f"Warning: model_tissues.tsv missing required columns (model, organ)", file=sys.stderr)
                df['organs'] = ""
        except Exception as e:
            print(f"Warning: Error reading model_tissues.tsv: {e}", file=sys.stderr)
            df['organs'] = ""
    else:
        print(f"Warning: model_tissues.tsv not found, skipping organs column", file=sys.stderr)
        df['organs'] = ""
    
    # Add finemo motif columns
    print("\nAdding finemo motif columns...")
    
    # Initialize motif dictionaries for each variant
    variant_motif_dicts = {}
    for variant_id in df['variant_id'].values:
        variant_motif_dicts[variant_id] = {
            'ref_specific_motifs': {},
            'alt_specific_motifs': {},
            'unchanged_motifs_in_region': {}
        }
    
    # Get all model columns
    model_cols = [col for col in df.columns if col.startswith('model_prioritized_by_any-')]
    models = [col.replace('model_prioritized_by_any-', '') for col in model_cols]
    
    print(f"Processing {len(models)} models...")
    
    # Process each model
    for model_name in models:
        col_name = f'model_prioritized_by_any-{model_name}'
        if col_name not in df.columns:
            continue
        
        # Get prioritized variants for this model
        prioritized_mask = df[col_name].notna() & (df[col_name] == True)
        prioritized_variant_ids = set(df[prioritized_mask]['variant_id'].values)
        
        if len(prioritized_variant_ids) == 0:
            continue
        
        # Get finemo file path
        finemo_file_path = get_finemo_file_path(model_name)
        if finemo_file_path is None:
            continue
        
        print(f"  Processing {model_name} ({len(prioritized_variant_ids)} prioritized variants)...")
        
        # Reset debug counters for this model
        if hasattr(process_model_finemo_file, '_debug_count'):
            process_model_finemo_file._debug_count = 0
            process_model_finemo_file._unmapped_motifs = set()
            process_model_finemo_file._mapped_motifs = set()
        
        # Process finemo file for each variant
        for variant_id in prioritized_variant_ids:
            if variant_id not in variant_motif_dicts:
                continue
            
            motif_dicts = variant_motif_dicts[variant_id]
            process_model_finemo_file(
                model_name,
                finemo_file_path,
                {variant_id},
                motif_dicts,
                motif_mapping
            )
        
        # Print mapping statistics for debugging
        if hasattr(process_model_finemo_file, '_mapped_motifs') and process_model_finemo_file._mapped_motifs:
            print(f"  Successfully mapped {len(process_model_finemo_file._mapped_motifs)} unique motifs")
            # Show a few examples
            examples = list(process_model_finemo_file._mapped_motifs)[:3]
            for original, mapped in examples:
                print(f"    Example: '{original}' -> '{mapped}'")
        
        if hasattr(process_model_finemo_file, '_unmapped_motifs') and process_model_finemo_file._unmapped_motifs:
            print(f"  Warning: Found {len(process_model_finemo_file._unmapped_motifs)} unmapped motif names (first 10): {list(process_model_finemo_file._unmapped_motifs)[:10]}")
    
    # Format motif distributions
    print("Formatting motif distributions...")
    ref_specific_strs = []
    alt_specific_strs = []
    unchanged_strs = []
    
    for variant_id in df['variant_id'].values:
        if variant_id in variant_motif_dicts:
            motif_dicts = variant_motif_dicts[variant_id]
            ref_specific_strs.append(format_motif_distribution(motif_dicts['ref_specific_motifs'], include_model=True))
            alt_specific_strs.append(format_motif_distribution(motif_dicts['alt_specific_motifs'], include_model=True))
            unchanged_strs.append(format_motif_distribution(motif_dicts['unchanged_motifs_in_region'], include_model=False))
        else:
            ref_specific_strs.append("")
            alt_specific_strs.append("")
            unchanged_strs.append("")
    
    df['ref_specific_motifs'] = ref_specific_strs
    df['alt_specific_motifs'] = alt_specific_strs
    df['unchanged_motifs_in_region'] = unchanged_strs
    
    # STEP: Merge additional TSV files
    print("\nMerging additional TSV files...")
    
    if variant_dataset and SPLITS_DIR:
        # Build list of files to merge
        files_to_merge = []
        
        # Add closest_elements.tsv
        closest_elements_file = os.path.join(SPLITS_DIR, f"{variant_dataset}.closest_elements.tsv")
        if os.path.exists(closest_elements_file):
            files_to_merge.append(closest_elements_file)
            print(f"  Adding closest_elements.tsv")
        else:
            print(f"  Warning: closest_elements.tsv not found: {closest_elements_file}")
        
        # Add logfc.tsv
        logfc_file = os.path.join(SPLITS_DIR, f"{variant_dataset}.logfc.tsv")
        if os.path.exists(logfc_file):
            files_to_merge.append(logfc_file)
            print(f"  Adding logfc.tsv")
        else:
            print(f"  Warning: logfc.tsv not found: {logfc_file}")
        
        # Add aaq.tsv
        aaq_file = os.path.join(SPLITS_DIR, f"{variant_dataset}.aaq.tsv")
        if os.path.exists(aaq_file):
            files_to_merge.append(aaq_file)
            print(f"  Adding aaq.tsv")
        else:
            print(f"  Warning: aaq.tsv not found: {aaq_file}")
        
        # Add VEP most severe consequences TSV
        vep_file = os.path.join(SPLITS_DIR, f"{variant_dataset}.VEP.most_severe_csqs.tsv")
        if os.path.exists(vep_file):
            files_to_merge.append(vep_file)
            print(f"  Adding VEP.most_severe_csqs.tsv")
        else:
            print(f"  Warning: VEP.most_severe_csqs.tsv not found: {vep_file}")
        
        # Add GENCODE region type TSV
        gencode_file = os.path.join(SPLITS_DIR, f"{variant_dataset}.GENCODE.region_type.tsv")
        if os.path.exists(gencode_file):
            files_to_merge.append(gencode_file)
            print(f"  Adding GENCODE.region_type.tsv")
        else:
            print(f"  Warning: GENCODE.region_type.tsv not found: {gencode_file}")
        
        # Add gnomAD TSV
        gnomad_file = os.path.join(SPLITS_DIR, f"{variant_dataset}.gnomad.tsv")
        if os.path.exists(gnomad_file):
            files_to_merge.append(gnomad_file)
            print(f"  Adding gnomad.tsv")
        else:
            print(f"  Warning: gnomad.tsv not found: {gnomad_file}")
        
        # Add patient_hpo_expanded.tsv
        hpo_file = os.path.join(SPLITS_DIR, f"{variant_dataset}.patient_hpo_expanded.tsv")
        if os.path.exists(hpo_file):
            files_to_merge.append(hpo_file)
            print(f"  Adding patient_hpo_expanded.tsv")
        else:
            print(f"  Warning: patient_hpo_expanded.tsv not found: {hpo_file}")
        
        # Merge all files
        if files_to_merge:
            print(f"\nMerging {len(files_to_merge)} additional files...")
            for merge_file in files_to_merge:
                try:
                    df_merge = pd.read_csv(merge_file, sep='\t')
                    # Merge on variant_id
                    df = df.merge(df_merge, on='variant_id', how='left')
                    print(f"  Merged {os.path.basename(merge_file)}: {len(df_merge)} rows")
                except Exception as e:
                    print(f"  Warning: Error merging {merge_file}: {e}", file=sys.stderr)
    
    # Filter per-model columns to only prioritized models
    print("\nFiltering per-model columns to only prioritized models...")
    prioritized_models = set()
    for col in df.columns:
        if col.startswith('model_prioritized_by_any-'):
            model_name = col.replace('model_prioritized_by_any-', '')
            # Check if any variant is prioritized by this model
            col_values = df[col]
            is_prioritized = (
                (col_values.astype(str).str.lower() == 'true') |
                (col_values == True) |
                (col_values == 1) |
                (col_values.astype(str) == '1')
            ).fillna(False)
            
            if is_prioritized.any():
                prioritized_models.add(model_name)
    
    print(f"Found {len(prioritized_models)} prioritized models")
    
    # Filter per-model columns
    non_per_model_cols = [
        col for col in df.columns 
        if not (col.startswith('logfc-') or 
               col.startswith('aaq-') or 
               col.startswith('model_prioritized_by_any-'))
    ]
    
    # Filter model_prioritized_by_any-{model} columns
    model_prio_cols_to_keep = [
        col for col in df.columns 
        if col.startswith('model_prioritized_by_any-') and 
           col.replace('model_prioritized_by_any-', '') in prioritized_models
    ]
    
    # Filter logfc-{model} columns
    logfc_cols_to_keep = [
        col for col in df.columns 
        if col.startswith('logfc-') and 
           col.replace('logfc-', '') in prioritized_models
    ]
    
    # Filter aaq-{model} columns
    aaq_cols_to_keep = [
        col for col in df.columns 
        if col.startswith('aaq-') and 
           col.replace('aaq-', '') in prioritized_models
    ]
    
    # Combine all columns to keep
    cols_to_keep = non_per_model_cols + model_prio_cols_to_keep + logfc_cols_to_keep + aaq_cols_to_keep
    
    # Filter dataframe
    df = df[cols_to_keep].copy()
    print(f"Filtered per-model columns: kept {len(model_prio_cols_to_keep)} model_prioritized_by_any-, {len(logfc_cols_to_keep)} logfc-, {len(aaq_cols_to_keep)} aaq- columns")
    
    # Add HTML URL column
    print("\nAdding HTML URL column...")
    if variant_dataset and model_dataset:
        html_urls = []
        dummy_variant_id = "DUMMY_VARIANT_ID_FOR_URL_PATTERN"
        dummy_html_path = f"{OUTPUT_DIR}/{variant_dataset}/{model_dataset}/cluster_0/{dummy_variant_id}/00-summary.html"
        
        try:
            result = subprocess.run(
                ['mitra-utils', 'url', dummy_html_path],
                capture_output=True,
                text=True,
                check=True
            )
            url_pattern = result.stdout.strip()
            
            # Generate URLs for all variants
            for variant_id in df['variant_id']:
                url = url_pattern.replace(dummy_variant_id, variant_id)
                html_urls.append(url)
            
            df['variant_html_url'] = html_urls
            
            # Reorder columns so variant_html_url is the second column (after variant_id)
            cols = list(df.columns)
            if 'variant_id' in cols and 'variant_html_url' in cols:
                cols.remove('variant_html_url')
                variant_id_idx = cols.index('variant_id')
                cols.insert(variant_id_idx + 1, 'variant_html_url')
                df = df[cols]
                print(f"Added HTML URL column: {sum(1 for url in html_urls if url)}/{len(html_urls)} variants have URLs")
        except Exception as e:
            print(f"Warning: Could not add HTML URL column: {e}", file=sys.stderr)
            df['variant_html_url'] = ""
    else:
        df['variant_html_url'] = ""
    
    # Create merged closest element columns
    print("\nCreating merged closest element columns...")
    df['closest_genes'] = df.apply(lambda row: merge_closest_elements(row, 'genes', 8), axis=1)
    df['closest_miRNA'] = df.apply(lambda row: merge_closest_elements(row, 'miRNA', 5), axis=1)
    df['closest_lncRNA'] = df.apply(lambda row: merge_closest_elements(row, 'lncRNA', 5), axis=1)
    
    # Remove individual closest element columns
    cols_to_remove = []
    for i in range(1, 9):
        cols_to_remove.extend([f'closest_genes_{i}', f'closest_genes_distance_{i}'])
    for i in range(1, 6):
        cols_to_remove.extend([f'closest_miRNA_{i}', f'closest_miRNA_distance_{i}'])
        cols_to_remove.extend([f'closest_lncRNA_{i}', f'closest_lncRNA_distance_{i}'])
    
    cols_to_remove = [col for col in cols_to_remove if col in df.columns]
    df = df.drop(columns=cols_to_remove)
    print(f"Created merged closest element columns and removed {len(cols_to_remove)} individual columns")
    
    # Reorder columns so merged closest element columns appear after variant_html_url
    cols = list(df.columns)
    merged_closest_cols = ['closest_genes', 'closest_miRNA', 'closest_lncRNA']
    if 'variant_html_url' in cols:
        for col in merged_closest_cols:
            if col in cols:
                cols.remove(col)
        variant_html_url_idx = cols.index('variant_html_url')
        for i, col in enumerate(merged_closest_cols):
            if col in df.columns:
                cols.insert(variant_html_url_idx + 1 + i, col)
        df = df[cols]
    
    # Create simons_searchlight_genes_in_closest_genes column
    print("\nCreating simons_searchlight_genes_in_closest_genes column...")
    searchlight_genes_file = "/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/curation/broad/varbook-container/snakemake/refs/simonss_searchlight_genes.txt"
    searchlight_genes = set()
    
    if os.path.exists(searchlight_genes_file):
        with open(searchlight_genes_file, 'r') as f:
            for line in f:
                gene = line.strip().rstrip(',').strip()
                if gene:
                    searchlight_genes.add(gene.upper())
        print(f"Loaded {len(searchlight_genes)} genes from simons's searchlight genes file")
    else:
        print(f"Warning: simons's searchlight genes file not found: {searchlight_genes_file}")
    
    df['simons_searchlight_genes_in_closest_genes'] = df['closest_genes'].apply(
        lambda x: filter_searchlight_genes(x, searchlight_genes)
    )
    
    # Reorder columns so simons_searchlight_genes_in_closest_genes appears after closest_genes
    cols = list(df.columns)
    if 'closest_genes' in cols and 'simons_searchlight_genes_in_closest_genes' in cols:
        cols.remove('simons_searchlight_genes_in_closest_genes')
        closest_genes_idx = cols.index('closest_genes')
        cols.insert(closest_genes_idx + 1, 'simons_searchlight_genes_in_closest_genes')
        df = df[cols]
    
    # Create G2P intersection columns
    print("\nCreating G2P intersection columns...")
    g2p_file = "/oak/stanford/groups/akundaje/airanman/refs/gene2phenotype/2025-12-12/AllG2P.csv"
    g2p_lookup = {}
    
    if os.path.exists(g2p_file):
        try:
            df_g2p = pd.read_csv(g2p_file, sep=',', low_memory=False)
            print(f"Loaded G2P file with {len(df_g2p)} entries")
            
            for idx, row in df_g2p.iterrows():
                gene_symbol = str(row.get('gene symbol', '')).strip().upper()
                if not gene_symbol or gene_symbol == 'NAN' or pd.isna(row.get('gene symbol')):
                    continue
                
                g2p_entry = {
                    'disease_name': str(row.get('disease name', '')).strip() if pd.notna(row.get('disease name')) else 'N/A',
                    'confidence': str(row.get('confidence', '')).strip() if pd.notna(row.get('confidence')) else 'N/A',
                    'molecular mechanism': str(row.get('molecular mechanism', '')).strip() if pd.notna(row.get('molecular mechanism')) else 'N/A',
                    'molecular mechanism categorisation': str(row.get('molecular mechanism categorisation', '')).strip() if pd.notna(row.get('molecular mechanism categorisation')) else 'N/A',
                    'variant_consequence': str(row.get('variant consequence', '')).strip() if pd.notna(row.get('variant consequence')) else 'N/A',
                    'variant_types': str(row.get('variant types', '')).strip() if pd.notna(row.get('variant types')) else 'N/A',
                    'phenotypes': str(row.get('phenotypes', '')).strip() if pd.notna(row.get('phenotypes')) else 'N/A',
                    'panel': str(row.get('panel', '')).strip() if pd.notna(row.get('panel')) else 'N/A'
                }
                
                if gene_symbol not in g2p_lookup:
                    g2p_lookup[gene_symbol] = []
                g2p_lookup[gene_symbol].append(g2p_entry)
                
                # Also add to lookup by previous gene symbols
                previous_symbols_str = str(row.get('previous gene symbols', '')).strip()
                if previous_symbols_str and previous_symbols_str != 'NAN' and pd.notna(row.get('previous gene symbols')):
                    previous_symbols = [s.strip().upper() for s in previous_symbols_str.split(';') if s.strip()]
                    for prev_symbol in previous_symbols:
                        if prev_symbol:
                            if prev_symbol not in g2p_lookup:
                                g2p_lookup[prev_symbol] = []
                            g2p_lookup[prev_symbol].append(g2p_entry)
            
            print(f"Created G2P lookup with {len(g2p_lookup)} unique gene symbols")
        except Exception as e:
            print(f"Error loading G2P file: {e}", file=sys.stderr)
            g2p_lookup = {}
    else:
        print(f"Warning: G2P file not found: {g2p_file}")
        g2p_lookup = {}
    
    df['g2p_genes_in_closest_genes'] = df['closest_genes'].apply(
        lambda x: intersect_with_g2p(x, g2p_lookup, 'gene')
    )
    df['g2p_miRNA_in_closest_miRNA'] = df['closest_miRNA'].apply(
        lambda x: intersect_with_g2p(x, g2p_lookup, 'miRNA')
    )
    df['g2p_lncRNA_in_closest_lncRNA'] = df['closest_lncRNA'].apply(
        lambda x: intersect_with_g2p(x, g2p_lookup, 'lncRNA')
    )
    
    # Reorder columns so G2P columns appear after their respective closest element columns
    cols = list(df.columns)
    g2p_cols = ['g2p_genes_in_closest_genes', 'g2p_miRNA_in_closest_miRNA', 'g2p_lncRNA_in_closest_lncRNA']
    
    if 'g2p_genes_in_closest_genes' in cols:
        cols.remove('g2p_genes_in_closest_genes')
        if 'simons_searchlight_genes_in_closest_genes' in cols:
            insert_idx = cols.index('simons_searchlight_genes_in_closest_genes') + 1
        elif 'closest_genes' in cols:
            insert_idx = cols.index('closest_genes') + 1
        else:
            insert_idx = len(cols)
        cols.insert(insert_idx, 'g2p_genes_in_closest_genes')
    
    if 'g2p_miRNA_in_closest_miRNA' in cols:
        cols.remove('g2p_miRNA_in_closest_miRNA')
        if 'closest_miRNA' in cols:
            insert_idx = cols.index('closest_miRNA') + 1
        else:
            insert_idx = len(cols)
        cols.insert(insert_idx, 'g2p_miRNA_in_closest_miRNA')
    
    if 'g2p_lncRNA_in_closest_lncRNA' in cols:
        cols.remove('g2p_lncRNA_in_closest_lncRNA')
        if 'closest_lncRNA' in cols:
            insert_idx = cols.index('closest_lncRNA') + 1
        else:
            insert_idx = len(cols)
        cols.insert(insert_idx, 'g2p_lncRNA_in_closest_lncRNA')
    
    df = df[cols]
    
    # Calculate mean columns for prioritized fetal brain models
    print("\nCalculating mean columns for prioritized fetal brain models...")
    
    # Identify fetal brain models (models starting with KUN_FB)
    all_models = set()
    for col in df.columns:
        if col.startswith('model_prioritized_by_any-'):
            model_name = col.replace('model_prioritized_by_any-', '')
            all_models.add(model_name)
        elif col.startswith('logfc-'):
            model_name = col.replace('logfc-', '')
            all_models.add(model_name)
        elif col.startswith('aaq-'):
            model_name = col.replace('aaq-', '')
            all_models.add(model_name)
    
    # Filter to only fetal brain models
    fetal_brain_models = [model for model in all_models if model.startswith('KUN_FB')]
    print(f"Found {len(fetal_brain_models)} fetal brain models")
    
    if fetal_brain_models:
        # Calculate mean columns for each variant
        mean_logfc_values = []
        mean_aaq_values = []
        mean_logfc_x_aaq_values = []
        
        for _, row in df.iterrows():
            # Find prioritized fetal brain models for this variant
            prioritized_fb_models = []
            for model in fetal_brain_models:
                prio_col = f'model_prioritized_by_any-{model}'
                if prio_col in df.columns:
                    prio_value = row[prio_col]
                    # Check if prioritized (handle various formats: True, 'True', 1, '1')
                    is_prioritized = (
                        (pd.notna(prio_value) and 
                         (prio_value == True or 
                          str(prio_value).lower() == 'true' or 
                          prio_value == 1 or 
                          str(prio_value) == '1'))
                    )
                    if is_prioritized:
                        prioritized_fb_models.append(model)
            
            # Collect logfc and aaq values for prioritized fetal brain models
            logfc_values = []
            aaq_values = []
            logfc_x_aaq_values = []
            
            for model in prioritized_fb_models:
                logfc_col = f'logfc-{model}'
                aaq_col = f'aaq-{model}'
                
                if logfc_col in df.columns and pd.notna(row[logfc_col]):
                    logfc_val = row[logfc_col]
                    try:
                        logfc_val = float(logfc_val)
                        logfc_values.append(logfc_val)
                        
                        if aaq_col in df.columns and pd.notna(row[aaq_col]):
                            aaq_val = row[aaq_col]
                            try:
                                aaq_val = float(aaq_val)
                                aaq_values.append(aaq_val)
                                logfc_x_aaq_values.append(logfc_val * aaq_val)
                            except (ValueError, TypeError):
                                pass
                    except (ValueError, TypeError):
                        pass
                elif aaq_col in df.columns and pd.notna(row[aaq_col]):
                    aaq_val = row[aaq_col]
                    try:
                        aaq_val = float(aaq_val)
                        aaq_values.append(aaq_val)
                    except (ValueError, TypeError):
                        pass
            
            # Calculate means
            mean_logfc = np.nanmean(logfc_values) if logfc_values else np.nan
            mean_aaq = np.nanmean(aaq_values) if aaq_values else np.nan
            mean_logfc_x_aaq = np.nanmean(logfc_x_aaq_values) if logfc_x_aaq_values else np.nan
            
            mean_logfc_values.append(mean_logfc)
            mean_aaq_values.append(mean_aaq)
            mean_logfc_x_aaq_values.append(mean_logfc_x_aaq)
        
        # Add columns to dataframe
        df['mean_prioritized_logfc'] = mean_logfc_values
        df['mean_prioritized_aaq'] = mean_aaq_values
        df['mean_prioritized_logfc_x_aaq'] = mean_logfc_x_aaq_values
        
        # Count variants with non-null values
        non_null_logfc = df['mean_prioritized_logfc'].notna().sum()
        non_null_aaq = df['mean_prioritized_aaq'].notna().sum()
        non_null_logfc_x_aaq = df['mean_prioritized_logfc_x_aaq'].notna().sum()
        
        print(f"Added mean columns:")
        print(f"  mean_prioritized_logfc: {non_null_logfc}/{len(df)} variants have values")
        print(f"  mean_prioritized_aaq: {non_null_aaq}/{len(df)} variants have values")
        print(f"  mean_prioritized_logfc_x_aaq: {non_null_logfc_x_aaq}/{len(df)} variants have values")
    else:
        print("Warning: No fetal brain models found, setting mean columns to NaN")
        df['mean_prioritized_logfc'] = np.nan
        df['mean_prioritized_aaq'] = np.nan
        df['mean_prioritized_logfc_x_aaq'] = np.nan
    
    # Drop logfc and score columns (but keep the filtered logfc columns we already have)
    print("\nDropping score columns...")
    score_cols_to_drop = [col for col in df.columns if col.startswith('score-')]
    df = df.drop(columns=score_cols_to_drop)
    print(f"Dropped {len(score_cols_to_drop)} score columns")
    
    # Reorder columns: variant_id, chr, pos, ref, alt, then motif columns, then rest
    print("\nReordering columns...")
    cols = list(df.columns)
    
    # Define the columns to move
    genomic_cols = ['chr', 'pos', 'ref', 'alt']
    motif_cols = ['ref_specific_motifs', 'alt_specific_motifs', 'unchanged_motifs_in_region']
    
    # Remove columns from their current positions
    for col in genomic_cols + motif_cols:
        if col in cols:
            cols.remove(col)
    
    # Find variant_id index
    if 'variant_id' in cols:
        variant_id_idx = cols.index('variant_id')
        # Insert genomic columns right after variant_id
        for i, col in enumerate(genomic_cols):
            if col in df.columns:
                cols.insert(variant_id_idx + 1 + i, col)
        
        # Find the position after the last genomic column (or variant_id if no genomic cols)
        insert_idx = variant_id_idx + 1 + len([c for c in genomic_cols if c in df.columns])
        
        # Insert motif columns after genomic columns
        for i, col in enumerate(motif_cols):
            if col in df.columns:
                cols.insert(insert_idx + i, col)
    else:
        # If variant_id not found, just prepend the columns
        cols = genomic_cols + motif_cols + cols
    
    # Reorder dataframe
    df = df[cols]
    print(f"Reordered columns: variant_id, {', '.join([c for c in genomic_cols if c in df.columns])}, {', '.join([c for c in motif_cols if c in df.columns])}, ...")
    
    # Save output
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    df.to_csv(output_file, sep='\t', index=False)
    
    end_time = time.time()
    duration = end_time - start_time
    print(f"\n[TIMING] Clustering completed in {duration:.1f}s")
    print(f"Saved {len(df)} variants to {output_file}")
    print(f"Total columns: {len(df.columns)}")
    print(f"{'='*80}\n")
    
    return output_file


def _build_arg_parser():
    """Construct the command-line parser for this script.

    The input is selected by exactly one of ``--input`` (an explicit filtered
    TSV path) or ``--variant-dataset`` (the path is derived later from the
    dataset names). The remaining options tune the clustering run.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(
        description='Generate clustered TSV files from filtered TSV files with human-readable columns. '
                    'Maps motif names using hdma_motif_mapping.tsv.'
    )

    # One, and only one, way of locating the input must be supplied.
    source = cli.add_mutually_exclusive_group(required=True)
    source.add_argument(
        '--input',
        help='Path to filtered TSV file (must contain score-{model} columns)',
    )
    source.add_argument(
        '--variant-dataset',
        help='Variant dataset name (e.g., "Broad neurodevelopmental and neuromuscular disorders"). '
             'Requires --model-dataset.',
    )

    cli.add_argument(
        '--model-dataset',
        help='Model dataset name (e.g., "Fetal Brain"). Required if using --variant-dataset.',
    )
    cli.add_argument(
        '--output',
        help='Output file path (default: auto-generated from input file, replacing .filtered.tsv with .clustered.tsv)',
    )
    cli.add_argument(
        '--n-clusters',
        type=int,
        default=35,
        help='Number of clusters (default: 35)',
    )
    cli.add_argument(
        '--random-state',
        type=int,
        default=42,
        help='Random state for reproducibility (default: 42)',
    )
    cli.add_argument(
        '--motif-mapping-file',
        default='hdma_motif_mapping.tsv',
        help='Path to motif mapping TSV file (default: hdma_motif_mapping.tsv)',
    )
    return cli


def main():
    """Command-line entry point.

    Parses the arguments and resolves the filtered input TSV. The path comes
    either directly from ``--input`` or from ``get_filtered_tsv_path`` when
    ``--variant-dataset``/``--model-dataset`` are given. Then calls
    ``run_clustering``.

    Returns:
        0 when ``run_clustering`` completes, 1 when it (or the success
        message) raises. Invalid argument combinations exit through
        ``ArgumentParser.error`` (``SystemExit`` with status 2) before any
        work is done.
    """
    cli = _build_arg_parser()
    opts = cli.parse_args()

    # argparse cannot express "B is required only when A is given", so
    # check the pairing by hand.
    if opts.variant_dataset and not opts.model_dataset:
        cli.error("--model-dataset is required when using --variant-dataset")

    if not opts.input:
        input_file = get_filtered_tsv_path(opts.variant_dataset, opts.model_dataset)
        print(f"Auto-constructed input path: {input_file}")
    else:
        input_file = opts.input

    # Top-level boundary: report any pipeline failure and map it to exit code 1.
    try:
        written = run_clustering(
            input_file=input_file,
            output_file=opts.output,
            n_clusters=opts.n_clusters,
            random_state=opts.random_state,
            motif_mapping_file=opts.motif_mapping_file,
            variant_dataset=opts.variant_dataset,
            model_dataset=opts.model_dataset,
        )
        print(f"Success! Output saved to: {written}")
        return 0
    except Exception as exc:
        print(f"Error: {exc}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        return 1


if __name__ == '__main__':
    # main() returns an int status; SystemExit turns it into the process exit code.
    raise SystemExit(main())

