#!/usr/bin/env python3
"""
Generate combined figure with motif scans (finemo or FIMO) and ChromBPNet predictions.

Usage:
    python regenerate_plot_svg.py <variant_id> <model_name> <output_format>
    # Finemo only (top panel + SHAP overlays from finemo):
    python regenerate_plot_svg.py variant_1 thyroid_gland__ENCSR474XFV svg --motifs-tsv /path/to/finemo.tsv
    # FIMO only (top panel from JASPAR FIMO scan):
    python regenerate_plot_svg.py variant_1 thyroid_gland__ENCSR474XFV svg --jaspar-meme /path/to/jaspar.meme
    # Both: top panel = FIMO, SHAP overlays = finemo (they occupy different places in the figure):
    python regenerate_plot_svg.py variant_1 thyroid_gland__ENCSR474XFV svg --motifs-tsv /path/to/finemo.tsv --jaspar-meme /path/to/jaspar.meme
"""

import argparse
import re
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec
import logomaker
import pyfaidx
import subprocess
import tempfile
import os
from pathlib import Path
import sys

# For loading predictions/SHAP from files (chr15 thyroid ENCSR474XFV)
try:
    import h5py
    import deepdish
    import torch
    _FILE_LOADING_AVAILABLE = True
except ImportError:
    _FILE_LOADING_AVAILABLE = False

# Configure matplotlib to save text as text (not paths) in SVG for selectability.
# With 'none', labels are written as SVG text elements rather than outlined glyph paths.
matplotlib.rcParams['svg.fonttype'] = 'none'

# Prepend "<directory of this script>/varbook" to sys.path so the `varbook.*`
# imports below can resolve against a checkout sitting next to this script.
# NOTE(review): this implies a nested layout (varbook/varbook/...); confirm.
sys.path.insert(0, str(Path(__file__).parent / 'varbook'))

from varbook.plot.variant.profiles import (
    _extract_variant_sequences,
    _plot_profile,
    _plot_shap,
    _plotter_shap,
    _resolve_model_metadata,
    _process_single_model_with_folds,
    _load_motif_mapping,
    _parse_finemo_motifs,
    _add_motif_overlays,
    _build_motif_to_color,
)

# ----------------------------------------------------------------------------
# File-based predictions/SHAP for chr15 thyroid ENCSR474XFV (from prototype)
# Paths match prototype: .../gregor-luria-h5-prio-uwcrdr/... and .../gregor-luria-shap-uwcrdr/...
# ----------------------------------------------------------------------------
# Variant ID that selects the file-based loading path.
# NOTE(review): the comparison against the CLI variant_id happens outside this view; confirm.
FILE_LOADING_VARIANT_ID = "chr15_thyroid"
FILE_LOADING_MODEL_MATCH = "ENCSR474XFV"  # model_name contains this
# Headerless TSV (chr, pos, ref, alt, variant_index) searched by _variant_idx_from_uwcrdr;
# the row position found there indexes the per-fold prediction/SHAP arrays.
UWCRDR_VARIANTS_PATH = "/oak/stanford/groups/akundaje/airanman/projects/lab/rare-disease-manuscript/variants/from_other_projects/uwcrdr.wo_structural.tsv"
# Root directories of precomputed per-fold prediction (.h5) and SHAP outputs.
# Presumably joined with a model-specific subpath before being passed to the
# _load_*_from_files loaders as `file_prefix` — usage is outside this view; verify.
SCORES_BASE = "/oak/stanford/groups/akundaje/airanman/nautilus-sync/gregor-luria-h5-prio-uwcrdr/pvc/outputs/outputs/gregor-luria-scoring-uwcrdr/variant_scoring"
SHAP_BASE = "/oak/stanford/groups/akundaje/airanman/nautilus-sync/gregor-luria-shap-uwcrdr/pvc/outputs/outputs/gregor-luria-uwcrdr-shap/variant_scoring"


def _load_scores_from_files(file_prefix):
    """Load variant prediction scores from precomputed files, averaging across folds.
    Same layout as prototype: one file per fold at {file_prefix}/fold_{f}/.variant_predictions.h5.
    Returns (preds1, preds2, counts1, counts2) for profiles and counts (for logfc).
    """
    def _softmax(x, temp=1):
        norm_x = x - np.mean(x, axis=1, keepdims=True)
        return np.exp(temp * norm_x) / np.sum(np.exp(temp * norm_x), axis=1, keepdims=True)

    fold_preds1, fold_preds2 = [], []
    fold_counts1, fold_counts2 = [], []
    for f in range(5):
        path = f"{file_prefix}/fold_{f}/.variant_predictions.h5"
        if not os.path.exists(path):
            continue
        print(f"   Loaded scores: {path}")
        with h5py.File(path, 'r') as h5file:
            c1 = np.array(h5file["observed"]["allele1_pred_counts"][:])
            p1 = np.array(h5file["observed"]["allele1_pred_profiles"][:])
            c2 = np.array(h5file["observed"]["allele2_pred_counts"][:])
            p2 = np.array(h5file["observed"]["allele2_pred_profiles"][:])
            fold_preds1.append(c1 * _softmax(p1))
            fold_preds2.append(c2 * _softmax(p2))
            fold_counts1.append(c1)
            fold_counts2.append(c2)
    if not fold_preds1:
        raise FileNotFoundError(f"No fold files under {file_prefix}")
    avg_preds1 = np.mean(fold_preds1, axis=0)
    avg_preds2 = np.mean(fold_preds2, axis=0)
    avg_c1 = np.mean(fold_counts1, axis=0)
    avg_c2 = np.mean(fold_counts2, axis=0)
    return avg_preds1, avg_preds2, avg_c1, avg_c2


def _load_shaps_from_files(file_prefix):
    """Load SHAP values from precomputed files, averaging across folds.
    Same layout as prototype: one file per fold at {file_prefix}/fold_{f}/.variant_shap.counts.h5,
    key ["projected_shap"]["seq"], first half = allele1, second half = allele2, then transpose (0,2,1).
    """
    shaps1, shaps2 = None, None
    n_folds = 0
    for f in range(5):
        path = f"{file_prefix}/fold_{f}/.variant_shap.counts.h5"
        if not os.path.exists(path):
            continue
        print(f"   Loaded SHAP:   {path}")
        fold_shaps = deepdish.io.load(path)
        shaps_f = fold_shaps["projected_shap"]["seq"]
        N = shaps_f.shape[0] // 2
        if shaps1 is None:
            shaps1 = shaps_f[:N, :, :].copy()
            shaps2 = shaps_f[N:, :, :].copy()
        else:
            shaps1 += shaps_f[:N, :, :]
            shaps2 += shaps_f[N:, :, :]
        n_folds += 1
    if shaps1 is None or n_folds == 0:
        raise FileNotFoundError(f"No fold SHAP files under {file_prefix}")
    shaps1 = np.transpose(shaps1 / n_folds, (0, 2, 1))
    shaps2 = np.transpose(shaps2 / n_folds, (0, 2, 1))
    return shaps1, shaps2


def _variant_idx_from_uwcrdr(chrom, pos, ref, alt, uwcrdr_tsv_path=UWCRDR_VARIANTS_PATH):
    """Return row index in uwcrdr variants file matching (chrom, pos, ref, alt)."""
    df = pd.read_csv(uwcrdr_tsv_path, sep="\t", names=['chr', 'pos', 'ref', 'alt', 'variant_index'])
    pos = int(pos)
    ref = str(ref) if ref != "-" else ""
    alt = str(alt) if alt != "-" else ""
    chrom_s = str(chrom).replace("chr", "") if str(chrom).startswith("chr") else str(chrom)
    df_chr = df['chr'].astype(str).str.replace("chr", "", regex=False)
    match = (df_chr == chrom_s) & (df['pos'].astype(int) == pos) & (df['ref'].astype(str) == ref) & (df['alt'].astype(str) == alt)
    if not match.any():
        match = (df['chr'].astype(str) == str(chrom)) & (df['pos'].astype(int) == pos) & (df['ref'].astype(str) == ref) & (df['alt'].astype(str) == alt)
    if not match.any():
        raise ValueError(f"Variant {chrom}:{pos}:{ref}:{alt} not found in {uwcrdr_tsv_path}")
    return int(np.where(match)[0][0])  # 0-based row index for array indexing (preds1[idx], etc.)


# ============================================================================
# FIMO SCANNING
# ============================================================================

def extract_sequence_for_fimo(chrom, pos, ref, alt, genome_fa, window_size=300):
    """
    Extract the reference sequence around a variant for FIMO scanning.

    The window is ``[pos0 - window_size//2, pos0 + window_size//2)`` in 0-based
    coordinates, clamped to the chromosome bounds. Near a chromosome end the returned
    sequence is therefore shorter than ``window_size``, and the variant is no longer
    centred; ``variant_pos_in_seq`` always reports its true offset.

    Parameters:
    -----------
    chrom : str
        Chromosome name as it appears in the FASTA
    pos : int
        1-based variant position (numeric strings are accepted)
    ref : str
        Reference allele (not used; the sequence is read from the genome)
    alt : str
        Alternate allele (not used)
    genome_fa : str
        Path to genome FASTA
    window_size : int
        Window size around variant

    Returns:
    --------
    sequence : str
        Upper-cased reference DNA sequence around the variant
    start_pos : int
        Genomic start position of sequence (0-based)
    variant_pos_in_seq : int
        Position of variant within sequence (0-based)
    """
    pos_0based = int(pos) - 1
    half_window = window_size // 2

    # Context manager closes the FASTA file handle (the original leaked it).
    with pyfaidx.Fasta(genome_fa) as genome:
        record = genome[str(chrom)]
        # Clamp to the chromosome: pyfaidx treats a negative slice start as an
        # offset from the chromosome end, which would return the wrong sequence.
        seq_start = max(0, pos_0based - half_window)
        seq_end = min(len(record), pos_0based + half_window)
        sequence = str(record[seq_start:seq_end].seq).upper()

    # Equals half_window for variants away from the chromosome start.
    variant_pos_in_seq = pos_0based - seq_start
    return sequence, seq_start, variant_pos_in_seq


