#!/usr/bin/env python3
"""
Generate an HTML summary page for a variant that displays all plots in an easily-readable format.

Usage:
    python generate_variant_summary_html.py <variant_dir> <variant_id> <clustered_tsv> <model_dataset> <output_html>
"""

import html
import os
import sys
from pathlib import Path
from typing import Dict, List, Optional
from urllib.parse import quote

import pandas as pd


def _parse_prioritized_models(cell) -> List[str]:
    """Extract model names from one ``;MODEL1(score);MODEL2(score);...`` cell.

    Entries that contain no ``(`` are ignored, as are empty names.
    NaN/blank cells yield an empty list.
    """
    if pd.isna(cell):
        return []
    names = []
    for entry in str(cell).split(';'):
        entry = entry.strip()
        if '(' not in entry:
            continue
        name = entry.split('(')[0]
        if name:
            names.append(name)
    return names


def get_prioritized_models_from_tsv(clustered_tsv: Path, variant_id: str, model_dataset: str) -> List[str]:
    """Get prioritized models for a variant from clustered.tsv.

    Reads the ``models_prioritized_by_any-{model_dataset}`` column and parses
    its semicolon-separated ``MODEL(score)`` entries into model names. If that
    column is absent or yields no models, every
    ``models_prioritized_by_any-*`` column is scanned instead (fallback for
    different column naming).

    Parameters:
    -----------
    clustered_tsv : Path
        Path to clustered.tsv file
    variant_id : str
        Variant ID to look up (matched against the ``variant_id`` column,
        compared as strings)
    model_dataset : str
        Model dataset name (e.g., "Fetal Brain")

    Returns:
    --------
    List[str]
        Model names that prioritize this variant, de-duplicated, in order of
        first appearance. This order is deterministic across runs. Returns an
        empty list (after printing a message) when the file is missing or
        unreadable, when it has no ``variant_id`` column, or when the variant
        is not present.
    """
    if not clustered_tsv.exists():
        print(f"Warning: clustered.tsv does not exist: {clustered_tsv}")
        return []

    try:
        df = pd.read_csv(clustered_tsv, sep='\t')
    except Exception as e:
        # Best-effort: an unreadable table should not abort HTML generation.
        print(f"Error reading clustered.tsv: {e}")
        return []

    if 'variant_id' not in df.columns:
        print(f"Warning: clustered.tsv has no 'variant_id' column: {clustered_tsv}")
        return []

    # Compare as strings so IDs that pandas inferred as integers still match.
    variant_row = df[df['variant_id'].astype(str) == variant_id]

    if variant_row.empty:
        print(f"Warning: Variant {variant_id} not found in clustered.tsv")
        return []

    row = variant_row.iloc[0]

    # First, try the aggregated column for this model_dataset.
    aggregated_col = f'models_prioritized_by_any-{model_dataset}'
    models = _parse_prioritized_models(row[aggregated_col]) if aggregated_col in df.columns else []

    # If no models found, check all models_prioritized_by_any-* columns.
    if not models:
        for col in df.columns:
            if col.startswith('models_prioritized_by_any-'):
                models.extend(_parse_prioritized_models(row[col]))

    # dict.fromkeys de-duplicates while keeping first-seen order. A set would
    # give a hash-randomised order that changes between runs.
    return list(dict.fromkeys(models))


def find_expected_plots(variant_dir: Path, variant_id: str, clustered_tsv: Path, model_dataset: str) -> dict:
    """Build the complete set of plot paths a variant should have, existing or not.

    The four fixed plots are always returned as paths inside *variant_dir*.
    One profile entry is created per model that ``clustered_tsv`` lists as
    prioritizing *variant_id* for *model_dataset*.

    Returns:
        dict with keys 'barplot', 'barplot_superset', 'scatterplot',
        'scatterplot_superset' (each a Path) and 'profiles': a list of dicts
        with keys 'path' (Path), 'model' (str) and 'exists' (bool, checked now).
    """
    model_names = get_prioritized_models_from_tsv(clustered_tsv, variant_id, model_dataset)
    profile_pngs = [(name, variant_dir / f"03-profile-{name}.png") for name in model_names]

    return {
        'barplot': variant_dir / "01-model-specificity-barplot.png",
        'barplot_superset': variant_dir / "01-model-specificity-barplot-superset.png",
        'scatterplot': variant_dir / "02-model-scatterplot.html",
        'scatterplot_superset': variant_dir / "02-model-scatterplot-superset.html",
        'profiles': [
            {'path': png, 'model': name, 'exists': png.exists()}
            for name, png in profile_pngs
        ],
    }


