"""
Tests for varbook.annotate.kmeans module.
"""

import pytest
import pandas as pd
import numpy as np
from pathlib import Path
import tempfile


def test_kmeans_import():
    """The kmeans module imports cleanly and exposes its clustering entry point."""
    from varbook.annotate import kmeans as kmeans_module

    entry_point_name = 'perform_kmeans_clustering'
    assert hasattr(kmeans_module, entry_point_name)


def test_kmeans_basic_clustering():
    """Test basic KMeans clustering functionality.

    Builds 100 synthetic variants with three standard-normal feature columns,
    asks for 5 clusters, and checks that the cluster column is added, holds
    between 1 and 5 distinct labels, and has no missing assignments.
    """
    from varbook.annotate.kmeans import perform_kmeans_clustering

    # Seeded generator: keeps the test reproducible and independent of any
    # global NumPy random state touched by other tests.
    rng = np.random.default_rng(0)
    n_variants = 100

    # Create a simple test dataset
    df = pd.DataFrame({
        'variant_id': [f'var_{i}' for i in range(n_variants)],
        'logfc_KUN_FB_model1': rng.standard_normal(n_variants),
        'logfc_KUN_FB_model2': rng.standard_normal(n_variants),
        'logfc_KUN_FB_model3': rng.standard_normal(n_variants),
    })

    # Perform clustering
    result_df = perform_kmeans_clustering(
        df=df,
        variant_id_col='variant_id',
        cluster_col='kmeans_5-KUN_FB',
        n_clusters=5,
        feature_columns=['logfc_KUN_FB_model1', 'logfc_KUN_FB_model2', 'logfc_KUN_FB_model3']
    )

    # Check that cluster column was added
    assert 'kmeans_5-KUN_FB' in result_df.columns

    # At most n_clusters labels (could be fewer if the data is degenerate),
    # but at least one: an empty result would otherwise slip through, since
    # nunique() is 0 and notna().all() is vacuously True on an empty column.
    unique_clusters = result_df['kmeans_5-KUN_FB'].nunique()
    assert 1 <= unique_clusters <= 5

    # Check that all variants are assigned to a cluster
    assert result_df['kmeans_5-KUN_FB'].notna().all()


def test_kmeans_with_missing_values():
    """Test KMeans clustering handles missing values.

    Injects NaNs into one feature column and checks that every variant still
    receives a cluster label.
    """
    from varbook.annotate.kmeans import perform_kmeans_clustering

    # Seeded generator for reproducibility (see test_kmeans_basic_clustering).
    rng = np.random.default_rng(1)
    n_variants = 100

    # Create dataset with missing values
    df = pd.DataFrame({
        'variant_id': [f'var_{i}' for i in range(n_variants)],
        'logfc_KUN_FB_model1': rng.standard_normal(n_variants),
        'logfc_KUN_FB_model2': rng.standard_normal(n_variants),
    })

    # .loc label slicing is inclusive, so this blanks rows 0..10 (11 rows).
    df.loc[0:10, 'logfc_KUN_FB_model1'] = np.nan
    # Sanity-check the fixture before the call, so the test really exercises
    # the missing-value path.
    assert df['logfc_KUN_FB_model1'].isna().sum() == 11

    result_df = perform_kmeans_clustering(
        df=df,
        variant_id_col='variant_id',
        cluster_col='kmeans_3-KUN_FB',
        n_clusters=3,
        feature_columns=['logfc_KUN_FB_model1', 'logfc_KUN_FB_model2']
    )

    # Should still produce clusters. Per the original note the implementation
    # fills NaNs with 0; that is not verified here, so confirm against kmeans.py.
    assert 'kmeans_3-KUN_FB' in result_df.columns
    assert result_df['kmeans_3-KUN_FB'].notna().all()


def test_kmeans_single_cluster():
    """With n_clusters=1, every variant must be labelled cluster 0."""
    from varbook.annotate.kmeans import perform_kmeans_clustering

    ranks = (1, 2, 3)
    frame = pd.DataFrame({
        'variant_id': [f'var_{k}' for k in ranks],
        'score_KUN_FB_model1': [float(k) for k in ranks],
    })

    clustered = perform_kmeans_clustering(
        df=frame,
        variant_id_col='variant_id',
        cluster_col='kmeans_1-KUN_FB',
        n_clusters=1,
        feature_columns=['score_KUN_FB_model1'],
    )

    # A single requested cluster leaves only one possible label.
    labels = clustered['kmeans_1-KUN_FB']
    assert labels.eq(0).all()


def test_kmeans_output_file(tmp_path):
    """Test KMeans clustering with output file.

    Passes ``output_file`` and checks the file exists, can be read back as
    tab-separated data, has one row per returned variant, and contains the
    cluster column.

    Args:
        tmp_path: pytest's per-test temporary directory fixture.
    """
    from varbook.annotate.kmeans import perform_kmeans_clustering

    # Seeded generator for reproducibility (see test_kmeans_basic_clustering).
    rng = np.random.default_rng(2)
    n_variants = 50

    df = pd.DataFrame({
        'variant_id': [f'var_{i}' for i in range(n_variants)],
        'logfc_KUN_FB_model1': rng.standard_normal(n_variants),
        'logfc_KUN_FB_model2': rng.standard_normal(n_variants),
    })

    output_file = tmp_path / "output.tsv"

    result_df = perform_kmeans_clustering(
        df=df,
        variant_id_col='variant_id',
        cluster_col='kmeans_3-KUN_FB',
        n_clusters=3,
        feature_columns=['logfc_KUN_FB_model1', 'logfc_KUN_FB_model2'],
        output_file=str(output_file)
    )

    # Check that file was created
    assert output_file.exists()

    # Check that file can be read back as TSV and mirrors the returned frame
    saved_df = pd.read_csv(output_file, sep='\t')
    assert len(saved_df) == len(result_df)
    assert 'kmeans_3-KUN_FB' in saved_df.columns