def _find_fimo_output(output_dir):
    """Return (path, candidates): first existing FIMO result file under output_dir (or None) and the paths checked."""
    candidates = [
        os.path.join(output_dir, 'fimo', 'fimo.tsv'),  # subdirectory layout
        os.path.join(output_dir, 'fimo.tsv'),  # directly in output_dir
        os.path.join(output_dir, 'fimo', 'fimo.txt'),  # legacy .txt name
        os.path.join(output_dir, 'fimo.txt'),  # legacy .txt, directly in output_dir
    ]
    for path in candidates:
        if os.path.exists(path):
            return path, candidates
    return None, candidates


def _read_fimo_table(fimo_tsv):
    """Read a FIMO results table, keeping a '#'-prefixed header line (legacy fimo.txt) as column names."""
    import io

    with open(fimo_tsv, 'r') as f:
        lines = f.read().splitlines()
    print("First few lines of FIMO output:")
    for i, line in enumerate(lines[:5]):
        print(f"  Line {i+1}: {line.strip()[:100]}")

    # pd.read_csv(comment='#') would drop a header such as "#pattern name\tsequence name..."
    # and promote the first hit to the header. Instead, take the first tab-separated
    # line (with leading '#' removed) as the header, and skip other '#' comment lines.
    header = None
    data_lines = []
    for line in lines:
        if not line.strip():
            continue
        if header is None:
            candidate = line.lstrip('#')
            if '\t' in candidate:
                header = candidate
            continue
        if line.startswith('#'):
            continue
        data_lines.append(line)
    if header is None:
        return pd.DataFrame()
    return pd.read_csv(io.StringIO('\n'.join([header] + data_lines)), sep='\t')


def _standardize_fimo_columns(fimo_df):
    """Rename FIMO columns across output versions to pattern_name/sequence_name/stop/p-value/q-value/matched_sequence."""
    fimo_df = fimo_df.copy()
    # Prefer motif_alt_id (TF name, e.g. "ZNF530") over motif_id (JASPAR ID, e.g. "MA1981.2").
    if 'motif_alt_id' in fimo_df.columns:
        if 'motif_id' in fimo_df.columns:
            # Motifs without an alternate name leave motif_alt_id empty (NaN); fall back to
            # motif_id so every hit keeps a usable name downstream.
            fimo_df['motif_alt_id'] = fimo_df['motif_alt_id'].fillna(fimo_df['motif_id'])
        fimo_df = fimo_df.rename(columns={'motif_alt_id': 'pattern_name'})
    elif 'pattern name' in fimo_df.columns:
        fimo_df = fimo_df.rename(columns={'pattern name': 'pattern_name'})
    elif 'motif_id' in fimo_df.columns:
        fimo_df = fimo_df.rename(columns={'motif_id': 'pattern_name'})

    column_mapping = {
        'sequence name': 'sequence_name',
        'sequence-id': 'sequence_name',
        'sequence_id': 'sequence_name',
        'end': 'stop',
        'pvalue': 'p-value',
        'p_value': 'p-value',
        'qvalue': 'q-value',
        'q_value': 'q-value',
        'matched sequence': 'matched_sequence'
    }
    # Only rename when the target name is not already present (avoid duplicate columns).
    for old_col, new_col in column_mapping.items():
        if old_col in fimo_df.columns and new_col not in fimo_df.columns:
            fimo_df = fimo_df.rename(columns={old_col: new_col})
    return fimo_df


def run_fimo_scan(sequence, sequence_name, jaspar_meme_file, output_dir, pvalue_threshold=1e-2):
    """
    Run FIMO scan on sequence using JASPAR database.

    Writes ``sequence.fa`` into ``output_dir``, runs ``fimo --thresh <p> --oc <output_dir>``,
    locates the result table, and returns it with standardized column names.

    Parameters:
    -----------
    sequence : str
        DNA sequence to scan
    sequence_name : str
        Name for sequence (used in FIMO output)
    jaspar_meme_file : str
        Path to JASPAR CORE database in MEME format
    output_dir : str
        Directory for FIMO input/output (must already exist)
    pvalue_threshold : float
        P-value threshold for FIMO (default: 1e-2)

    Returns:
    --------
    fimo_df : pd.DataFrame
        FIMO results with columns: pattern_name, sequence_name, start, stop,
        strand, score, p-value, q-value, matched_sequence (when FIMO emits them)

    Raises:
    -------
    RuntimeError
        If FIMO exits with a non-zero status.
    FileNotFoundError
        If no FIMO result file is found.
    ValueError
        If required columns are missing after standardization.
    """
    fasta_file = os.path.join(output_dir, 'sequence.fa')
    with open(fasta_file, 'w') as f:
        f.write(f">{sequence_name}\n{sequence}\n")

    cmd = [
        'fimo',
        '--thresh', str(pvalue_threshold),
        '--oc', output_dir,
        jaspar_meme_file,
        fasta_file
    ]

    print(f"Running FIMO: {' '.join(cmd)}")
    result = subprocess.run(cmd, capture_output=True, text=True)

    if result.returncode != 0:
        print(f"FIMO stderr: {result.stderr}")
        print(f"FIMO stdout: {result.stdout}")
        raise RuntimeError(f"FIMO failed with return code {result.returncode}: {result.stderr}")

    fimo_tsv, possible_outputs = _find_fimo_output(output_dir)
    if fimo_tsv is None:
        # List what was actually created to make the failure debuggable.
        print(f"\nDebugging FIMO output location:")
        print(f"Output directory: {output_dir}")
        if os.path.exists(output_dir):
            print(f"Contents of output_dir: {os.listdir(output_dir)}")
            for item in os.listdir(output_dir):
                item_path = os.path.join(output_dir, item)
                if os.path.isdir(item_path):
                    print(f"  Subdirectory '{item}' contents: {os.listdir(item_path)}")
        raise FileNotFoundError(
            f"FIMO output not found. Checked: {possible_outputs}\n"
            f"FIMO stdout: {result.stdout}\n"
            f"FIMO stderr: {result.stderr}"
        )
    print(f"Found FIMO output at: {fimo_tsv}")

    fimo_df = _read_fimo_table(fimo_tsv)
    print(f"FIMO output columns: {list(fimo_df.columns)}")
    print(f"FIMO output shape: {fimo_df.shape}")
    if len(fimo_df) > 0:
        print(f"First row: {fimo_df.iloc[0].to_dict()}")

    fimo_df = _standardize_fimo_columns(fimo_df)
    print(f"Columns after mapping: {list(fimo_df.columns)}")

    required_cols = ['pattern_name', 'start', 'stop', 'p-value']
    missing_cols = [col for col in required_cols if col not in fimo_df.columns]
    if missing_cols:
        raise ValueError(
            f"FIMO output missing required columns: {missing_cols}\n"
            f"Available columns: {list(fimo_df.columns)}\n"
            f"Please check FIMO output format."
        )

    print(f"Found {len(fimo_df)} FIMO hits")
    return fimo_df


def finemo_motifs_to_dataframe(ref_motifs, alt_motifs, gtex_thyroid_tpm=None, exclude_low_tpm=False):
    """
    Convert finemo ref/alt motif lists (from _parse_finemo_motifs) to a DataFrame
    compatible with plot_fimo_scans and add_foxe1_highlighting.

    Ref and alt hits are merged into one track (ref first). Each input dict must carry
    ``start_plot`` and ``end_plot``; ``motif`` (a "pos_patterns."/"neg_patterns."
    prefix is stripped) and ``score`` (default 1.0) are optional.

    Output columns: tf_name (upper-cased motif name), start_rel, end_rel, score,
    color, pattern_name, text_color, motif_alpha. When ``gtex_thyroid_tpm`` is given,
    rows are filtered and coloured by ``apply_gtex_filter_and_colors`` (the same TPM
    thresholds and colour tiers used for the FIMO track).
    """
    def _strip_pattern_prefix(name):
        for prefix in ('pos_patterns.', 'neg_patterns.'):
            if name.startswith(prefix):
                return name[len(prefix):]
        return name

    rows = []
    for m in list(ref_motifs or []) + list(alt_motifs or []):
        raw_name = m.get('motif')
        # Tolerate a missing/None motif name instead of crashing on .startswith.
        motif_name = _strip_pattern_prefix('' if raw_name is None else str(raw_name))
        rows.append({
            'tf_name': motif_name.upper(),
            'start_rel': m['start_plot'],
            'end_rel': m['end_plot'],
            'score': m.get('score', 1.0),
            'color': 'navy',
            'pattern_name': motif_name,
        })
    if not rows:
        # Same schema as the non-empty, no-GTEx result so callers can rely on the columns.
        return pd.DataFrame(columns=['tf_name', 'start_rel', 'end_rel', 'score', 'color',
                                     'pattern_name', 'text_color', 'motif_alpha'])
    df = pd.DataFrame(rows)
    if gtex_thyroid_tpm is not None:
        return apply_gtex_filter_and_colors(df, gtex_thyroid_tpm, exclude_low_tpm=exclude_low_tpm)
    df['text_color'] = 'black'
    df['motif_alpha'] = 1.0
    return df


