#!/usr/bin/env python3
"""
Shared configuration module for model datasets.

This module provides a unified way to access model dataset configurations
used across multiple scripts (run_filter.py, run_kmeans.py, etc.).

The configuration is parsed from Snakefile.original (or Snakefile) when one is
found; otherwise a hardcoded fallback mapping is used.
"""

import os
import re


def get_model_datasets_config(snakefile_path=None):
    """
    Get VARIANT_DATASET_CONFIGS from Snakefile or return fallback config.
    
    Parameters:
    -----------
    snakefile_path : str, optional
        Path to Snakefile.original. If None, tries to find it automatically.
    
    Returns:
    --------
    dict
        VARIANT_DATASET_CONFIGS dictionary
    """
    # Try to read from Snakefile if available
    if snakefile_path is None:
        # Try common locations
        script_dir = os.path.dirname(os.path.abspath(__file__))
        potential_paths = [
            os.path.join(script_dir, "Snakefile.original"),
            os.path.join(script_dir, "Snakefile"),
        ]
        for path in potential_paths:
            if os.path.exists(path):
                snakefile_path = path
                break
    
    if snakefile_path and os.path.exists(snakefile_path):
        try:
            # Read the Snakefile and extract VARIANT_DATASET_CONFIGS using exec
            with open(snakefile_path, 'r') as f:
                content = f.read()
            
            # Create a namespace to execute the config
            namespace = {}
            # Extract just the VARIANT_DATASET_CONFIGS definition
            # Find the start and end of the dictionary
            # Look for VARIANT_DATASET_CONFIGS = { ... }
            pattern = r'VARIANT_DATASET_CONFIGS\s*=\s*\{'
            match = re.search(pattern, content)
            if match:
                start_pos = match.start()
                # Find matching closing brace (simple approach - count braces)
                brace_count = 0
                in_string = False
                string_char = None
                i = start_pos + match.end() - 1
                while i < len(content):
                    char = content[i]
                    if char in ('"', "'") and (i == 0 or content[i-1] != '\\'):
                        if not in_string:
                            in_string = True
                            string_char = char
                        elif char == string_char:
                            in_string = False
                            string_char = None
                    elif not in_string:
                        if char == '{':
                            brace_count += 1
                        elif char == '}':
                            brace_count -= 1
                            if brace_count == 0:
                                # Found the end
                                config_str = content[start_pos:i+1]
                                # Execute in namespace
                                exec(config_str, namespace)
                                if 'VARIANT_DATASET_CONFIGS' in namespace:
                                    print(f"Loaded VARIANT_DATASET_CONFIGS from {snakefile_path}")
                                    return namespace['VARIANT_DATASET_CONFIGS']
                                break
                    i += 1
        except Exception as e:
            print(f"Warning: Could not parse Snakefile for config: {e}")
    
    # Fallback: Return hardcoded config
    print("Using fallback hardcoded VARIANT_DATASET_CONFIGS")
    return get_fallback_config()


def get_fallback_config():
    """
    Build the hardcoded VARIANT_DATASET_CONFIGS used when no Snakefile is parsed.

    A brand-new dictionary (with new nested lists/dicts) is constructed on
    every call, so callers may mutate the result freely.

    Returns:
    --------
    dict
        Fallback VARIANT_DATASET_CONFIGS
    """
    fetal_brain = dict(
        name='Fetal Brain',
        models=['KUN_FB*'],
        model_superset=['KUN_FB*', 'KUN_HDMA*'],
    )
    variant_dataset = "Broad neurodevelopmental and neuromuscular disorders"
    return {variant_dataset: [fetal_brain]}


def get_model_datasets_list(variant_dataset, snakefile_path=None):
    """
    Return the model dataset config dicts registered for one variant dataset.

    An entry in VARIANT_DATASET_CONFIGS may be either a plain list of model
    dataset dicts, or a dict holding that list under 'model_datasets'.

    Parameters:
    -----------
    variant_dataset : str
        Variant dataset name
    snakefile_path : str, optional
        Path to Snakefile.original

    Returns:
    --------
    list
        List of model dataset config dicts; empty if the variant dataset
        is not configured.

    Raises:
    -------
    ValueError
        If the entry is neither a list nor a dict with 'model_datasets'.
    """
    all_configs = get_model_datasets_config(snakefile_path)
    if variant_dataset not in all_configs:
        return []

    entry = all_configs[variant_dataset]
    if isinstance(entry, list):
        return entry
    if isinstance(entry, dict) and 'model_datasets' in entry:
        return entry['model_datasets']

    raise ValueError(
        f"Invalid config format for variant_dataset '{variant_dataset}': "
        "expected list or dict with 'model_datasets' key"
    )


def get_model_patterns_from_dataset(variant_dataset, model_dataset_name, snakefile_path=None):
    """
    Look up the model patterns configured for a named model dataset.

    The first dict entry whose 'name' equals ``model_dataset_name`` and whose
    'models' list is non-empty wins; entries with an empty or missing
    'models' list are skipped.

    Parameters:
    -----------
    variant_dataset : str
        Variant dataset name
    model_dataset_name : str
        Model dataset name (e.g., "Fetal Brain")
    snakefile_path : str, optional
        Path to Snakefile.original

    Returns:
    --------
    list of str
        Model patterns (e.g., ["KUN_FB*"])

    Raises:
    -------
    ValueError
        If no matching entry with non-empty 'models' exists; the message
        lists the available model dataset names.
    """
    model_datasets = get_model_datasets_list(variant_dataset, snakefile_path)

    candidate_patterns = (
        entry.get('models', [])
        for entry in model_datasets
        if isinstance(entry, dict) and entry.get('name') == model_dataset_name
    )
    patterns = next((found for found in candidate_patterns if found), None)
    if patterns is not None:
        return patterns

    available = [entry.get('name') for entry in model_datasets if isinstance(entry, dict)]
    raise ValueError(
        f"Could not find model patterns for model_dataset '{model_dataset_name}' "
        f"in variant_dataset '{variant_dataset}'. "
        f"Available model_datasets: {available}"
    )


def list_model_datasets(variant_dataset, snakefile_path=None):
    """
    Collect the names of all model datasets configured for a variant dataset.

    Only dict entries that carry a 'name' key contribute; anything else is
    skipped.

    Parameters:
    -----------
    variant_dataset : str
        Variant dataset name
    snakefile_path : str, optional
        Path to Snakefile.original

    Returns:
    --------
    list of str
        List of model dataset names, in configuration order
    """
    names = []
    for entry in get_model_datasets_list(variant_dataset, snakefile_path):
        if not isinstance(entry, dict) or 'name' not in entry:
            continue
        names.append(entry.get('name'))
    return names



