# PayShield-ML/tests/test_models/test_metrics.py
# Sibi Krishnamoorthy
# prod
# 8a08300
"""
Tests for Evaluation Metrics.
Tests threshold optimization and metric calculation logic.
"""
import numpy as np
import pytest
from src.models.metrics import calculate_metrics, find_optimal_threshold
class TestCalculateMetrics:
    """Checks on the metric dictionary returned by ``calculate_metrics``."""

    def test_perfect_predictions(self):
        """Scores that cleanly separate the classes give perfect metrics."""
        labels = np.array([0, 0, 1, 1, 1])
        scores = np.array([0.1, 0.2, 0.9, 0.95, 0.99])

        result = calculate_metrics(labels, scores, threshold=0.5)

        # Every negative scores below 0.5 and every positive above it, so
        # the thresholded metrics are expected to be exactly 1.0.
        for key in ("precision", "recall", "f1"):
            assert result[key] == 1.0
        assert result["pr_auc"] > 0.99

    def test_random_predictions(self):
        """Seeded random labels and scores keep every metric inside [0, 1]."""
        np.random.seed(42)  # fixed seed so the sampled data is reproducible
        labels = np.random.randint(0, 2, 100)
        scores = np.random.random(100)

        result = calculate_metrics(labels, scores, threshold=0.5)

        # Only the valid range is checked here, not how low the values are.
        for key in ("precision", "recall", "f1", "pr_auc"):
            assert 0 <= result[key] <= 1
class TestFindOptimalThreshold:
    """Tests for ``find_optimal_threshold``.

    Each test unpacks the result as a ``(threshold, metrics)`` pair, where
    ``metrics`` is indexed by metric name (``"recall"``, ``"f1"``).
    """

    def test_finds_threshold_meeting_recall(self):
        """Test that the threshold found on imbalanced data gives usable recall."""
        np.random.seed(42)  # fixed seed so the sampled scores are reproducible

        # Create imbalanced dataset (like fraud): 95% negative, 5% positive.
        n_samples = 1000
        n_positive = n_samples // 20
        n_negative = n_samples - n_positive
        y_true = np.array([0] * n_negative + [1] * n_positive)

        # Model that's good but not perfect: the positive class gets higher
        # probabilities. Negatives are sampled first, then positives, to
        # match the label order in y_true.
        y_prob = np.concatenate(
            [
                np.random.beta(2, 5, n_negative),  # Negative class: low probs
                np.random.beta(5, 2, n_positive),  # Positive class: high probs
            ]
        )

        threshold, metrics = find_optimal_threshold(y_true, y_prob, min_recall=0.70)

        # Should find a threshold strictly inside (0, 1).
        assert 0 < threshold < 1
        # Should meet or come close to the recall target; the asserted bound
        # (0.4) is looser than the requested min_recall (0.70).
        assert metrics["recall"] >= 0.4

    def test_fallback_to_f1(self):
        """Test that a tiny dataset with a 99% recall target still gives a valid result."""
        # Five samples with a single positive.
        y_true = np.array([0, 0, 0, 0, 1])
        y_prob = np.array([0.1, 0.2, 0.3, 0.4, 0.5])

        # NOTE(review): the only positive also has the highest score, so
        # recall 1.0 is reachable at any threshold below 0.5. Whether this
        # input really exercises the F1 fallback depends on how
        # find_optimal_threshold searches thresholds -- TODO confirm.
        threshold, metrics = find_optimal_threshold(y_true, y_prob, min_recall=0.99)

        # Should still return something valid.
        assert 0 < threshold < 1
        assert metrics["f1"] >= 0

    def test_threshold_range(self):
        """Test that the threshold found with default arguments lies in [0, 1]."""
        np.random.seed(42)  # fixed seed so the sampled data is reproducible
        y_true = np.random.randint(0, 2, 100)
        y_prob = np.random.random(100)

        threshold, _ = find_optimal_threshold(y_true, y_prob)

        assert 0 <= threshold <= 1