| """ |
| Evaluation utilities for beat and downbeat detection. |
| |
| This module provides functions to evaluate beat/downbeat predictions against |
| ground truth annotations using F1-scores at various timing thresholds and |
| continuity-based metrics (CMLt, AMLt). |
| |
| The evaluation metrics include: |
| - **F1-scores**: Calculated for timing thresholds from 3ms to 30ms |
| - **Weighted F1**: Weights are inversely proportional to threshold (e.g., 3ms: 1, 6ms: 1/2) |
| - **CMLt (Correct Metrical Level Total)**: Accuracy at the correct metrical level |
| - **AMLt (Any Metrical Level Total)**: Accuracy allowing for metrical variations |
| (double/half tempo, off-beat, etc.) |
| - **CMLc/AMLc**: Continuous versions (longest correct segment) |
| |
| Example usage: |
| from ..data.eval import ( |
| evaluate_beats, evaluate_all, compute_weighted_f1, |
| compute_continuity_metrics, format_results |
| ) |
| |
| # Evaluate single track |
| results = evaluate_beats(pred_beats, gt_beats) |
| print(f"Weighted F1: {results['weighted_f1']:.4f}") |
| print(f"CMLt: {results['continuity']['CMLt']:.4f}") |
| print(f"AMLt: {results['continuity']['AMLt']:.4f}") |
| |
| # Evaluate with custom thresholds |
| results = evaluate_beats(pred_beats, gt_beats, thresholds_ms=[5, 10, 20]) |
| |
| # Evaluate all tracks in dataset |
| summary = evaluate_all(predictions, ground_truths) |
| print(format_results(summary)) |
| """ |
|
|
| from typing import Sequence |
| import numpy as np |
| import mir_eval |
|
|
|
|
| |
# Timing tolerances (in milliseconds) swept when computing per-threshold
# F1-scores; weights for the weighted F1 are derived from these (3ms: 1,
# 6ms: 1/2, ..., inversely proportional to the threshold).
DEFAULT_THRESHOLDS_MS = [3, 6, 9, 12, 15, 18, 21, 24, 27, 30]

# Beats earlier than this many seconds are trimmed before computing the
# continuity metrics (CMLt/AMLt etc.); early beats may not yet have stable
# inter-beat intervals (see compute_continuity_metrics).
DEFAULT_MIN_BEAT_TIME = 5.0
|
|
|
|
def match_events(
    pred: np.ndarray,
    gt: np.ndarray,
    tolerance_sec: float,
) -> tuple[int, int, int]:
    """
    Match predicted events to ground truth events within a tolerance.

    Uses greedy matching: ground truth events are processed in ascending
    time order and each is matched to the closest still-unmatched prediction
    within the tolerance window. Greedy matching is not globally optimal but
    is the standard approach for onset/beat evaluation.

    Args:
        pred: Predicted event times in seconds, shape (N,)
        gt: Ground truth event times in seconds, shape (M,)
        tolerance_sec: Maximum time difference for a match (in seconds)

    Returns:
        Tuple of (true_positives, false_positives, false_negatives)
    """
    # Edge cases: with no ground truth every prediction is a false positive;
    # with no predictions every ground truth event is a false negative.
    if len(gt) == 0:
        return 0, len(pred), 0
    if len(pred) == 0:
        return 0, 0, len(gt)

    pred = np.sort(pred)
    gt = np.sort(gt)

    matched_pred = np.zeros(len(pred), dtype=bool)
    matched_gt = np.zeros(len(gt), dtype=bool)

    for i, gt_time in enumerate(gt):
        diffs = np.abs(pred - gt_time)
        # Only unmatched predictions inside the tolerance window qualify.
        candidates = np.where((diffs <= tolerance_sec) & ~matched_pred)[0]
        if len(candidates) > 0:
            # Take the closest qualifying prediction for this GT event.
            best_idx = candidates[np.argmin(diffs[candidates])]
            matched_pred[best_idx] = True
            matched_gt[i] = True

    tp = int(matched_gt.sum())
    # Bug fix: the previous code computed fp through a convoluted `and/or`
    # expression and then immediately overwrote it with `len(pred) - tp`;
    # a single direct computation is equivalent and readable.
    fp = len(pred) - tp
    fn = len(gt) - tp

    return tp, fp, fn