def find_plots(variant_dir: Path) -> dict:
    """Find the plot files that already exist in the variant directory.

    Returns:
        dict with keys 'barplot', 'barplot_superset', 'scatterplot',
        'scatterplot_superset' (the Path if the file exists, else None) and
        'profiles': a list, sorted by filename, of dicts with keys 'path',
        'model' and 'exists'. 'exists' is always True here because the entries
        come from a glob of files on disk. It is included so the shape matches
        the entries built by find_expected_plots.
    """
    fixed_plot_files = {
        'barplot': "01-model-specificity-barplot.png",
        'barplot_superset': "01-model-specificity-barplot-superset.png",
        'scatterplot': "02-model-scatterplot.html",
        'scatterplot_superset': "02-model-scatterplot-superset.html",
    }

    plots = {}
    for key, filename in fixed_plot_files.items():
        candidate = variant_dir / filename
        plots[key] = candidate if candidate.exists() else None

    prefix = "03-profile-"
    # The glob guarantees the prefix, so slice it off once. str.replace would
    # also strip any later occurrence inside the model name itself.
    plots['profiles'] = [
        {'path': profile_file, 'model': profile_file.stem[len(prefix):], 'exists': True}
        for profile_file in sorted(variant_dir.glob(f"{prefix}*.png"))
    ]

    return plots


# Inline handler for plots that may not exist yet: hide the broken element and
# reveal the hidden "pending" note placed immediately after it.
_ONERROR_JS = "this.style.display='none'; this.nextElementSibling.style.display='block';"
# Hidden note that _ONERROR_JS reveals when a plot fails to load.
_PENDING_NOTE = ('<div style="display:none; color: #999; font-style: italic; padding: 20px;">'
                 'Plot will be available after generation</div>')


def _media_tag(kind: str, src: str, alt: str, pending: bool) -> str:
    """Return an ``<iframe>`` tag if *kind* is ``'iframe'``, else an ``<img>`` tag.

    *src* and *alt* must already be safe to place inside an HTML attribute.
    When *pending* is true, the tag carries the ``onerror`` fallback hook.
    """
    onerror = f' onerror="{_ONERROR_JS}"' if pending else ''
    if kind == 'iframe':
        # NOTE(review): browsers generally do not fire ``error`` on an iframe whose
        # document is missing, so the pending note may never show for scatterplots.
        return f'<iframe src="{src}" frameborder="0"{onerror}></iframe>'
    return f'<img src="{src}" alt="{alt}"{onerror}>'


def _render_plot_section(title: str, plot_path, label: str, kind: str, alt: str,
                         empty_message: str, to_src) -> str:
    """Render one ``<div class="section">`` that holds a single barplot or scatterplot.

    A falsy *plot_path* renders *empty_message*. Any other value is still
    linked. If it is not an existing ``Path``, the pending fallback is attached.
    *to_src* converts a path into an attribute-safe URL.
    """
    parts = ['    <div class="section">\n', f'        <h2>{title}</h2>\n']
    if plot_path:
        pending = not (isinstance(plot_path, Path) and plot_path.exists())
        parts += [
            '        <div class="plot-container">\n',
            f'            <div class="plot-label">{label}</div>\n',
            f'            {_media_tag(kind, to_src(plot_path), alt, pending)}\n',
        ]
        if pending:
            parts.append(f'            {_PENDING_NOTE}\n')
        parts.append('        </div>\n')
    else:
        parts.append(f'        <div class="no-plots">{empty_message}</div>\n')
    parts.append('    </div>\n\n')
    return ''.join(parts)