def process_fimo_results(fimo_df, variant_pos_genomic, seq_start_genomic, window_size=300):
    """
    Process FIMO results and convert to plot coordinates.

    The input DataFrame is not modified; a filtered copy is returned.

    Parameters:
    -----------
    fimo_df : pd.DataFrame
        FIMO results with at least pattern_name, start, stop, p-value
    variant_pos_genomic : int
        Genomic position of variant (1-based)
    seq_start_genomic : int
        Genomic start position of sequence (0-based)
    window_size : int
        Window size; hits outside [-window_size//2, window_size//2] are dropped

    Returns:
    --------
    fimo_df : pd.DataFrame
        Copy of the FIMO hits inside the window, with added columns:
        - start_rel: (start - 1) minus the variant's 0-based offset (variant at 0)
        - end_rel: (stop - 1) minus the same offset
        - color: Color based on p-value
        - tf_name: Upper-cased TF name with any "MA..._" JASPAR prefix removed
    """
    # NOTE(review): FIMO's `stop` is documented as inclusive, so end_rel is the relative
    # index of the motif's last base; plot_fimo_scans draws width end_rel - start_rel.
    # Confirm this matches the end convention used for finemo hits.
    df = fimo_df.copy()  # do not mutate the caller's table

    # FIMO start/stop are 1-based within the scanned sequence; seq_start_genomic is
    # 0-based and variant_pos_genomic is 1-based.
    variant_pos_in_seq = variant_pos_genomic - seq_start_genomic - 1
    df['start_rel'] = (df['start'] - 1) - variant_pos_in_seq
    df['end_rel'] = (df['stop'] - 1) - variant_pos_in_seq

    def clean_tf_name(pattern_name):
        # Missing names become "" instead of raising on the '_' membership test.
        if pd.isna(pattern_name):
            return ''
        name = str(pattern_name)
        # Remove JASPAR ID prefix if present, e.g. "MA0001.1_IRF7" -> "IRF7".
        prefix, sep, rest = name.partition('_')
        if sep and prefix.startswith('MA'):
            return rest
        return name

    # Upper-case for consistent gene/TF symbols (e.g. Sox6 -> SOX6).
    df['tf_name'] = [clean_tf_name(p).upper() for p in df['pattern_name']]

    def assign_color(pval):
        if pval < 1e-5:
            return 'navy'
        elif pval < 1e-4:
            return 'mediumblue'
        elif pval < 1e-3:
            return 'lightblue'
        elif pval < 1e-2:
            return 'lightgray'
        else:
            return 'gainsboro'

    df['color'] = [assign_color(p) for p in df['p-value']]

    half_window = window_size // 2
    in_window = (df['start_rel'] >= -half_window) & (df['end_rel'] <= half_window)
    return df[in_window].copy()


def load_gtex_thyroid_tpm(gtex_tsv_path):
    """
    Load GTEx RNA-seq thyroid TPM per gene from a TSV.

    Expected format: TSV with a header, a gene-identifier column and a thyroid TPM column.
    Column detection:
    - gene: the first of 'gene_symbol', 'gene_name', 'Name', 'gene_id', 'gene' that is
      present (in that priority order); otherwise the first column.
    - thyroid TPM: the first non-gene column whose name contains 'thyroid'
      (case-insensitive), or a column named exactly 'TPM' when the file has at most
      three columns; otherwise the first non-gene column after the first column.

    Gene identifiers are upper-cased. When the same identifier appears more than once,
    the highest TPM is kept so the index is unique and lookups return a scalar.

    Returns:
    --------
    pandas.Series
        Index = upper-cased gene identifier (str), value = thyroid TPM (float).

    Raises:
    -------
    ValueError
        If no TPM column can be inferred.
    """
    df = pd.read_csv(gtex_tsv_path, sep='\t')
    cols = list(df.columns)

    # Honour the candidate priority (not file column order) so a symbol column wins
    # over an ID column when both are present.
    gene_candidates = ('gene_symbol', 'gene_name', 'Name', 'gene_id', 'gene')
    gene_col = next((c for c in gene_candidates if c in cols), None)
    if gene_col is None:
        gene_col = cols[0]

    tpm_col = None
    for c in cols:
        if c == gene_col:
            continue
        if 'thyroid' in str(c).lower() or (c == 'TPM' and len(cols) <= 3):
            tpm_col = c
            break
    if tpm_col is None:
        # Fall back to the second column, but never to the gene column itself.
        fallback = [c for c in cols[1:] if c != gene_col]
        if fallback:
            tpm_col = fallback[0]
    if tpm_col is None:
        raise ValueError(f"Could not infer thyroid TPM column from {gtex_tsv_path}. Columns: {cols}")

    out = df.set_index(df[gene_col].astype(str).str.upper())[tpm_col].astype(float)
    if not out.index.is_unique:
        # Duplicate symbols would make .loc return a Series downstream.
        out = out.groupby(level=0).max()
    return out


def apply_gtex_filter_and_colors(fimo_df, gtex_thyroid_tpm, exclude_low_tpm=False):
    """
    Filter FIMO hits by GTEx thyroid TPM >= 1 (or >= 5 when exclude_low_tpm)
    and assign color by TPM bin.
    Matches original figure: "elements for TFs with less than one TPM in thyroid
    were removed"; color = GTEx thyroid expression level.

    Hits whose tf_name is not found in gtex_thyroid_tpm (looked up upper-cased first,
    then as given) are dropped. If a gene appears more than once in the TPM index,
    its highest TPM is used. The input DataFrame is not modified.

    When exclude_low_tpm is True, drops bins below 5 (keeps only >100, 35–100, 15–35, 5–15).

    TPM bins: colorblind-friendly sequential blue (ColorBrewer-style).
    Highest tier (TPM > 100) near-full opacity; lower tiers use more transparency.
    - TPM > 100:  #2171B5 (white text), font 11, alpha 0.9
    - 35--100:    #6BAED6 (black text), font 10, alpha 0.76
    - 15--35:     #9ECAE1 (black text), font 9.5, alpha 0.62
    - 5--15:      #C6DBEF (black text), font 9, alpha 0.56
    - 1--5:       #DEEBF7 (black text), font 9, alpha 0.48  [dropped if exclude_low_tpm]
    - < 1:        drop
    """
    df = fimo_df.copy()
    tpm_index = gtex_thyroid_tpm.index

    def tpm_for_tf(tf_name):
        # JASPAR motif_alt_id usually equals the gene symbol; try upper-case first.
        for key in (str(tf_name).upper(), tf_name):
            if key in tpm_index:
                value = gtex_thyroid_tpm.loc[key]
                # A duplicated gene in the index makes .loc return a Series;
                # collapse to the highest TPM so the column stays scalar.
                return value.max() if isinstance(value, pd.Series) else value
        return np.nan

    df['thyroid_tpm'] = df['tf_name'].map(tpm_for_tf)
    df = df.dropna(subset=['thyroid_tpm'])
    min_tpm = 5.0 if exclude_low_tpm else 1.0
    # .copy() so the column assignments below act on an independent frame
    # (avoids pandas' SettingWithCopy ambiguity on a filtered slice).
    df = df[df['thyroid_tpm'] >= min_tpm].copy()

    def color_text_font_alpha_by_tpm(tpm):
        if tpm > 100:
            return ('#2171B5', 'white', 11, 0.9)
        if tpm >= 35:
            return ('#6BAED6', 'black', 10, 0.76)
        if tpm >= 15:
            return ('#9ECAE1', 'black', 9.5, 0.62)
        if tpm >= 5:
            return ('#C6DBEF', 'black', 9, 0.56)
        return ('#DEEBF7', 'black', 9, 0.48)

    # Evaluate the tier once per row instead of once per output column.
    styles = [color_text_font_alpha_by_tpm(t) for t in df['thyroid_tpm']]
    df['color'] = [s[0] for s in styles]
    df['text_color'] = [s[1] for s in styles]
    df['motif_fontsize'] = [s[2] for s in styles]
    df['motif_alpha'] = [s[3] for s in styles]
    return df


def filter_fimo_for_display(fimo_df, best_hit_per_tf=True, max_display_pvalue=1e-3,
                            gtex_thyroid_tpm=None, protein_footprint=None,
                            footprint_only=False, exclude_low_tpm=False):
    """
    Trim FIMO hits down to a readable set for the figure.

    Rationale (original caption): predicted TF-binding elements are shown with each
    TF's GTEx thyroid expression as the colour, TFs below one thyroid TPM are removed,
    and a large protein-footprint region is shaded blue. The shading only *marks* the
    footprint — motifs are shown across the whole window.

    Steps, in order:
    1. If ``gtex_thyroid_tpm`` is given, keep TFs with thyroid TPM >= 1 (>= 5 with
       ``exclude_low_tpm``) and colour them by TPM (``apply_gtex_filter_and_colors``).
    2. If ``max_display_pvalue`` is not None, keep hits with p-value strictly below it.
    3. If ``footprint_only`` and ``protein_footprint`` are both set, keep only hits
       overlapping ``(start, end)`` of the footprint. This departs from the caption;
       use it only when a footprint-only view is wanted.
    4. Deduplicate by lowest p-value:
       - ``best_hit_per_tf=True``: a single hit per tf_name;
       - ``best_hit_per_tf=False``: one hit per (tf_name, start_rel). Repeats at
         different positions survive (as in the original figure), while several JASPAR
         matrices for one TF at the same site (e.g. MA0142.1_PGR and MA0142.2_PGR, both
         "PGR") collapse to one label.

    Returns an empty frame as soon as no hits remain.
    """
    hits = fimo_df.copy()

    if gtex_thyroid_tpm is not None:
        hits = apply_gtex_filter_and_colors(hits, gtex_thyroid_tpm, exclude_low_tpm=exclude_low_tpm)

    if max_display_pvalue is not None:
        hits = hits[hits['p-value'] < max_display_pvalue]

    if footprint_only and protein_footprint is not None:
        fp_lo, fp_hi = protein_footprint[0], protein_footprint[1]
        touches_footprint = (hits['start_rel'] < fp_hi) & (hits['end_rel'] > fp_lo)
        hits = hits[touches_footprint]

    if hits.shape[0] == 0:
        return hits

    group_keys = 'tf_name' if best_hit_per_tf else ['tf_name', 'start_rel']
    best_row_labels = hits.groupby(group_keys)['p-value'].idxmin()
    return hits.loc[best_row_labels]


# ============================================================================
# PLOTTING FUNCTIONS
# ============================================================================

def _pack_motif_rows(intervals, y_spacing, overlap_thresh, max_rows=500):
    """Greedily place each (left, right) interval on the lowest free row; None if none of max_rows fits.

    Rows sit at multiples of ``y_spacing``. A candidate row is blocked when any
    occupied row within ``overlap_thresh`` of it holds an interval that overlaps or
    touches ``(left, right)`` — touching endpoints count as overlap, leaving a gap.
    Returns ``(placements, occupied)``: ``placements[i]`` is the y for ``intervals[i]``
    (or None) and ``occupied`` maps each used y to the intervals placed there.
    """
    occupied = {}
    placements = []
    for left, right in intervals:
        placed_y = None
        for row_idx in range(max_rows):
            y_val = row_idx * y_spacing
            blocked = any(
                not (right < other_left or left > other_right)
                for other_y, extents in occupied.items()
                if abs(other_y - y_val) < overlap_thresh
                for other_left, other_right in extents
            )
            if not blocked:
                occupied.setdefault(y_val, []).append((left, right))
                placed_y = y_val
                break
        placements.append(placed_y)
    return placements, occupied