|
|
|
|
def compute_f1(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """
    Derive precision, recall, and F1-score from raw match counts.

    Args:
        tp: True positives
        fp: False positives
        fn: False negatives

    Returns:
        Tuple of (precision, recall, f1_score); each component is 0.0 when
        its denominator would be zero.
    """
    predicted_total = tp + fp
    actual_total = tp + fn

    precision = tp / predicted_total if predicted_total > 0 else 0.0
    recall = tp / actual_total if actual_total > 0 else 0.0

    denom = precision + recall
    if denom > 0:
        f1 = 2 * precision * recall / denom
    else:
        f1 = 0.0

    return precision, recall, f1
|
|
|
|
| def compute_weighted_f1( |
| f1_scores: dict[int, float], |
| thresholds_ms: Sequence[int] | None = None, |
| ) -> float: |
| """ |
| Compute weighted F1-score where weights are inversely proportional to threshold. |
| |
| The weight for threshold T ms is 1 / (T / min_threshold). |
| For example, with thresholds [3, 6, 9, ...]: |
| - 3ms: weight = 1 |
| - 6ms: weight = 0.5 |
| - 9ms: weight = 0.333... |
| |
| Args: |
| f1_scores: Dict mapping threshold (ms) to F1-score |
| thresholds_ms: List of thresholds used (for weight calculation) |
| |
| Returns: |
| Weighted F1-score |
| """ |
| if not f1_scores: |
| return 0.0 |
|
|
| if thresholds_ms is None: |
| thresholds_ms = sorted(f1_scores.keys()) |
|
|
| min_threshold = min(thresholds_ms) |
| total_weight = 0.0 |
| weighted_sum = 0.0 |
|
|
| for t in thresholds_ms: |
| if t in f1_scores: |
| weight = min_threshold / t |
| weighted_sum += weight * f1_scores[t] |
| total_weight += weight |
|
|
| return weighted_sum / total_weight if total_weight > 0 else 0.0 |
|
|
|
|
def compute_continuity_metrics(
    pred_times: Sequence[float],
    gt_times: Sequence[float],
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
    phase_threshold: float = 0.175,
    period_threshold: float = 0.175,
) -> dict:
    """
    Compute continuity-based beat tracking metrics via mir_eval.

    The metrics account for the metrical level of the tracking:
    - CMLt (Correct Metric Level Total): accuracy at the correct level
    - AMLt (Any Metric Level Total): accuracy allowing metrical variations
      (double/half tempo, off-beat, etc.)
    - CMLc/AMLc: continuous variants (longest correct segment)

    Args:
        pred_times: Predicted beat times in seconds
        gt_times: Ground truth beat times in seconds
        min_beat_time: Minimum time to start evaluation (default: 5.0s).
            Set to 0.0 to use all beats; early beats may not have stable
            inter-beat intervals.
        phase_threshold: Max phase error as ratio of beat interval (0.175)
        period_threshold: Max period error as ratio of beat interval (0.175)

    Returns:
        Dict with keys 'CMLc', 'CMLt', 'AMLc', 'AMLt' (all floats).
    """
    predictions = np.sort(np.array(pred_times, dtype=np.float64))
    references = np.sort(np.array(gt_times, dtype=np.float64))

    # Drop early beats whose inter-beat intervals are not yet stable.
    predictions = mir_eval.beat.trim_beats(predictions, min_beat_time=min_beat_time)
    references = mir_eval.beat.trim_beats(references, min_beat_time=min_beat_time)

    # mir_eval needs at least two beats on each side to estimate intervals.
    if min(len(references), len(predictions)) < 2:
        return {"CMLc": 0.0, "CMLt": 0.0, "AMLc": 0.0, "AMLt": 0.0}

    cmlc, cmlt, amlc, amlt = mir_eval.beat.continuity(
        references,
        predictions,
        continuity_phase_threshold=phase_threshold,
        continuity_period_threshold=period_threshold,
    )

    return {
        "CMLc": float(cmlc),
        "CMLt": float(cmlt),
        "AMLc": float(amlc),
        "AMLt": float(amlt),
    }
|
|
|
|
def evaluate_beats(
    pred_times: Sequence[float],
    gt_times: Sequence[float],
    thresholds_ms: Sequence[int] | None = None,
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
) -> dict:
    """
    Evaluate beat predictions against ground truth at multiple thresholds.

    Args:
        pred_times: Predicted beat times in seconds
        gt_times: Ground truth beat times in seconds
        thresholds_ms: Timing thresholds in milliseconds (default: 3-30ms)
        min_beat_time: Minimum time for continuity metrics (default: 5.0s)

    Returns:
        Dict containing:
            - 'per_threshold': Dict[threshold_ms, {'precision', 'recall',
              'f1', 'tp', 'fp', 'fn'}]
            - 'f1_scores': Dict[threshold_ms, f1_score]
            - 'weighted_f1': Weighted F1-score across all thresholds
            - 'continuity': Dict with CMLc, CMLt, AMLc, AMLt metrics
            - 'num_predictions': Number of predictions
            - 'num_ground_truth': Number of ground truth events
    """
    thresholds = DEFAULT_THRESHOLDS_MS if thresholds_ms is None else thresholds_ms

    predictions = np.array(pred_times, dtype=np.float64)
    references = np.array(gt_times, dtype=np.float64)

    per_threshold = {}
    f1_scores = {}
    for ms in thresholds:
        # Thresholds are specified in ms; matching works in seconds.
        tp, fp, fn = match_events(predictions, references, ms / 1000.0)
        precision, recall, f1 = compute_f1(tp, fp, fn)
        per_threshold[ms] = {
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "tp": tp,
            "fp": fp,
            "fn": fn,
        }
        f1_scores[ms] = f1

    return {
        "per_threshold": per_threshold,
        "f1_scores": f1_scores,
        "weighted_f1": compute_weighted_f1(f1_scores, thresholds),
        "continuity": compute_continuity_metrics(pred_times, gt_times, min_beat_time),
        "num_predictions": len(predictions),
        "num_ground_truth": len(references),
    }
|
|
|
|
def evaluate_track(
    pred_beats: Sequence[float],
    pred_downbeats: Sequence[float],
    gt_beats: Sequence[float],
    gt_downbeats: Sequence[float],
    thresholds_ms: Sequence[int] | None = None,
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
) -> dict:
    """
    Evaluate beat and downbeat predictions for a single track.

    Args:
        pred_beats: Predicted beat times in seconds
        pred_downbeats: Predicted downbeat times in seconds
        gt_beats: Ground truth beat times in seconds
        gt_downbeats: Ground truth downbeat times in seconds
        thresholds_ms: Timing thresholds in milliseconds
        min_beat_time: Minimum time for continuity metrics (default: 5.0s)

    Returns:
        Dict containing:
            - 'beats': evaluate_beats results for beats
            - 'downbeats': evaluate_beats results for downbeats
            - 'combined_weighted_f1': mean of the two weighted F1 scores
    """
    results = {
        "beats": evaluate_beats(pred_beats, gt_beats, thresholds_ms, min_beat_time),
        "downbeats": evaluate_beats(
            pred_downbeats, gt_downbeats, thresholds_ms, min_beat_time
        ),
    }
    results["combined_weighted_f1"] = (
        results["beats"]["weighted_f1"] + results["downbeats"]["weighted_f1"]
    ) / 2
    return results
|
|
|
|
def evaluate_all(
    predictions: Sequence[dict],
    ground_truths: Sequence[dict],
    thresholds_ms: Sequence[int] | None = None,
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
    verbose: bool = False,
) -> dict:
    """
    Evaluate predictions for multiple tracks and aggregate the results.

    Args:
        predictions: List of dicts with 'beats' and 'downbeats' keys
        ground_truths: List of dicts with 'beats' and 'downbeats' keys
        thresholds_ms: Timing thresholds in milliseconds
        min_beat_time: Minimum time for continuity metrics (default: 5.0s)
        verbose: If True, print per-track results

    Returns:
        Dict containing:
            - 'per_track': List of per-track results
            - 'mean_beat_weighted_f1': Mean weighted F1 for beats
            - 'mean_downbeat_weighted_f1': Mean weighted F1 for downbeats
            - 'mean_combined_weighted_f1': Mean combined weighted F1
            - 'beat_f1_by_threshold': Mean F1 per threshold for beats
            - 'downbeat_f1_by_threshold': Mean F1 per threshold for downbeats
            - 'beat_continuity': Mean continuity metrics for beats
            - 'downbeat_continuity': Mean continuity metrics for downbeats
            - 'num_tracks': Number of evaluated tracks

    Raises:
        ValueError: If predictions and ground_truths differ in length.
    """
    if len(predictions) != len(ground_truths):
        raise ValueError(
            f"Number of predictions ({len(predictions)}) must match "
            f"number of ground truths ({len(ground_truths)})"
        )

    if thresholds_ms is None:
        thresholds_ms = DEFAULT_THRESHOLDS_MS

    continuity_keys = ("CMLc", "CMLt", "AMLc", "AMLt")

    per_track = []
    beat_wf1 = []
    downbeat_wf1 = []
    combined_wf1 = []
    beat_f1_by_threshold = {t: [] for t in thresholds_ms}
    downbeat_f1_by_threshold = {t: [] for t in thresholds_ms}
    beat_continuity = {key: [] for key in continuity_keys}
    downbeat_continuity = {key: [] for key in continuity_keys}

    for idx, (pred, gt) in enumerate(zip(predictions, ground_truths)):
        result = evaluate_track(
            pred_beats=pred["beats"],
            pred_downbeats=pred["downbeats"],
            gt_beats=gt["beats"],
            gt_downbeats=gt["downbeats"],
            thresholds_ms=thresholds_ms,
            min_beat_time=min_beat_time,
        )
        per_track.append(result)

        beats = result["beats"]
        downbeats = result["downbeats"]
        beat_wf1.append(beats["weighted_f1"])
        downbeat_wf1.append(downbeats["weighted_f1"])
        combined_wf1.append(result["combined_weighted_f1"])

        for t in thresholds_ms:
            beat_f1_by_threshold[t].append(beats["f1_scores"][t])
            downbeat_f1_by_threshold[t].append(downbeats["f1_scores"][t])

        for key in continuity_keys:
            beat_continuity[key].append(beats["continuity"][key])
            downbeat_continuity[key].append(downbeats["continuity"][key])

        if verbose:
            bc = beats["continuity"]
            print(
                f"Track {idx}: Beat F1={beats['weighted_f1']:.4f}, "
                f"CMLt={bc['CMLt']:.4f}, AMLt={bc['AMLt']:.4f}, "
                f"Downbeat F1={downbeats['weighted_f1']:.4f}, "
                f"Combined={result['combined_weighted_f1']:.4f}"
            )

    return {
        "per_track": per_track,
        "mean_beat_weighted_f1": float(np.mean(beat_wf1)),
        "mean_downbeat_weighted_f1": float(np.mean(downbeat_wf1)),
        "mean_combined_weighted_f1": float(np.mean(combined_wf1)),
        "beat_f1_by_threshold": {
            t: float(np.mean(scores)) for t, scores in beat_f1_by_threshold.items()
        },
        "downbeat_f1_by_threshold": {
            t: float(np.mean(scores)) for t, scores in downbeat_f1_by_threshold.items()
        },
        "beat_continuity": {
            key: float(np.mean(vals)) for key, vals in beat_continuity.items()
        },
        "downbeat_continuity": {
            key: float(np.mean(vals)) for key, vals in downbeat_continuity.items()
        },
        "num_tracks": len(predictions),
    }
|
|
|
|
def format_results(results: dict, title: str = "Evaluation Results") -> str:
    """
    Format evaluation results as a human-readable string.

    Accepts either the summary dict produced by ``evaluate_all`` (detected
    via its ``num_tracks`` key) or the single-track dict produced by
    ``evaluate_track`` (detected via its ``beats``/``downbeats`` keys).

    Args:
        results: Results dict from evaluate_all or evaluate_track
        title: Title for the report

    Returns:
        Formatted string report
    """
    # Header: title, an '=' underline of matching width, then a blank line.
    lines = [title, "=" * len(title), ""]

    # Multi-track summary (output of evaluate_all).
    if "num_tracks" in results:
        lines.append(f"Number of tracks: {results['num_tracks']}")
        lines.append("")
        lines.append("Overall Metrics:")
        lines.append(
            f"  Mean Beat Weighted F1: {results['mean_beat_weighted_f1']:.4f}"
        )
        lines.append(
            f"  Mean Downbeat Weighted F1: {results['mean_downbeat_weighted_f1']:.4f}"
        )
        lines.append(
            f"  Mean Combined Weighted F1: {results['mean_combined_weighted_f1']:.4f}"
        )
        lines.append("")

        lines.append("Beat F1 by Threshold:")
        for t, f1 in sorted(results["beat_f1_by_threshold"].items()):
            lines.append(f"  {t:2d}ms: {f1:.4f}")
        lines.append("")

        lines.append("Downbeat F1 by Threshold:")
        for t, f1 in sorted(results["downbeat_f1_by_threshold"].items()):
            lines.append(f"  {t:2d}ms: {f1:.4f}")
        lines.append("")

        # Mean continuity metrics over all tracks.
        if "beat_continuity" in results:
            lines.append("Beat Continuity Metrics:")
            bc = results["beat_continuity"]
            lines.append(f"  CMLt: {bc['CMLt']:.4f} (Correct Metrical Level Total)")
            lines.append(f"  AMLt: {bc['AMLt']:.4f} (Any Metrical Level Total)")
            lines.append(
                f"  CMLc: {bc['CMLc']:.4f} (Correct Metrical Level Continuous)"
            )
            lines.append(f"  AMLc: {bc['AMLc']:.4f} (Any Metrical Level Continuous)")
            lines.append("")

        if "downbeat_continuity" in results:
            lines.append("Downbeat Continuity Metrics:")
            dc = results["downbeat_continuity"]
            lines.append(f"  CMLt: {dc['CMLt']:.4f} (Correct Metrical Level Total)")
            lines.append(f"  AMLt: {dc['AMLt']:.4f} (Any Metrical Level Total)")
            lines.append(
                f"  CMLc: {dc['CMLc']:.4f} (Correct Metrical Level Continuous)"
            )
            lines.append(f"  AMLc: {dc['AMLc']:.4f} (Any Metrical Level Continuous)")

    # Single-track result (output of evaluate_track).
    elif "beats" in results and "downbeats" in results:
        lines.append("Beat Detection:")
        lines.append(f"  Weighted F1: {results['beats']['weighted_f1']:.4f}")
        lines.append(f"  Predictions: {results['beats']['num_predictions']}")
        lines.append(f"  Ground Truth: {results['beats']['num_ground_truth']}")

        # Per-track continuity metrics, when present.
        if "continuity" in results["beats"]:
            bc = results["beats"]["continuity"]
            lines.append(f"  CMLt: {bc['CMLt']:.4f}  AMLt: {bc['AMLt']:.4f}")
            lines.append(f"  CMLc: {bc['CMLc']:.4f}  AMLc: {bc['AMLc']:.4f}")
        lines.append("")

        lines.append("Downbeat Detection:")
        lines.append(f"  Weighted F1: {results['downbeats']['weighted_f1']:.4f}")
        lines.append(f"  Predictions: {results['downbeats']['num_predictions']}")
        lines.append(f"  Ground Truth: {results['downbeats']['num_ground_truth']}")

        # Per-track continuity metrics, when present.
        if "continuity" in results["downbeats"]:
            dc = results["downbeats"]["continuity"]
            lines.append(f"  CMLt: {dc['CMLt']:.4f}  AMLt: {dc['AMLt']:.4f}")
            lines.append(f"  CMLc: {dc['CMLc']:.4f}  AMLc: {dc['AMLc']:.4f}")
        lines.append("")

        lines.append(f"Combined Weighted F1: {results['combined_weighted_f1']:.4f}")

    return "\n".join(lines)
|
|
|
|
if __name__ == "__main__":
    # Demo: build a synthetic track, then evaluate it both individually
    # and as a two-track dataset.
    print("Running evaluation demo...\n")

    # Ground truth: a beat every 0.5 s and a downbeat every 2.0 s over 30 s.
    gt_beats = np.arange(0, 30, 0.5).tolist()
    gt_downbeats = np.arange(0, 30, 2.0).tolist()

    # Predictions: jitter the ground truth with Gaussian noise, drop the
    # final two beats (false negatives) and insert one spurious beat
    # (false positive). Fixed seed keeps the demo reproducible.
    np.random.seed(42)
    noisy_beats = np.array(gt_beats) + np.random.normal(0, 0.005, len(gt_beats))
    pred_beats = noisy_beats.tolist()[:-2]
    pred_beats.append(15.25)

    noisy_downbeats = np.array(gt_downbeats) + np.random.normal(
        0, 0.003, len(gt_downbeats)
    )
    pred_downbeats = noisy_downbeats.tolist()

    # Single-track evaluation.
    results = evaluate_track(
        pred_beats=pred_beats,
        pred_downbeats=pred_downbeats,
        gt_beats=gt_beats,
        gt_downbeats=gt_downbeats,
    )
    print(format_results(results, "Single Track Demo"))
    print("\n" + "=" * 50 + "\n")

    # Dataset-level evaluation over two identical tracks.
    predictions = [
        {"beats": pred_beats, "downbeats": pred_downbeats},
        {"beats": pred_beats, "downbeats": pred_downbeats},
    ]
    ground_truths = [
        {"beats": gt_beats, "downbeats": gt_downbeats},
        {"beats": gt_beats, "downbeats": gt_downbeats},
    ]

    all_results = evaluate_all(predictions, ground_truths, verbose=True)
    print()
    print(format_results(all_results, "Multi-Track Demo"))
|