"""
Tests for the varbook.plot.variant and varbook.plot.models plotting modules.
"""

import pytest
import pandas as pd
import numpy as np
from pathlib import Path
import tempfile


def test_model_scatterplot_import():
    """The scatterplot module imports and exposes its plotting entry point."""
    from varbook.plot.variant import model_scatterplot as scatter_module

    entry_point = 'plot_variant_model_scatterplot'
    assert hasattr(scatter_module, entry_point)


def test_model_scatterplot_basic():
    """Render an interactive two-model scatterplot and check both output files.

    The fixture has three variants scored by two ``KUN_FB`` models. Each
    model contributes a ``logfc``, an ``aaq`` and an ``is_prioritized_by``
    column.
    """
    from varbook.plot.variant.model_scatterplot import plot_variant_model_scatterplot

    columns = {
        'variant_id': ['var_1', 'var_2', 'var_3'],
        'logfc-KUN_FB_model1': [0.5, -0.3, 0.8],
        'logfc-KUN_FB_model2': [-0.2, 0.6, 0.1],
        'aaq-KUN_FB_model1': [0.1, 0.05, 0.15],
        'aaq-KUN_FB_model2': [0.08, 0.12, 0.06],
        'is_prioritized_by-KUN_FB_model1': [True, False, True],
        'is_prioritized_by-KUN_FB_model2': [False, True, False],
    }
    frame = pd.DataFrame(columns)

    with tempfile.TemporaryDirectory() as workdir:
        base = Path(workdir)
        md_path = base / 'scatter.md'
        html_path = base / 'scatter.html'

        plot_variant_model_scatterplot(
            df=frame,
            variant_id='var_1',
            x_col='logfc',
            y_col='aaq',
            datasets=['KUN_FB'],
            output_md=str(md_path),
            output_html=str(html_path),
            interactive=True,
        )

        # Interactive mode: both the markdown and the HTML output must exist.
        for produced in (md_path, html_path):
            assert produced.exists()


def test_model_scatterplot_static():
    """Render a single-model scatterplot with ``interactive=False``.

    Only the markdown output is checked here.
    """
    from varbook.plot.variant.model_scatterplot import plot_variant_model_scatterplot

    single_model = pd.DataFrame(
        {
            'variant_id': ['var_1'],
            'logfc-KUN_FB_model1': [0.5],
            'aaq-KUN_FB_model1': [0.1],
            'is_prioritized_by-KUN_FB_model1': [True],
        }
    )

    with tempfile.TemporaryDirectory() as workdir:
        base = Path(workdir)
        md_path = base / 'scatter.md'
        html_path = base / 'scatter.html'

        plot_variant_model_scatterplot(
            df=single_model,
            variant_id='var_1',
            x_col='logfc',
            y_col='aaq',
            datasets=['KUN_FB'],
            output_md=str(md_path),
            output_html=str(html_path),
            interactive=False,
        )

        assert md_path.exists()


def test_model_specificity_barplot_import():
    """The specificity-barplot module imports and exposes its plotting entry point."""
    from varbook.plot.variant import model_specificity_barplot as barplot_module

    entry_point = 'plot_variant_model_specificity_barplot'
    assert hasattr(barplot_module, entry_point)


def test_model_specificity_barplot_basic():
    """Render a specificity barplot for one variant and check the image file.

    Three ``KUN_FB`` models supply ``is_prioritized_by`` flags. A
    hand-written model-to-organ mapping groups them: two models map to
    'Brain' and one maps to 'Heart'.
    """
    from varbook.plot.variant.model_specificity_barplot import plot_variant_model_specificity_barplot

    flags = pd.DataFrame(
        {
            'variant_id': ['var_1', 'var_2'],
            'is_prioritized_by-KUN_FB_model1': [True, False],
            'is_prioritized_by-KUN_FB_model2': [True, True],
            'is_prioritized_by-KUN_FB_model3': [False, False],
        }
    )

    # Simplified organ metadata; keys are model names without the column prefix.
    model_organs = {
        'KUN_FB_model1': 'Brain',
        'KUN_FB_model2': 'Brain',
        'KUN_FB_model3': 'Heart',
    }

    with tempfile.TemporaryDirectory() as workdir:
        png_path = Path(workdir).joinpath('barplot.png')

        plot_variant_model_specificity_barplot(
            df=flags,
            variant_id='var_1',
            prioritization_col='is_prioritized_by',
            datasets=['KUN_FB'],
            organ_mapping=model_organs,
            output_file=str(png_path),
        )

        assert png_path.exists()


def test_plot_models_heatmap_import():
    """The heatmap module imports and exposes ``plot_models_heatmap``."""
    from varbook.plot.models import heatmap as heatmap_module

    entry_point = 'plot_models_heatmap'
    assert hasattr(heatmap_module, entry_point)


def test_plot_models_heatmap_basic():
    """Render an unclustered heatmap of per-model ``logfc`` values and check the file.

    The input data comes from a seeded random generator, so it is identical
    on every run and any failure is reproducible.
    """
    from varbook.plot.models.heatmap import plot_models_heatmap

    n_variants = 10
    # Seeded local generator. The previous unseeded np.random.randn calls
    # produced different input on every run, which made failures flaky.
    rng = np.random.default_rng(0)

    df = pd.DataFrame({
        'variant_id': [f'var_{i}' for i in range(n_variants)],
        'logfc-KUN_FB_model1': rng.standard_normal(n_variants),
        'logfc-KUN_FB_model2': rng.standard_normal(n_variants),
        'logfc-KUN_FB_model3': rng.standard_normal(n_variants),
    })

    with tempfile.TemporaryDirectory() as temp_dir:
        output_file = Path(temp_dir) / 'heatmap.png'

        plot_models_heatmap(
            df=df,
            heatmap_col='logfc',
            x_col='variant_id',
            y_col='model',
            datasets=['KUN_FB'],
            output_file=str(output_file),
            cluster_rows=False,
            cluster_cols=False,
        )

        # Check that the file was created.
        assert output_file.exists()