def plot_fimo_scans(ax, fimo_df, window_size=160, protein_footprint=(-40, -10), variant_region=(-5, 5),
                    annotation_bars=None, font_scale=1.5, motif_box_scale=1.25, annotation_bar_scale=1.2):
    """
    Plot FIMO scan tracks. Fixed shaded regions are not drawn here;
    FOXE1 highlighting is added via add_foxe1_highlighting() across all panels.
    ``protein_footprint`` and ``variant_region`` are accepted for call compatibility
    but are not used by this function.

    fimo_df needs tf_name, start_rel, end_rel, color; optional thyroid_tpm (sort key),
    motif_alpha and text_color (missing/NaN -> 1.0 and 'black').
    Motifs are packed into rows so non-overlapping motifs share a row; with thyroid_tpm
    present, higher-TPM motifs are placed first and so land on the lowest rows.
    Motifs that cannot be placed within 500 rows are not drawn.

    annotation_bars: optional list of (label, start_rel, end_rel) or
     (label, start_rel, end_rel, opts) drawn above the motif rows. opts may include
     {"color": ...} (default "#CC79A7") and {"no_left_edge": True} to draw the bar
     fill and only top/right/bottom edges.
    font_scale: multiply all font sizes by this (default 1.5).
    motif_box_scale: scale motif bar height and row spacing (default 1.25).
    annotation_bar_scale: scale annotation bar height relative to motif bars (default 1.2).
    """
    half_window = window_size // 2

    # No x ticks/labels on this panel (the x-axis appears only on the bottom panel).
    ax.set_xlim(-half_window, half_window)
    ax.set_xticks([])

    motif_fontsize = 11 * font_scale + 2  # motif labels and annotation labels share this size
    bar_height = 1.15 * motif_box_scale
    y_spacing = 1.45 * motif_box_scale
    bar_height_ann = bar_height * annotation_bar_scale
    y_spacing_ann = y_spacing * annotation_bar_scale  # spacing for annotation rows
    overlap_thresh = 1.1 * motif_box_scale  # rows closer than this are treated as the same row

    if 'thyroid_tpm' in fimo_df.columns:
        fimo_df_sorted = fimo_df.sort_values(
            ['thyroid_tpm', 'start_rel'],
            ascending=[False, True],
            na_position='last'
        )
    else:
        fimo_df_sorted = fimo_df.sort_values(['start_rel', 'tf_name'])

    rows = [row for _, row in fimo_df_sorted.iterrows()]
    # Shift by +0.5 so bars align with ChromBPNet/SHAP (each nucleotide at [i, i+1)).
    intervals = [(row['start_rel'] + 0.5, row['end_rel'] + 0.5) for row in rows]
    placements, occupied = _pack_motif_rows(intervals, y_spacing, overlap_thresh)

    for row, y_pos in zip(rows, placements):
        if y_pos is None:
            continue
        left = row['start_rel'] + 0.5
        width = row['end_rel'] - row['start_rel']
        bar_alpha = row.get('motif_alpha', 1.0)
        if pd.isna(bar_alpha):  # e.g. rows concatenated from a frame without this column
            bar_alpha = 1.0
        ax.barh(y_pos, width, left=left, height=bar_height,
                color=row['color'], alpha=bar_alpha, edgecolor='black', linewidth=0.5, zorder=2)
        text_color = row.get('text_color', 'black')
        if pd.isna(text_color):
            text_color = 'black'
        # All motif labels use one font size regardless of TPM tier.
        ax.text(left + width / 2, y_pos, row['tf_name'], ha='center', va='center',
                fontsize=motif_fontsize, color=text_color, zorder=3)

    if occupied:
        # At least 1 unit above the top row (the original padding), more when bars are
        # tall enough (large motif_box_scale) that a fixed +1 would clip their top edge.
        y_max = max(occupied) + max(1.0, bar_height / 2 + 0.1)
    else:
        y_max = 1

    if annotation_bars:
        for i, bar in enumerate(annotation_bars):
            label, start_rel, end_rel = bar[0], bar[1], bar[2]
            opts = (bar[3] if len(bar) > 3 else None) or {}
            bar_color = opts.get("color", "#CC79A7")  # default Okabe–Ito pink
            y_ann = y_max + (i + 1) * y_spacing_ann
            left = start_rel + 0.5
            width = end_rel - start_rel
            if opts.get("no_left_edge", False):
                ax.barh(y_ann, width, left=left, height=bar_height_ann,
                        color=bar_color, edgecolor='none', zorder=2)
                # Draw only top, right, bottom edges (no left).
                y_lo = y_ann - bar_height_ann / 2
                y_hi = y_ann + bar_height_ann / 2
                ax.plot([left, left + width], [y_hi, y_hi], 'k-', linewidth=0.5, zorder=2)
                ax.plot([left + width, left + width], [y_lo, y_hi], 'k-', linewidth=0.5, zorder=2)
                ax.plot([left + width, left], [y_lo, y_lo], 'k-', linewidth=0.5, zorder=2)
            else:
                ax.barh(y_ann, width, left=left, height=bar_height_ann,
                        color=bar_color, edgecolor='black', linewidth=0.5, zorder=2)
            ax.text(left + width / 2, y_ann, label, ha='center', va='center',
                    fontsize=motif_fontsize, color='white', zorder=3)
        y_max = y_max + len(annotation_bars) * y_spacing_ann + (bar_height_ann / 2) + 0.1  # pad above top bar
    ax.set_ylim(-(bar_height / 2) - 0.1, y_max)  # pad below bottom bar so edges aren't clipped

    ax.set_yticks([])
    ax.set_ylabel('Predicted TF-binding\nElements', rotation=90, labelpad=10, fontsize=14 * font_scale + 1)
    for side in ('bottom', 'top', 'right', 'left'):
        ax.spines[side].set_visible(False)


def add_foxe1_highlighting(axes_list, fimo_df, alpha_per_span=0.14):
    """
    Shade every FOXE1 binding element on each axis in ``axes_list``.

    Each FOXE1 row contributes one ``axvspan`` (zorder 0, behind the data) from
    ``start_rel + 0.5`` to ``end_rel + 0.5``, which matches the ChromBPNet/SHAP
    convention of nucleotide i occupying [i, i+1). All spans use the same alpha, so
    overlapping FOXE1 hits stack and appear darker. Does nothing when fimo_df has no
    'tf_name' column or no FOXE1 rows.
    """
    if 'tf_name' not in fimo_df.columns:
        return
    foxe1_hits = fimo_df[fimo_df['tf_name'] == 'FOXE1']
    if foxe1_hits.empty:
        return
    for axis in axes_list:
        shifted_starts = foxe1_hits['start_rel'] + 0.5
        shifted_ends = foxe1_hits['end_rel'] + 0.5
        for span_start, span_end in zip(shifted_starts, shifted_ends):
            axis.axvspan(span_start, span_end, alpha=alpha_per_span, color='#42A5F5', zorder=0)


def build_figure_caption(fimo_display_pvalue, gtex_used, foxe1_count, ref_label, alt_label,
                         exclude_low_tpm=False, use_fimo_top=False, use_finemo_overlays=False):
    """Build a programmatic caption describing how the figure was created.

    Args:
        fimo_display_pvalue: FIMO display p-value cutoff (rendered with ``{:.0e}``).
        gtex_used: whether GTEx thyroid TPM filtering/coloring was applied.
        foxe1_count: number of highlighted FOXE1 hits; 0 omits the highlighting sentence.
        ref_label, alt_label: allele labels; falsy values are rendered as an em dash.
        exclude_low_tpm: True when the TPM threshold is 5, so the 1–5 TPM tier is absent.
        use_fimo_top: the top motif panel comes from FIMO.
        use_finemo_overlays: the contribution panels carry finemo motif overlays.

    Returns:
        A single caption string (the individual sentences joined with spaces).
    """
    if use_fimo_top and use_finemo_overlays:
        lines = [
            "Top: predicted TF-binding elements from FIMO (JASPAR CORE), p < {:.0e}. "
            "SHAP panels: motif overlays from finemo (ChromBPNet SHAP-based); score ≥ 1.0.".format(fimo_display_pvalue),
        ]
    elif use_finemo_overlays:
        lines = [
            "Predicted TF-binding elements from finemo (ChromBPNet SHAP-based motif calling). "
            "Motifs shown for ref and alt alleles; score ≥ 1.0.",
        ]
    else:
        lines = [
            "Predicted TF-binding elements from FIMO (JASPAR CORE), p < {:.0e}, replicating the "
            "findings in Grasberger, Dumitrescu et al. 2024.".format(fimo_display_pvalue),
        ]
    if gtex_used:
        tpm_thresh = "≥ 5" if exclude_low_tpm else "≥ 1"
        # Only list tiers that can actually appear: with exclude_low_tpm the 1–5 TPM TFs
        # are filtered out (and the figure legend omits that tier), so don't advertise it.
        lower_tiers = "15–35, 5–15" if exclude_low_tpm else "15–35, 5–15, 1–5"
        lines.append(
            "TFs with thyroid TPM {} (GTEx) retained. Box color indicates expression level "
            "using colorblind-friendly sequential blue (darker = higher TPM): >100, 35–100, "
            "{} TPM; lower tiers use increased transparency. "
            "Motif names are shown centered on each binding element.".format(tpm_thresh, lower_tiers)
        )
    if foxe1_count > 0:
        lines.append("Blue highlighting: FOXE1 binding regions ({} hit{}).".format(
            foxe1_count, "s" if foxe1_count != 1 else ""))
    lines.append("Bottom: ChromBPNet predicted profiles and contribution scores for ref ({}) and alt ({}).".format(
        ref_label or "—", alt_label or "—"))
    return " ".join(lines)


