#!/usr/bin/env python3
"""
Exploratory Data Analysis (EDA) for ChromBPNet variant analysis.

This script performs comprehensive EDA on the "Broad neurodevelopmental and neuromuscular disorders"
variant set across Fetal Brain (KUN_FB*) models, with optional HDMA model support.

Inspired by the HDMA motif compendium analysis:
https://greenleaflab.github.io/HDMA/code/03-chrombpnet/03-syntax/01-motif_compendium.html

Analyses include:
1. Motif distribution analysis (frequencies, per-variant counts, top motifs)
2. Cluster-based motif analysis (motif patterns across kmeans clusters)
3. Prioritization analysis (model prioritization patterns, AAQ distributions)
4. Motif-prioritization relationship analysis
5. Summary visualizations
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import re
from collections import Counter, defaultdict
from typing import Dict, List, Tuple, Optional
import argparse
import warnings
# Silence library warnings (pandas/seaborn emit many FutureWarnings in bulk EDA runs).
warnings.filterwarnings('ignore')

# Global plotting defaults shared by every figure produced in this script.
sns.set_style("whitegrid")
plt.rcParams.update({'figure.figsize': (12, 8), 'font.size': 10})


def parse_motif_string(motif_str: str, motif_type: str = 'all') -> List[Tuple[str, int, Optional[str]]]:
    """
    Parse motif strings from the clustered TSV.

    Format examples:
    - "CTCF (2, KUN_FB_neuron), REST (1, KUN_FB_astrocyte)"
    - "ZEB/SNAI (7), HD:DLX/LHX#1 (5)"

    Parameters
    ----------
    motif_str : raw cell value; NaN or '' yields an empty list.
    motif_type : accepted for interface compatibility; currently unused.

    Returns:
    --------
    List of tuples: (motif_name, count, model_name), where model_name is
    None when the parenthesized group carries no model tag.
    """
    if pd.isna(motif_str) or motif_str == '':
        return []

    motifs = []
    # Pattern to match: "MOTIF_NAME (count, MODEL)" or "MOTIF_NAME (count)"
    pattern = r'([^(]+)\s*\((\d+)(?:,\s*([^)]+))?\)'

    for match in re.finditer(pattern, str(motif_str)):
        # The name group greedily picks up the ", " separator (and any ")"
        # residue) left over from the previous entry, so strip those
        # characters in addition to whitespace; otherwise every motif after
        # the first comes back as e.g. ", REST" instead of "REST".
        motif_name = match.group(1).strip().strip(',)').strip()
        count = int(match.group(2))
        model_name = match.group(3).strip() if match.group(3) else None

        motifs.append((motif_name, count, model_name))

    return motifs


def load_data(clustered_tsv: str) -> pd.DataFrame:
    """Read the clustered variant TSV into a DataFrame, logging its shape."""
    print(f"Loading data from: {clustered_tsv}")
    frame = pd.read_csv(clustered_tsv, sep='\t')
    print(f"Loaded {len(frame)} variants with {len(frame.columns)} columns")
    return frame


def extract_fetal_brain_models(df: pd.DataFrame) -> List[str]:
    """Return the sorted Fetal Brain (KUN_FB*) model names found in the columns."""
    prefix = 'model_prioritized_by_any-'
    return sorted(
        col[len(prefix):]
        for col in df.columns
        if col.startswith(prefix + 'KUN_FB')
    )


def analyze_motif_distributions(df: pd.DataFrame) -> Dict:
    """
    Tally motif occurrences across all variants.

    Returns dictionary with:
    - ref_specific: Counter of ref-specific motifs
    - alt_specific: Counter of alt-specific motifs
    - unchanged: Counter of unchanged motifs
    - per_variant_counts: per-variant motif counts for each category
    """
    print("\n" + "="*80)
    print("1. MOTIF DISTRIBUTION ANALYSIS")
    print("="*80)

    # (result key, source column, per-variant tally key) for each category.
    sections = (
        ('ref_specific', 'ref_specific_motifs', 'ref'),
        ('alt_specific', 'alt_specific_motifs', 'alt'),
        ('unchanged', 'unchanged_motifs_in_region', 'unchanged'),
    )

    results = {key: Counter() for key, _, _ in sections}
    results['per_variant_counts'] = {tally: [] for _, _, tally in sections}

    # Accumulate counts row by row, one pass per category.
    for _, row in df.iterrows():
        for key, column, tally in sections:
            parsed = parse_motif_string(row.get(column, ''))
            for name, occurrences, _model in parsed:
                results[key][name] += occurrences
            results['per_variant_counts'][tally].append(len(parsed))

    print(f"\nRef-specific motifs: {len(results['ref_specific'])} unique motifs")
    print(f"Alt-specific motifs: {len(results['alt_specific'])} unique motifs")
    print(f"Unchanged motifs: {len(results['unchanged'])} unique motifs")

    print(f"\nAverage motifs per variant:")
    print(f"  Ref-specific: {np.mean(results['per_variant_counts']['ref']):.2f}")
    print(f"  Alt-specific: {np.mean(results['per_variant_counts']['alt']):.2f}")
    print(f"  Unchanged: {np.mean(results['per_variant_counts']['unchanged']):.2f}")

    # Top-10 listings per category.
    for label, key in (('ref-specific', 'ref_specific'),
                       ('alt-specific', 'alt_specific'),
                       ('unchanged', 'unchanged')):
        print(f"\nTop 10 {label} motifs:")
        for motif, total in results[key].most_common(10):
            print(f"  {motif}: {total}")

    return results


def analyze_cluster_motifs(df: pd.DataFrame, motif_results: Dict) -> Dict:
    """
    Tally motif occurrences within each kmeans_35 cluster.

    ``motif_results`` is accepted for interface compatibility but unused.
    Returns {cluster_id: {'ref_specific': Counter, 'alt_specific': Counter,
    'unchanged': Counter, 'n_variants': int}} (empty if no cluster column).
    """
    print("\n" + "="*80)
    print("2. CLUSTER-BASED MOTIF ANALYSIS")
    print("="*80)

    if 'kmeans_35' not in df.columns:
        print("Warning: kmeans_35 column not found. Skipping cluster analysis.")
        return {}

    # (result key, source column) per motif category.
    category_columns = (
        ('ref_specific', 'ref_specific_motifs'),
        ('alt_specific', 'alt_specific_motifs'),
        ('unchanged', 'unchanged_motifs_in_region'),
    )

    cluster_motifs = {}

    for _, row in df.iterrows():
        cluster = row.get('kmeans_35')
        if pd.isna(cluster):
            # Rows without a cluster assignment are skipped entirely.
            continue

        bucket = cluster_motifs.setdefault(cluster, {
            'ref_specific': Counter(),
            'alt_specific': Counter(),
            'unchanged': Counter(),
            'n_variants': 0,
        })
        bucket['n_variants'] += 1

        for key, column in category_columns:
            for name, occurrences, _model in parse_motif_string(row.get(column, '')):
                bucket[key][name] += occurrences

    print(f"\nFound {len(cluster_motifs)} clusters")
    for cluster in sorted(cluster_motifs):
        bucket = cluster_motifs[cluster]
        print(f"  Cluster {cluster}: {bucket['n_variants']} variants, "
              f"{len(bucket['ref_specific'])} ref-specific, "
              f"{len(bucket['alt_specific'])} alt-specific, "
              f"{len(bucket['unchanged'])} unchanged motifs")

    return cluster_motifs


def analyze_prioritization(df: pd.DataFrame, models: List[str]) -> Dict:
    """
    Analyze variant prioritization patterns across models.

    Parameters
    ----------
    df : DataFrame with `model_prioritized_by_any-<model>` flag columns and
        optional `aaq-<model>` score columns.
    models : model names to inspect (columns absent from df are skipped).

    Returns
    -------
    Dict with:
    - model_prioritization_counts: {model: prioritized-variant count (plain int)}
    - variants_by_n_models: Counter {n prioritizing models: n variants}
    - aaq_stats: pooled mean/median/std/min/max over all aaq-* values
      (empty dict when no aaq columns/values exist)
    """
    print("\n" + "="*80)
    print("3. PRIORITIZATION ANALYSIS")
    print("="*80)

    def _is_prioritized(val) -> bool:
        # Flags may arrive as bool, int, or string depending on how the TSV
        # was produced/parsed, so accept True/1/'True'/'1' (NaN is False).
        return pd.notna(val) and (
            val == True or str(val).lower() == 'true' or val == 1 or str(val) == '1'
        )

    results = {
        'model_prioritization_counts': {},
        'variants_by_n_models': Counter(),
        'aaq_stats': {}
    }

    # Count prioritizations per model.
    for model in models:
        col = f'model_prioritized_by_any-{model}'
        if col in df.columns:
            # int() cast: pandas .sum() yields numpy.int64, which json.dump
            # cannot serialize — main() writes these counts to JSON.
            results['model_prioritization_counts'][model] = int(
                df[col].apply(_is_prioritized).sum()
            )

    # Count how many models prioritize each variant.
    priority_cols = [f'model_prioritized_by_any-{model}' for model in models
                     if f'model_prioritized_by_any-{model}' in df.columns]

    for idx, row in df.iterrows():
        n_prioritized = sum(1 for col in priority_cols if _is_prioritized(row[col]))
        results['variants_by_n_models'][n_prioritized] += 1

    # AAQ statistics pooled across all available models.
    aaq_cols = [f'aaq-{model}' for model in models if f'aaq-{model}' in df.columns]
    if aaq_cols:
        all_aaq = []
        for col in aaq_cols:
            all_aaq.extend(df[col].dropna().tolist())

        if all_aaq:
            # float() casts keep the dict JSON-friendly.
            results['aaq_stats'] = {
                'mean': float(np.mean(all_aaq)),
                'median': float(np.median(all_aaq)),
                'std': float(np.std(all_aaq)),
                'min': float(np.min(all_aaq)),
                'max': float(np.max(all_aaq))
            }

    # Print summary
    print(f"\nPrioritization by model:")
    for model, count in sorted(results['model_prioritization_counts'].items(), key=lambda x: x[1], reverse=True):
        print(f"  {model}: {count} variants")

    print(f"\nVariants prioritized by N models:")
    for n_models in sorted(results['variants_by_n_models'].keys()):
        count = results['variants_by_n_models'][n_models]
        print(f"  {n_models} models: {count} variants")

    if results['aaq_stats']:
        print(f"\nAAQ statistics:")
        for stat, value in results['aaq_stats'].items():
            print(f"  {stat}: {value:.4f}")

    return results


def analyze_motif_prioritization_relationships(df: pd.DataFrame, models: List[str]) -> Dict:
    """
    Relate gained/lost motifs to the models that prioritize each variant.

    Only variants prioritized by at least one model contribute. Returns a
    dict with per-motif model association Counters and overall ref/alt
    motif Counters restricted to prioritized variants.
    """
    print("\n" + "="*80)
    print("4. MOTIF-PRIORITIZATION RELATIONSHIP ANALYSIS")
    print("="*80)

    results = {
        'motif_model_associations': defaultdict(lambda: Counter()),
        'prioritized_variant_motifs': {
            'ref_specific': Counter(),
            'alt_specific': Counter()
        }
    }

    # Map model -> its prioritization flag column (only columns present in df).
    priority_cols = {model: f'model_prioritized_by_any-{model}' for model in models
                     if f'model_prioritized_by_any-{model}' in df.columns}

    for _, row in df.iterrows():
        # Flags may be bool, int, or string; accept True/1/'True'/'1'.
        prioritized_models = [
            model for model, col in priority_cols.items()
            if pd.notna(row[col]) and (row[col] == True
                                       or str(row[col]).lower() == 'true'
                                       or row[col] == 1
                                       or str(row[col]) == '1')
        ]
        if not prioritized_models:
            continue

        for key, column in (('ref_specific', 'ref_specific_motifs'),
                            ('alt_specific', 'alt_specific_motifs')):
            for name, occurrences, model in parse_motif_string(row.get(column, '')):
                # Model-level association only when the motif is tagged with
                # a model that also prioritizes this variant.
                if model and model in prioritized_models:
                    results['motif_model_associations'][name][model] += occurrences
                results['prioritized_variant_motifs'][key][name] += occurrences

    print(f"\nTop motifs associated with prioritized variants:")
    print("Ref-specific:")
    for motif, count in results['prioritized_variant_motifs']['ref_specific'].most_common(10):
        print(f"  {motif}: {count}")

    print("\nAlt-specific:")
    for motif, count in results['prioritized_variant_motifs']['alt_specific'].most_common(10):
        print(f"  {motif}: {count}")

    return results


def analyze_model_specific_motifs(df: pd.DataFrame, models: List[str]) -> Dict:
    """
    Analyze motif patterns specific to each model.

    A motif occurrence is credited to a model only when the motif string tags
    it with that model AND that model prioritizes the variant. ``n_variants``
    is the total number of variants each model prioritizes.

    Returns
    -------
    Dict mapping model name to
    {'ref_specific': Counter, 'alt_specific': Counter, 'n_variants': int}.
    """
    print("\n" + "="*80)
    print("5. MODEL-SPECIFIC MOTIF ANALYSIS")
    print("="*80)

    model_motifs = defaultdict(lambda: {
        'ref_specific': Counter(),
        'alt_specific': Counter(),
        'n_variants': 0
    })

    # Process each variant
    for idx, row in df.iterrows():
        # Models whose flag column marks this variant as prioritized
        # (flags may be bool, int, or string: True/1/'True'/'1').
        prioritized_models = []
        for model in models:
            col = f'model_prioritized_by_any-{model}'
            if col in df.columns:
                val = row[col]
                if pd.notna(val) and (val == True or str(val).lower() == 'true' or val == 1 or str(val) == '1'):
                    prioritized_models.append(model)

        # Credit motif counts to the tagged, prioritizing model.
        ref_motifs = parse_motif_string(row.get('ref_specific_motifs', ''))
        for motif_name, count, model in ref_motifs:
            if model and model in prioritized_models:
                model_motifs[model]['ref_specific'][motif_name] += count

        alt_motifs = parse_motif_string(row.get('alt_specific_motifs', ''))
        for motif_name, count, model in alt_motifs:
            if model and model in prioritized_models:
                model_motifs[model]['alt_specific'][motif_name] += count

    # Count unique prioritized variants per model. This is the authoritative
    # 'n_variants' value; no per-row bookkeeping is needed above because any
    # model present in model_motifs necessarily has its flag column in df.
    for model in models:
        col = f'model_prioritized_by_any-{model}'
        if col in df.columns:
            prioritized = df[col].apply(lambda x: (
                pd.notna(x) and
                (x == True or str(x).lower() == 'true' or x == 1 or str(x) == '1')
            )).sum()
            if model in model_motifs:
                # int() cast keeps the value JSON-serializable (numpy int otherwise).
                model_motifs[model]['n_variants'] = int(prioritized)

    # Print summary
    print(f"\nModel-specific motif patterns:")
    for model in sorted(models):
        if model in model_motifs:
            n_vars = model_motifs[model]['n_variants']
            n_ref = len(model_motifs[model]['ref_specific'])
            n_alt = len(model_motifs[model]['alt_specific'])
            print(f"  {model}: {n_vars} variants, {n_ref} ref-specific, {n_alt} alt-specific motifs")
            if n_ref > 0:
                top_ref = model_motifs[model]['ref_specific'].most_common(3)
                print(f"    Top ref-specific: {', '.join([f'{m}({c})' for m, c in top_ref])}")
            if n_alt > 0:
                top_alt = model_motifs[model]['alt_specific'].most_common(3)
                print(f"    Top alt-specific: {', '.join([f'{m}({c})' for m, c in top_alt])}")

    return dict(model_motifs)


def create_visualizations(df: pd.DataFrame, motif_results: Dict, cluster_results: Dict,
                         priority_results: Dict, motif_priority_results: Dict,
                         model_motif_results: Dict, output_dir: Path):
    """Create summary visualizations and save them as PNGs under ``output_dir``.

    Figures: motif frequency bars, motifs-per-variant histograms,
    prioritization bars, cluster sizes, and a model x motif heatmap.

    Note: ``df`` and ``motif_priority_results`` are accepted for interface
    compatibility but are not currently used by any plot.
    """
    print("\n" + "="*80)
    print("6. CREATING VISUALIZATIONS")
    print("="*80)

    output_dir.mkdir(parents=True, exist_ok=True)

    # 1. Motif frequency plots: top-20 horizontal bars per motif category
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Ref-specific top motifs
    top_ref = dict(motif_results['ref_specific'].most_common(20))
    axes[0].barh(range(len(top_ref)), list(top_ref.values()))
    axes[0].set_yticks(range(len(top_ref)))
    axes[0].set_yticklabels(list(top_ref.keys()), fontsize=8)
    axes[0].set_xlabel('Count')
    axes[0].set_title('Top 20 Ref-Specific Motifs')
    axes[0].invert_yaxis()  # most frequent motif at the top

    # Alt-specific top motifs
    top_alt = dict(motif_results['alt_specific'].most_common(20))
    axes[1].barh(range(len(top_alt)), list(top_alt.values()))
    axes[1].set_yticks(range(len(top_alt)))
    axes[1].set_yticklabels(list(top_alt.keys()), fontsize=8)
    axes[1].set_xlabel('Count')
    axes[1].set_title('Top 20 Alt-Specific Motifs')
    axes[1].invert_yaxis()

    # Unchanged top motifs
    top_unchanged = dict(motif_results['unchanged'].most_common(20))
    axes[2].barh(range(len(top_unchanged)), list(top_unchanged.values()))
    axes[2].set_yticks(range(len(top_unchanged)))
    axes[2].set_yticklabels(list(top_unchanged.keys()), fontsize=8)
    axes[2].set_xlabel('Count')
    axes[2].set_title('Top 20 Unchanged Motifs')
    axes[2].invert_yaxis()

    plt.tight_layout()
    plt.savefig(output_dir / 'motif_frequencies.png', dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved: {output_dir / 'motif_frequencies.png'}")

    # 2. Motifs per variant distribution (one histogram per category)
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    axes[0].hist(motif_results['per_variant_counts']['ref'], bins=30, edgecolor='black')
    axes[0].set_xlabel('Number of Ref-Specific Motifs')
    axes[0].set_ylabel('Number of Variants')
    axes[0].set_title('Distribution of Ref-Specific Motifs per Variant')

    axes[1].hist(motif_results['per_variant_counts']['alt'], bins=30, edgecolor='black')
    axes[1].set_xlabel('Number of Alt-Specific Motifs')
    axes[1].set_ylabel('Number of Variants')
    axes[1].set_title('Distribution of Alt-Specific Motifs per Variant')

    axes[2].hist(motif_results['per_variant_counts']['unchanged'], bins=30, edgecolor='black')
    axes[2].set_xlabel('Number of Unchanged Motifs')
    axes[2].set_ylabel('Number of Variants')
    axes[2].set_title('Distribution of Unchanged Motifs per Variant')

    plt.tight_layout()
    plt.savefig(output_dir / 'motifs_per_variant.png', dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved: {output_dir / 'motifs_per_variant.png'}")

    # 3. Prioritization by model (sorted descending by count)
    if priority_results['model_prioritization_counts']:
        fig, ax = plt.subplots(figsize=(12, 8))
        models_sorted = sorted(priority_results['model_prioritization_counts'].items(),
                              key=lambda x: x[1], reverse=True)
        # Drop the common prefix so y-axis labels stay readable
        models_short = [m.replace('KUN_FB_', '') for m, _ in models_sorted]
        counts = [c for _, c in models_sorted]

        ax.barh(range(len(models_short)), counts)
        ax.set_yticks(range(len(models_short)))
        ax.set_yticklabels(models_short, fontsize=9)
        ax.set_xlabel('Number of Prioritized Variants')
        ax.set_title('Variant Prioritization by Model')
        ax.invert_yaxis()

        plt.tight_layout()
        plt.savefig(output_dir / 'prioritization_by_model.png', dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Saved: {output_dir / 'prioritization_by_model.png'}")

    # 4. Variants by number of prioritizing models
    if priority_results['variants_by_n_models']:
        fig, ax = plt.subplots(figsize=(10, 6))
        n_models = sorted(priority_results['variants_by_n_models'].keys())
        counts = [priority_results['variants_by_n_models'][n] for n in n_models]

        ax.bar(n_models, counts, edgecolor='black')
        ax.set_xlabel('Number of Models Prioritizing Variant')
        ax.set_ylabel('Number of Variants')
        ax.set_title('Distribution of Variants by Number of Prioritizing Models')

        plt.tight_layout()
        plt.savefig(output_dir / 'variants_by_n_models.png', dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Saved: {output_dir / 'variants_by_n_models.png'}")

    # 5. Cluster analysis (if available)
    if cluster_results:
        fig, ax = plt.subplots(figsize=(12, 8))
        clusters = sorted(cluster_results.keys())
        cluster_sizes = [cluster_results[c]['n_variants'] for c in clusters]

        ax.bar(clusters, cluster_sizes, edgecolor='black')
        ax.set_xlabel('Cluster ID')
        ax.set_ylabel('Number of Variants')
        ax.set_title('Variant Distribution Across Clusters')

        plt.tight_layout()
        plt.savefig(output_dir / 'cluster_distribution.png', dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Saved: {output_dir / 'cluster_distribution.png'}")

    # 6. Model-specific top motifs heatmap
    if model_motif_results:
        # Focus on the top motifs overall (up to 20 unique names drawn from
        # the top-15 ref-specific and top-15 alt-specific lists).
        top_motifs = [m for m, _ in motif_results['ref_specific'].most_common(15)]
        top_motifs.extend([m for m, _ in motif_results['alt_specific'].most_common(15)])
        top_motifs = list(set(top_motifs))[:20]  # Top 20 unique motifs

        if top_motifs and model_motif_results:
            # Create matrix: models x motifs (only models with prioritized variants)
            models_list = sorted([m for m in model_motif_results.keys() if model_motif_results[m]['n_variants'] > 0])
            if len(models_list) > 0:
                matrix = np.zeros((len(models_list), len(top_motifs)))

                for i, model in enumerate(models_list):
                    for j, motif in enumerate(top_motifs):
                        # Sum ref and alt counts
                        count = (model_motif_results[model]['ref_specific'].get(motif, 0) +
                                model_motif_results[model]['alt_specific'].get(motif, 0))
                        matrix[i, j] = count

                # Create heatmap; figure size scales with matrix dimensions
                fig, ax = plt.subplots(figsize=(max(12, len(top_motifs)*0.5), max(8, len(models_list)*0.3)))
                models_short = [m.replace('KUN_FB_', '') for m in models_list]
                sns.heatmap(matrix,
                           xticklabels=top_motifs,
                           yticklabels=models_short,
                           cmap='YlOrRd',
                           cbar_kws={'label': 'Motif Count'},
                           ax=ax)
                ax.set_xlabel('Motif')
                ax.set_ylabel('Model')
                ax.set_title('Model-Specific Motif Patterns (Top Motifs)')
                plt.xticks(rotation=45, ha='right')
                plt.yticks(rotation=0)

                plt.tight_layout()
                plt.savefig(output_dir / 'model_motif_heatmap.png', dpi=300, bbox_inches='tight')
                plt.close()
                print(f"Saved: {output_dir / 'model_motif_heatmap.png'}")

    print(f"\nAll visualizations saved to: {output_dir}")


def main():
    """Run the full EDA pipeline: load data, analyze, plot, write a JSON summary.

    Returns
    -------
    int: 0 on success, 1 if the input file does not exist (used as the
    process exit code by the ``__main__`` guard).
    """
    parser = argparse.ArgumentParser(
        description='EDA for ChromBPNet variant analysis',
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        '--input',
        type=str,
        default='snakemake/data/Broad neurodevelopmental and neuromuscular disorders.Fetal Brain.clustered.tsv',
        help='Path to clustered TSV file'
    )
    parser.add_argument(
        '--output-dir',
        type=str,
        default='snakemake/eda_output',
        help='Output directory for plots and results'
    )
    parser.add_argument(
        '--include-hdma',
        action='store_true',
        help='Include HDMA models in analysis (in addition to Fetal Brain)'
    )

    args = parser.parse_args()

    # Convert to Path objects
    input_path = Path(args.input)
    output_dir = Path(args.output_dir)

    if not input_path.exists():
        print(f"Error: Input file not found: {input_path}")
        return 1

    # Load data
    df = load_data(str(input_path))

    # Extract models
    models = extract_fetal_brain_models(df)
    print(f"\nFound {len(models)} Fetal Brain models")

    if args.include_hdma:
        # HDMA models share the same column prefix, just a different model tag.
        hdma_models = [col.replace('model_prioritized_by_any-', '')
                      for col in df.columns
                      if col.startswith('model_prioritized_by_any-KUN_HDMA')]
        models.extend(hdma_models)
        print(f"Added {len(hdma_models)} HDMA models")

    # Run analyses
    motif_results = analyze_motif_distributions(df)
    cluster_results = analyze_cluster_motifs(df, motif_results)
    priority_results = analyze_prioritization(df, models)
    motif_priority_results = analyze_motif_prioritization_relationships(df, models)
    model_motif_results = analyze_model_specific_motifs(df, models)

    # Create visualizations
    create_visualizations(df, motif_results, cluster_results, priority_results,
                         motif_priority_results, model_motif_results, output_dir)

    # Save summary statistics
    summary = {
        'n_variants': len(df),
        'n_models': len(models),
        'motif_summary': {
            'n_ref_specific': len(motif_results['ref_specific']),
            'n_alt_specific': len(motif_results['alt_specific']),
            'n_unchanged': len(motif_results['unchanged']),
            'top_ref_motifs': dict(motif_results['ref_specific'].most_common(20)),
            'top_alt_motifs': dict(motif_results['alt_specific'].most_common(20)),
            'top_unchanged_motifs': dict(motif_results['unchanged'].most_common(20))
        },
        'prioritization_summary': {
            'total_prioritizations': sum(priority_results['model_prioritization_counts'].values()),
            'variants_by_n_models': dict(priority_results['variants_by_n_models'])
        }
    }

    summary_file = output_dir / 'summary_statistics.json'
    with open(summary_file, 'w') as f:
        # default=int: prioritization counts can be numpy integers (pandas
        # .sum() results), which json cannot serialize natively.
        json.dump(summary, f, indent=2, default=int)
    print(f"\nSaved summary statistics to: {summary_file}")

    print("\n" + "="*80)
    print("EDA COMPLETE")
    print("="*80)

    return 0


if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status.
    # raise SystemExit avoids relying on the site-injected `exit` helper,
    # which is intended for interactive use and absent under `python -S`.
    raise SystemExit(main())