def _render_profiles_section(profiles: list, to_src) -> str:
    """Render the "Profile Plots" section as a grid with one item per profile dict.

    Each dict needs 'path' and 'model'. When 'exists' is missing or falsy, the
    pending fallback is attached.
    """
    parts = ['    <div class="section">\n', '        <h2>Profile Plots</h2>\n']
    if profiles:
        parts.append('        <div class="profile-grid">\n')
        for profile in profiles:
            # Model names come from the TSV; escape them before embedding in HTML.
            model = html.escape(str(profile['model']))
            pending = not profile.get('exists', False)
            tag = _media_tag('img', to_src(profile['path']), f'Profile plot for {model}', pending)
            parts += [
                '            <div class="profile-item">\n',
                f'                <h3>{model}</h3>\n',
                f'                {tag}\n',
            ]
            if pending:
                parts.append(f'                {_PENDING_NOTE}\n')
            parts.append('            </div>\n')
        parts.append('        </div>\n')
    else:
        parts.append('        <div class="no-plots">No profile plots available</div>\n')
    parts.append('    </div>\n\n')
    return ''.join(parts)


def generate_html(variant_dir: Path, variant_id: str, output_html: Path, plots: dict):
    """Generate the HTML summary page for the variant and write it to *output_html*.

    Parameters:
    -----------
    variant_dir : Path
        Unused; kept for backward compatibility of the signature.
    variant_id : str
        Shown in the page title and header (HTML-escaped).
    output_html : Path
        Destination file. Parent directories are created. Plot links are made
        relative to this file's directory.
    plots : dict
        Mapping with optional keys 'barplot', 'barplot_superset',
        'scatterplot', 'scatterplot_superset' (Path or None) and 'profiles'
        (list of dicts with 'path', 'model' and optionally 'exists').

    The file is written as UTF-8 to match the page's ``<meta charset>``.
    """
    output_dir = output_html.parent

    def get_relative_path(file_path) -> str:
        """Return an attribute-safe URL for *file_path*, relative to the output HTML when possible."""
        try:
            rel = os.path.relpath(file_path, output_dir)
        except ValueError:
            # Different drives (Windows): no relative path exists, so use an
            # absolute file:// URL.
            return Path(file_path).resolve().as_uri()
        # Use forward slashes and percent-encoding, so that '#', '%', '?',
        # spaces, '&' or quotes in file names break neither the URL nor the
        # HTML attribute.
        return quote(Path(rel).as_posix(), safe='/')

    safe_id = html.escape(str(variant_id))

    head = f"""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Variant Summary: {safe_id}</title>
    <style>
        body {{
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
            line-height: 1.6;
            color: #333;
            max-width: 1400px;
            margin: 0 auto;
            padding: 20px;
            background-color: #f5f5f5;
        }}
        .header {{
            background-color: #2c3e50;
            color: white;
            padding: 20px;
            border-radius: 8px;
            margin-bottom: 30px;
        }}
        .header h1 {{
            margin: 0;
            font-size: 24px;
        }}
        .header .variant-id {{
            font-family: 'Courier New', monospace;
            font-size: 18px;
            margin-top: 10px;
            opacity: 0.9;
        }}
        .section {{
            background-color: white;
            padding: 25px;
            margin-bottom: 25px;
            border-radius: 8px;
            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        }}
        .section h2 {{
            margin-top: 0;
            color: #2c3e50;
            border-bottom: 2px solid #3498db;
            padding-bottom: 10px;
        }}
        .plot-container {{
            margin: 20px 0;
            text-align: center;
        }}
        .plot-container img {{
            max-width: 100%;
            height: auto;
            border: 1px solid #ddd;
            border-radius: 4px;
            box-shadow: 0 2px 8px rgba(0,0,0,0.1);
        }}
        .plot-container iframe {{
            width: 100%;
            height: 600px;
            border: 1px solid #ddd;
            border-radius: 4px;
            box-shadow: 0 2px 8px rgba(0,0,0,0.1);
        }}
        .profile-grid {{
            display: grid;
            grid-template-columns: repeat(auto-fit, minmax(500px, 1fr));
            gap: 20px;
            margin-top: 20px;
        }}
        .profile-item {{
            background-color: #f8f9fa;
            padding: 15px;
            border-radius: 4px;
            border: 1px solid #e0e0e0;
        }}
        .profile-item h3 {{
            margin-top: 0;
            color: #2c3e50;
            font-size: 16px;
        }}
        .profile-item img {{
            width: 100%;
            height: auto;
        }}
        .no-plots {{
            color: #999;
            font-style: italic;
            text-align: center;
            padding: 40px;
        }}
        .plot-label {{
            font-weight: 600;
            color: #555;
            margin-bottom: 10px;
            font-size: 14px;
        }}
    </style>
</head>
<body>
    <div class="header">
        <h1>Variant Summary</h1>
        <div class="variant-id">{safe_id}</div>
    </div>
"""

    page = [
        head,
        _render_plot_section(
            'Model Specificity Barplot (Cluster-level)', plots.get('barplot'),
            'Model prioritization across tissues/organs in the current cluster/model dataset',
            'img', 'Model Specificity Barplot',
            'No cluster-level barplot available', get_relative_path),
        _render_plot_section(
            'Model Specificity Barplot (Superset-level)', plots.get('barplot_superset'),
            'Model prioritization across tissues/organs in the superset (broader context)',
            'img', 'Model Specificity Barplot (Superset)',
            'No superset-level barplot available', get_relative_path),
        _render_plot_section(
            'Model Scatterplot (Cluster-level)', plots.get('scatterplot'),
            'Interactive scatterplot showing variant scores across models in the current cluster/model dataset',
            'iframe', '',
            'No cluster-level scatterplot available', get_relative_path),
        _render_plot_section(
            'Model Scatterplot (Superset-level)', plots.get('scatterplot_superset'),
            'Interactive scatterplot showing variant scores across all models in the superset (broader context)',
            'iframe', '',
            'No superset-level scatterplot available', get_relative_path),
        _render_profiles_section(plots.get('profiles') or [], get_relative_path),
        "</body>\n</html>\n",
    ]

    # Write HTML file
    output_html.parent.mkdir(parents=True, exist_ok=True)
    output_html.write_text(''.join(page), encoding='utf-8')

    print(f"✓ Generated HTML summary: {output_html}")