# ============================================================================
# MAIN FUNCTION
# ============================================================================

def generate_combined_figure(variant_id, model_name, variants_tsv, model_paths_tsv,
                            genome_fa, output_file,
                            window_size=160, input_len=2114,
                            protein_footprint=(-40, -10), variant_region=(-5, 5),
                            pvalue_threshold=1e-2, fimo_display_pvalue=5e-4, device='cuda',
                            gtex_thyroid_tsv=None, fimo_footprint_only=False,
                            fimo_all_hits_per_tf=False, exclude_low_tpm=False,
                            font_scale=1.5,
                            motifs_tsv=None, jaspar_meme_file=None, finemo_motif_fontsize=None,
                            motif_mapping_tsv=None, no_foxe1_highlighting=False,
                            top_panel_annotations_only=False):
    """
    Generate combined figure with motif scans (finemo or FIMO) and ChromBPNet predictions.

    The figure stacks four panels on a shared x-axis relative to the variant: the top
    motif panel, predicted profiles (ref vs alt), and contribution-score logos for ref
    and for alt.

    When motifs_tsv is provided: use finemo for SHAP-panel motif overlays (and for top panel if no FIMO).
    When jaspar_meme_file is provided: run FIMO for the top motif panel.
    Both can be provided: top panel = FIMO, SHAP panels = finemo overlays.
    At least one of motifs_tsv or jaspar_meme_file must be provided.
    GTEx thyroid TPM filtering/coloring applies to both; blue shading indicates FOXE1.

    Args:
        variant_id: value looked up in the 'variant_id' column of ``variants_tsv``.
        model_name: model identifier. Names containing FILE_LOADING_MODEL_MATCH may use
            precomputed predictions/SHAP; otherwise the model is resolved through
            ``model_paths_tsv`` and run on the fly.
        variants_tsv: tab-separated file with 'variant_id', 'chr', 'pos' and either
            'ref'/'alt' or 'allele1'/'allele2' columns.
        model_paths_tsv: tab-separated model table passed to ``_resolve_model_metadata``.
        genome_fa: reference FASTA used to extract the sequence scanned by FIMO.
        output_file: primary output path. Its extension picks the format (svg when there
            is none). A PDF copy is also written unless the primary output is a PDF.
        window_size: plotted window width (positions -window_size//2 .. window_size//2).
        input_len: ChromBPNet input length for on-the-fly sequence extraction.
        protein_footprint: (start, end) passed to the top panel and to the FIMO footprint filter.
        variant_region: (start, end) passed to the top panel.
        pvalue_threshold: FIMO scan threshold.
        fimo_display_pvalue: stricter FIMO display threshold.
        device: device for on-the-fly ChromBPNet inference.
        gtex_thyroid_tsv: optional GTEx thyroid TPM table used to filter/colour TFs.
        fimo_footprint_only, fimo_all_hits_per_tf, exclude_low_tpm: display filters.
        font_scale: multiplier applied to most font sizes.
        motifs_tsv: finemo annotations (SHAP overlays; top panel when FIMO is not used).
        jaspar_meme_file: JASPAR MEME database; enables FIMO for the top panel.
        finemo_motif_fontsize: if None, matches FIMO motif font size (11 * font_scale + 2);
            else use this value.
        motif_mapping_tsv: optional finemo motif-name mapping.
        no_foxe1_highlighting: skip FOXE1 shading (the caption then omits it too).
        top_panel_annotations_only: draw a short, blank top panel.

    Returns:
        The matplotlib Figure (it is saved to disk but not closed).

    Raises:
        ValueError: no motif source given, variant not found, allele columns missing, or
            the model name cannot be resolved.
        FileNotFoundError: the JASPAR MEME file is missing, or no model fold exists on disk.
    """
    if not motifs_tsv and not jaspar_meme_file:
        raise ValueError("At least one of --motifs-tsv (finemo) or --jaspar-meme (FIMO) must be provided")
    use_fimo_for_top = jaspar_meme_file is not None
    use_finemo_for_overlays = motifs_tsv is not None

    print("=" * 80)
    print(f"Generating combined figure for variant: {variant_id}")
    print(f"Model: {model_name}")
    print(f"  --motifs-tsv:   {motifs_tsv}")
    print(f"  --jaspar-meme: {jaspar_meme_file}")
    if use_fimo_for_top and use_finemo_for_overlays:
        print("Motifs: top panel = FIMO, SHAP overlays = finemo")
    else:
        print(f"Top motif panel: {'FIMO' if use_fimo_for_top else 'finemo'}")
        print(f"SHAP motif overlays: {'finemo' if use_finemo_for_overlays else 'none'}")
    print("=" * 80)
    
    # Step 1: Load variant information
    print("\n1. Loading variant information...")
    variants_df = pd.read_csv(variants_tsv, sep='\t')
    
    if 'variant_id' not in variants_df.columns:
        raise ValueError("variants_tsv must contain 'variant_id' column")
    
    variant_row = variants_df[variants_df['variant_id'] == variant_id]
    if len(variant_row) == 0:
        raise ValueError(f"Variant {variant_id} not found in {variants_tsv}")
    
    variant_row = variant_row.iloc[0]
    chrom = variant_row['chr']
    pos = int(variant_row['pos'])
    
    # Handle both column naming conventions
    if 'ref' in variant_row.index and 'alt' in variant_row.index:
        ref = variant_row['ref']
        alt = variant_row['alt']
    elif 'allele1' in variant_row.index and 'allele2' in variant_row.index:
        ref = variant_row['allele1']
        alt = variant_row['allele2']
    else:
        raise ValueError("Variants TSV must contain either 'ref'/'alt' or 'allele1'/'allele2' columns")
    
    print(f"   Variant: {chrom}:{pos}:{ref}→{alt}")
    
    ref_motifs, alt_motifs = None, None
    gtex_thyroid_tpm = None
    if gtex_thyroid_tsv:
        gtex_thyroid_tpm = load_gtex_thyroid_tpm(gtex_thyroid_tsv)
        print(f"   Loaded GTEx thyroid TPM for {len(gtex_thyroid_tpm)} genes")

    # Finemo: parse for SHAP-panel overlays (and for top panel when FIMO not used)
    if use_finemo_for_overlays:
        print("\n2–3. Parsing finemo motif annotations (SHAP overlays)...")
        motif_mapping = _load_motif_mapping(motif_mapping_tsv) if motif_mapping_tsv else None
        if motif_mapping_tsv:
            print(f"   Using motif mapping from {motif_mapping_tsv} ({len(motif_mapping)} entries)")
        ref_motifs, alt_motifs = _parse_finemo_motifs(
            motifs_tsv, variant_id, chrom, pos, window_size, model_name=model_name,
            motif_mapping=motif_mapping
        )
        print(f"   Found {len(ref_motifs or [])} ref motifs, {len(alt_motifs or [])} alt motifs (finemo)")

    # Filter out ZNF284 and NFI from profile overlays; add manual FOX (-9 to +2) and TEAD-CEBP-NFI (41 to 67)
    # Match display logic: finemo names can be "ZNF284_0", "NFI_1" etc.; normalize by stripping trailing _N
    _excluded_finemo = {'ZNF284', 'NFI'}

    def _normalize_motif_name(name):
        # Upper-cased first ';'-separated name with any trailing "_<digits>" suffix removed.
        s = (name or '').strip()
        if not s:
            return ''
        parts = [p.strip() for p in s.split(';') if p.strip()]
        if not parts:
            return s.upper()
        first = re.sub(r'_\d+$', '', parts[0]).strip() or parts[0]
        return first.upper()

    def _drop_finemo_motifs(motifs):
        # Remove excluded motif families; also turns a None motif list into [].
        return [m for m in (motifs or []) if _normalize_motif_name(m.get('motif')) not in _excluded_finemo]
    ref_motifs = _drop_finemo_motifs(ref_motifs)
    alt_motifs = _drop_finemo_motifs(alt_motifs)
    _fox_motif = {'motif': 'FOX', 'start_plot': -9, 'end_plot': 2, 'score': 1.0, 'show_score': False}
    _tead_motif = {'motif': 'TEAD', 'start_plot': 41, 'end_plot': 48, 'score': 1.0, 'show_score': False}
    _cebp_motif = {'motif': 'CEBP', 'start_plot': 49, 'end_plot': 60, 'score': 1.0, 'show_score': False}
    _nfi_motif = {'motif': 'NFI', 'start_plot': 60, 'end_plot': 66, 'score': 1.0, 'show_score': False}
    _manual_motifs = [_fox_motif, _tead_motif, _cebp_motif, _nfi_motif]
    # Give each allele its own copies. Otherwise ref and alt share the same dict objects, and any
    # per-motif mutation made for one panel (here or in a downstream helper) would leak into the other.
    ref_motifs = ref_motifs + [dict(m) for m in _manual_motifs]
    alt_motifs = alt_motifs + [dict(m) for m in _manual_motifs]
    # Do not show scores on any motif annotations (contribution panels)
    for m in ref_motifs + alt_motifs:
        m['show_score'] = False

    # Top motif panel: FIMO (if jaspar provided) or finemo (if only motifs_tsv)
    if use_fimo_for_top:
        # Step 2–3: Run FIMO for top panel
        if not os.path.exists(jaspar_meme_file):
            raise FileNotFoundError(
                f"JASPAR MEME file not found: {jaspar_meme_file}\n"
                "Pass a valid --jaspar-meme path to show FIMO motifs in the top panel."
            )
        print("\n2–3. FIMO scan for top motif panel...")
        print(f"   JASPAR MEME: {jaspar_meme_file}")
        sequence, seq_start_genomic, variant_pos_in_seq = extract_sequence_for_fimo(
            chrom, pos, ref, alt, genome_fa, window_size=window_size
        )
        print(f"   Sequence length: {len(sequence)} bp")
        print(f"   Genomic region: {chrom}:{seq_start_genomic}-{seq_start_genomic + len(sequence)}")
        with tempfile.TemporaryDirectory() as temp_dir:
            fimo_df = run_fimo_scan(
                sequence, f"{chrom}:{pos}", jaspar_meme_file, temp_dir,
                pvalue_threshold=pvalue_threshold
            )
            fimo_df = process_fimo_results(
                fimo_df, pos, seq_start_genomic, window_size=window_size
            )
            n_fimo_before_filter = len(fimo_df)
            print(f"   Found {n_fimo_before_filter} FIMO hits in window")
            fimo_df = filter_fimo_for_display(
                fimo_df,
                best_hit_per_tf=not fimo_all_hits_per_tf,
                max_display_pvalue=fimo_display_pvalue,
                gtex_thyroid_tpm=gtex_thyroid_tpm,
                protein_footprint=protein_footprint if fimo_footprint_only else None,
                footprint_only=fimo_footprint_only,
                exclude_low_tpm=exclude_low_tpm,
            )
            if len(fimo_df) == 0 and n_fimo_before_filter > 0:
                print(
                    f"   WARNING: All {n_fimo_before_filter} FIMO hits were filtered out. "
                    "Try relaxing --fimo-display-pvalue or omitting --gtex-thyroid-tsv to see unfiltered FIMO hits."
                )
            elif len(fimo_df) == 0:
                print("   No FIMO hits in window (FIMO returned 0 hits for this sequence).")
            msg = f"   Displaying {len(fimo_df)} motifs (FIMO, p < {fimo_display_pvalue})"
            if gtex_thyroid_tpm is not None:
                msg += " (GTEx thyroid TPM ≥ 5)" if exclude_low_tpm else " (GTEx thyroid TPM ≥ 1)"
            msg += ", all hits per TF" if fimo_all_hits_per_tf else ", best hit per TF"
            if fimo_footprint_only:
                msg += ", in protein footprint only"
            print(msg)
    else:
        # Top panel from finemo (no FIMO)
        fimo_df = finemo_motifs_to_dataframe(
            ref_motifs, alt_motifs,
            gtex_thyroid_tpm=gtex_thyroid_tpm,
            exclude_low_tpm=exclude_low_tpm,
        )
        print(f"   Top panel: {len(fimo_df)} motifs (finemo, score ≥ 1.0)")
    
    # Step 4–6: When variant is in uwcrdr file and model is ENCSR474XFV, pull predictions/SHAP from
    # precomputed files (same layout as prototype); otherwise run on-the-fly.
    variant_in_uwcrdr = False
    if FILE_LOADING_MODEL_MATCH in model_name and _FILE_LOADING_AVAILABLE:
        try:
            _variant_idx_from_uwcrdr(chrom, pos, ref, alt)
            variant_in_uwcrdr = True
        except ValueError:
            # Lookup failed (presumably: variant absent from the precomputed file). This is
            # expected for most variants; the diagnostics below report it and the
            # on-the-fly path is used instead.
            variant_in_uwcrdr = False
    use_file_predictions = (
        (variant_id == FILE_LOADING_VARIANT_ID or variant_in_uwcrdr)
        and FILE_LOADING_MODEL_MATCH in model_name
        and _FILE_LOADING_AVAILABLE
    )
    # Diagnose why file-loading is or isn't used
    if FILE_LOADING_MODEL_MATCH in model_name:
        if use_file_predictions:
            print(f"\n   [File-loading: using precomputed predictions/SHAP (variant in uwcrdr, deps OK)]")
        else:
            reasons = []
            if not _FILE_LOADING_AVAILABLE:
                reasons.append("h5py/deepdish/torch not available (failed import)")
            elif not (variant_id == FILE_LOADING_VARIANT_ID or variant_in_uwcrdr):
                reasons.append(
                    f"variant {chrom}:{pos}:{ref}:{alt} not in uwcrdr file ({UWCRDR_VARIANTS_PATH})"
                )
            if reasons:
                print(f"\n   [File-loading skipped: {'; '.join(reasons)}]")
    if use_file_predictions:
        encsr_id = model_name.split("__")[1] if "__" in model_name else model_name
        scores_dir = f"{SCORES_BASE}/{encsr_id}"
        shap_dir = f"{SHAP_BASE}/{encsr_id}"
        print(f"\n4–6. Pulling predictions and SHAP from files (variant={variant_id}, model={model_name})")
        print(f"   Scores: {scores_dir}")
        print(f"   SHAP:   {shap_dir}")
        variant_idx = _variant_idx_from_uwcrdr(chrom, pos, ref, alt)
        print(f"   Variant index in uwcrdr file: {variant_idx}")
        preds1, preds2, counts1, counts2 = _load_scores_from_files(scores_dir)
        ref_profile = preds1[variant_idx, :]
        alt_profile = preds2[variant_idx, :]
        ref_count = float(np.ravel(counts1)[variant_idx])
        alt_count = float(np.ravel(counts2)[variant_idx])
        logfc = np.log2(alt_count / ref_count) if (ref_count > 0 and alt_count > 0) else 0.0
        shaps1, shaps2 = _load_shaps_from_files(shap_dir)
        ref_shap = shaps1[variant_idx, :, :]   # (2114, 4)
        alt_shap = shaps2[variant_idx, :, :]
        ref_attr = [torch.from_numpy(ref_shap.T).float()]   # (4, 2114) for downstream .T → (2114, 4)
        alt_attr = [torch.from_numpy(alt_shap.T).float()]
        print(f"   Predictions loaded from files: ref_profile {ref_profile.shape}, alt_profile {alt_profile.shape}")
        print(f"   SHAP loaded from files: ref_shap {ref_shap.shape}, alt_shap {alt_shap.shape}")
        print(f"   LogFC: {logfc:.3f}")
    else:
        # Step 4: Resolve model metadata and generate ChromBPNet predictions on-the-fly
        print("\n4. Resolving model metadata...")
        # Resolve shorthand / alias model names to model_paths.tsv model_name
        model_name_to_resolve = model_name
        if model_name in ("ENCSR474XFV", "thyroid_gland__ENCSR474XFV"):
            model_name_to_resolve = "ENC_ENCODE_ENCSR474XFV"
        try:
            metadata = _resolve_model_metadata(model_name_to_resolve, model_paths_tsv)
        except ValueError as e:
            # Provide helpful error message with suggestions
            print(f"\nError: {e}")
            print(f"\nChecking for similar model names...")
            model_df = pd.read_csv(model_paths_tsv, sep='\t')
            available_models = model_df['model_name'].tolist()

            # Check for partial matches
            suggestions = []
            if 'ENCSR474XFV' in model_name:
                suggestions = [m for m in available_models if 'ENCSR474XFV' in m]
            elif 'thyroid' in model_name.lower():
                suggestions = [m for m in available_models if 'thyroid' in m.lower() or 'Thyroid' in m]

            if suggestions:
                print("\nDid you mean one of these?")
                for sug in suggestions[:5]:  # Show first 5 suggestions
                    print(f"  - {sug}")
            else:
                print("\nAvailable models containing 'ENCSR474XFV':")
                encsr_models = [m for m in available_models if 'ENCSR474XFV' in m]
                for m in encsr_models[:5]:
                    print(f"  - {m}")

            raise

        # Filter model folds to those that actually exist on disk
        model_folds_all = metadata['model_folds']
        model_folds = [mf for mf in model_folds_all if os.path.exists(mf)]
        # If ENC_ENCODE_ENCSR474XFV paths are in TSV but on-disk dir is ENCSR474XFV, try that base
        if not model_folds and "ENC_ENCODE_ENCSR474XFV" in model_name_to_resolve:
            alt_folds = [
                p.replace("ENC_ENCODE_ENCSR474XFV", "ENCSR474XFV")
                for p in model_folds_all
            ]
            model_folds = [p for p in alt_folds if os.path.exists(p)]
            if model_folds:
                print(
                    "Using paths under .../ATAC/ENCSR474XFV/ (model_paths had .../ENC_ENCODE_ENCSR474XFV/)."
                )
                model_folds_all = alt_folds
        if not model_folds:
            msg = (
                "None of the model_fold paths exist on disk for "
                f"model '{model_name_to_resolve}'. Checked:\n  - " + "\n  - ".join(model_folds_all)
            )
            if "ENCSR474XFV" in model_name_to_resolve:
                msg += (
                    "\n\nENCSR474XFV / thyroid_gland__ENCSR474XFV resolve to 'ENC_ENCODE_ENCSR474XFV'. "
                    "Ensure those ChromBPNet .h5 paths in model_paths.tsv exist or point to the correct location."
                )
            raise FileNotFoundError(msg)
        if len(model_folds) < len(model_folds_all):
            missing = [mf for mf in model_folds_all if mf not in model_folds]
            print(
                f"\nWarning: {len(missing)} of {len(model_folds_all)} folds are missing on disk "
                "and will be skipped for this figure."
            )
            for mf in missing:
                print(f"  - missing: {mf}")

        print("\n5. Extracting sequences for ChromBPNet...")
        ref_seq, alt_seq, ref_seq_str, alt_seq_str = _extract_variant_sequences(
            chrom, pos, ref, alt, metadata['genome'], input_len=input_len
        )
        print(f"   Extracted {input_len}bp sequences")

        print("\n6. Generating ChromBPNet predictions and attributions...")
        ref_profile, alt_profile, ref_count, alt_count, ref_attr, alt_attr, logfc = \
            _process_single_model_with_folds(
                model_folds, ref_seq, alt_seq, ref_seq_str, alt_seq_str,
                n_shuffles=20, device=device
            )
        print(f"   LogFC: {logfc:.3f}")
        print(f"   Predictions loaded: ref_profile {getattr(ref_profile, 'shape', type(ref_profile).__name__)}, "
              f"alt_profile {getattr(alt_profile, 'shape', type(alt_profile).__name__)}")
        _ra = ref_attr[0] if hasattr(ref_attr, '__getitem__') else ref_attr
        _aa = alt_attr[0] if hasattr(alt_attr, '__getitem__') else alt_attr
        print(f"   SHAP loaded: ref_attr {getattr(_ra, 'shape', type(_ra).__name__)}, "
              f"alt_attr {getattr(_aa, 'shape', type(_aa).__name__)}")

    # Step 7: Create figure layout (publication-friendly: white background, no rasterization)
    print("\n7. Creating figure layout...")
    fig = plt.figure(figsize=(20, 14), dpi=400, facecolor='white')
    # Shorter top panel in annotations-only mode
    height_ratios = [0.35, 1, 1, 1] if top_panel_annotations_only else [1.5, 1, 1, 1]
    gs = GridSpec(4, 1, figure=fig, height_ratios=height_ratios, hspace=0.12)
    
    # Step 8: Plot top section (motif panel or annotations only)
    if top_panel_annotations_only:
        # NOTE(review): this message mentions STR/AluSX1/accessible-chromatin bars, but the branch
        # below passes annotation_bars=None, so no bars are drawn. Confirm which behavior is intended.
        print("\n8. Plotting top panel (annotations only: STR, AluSX1, accessible chromatin)...")
    else:
        print("\n8. Plotting top motif panel ({} scans)...".format("FIMO" if use_fimo_for_top else "finemo"))
    ax_fimo = fig.add_subplot(gs[0])
    ax_fimo.set_gid('fimo-panel')  # SVG group id for collaboration/editing
    half_win = window_size // 2
    # Annotation bars: colorblind-friendly, distinct colors (Okabe–Ito–style)
    # Accessible chromatin: purple (open/accessible chromatin); AluSX: vermillion; STR: green
    annotation_bars = [
        ("Accessible chromatin element", -half_win, half_win,
         {"no_left_edge": True, "color": "#6A3D9A"}),   # purple
        ("AluSX1", -10, half_win, {"color": "#D55E00"}), # vermillion/orange-red
        ("STR", -3, 5, {"color": "#009E73"}),           # bluish green
    ]
    if top_panel_annotations_only:
        # Empty motif data and no STR/AluSX1/accessible chromatin bars; no Y title, no variant lines, no TPM legend
        empty_fimo_df = pd.DataFrame(columns=['tf_name', 'start_rel', 'end_rel'])
        plot_fimo_scans(ax_fimo, empty_fimo_df, window_size=window_size,
                       protein_footprint=protein_footprint, variant_region=variant_region,
                       annotation_bars=None, font_scale=font_scale)
        ax_fimo.set_ylabel('')
    else:
        plot_fimo_scans(ax_fimo, fimo_df, window_size=window_size,
                       protein_footprint=protein_footprint, variant_region=variant_region,
                       annotation_bars=annotation_bars, font_scale=font_scale)
        # Variant boundary dotted lines (left -0.5, right ref_length-0.5), beneath motifs & annotation bars
        ref_len = len(ref) if ref != "-" else 0
        ax_fimo.axvline(-0.5, color='k', linestyle='--', linewidth=1, zorder=0)
        ax_fimo.axvline(ref_len - 0.5, color='k', linestyle='--', linewidth=1, zorder=0)
        # TPM legend (top right) when GTEx thyroid colors are used
        if gtex_thyroid_tsv:
            tpm_handles = [
                mpatches.Patch(facecolor='#2171B5', edgecolor='black', label='TPM > 100'),
                mpatches.Patch(facecolor='#6BAED6', edgecolor='black', label='35–100'),
                mpatches.Patch(facecolor='#9ECAE1', edgecolor='black', label='15–35'),
                mpatches.Patch(facecolor='#C6DBEF', edgecolor='black', label='5–15'),
            ]
            if not exclude_low_tpm:
                tpm_handles.append(mpatches.Patch(facecolor='#DEEBF7', edgecolor='black', label='1–5'))
            ax_fimo.legend(handles=tpm_handles, loc='upper right', fontsize=12 * font_scale + 1)
    
    # Step 9: Plot ChromBPNet predictions (bottom section)
    print("\n9. Plotting ChromBPNet predictions...")
    
    ref_label = ref if ref != "-" else ""
    alt_label = alt if alt != "-" else ""
    # Panel 1: Predicted Profiles
    ax1 = fig.add_subplot(gs[1])
    ax1.set_gid('profile-panel')  # SVG group id for collaboration/editing
    ref_length = len(ref) if ref != "-" else 0
    alt_length = len(alt) if alt != "-" else 0
    _plot_profile(ref_profile, alt_profile, ref_length, alt_length,
                  ref_label, alt_label, window_size, ax1, logfc=logfc, logfc_pval=None,
                  logfc_legend_font_scale=font_scale + 1/12)  # 12*(font_scale+1/12)=12*font_scale+1
    ax1.set_ylabel('Predicted Profiles', rotation=90, labelpad=10, fontsize=14 * font_scale + 1)
    ax1.tick_params(axis='y', labelsize=12 * font_scale + 1)
    configure_shared_xaxis(ax1, window_size=window_size, font_scale=font_scale)
    ax1.set_xticks([])  # no x ticks on predicted profiles
    ax1.set_xticklabels([])
    
    # Panel 2: Contribution (ref)
    ax2 = fig.add_subplot(gs[2])
    ax2.set_gid('contrib-ref-panel')  # SVG group id for collaboration/editing
    ref_attr_np = ref_attr[0].cpu().numpy().T  # Shape: (2114, 4)
    alt_attr_np = alt_attr[0].cpu().numpy().T  # Shape: (2114, 4)
    
    # Prepare ref and alt SHAP data for plotting (extract window centered on variant)
    total_length = 2114
    C = total_length // 2   # variant at center of 2114 bp input
    F = window_size // 2
    ref_shap_plot = ref_attr_np[C-F:C+F+ref_length]
    alt_shap_plot = alt_attr_np[C-F:C+F+alt_length]
    
    # Create DataFrame for logomaker
    df_ref = pd.DataFrame(ref_shap_plot, columns=["A", "C", "G", "T"])
    df_ref.index += -F
    
    # Plot logo
    logomaker.Logo(df_ref, ax=ax2)
    ax2.axvline(-0.5, color='k', linestyle='--', linewidth=1)
    ax2.axvline(ref_length-0.5, color='k', linestyle='--', linewidth=1)
    
    # Set y-axis limits from both ref and alt so both panels share the same scale
    ymax = 1.1 * max(np.max(np.maximum(ref_shap_plot, 0)), np.max(np.maximum(alt_shap_plot, 0)))
    ymin = 1.1 * min(np.min(np.minimum(ref_shap_plot, 0)), np.min(np.minimum(alt_shap_plot, 0)))
    ax2.set_ylim(bottom=ymin, top=ymax)
    print(f"  [Contribution panels] shared ylim: ymin={ymin:.4f}, ymax={ymax:.4f}")
    # Match FIMO motif font size (11 * font_scale + 2) when finemo_motif_fontsize not specified
    finemo_fontsize = (11 * font_scale + 2) if finemo_motif_fontsize is None else finemo_motif_fontsize
    motif_to_color = _build_motif_to_color(ref_motifs, alt_motifs) if (ref_motifs or alt_motifs) else None
    if ref_motifs:
        _add_motif_overlays(ax2, ref_motifs, -F, y_position_frac=0.85, motif_fontsize=finemo_fontsize,
                           motif_to_color=motif_to_color)
    
    # Attach the allele label to ax2 explicitly (plt.text targets whatever the current axes is).
    ax2.text(0.988, 0.903, f"ref ({ref_label})",
             verticalalignment='top', horizontalalignment='right',
             transform=ax2.transAxes, size=12 * font_scale + 1, color='black',
             bbox=dict(boxstyle='round', facecolor='white', edgecolor='lightgrey'))
    
    ax2.set_ylabel('Contribution\n(ref)', rotation=90, labelpad=10, fontsize=14 * font_scale + 1)
    ax2.set_yticks([])
    ax2.set_yticklabels([])
    configure_shared_xaxis(ax2, window_size=window_size, font_scale=font_scale)
    ax2.set_xticks([])  # no x ticks on ref contribution plot
    ax2.set_xticklabels([])
    
    # Panel 3: Contribution (alt)
    ax3 = fig.add_subplot(gs[3])
    ax3.set_gid('contrib-alt-panel')  # SVG group id for collaboration/editing

    # alt_shap_plot already computed above for shared ylim
    # Create DataFrame for logomaker
    df_alt = pd.DataFrame(alt_shap_plot, columns=["A", "C", "G", "T"])
    df_alt.index += -F
    
    # Plot logo
    logomaker.Logo(df_alt, ax=ax3)
    ax3.axvline(-0.5, color='k', linestyle='--', linewidth=1)
    ax3.axvline(alt_length-0.5, color='k', linestyle='--', linewidth=1)
    
    # Same y-axis scale as ref panel (already computed above)
    ax3.set_ylim(bottom=ymin, top=ymax)
    if alt_motifs:
        _add_motif_overlays(ax3, alt_motifs, -F, y_position_frac=0.85, motif_fontsize=finemo_fontsize,
                           motif_to_color=motif_to_color)
    
    ax3.text(0.988, 0.903, f"alt ({alt_label})",
             verticalalignment='top', horizontalalignment='right',
             transform=ax3.transAxes, size=12 * font_scale + 1, color='black',
             bbox=dict(boxstyle='round', facecolor='white', edgecolor='lightgrey'))
    
    ax3.set_ylabel('Contribution\n(alt)', rotation=90, labelpad=10, fontsize=14 * font_scale + 1)
    ax3.set_yticks([])
    ax3.set_yticklabels([])
    ax3.set_xlabel('Relative genomic position (bp)', fontsize=14 * font_scale + 1)
    configure_shared_xaxis(ax3, window_size=window_size, font_scale=font_scale)
    
    # FOXE1 blue highlighting across all panels (additive); skip if --no-foxe1-highlighting
    if not no_foxe1_highlighting:
        add_foxe1_highlighting([ax_fimo, ax1, ax2, ax3], fimo_df)
    
    # Print programmatic caption to console (for manual addition to manuscript)
    gtex_used = gtex_thyroid_tsv is not None
    # Only describe FOXE1 shading in the caption when it was actually drawn.
    if no_foxe1_highlighting or 'tf_name' not in fimo_df.columns:
        foxe1_count = 0
    else:
        foxe1_count = int((fimo_df['tf_name'] == 'FOXE1').sum())
    caption = build_figure_caption(fimo_display_pvalue, gtex_used, foxe1_count, ref_label, alt_label,
                                   exclude_low_tpm=exclude_low_tpm,
                                   use_fimo_top=use_fimo_for_top, use_finemo_overlays=use_finemo_for_overlays)
    print("\n" + "=" * 60 + "\nFigure caption (add to manuscript):\n" + "=" * 60)
    print(caption)
    print("=" * 60 + "\n")
    
    # Step 10: Save figure (publication/collaboration-friendly SVG: text as text, semantic group ids, metadata)
    print(f"\n10. Saving figure to {output_file}...")
    # Take the format from the file suffix only. Splitting the whole path on '.' broke for names
    # without an extension, or with a dot only in a parent directory.
    out_fmt = Path(output_file).suffix.lstrip('.').lower() or 'svg'
    save_kw = dict(bbox_inches='tight', dpi=400)
    if out_fmt == 'svg':
        if use_fimo_for_top and use_finemo_for_overlays:
            motif_src = 'FIMO + finemo'
        elif use_finemo_for_overlays:
            motif_src = 'finemo'
        else:
            motif_src = 'FIMO'
        save_kw['metadata'] = {
            'Title': f"Variant {variant_id} — {model_name}",
            'Creator': 'regenerate_plot_svg.py (varbook-container)',
            # Panel names match the gids set on the axes above ('fimo-panel', not 'motif-panel').
            'Description': f'{motif_src} + ChromBPNet combined figure; panels: fimo-panel, profile-panel, contrib-ref-panel, contrib-alt-panel',
        }
        with matplotlib.rc_context({'svg.fonttype': 'none'}):  # keep text as text for editing
            fig.savefig(output_file, format='svg', **save_kw)
    else:
        fig.savefig(output_file, format=out_fmt, **save_kw)
    print(f"✓ Figure saved successfully")

    # Also save PDF version (in addition to primary output)
    if out_fmt != 'pdf':
        pdf_output = str(Path(output_file).with_suffix('.pdf'))
        print(f"    Also saving PDF to {pdf_output}...")
        pdf_save_kw = dict(bbox_inches='tight', dpi=400)
        fig.savefig(pdf_output, format='pdf', **pdf_save_kw)
        print(f"✓ PDF saved successfully")
    
    return fig


def configure_shared_xaxis(ax, variant_pos_rel=0, window_size=160, font_scale=1.5):
    """Apply the common x-axis layout used by every panel.

    The axis spans [-window_size // 2, window_size // 2]. Ticks are placed at the two
    window edges, at 0, and at ±50 when those fall inside the window (deduplicated and
    sorted). No centre line is drawn here; the variant bounds are drawn by the
    contribution panels themselves. ``variant_pos_rel`` is accepted but not used.
    """
    edge = window_size // 2
    ax.set_xlim(-edge, edge)
    tick_candidates = {-edge, -50, 0, 50, edge}
    ticks_in_window = sorted(tick for tick in tick_candidates if abs(tick) <= edge)
    ax.set_xticks(ticks_in_window)
    ax.tick_params(axis='x', labelsize=10 * font_scale + 1)


# ============================================================================
# COMMAND-LINE INTERFACE
# ============================================================================

def main():
    """Command-line entry point: parse arguments and render the combined figure.

    Steps:
      1. Parse CLI arguments.
      2. Validate them. At least one motif source (``--motifs-tsv`` and/or
         ``--jaspar-meme``) is required, and ``--window-size`` and
         ``--font-scale`` must be positive. Invalid input ends the program
         through ``parser.error`` (exit status 2).
      3. Derive the output path. If ``--output-file`` is not given it is
         ``<variant_id>_<model_name>.<output_format>``. Otherwise its extension
         is forced to match ``output_format``, because the figure writer picks
         the save format from the file extension.
      4. Forward everything to ``generate_combined_figure``.
    """
    parser = argparse.ArgumentParser(
        description="Generate combined figure with motif scans (finemo or FIMO) and ChromBPNet predictions"
    )
    parser.add_argument("variant_id", help="Variant ID (e.g., 'variant_1')")
    parser.add_argument("model_name", help="Model name (e.g., 'thyroid_gland__ENCSR474XFV')")
    parser.add_argument("output_format", nargs='?', default='svg', choices=['svg', 'png', 'pdf'],
                       help="Output format (default: svg)")
    
    parser.add_argument("--variants-tsv", required=True, 
                       help="Path to variants TSV file")
    parser.add_argument("--model-paths-tsv", required=True,
                       help="Path to model paths TSV file")
    parser.add_argument("--genome-fa", 
                       default="/oak/stanford/groups/akundaje/airanman/refs/hg38/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta",
                       help="Path to reference genome FASTA")
    parser.add_argument("--motifs-tsv",
                       help="Path to finemo motif annotations TSV. Used for SHAP-panel overlays; also for top panel if --jaspar-meme not set. Can be combined with --jaspar-meme (top=FIMO, overlays=finemo).")
    parser.add_argument("--jaspar-meme",
                       help="Path to JASPAR CORE database in MEME format. When provided, runs FIMO for the top motif panel. Can be combined with --motifs-tsv.")
    parser.add_argument("--output-file", 
                       help="Output file path (default: auto-generated)")
    parser.add_argument("--window-size", type=int, default=160,
                       help="Window size around variant, ±half (default: 160 → -80 to 80)")
    parser.add_argument("--protein-footprint", nargs=2, type=int, default=[-40, -10],
                       help="Protein footprint region (default: -40 -10)")
    parser.add_argument("--variant-region", nargs=2, type=int, default=[-5, 5],
                       help="Variant region (default: -5 5)")
    parser.add_argument("--pvalue-threshold", type=float, default=1e-2,
                       help="FIMO p-value threshold (default: 1e-2)")
    parser.add_argument("--fimo-display-pvalue", type=float, default=5e-4,
                       help="Only show motifs with p < this (default: 5e-4, between 1e-3 and 1e-4). Also applied when using GTEx to reduce count.")
    parser.add_argument("--fimo-footprint-only", action="store_true",
                       help="Only show motifs overlapping the protein footprint "
                            "(--protein-footprint; default -40 to -10). "
                            "The original caption shows motifs across the full window with blue "
                            "indicating the footprint; omit this flag to match that.")
    parser.add_argument("--fimo-all-hits-per-tf", action="store_true",
                       help="Show all FIMO hits per TF (same motif can repeat at different positions, "
                            "e.g. 1 bp apart). Default: only the best (lowest p-value) hit per TF. "
                            "Use this to match the original figure.")
    parser.add_argument("--gtex-thyroid-tsv",
                       help="TSV of GTEx thyroid TPM per gene. When provided, FIMO motifs are "
                            "filtered to TFs with thyroid TPM >= 1 and colored by expression "
                            "(like the original figure). Expects gene-identifier and thyroid-TPM columns.")
    parser.add_argument("--exclude-low-tpm", action="store_true",
                       help="Filter out TPM < 5; show only >100, 35–100, 15–35, 5–15. "
                            "Use when larger fonts cause motif overflow.")
    parser.add_argument("--font-scale", type=float, default=1.5,
                       help="Global font scale applied to all panels (default: 1.5)")
    parser.add_argument("--finemo-motif-fontsize", type=float, default=None,
                       help="Font size for finemo motif labels on SHAP panels (default: match FIMO, 11*font_scale+2)")
    parser.add_argument("--motif-mapping-tsv",
                       help="TSV mapping finemo pattern_id to human-readable motif names (columns: "
                            "pattern_id, motif_0, motif_1). If provided, finemo labels use these names.")
    parser.add_argument("--no-foxe1-highlighting", action="store_true",
                       help="Do not draw blue FOXE1 binding regions across panels")
    parser.add_argument("--top-panel-annotations-only", action="store_true",
                       help="Top panel: show only STR, AluSX1, and accessible chromatin boxes (no Y title, no motif tracks)")
    parser.add_argument("--device", default='cuda', choices=['cuda', 'cpu'],
                       help="Device for ChromBPNet (default: cuda)")
    
    args = parser.parse_args()

    # Validate argument combinations before doing any path work.
    if not args.motifs_tsv and not args.jaspar_meme:
        parser.error("At least one of --motifs-tsv or --jaspar-meme must be provided")
    if args.window_size <= 0:
        parser.error("--window-size must be a positive integer")
    if args.font_scale <= 0:
        parser.error("--font-scale must be positive")

    # Generate output file name if not provided; default format is SVG.
    if not args.output_file:
        args.output_file = f"{args.variant_id}_{args.model_name}.{args.output_format}"
    # The figure writer derives the save format from the file extension, so
    # make the extension match the requested output_format for every format
    # (previously only svg was normalized, and png/pdf could be silently ignored).
    wanted_ext = f".{args.output_format}"
    if not args.output_file.lower().endswith(wanted_ext):
        args.output_file = str(Path(args.output_file).with_suffix(wanted_ext))

    # Generate figure
    generate_combined_figure(
        variant_id=args.variant_id,
        model_name=args.model_name,
        variants_tsv=args.variants_tsv,
        model_paths_tsv=args.model_paths_tsv,
        genome_fa=args.genome_fa,
        output_file=args.output_file,
        window_size=args.window_size,
        protein_footprint=tuple(args.protein_footprint),
        variant_region=tuple(args.variant_region),
        pvalue_threshold=args.pvalue_threshold,
        fimo_display_pvalue=args.fimo_display_pvalue,
        device=args.device,
        gtex_thyroid_tsv=args.gtex_thyroid_tsv,
        fimo_footprint_only=args.fimo_footprint_only,
        fimo_all_hits_per_tf=args.fimo_all_hits_per_tf,
        exclude_low_tpm=args.exclude_low_tpm,
        font_scale=args.font_scale,
        motifs_tsv=args.motifs_tsv,
        jaspar_meme_file=args.jaspar_meme,
        finemo_motif_fontsize=args.finemo_motif_fontsize,
        motif_mapping_tsv=args.motif_mapping_tsv,
        no_foxe1_highlighting=args.no_foxe1_highlighting,
        top_panel_annotations_only=args.top_panel_annotations_only,
    )


if __name__ == '__main__':
    main()