def main():
    """Command-line entry point: parse argv, write the summary page, then report plot status.

    Expects exactly five positional arguments (see module usage). Otherwise it
    prints usage and exits with status 1.
    """
    if len(sys.argv) != 6:
        print("Usage: python generate_variant_summary_html.py <variant_dir> <variant_id> <clustered_tsv> <model_dataset> <output_html>")
        sys.exit(1)

    raw_dir, variant_id, raw_tsv, model_dataset, raw_out = sys.argv[1:]
    variant_dir = Path(raw_dir)
    clustered_tsv = Path(raw_tsv)
    output_html = Path(raw_out)

    # The plots may not have been produced yet; make sure their folder exists.
    variant_dir.mkdir(parents=True, exist_ok=True)

    # Paths are computed whether or not the files exist, so the page can be
    # written before the plots themselves.
    plots = find_expected_plots(variant_dir, variant_id, clustered_tsv, model_dataset)
    generate_html(variant_dir, variant_id, output_html, plots)

    expected_profiles = plots['profiles']
    print(f"Summary: Expected {len(expected_profiles)} profile plot(s)")
    for entry in expected_profiles:
        mark = "✓" if entry.get('exists', False) else "(pending)"
        print(f"  - {entry['model']}: {mark}")

    status_rows = (
        ("Barplot", 'barplot'),
        ("Barplot (superset)", 'barplot_superset'),
        ("Scatterplot", 'scatterplot'),
        ("Scatterplot (superset)", 'scatterplot_superset'),
    )
    for label, key in status_rows:
        candidate = plots.get(key)
        ready = isinstance(candidate, Path) and candidate.exists()
        print(f"  - {label}: {'✓' if ready else '(pending)'}")


if __name__ == '__main__':
    # Only run the CLI when executed as a script, not when imported.
    main()

