| | """ |
| | Evaluation utilities for beat and downbeat detection. |
| | |
| | This module provides functions to evaluate beat/downbeat predictions against |
| | ground truth annotations using F1-scores at various timing thresholds and |
| | continuity-based metrics (CMLt, AMLt). |
| | |
| | The evaluation metrics include: |
| | - **F1-scores**: Calculated for timing thresholds from 3ms to 30ms |
| | - **Weighted F1**: Weights are inversely proportional to threshold (e.g., 3ms: 1, 6ms: 1/2) |
| | - **CMLt (Correct Metrical Level Total)**: Accuracy at the correct metrical level |
| | - **AMLt (Any Metrical Level Total)**: Accuracy allowing for metrical variations |
| | (double/half tempo, off-beat, etc.) |
| | - **CMLc/AMLc**: Continuous versions (longest correct segment) |
| | |
| | Example usage: |
| | from ..data.eval import ( |
| | evaluate_beats, evaluate_all, compute_weighted_f1, |
| | compute_continuity_metrics, format_results |
| | ) |
| | |
| | # Evaluate single track |
| | results = evaluate_beats(pred_beats, gt_beats) |
| | print(f"Weighted F1: {results['weighted_f1']:.4f}") |
| | print(f"CMLt: {results['continuity']['CMLt']:.4f}") |
| | print(f"AMLt: {results['continuity']['AMLt']:.4f}") |
| | |
| | # Evaluate with custom thresholds |
| | results = evaluate_beats(pred_beats, gt_beats, thresholds_ms=[5, 10, 20]) |
| | |
| | # Evaluate all tracks in dataset |
| | summary = evaluate_all(predictions, ground_truths) |
| | print(format_results(summary)) |
| | """ |
| |
|
| | from typing import Sequence |
| | import numpy as np |
| | import mir_eval |
| |
|
| |
|
| | |
# Timing tolerances (in milliseconds) used for the F1 evaluation:
# 3 ms up to and including 30 ms, in 3 ms steps.
DEFAULT_THRESHOLDS_MS = list(range(3, 31, 3))

# Default time (in seconds) passed to mir_eval.beat.trim_beats by the
# continuity metrics; beats before this time are not evaluated.
DEFAULT_MIN_BEAT_TIME = 5.0
| |
|
| |
|
def match_events(
    pred: np.ndarray,
    gt: np.ndarray,
    tolerance_sec: float,
) -> tuple[int, int, int]:
    """
    Match predicted events to ground truth events within a tolerance.

    Computes a maximum one-to-one matching. Every ground-truth event is paired
    with at most one prediction (and vice versa), and a pair is only allowed
    when ``abs(pred - gt) <= tolerance_sec``.

    All tolerance windows have the same width. So a single sweep over both
    sorted sequences, pairing each ground-truth event with the earliest
    still-unused prediction inside its window, yields the largest possible
    number of matches. Matching each event to its *closest* unused prediction
    instead can lose matches when neighbouring windows overlap.

    Args:
        pred: Predicted event times in seconds, shape (N,)
        gt: Ground truth event times in seconds, shape (M,)
        tolerance_sec: Maximum time difference for a match (in seconds)

    Returns:
        Tuple of (true_positives, false_positives, false_negatives), where
        false_positives = N - true_positives and
        false_negatives = M - true_positives.
    """
    pred_times = np.sort(np.asarray(pred, dtype=np.float64)).tolist()
    gt_times = np.sort(np.asarray(gt, dtype=np.float64)).tolist()
    n_pred = len(pred_times)
    n_gt = len(gt_times)

    if n_gt == 0:
        return 0, n_pred, 0
    if n_pred == 0:
        return 0, 0, n_gt

    tp = 0
    j = 0  # earliest prediction that is neither matched nor discarded
    for gt_time in gt_times:
        # A prediction too early for this event is too early for every later
        # (larger) ground-truth event as well, so it can be discarded for good.
        while j < n_pred and gt_time - pred_times[j] > tolerance_sec:
            j += 1
        if j == n_pred:
            break
        if pred_times[j] - gt_time <= tolerance_sec:
            tp += 1
            j += 1

    return tp, n_pred - tp, n_gt - tp
| |
|
| |
|
def compute_f1(tp: int, fp: int, fn: int) -> tuple[float, float, float]:
    """
    Derive precision, recall and their harmonic mean (F1) from match counts.

    Any ratio whose denominator is not positive is reported as 0.0 rather
    than raising ZeroDivisionError.

    Args:
        tp: True positives
        fp: False positives
        fn: False negatives

    Returns:
        Tuple of (precision, recall, f1_score)
    """

    def _ratio(numerator, denominator):
        # Guarded division: an empty denominator yields 0.0.
        return numerator / denominator if denominator > 0 else 0.0

    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    f1 = _ratio(2 * precision * recall, precision + recall)
    return precision, recall, f1
| |
|
| |
|
| | def compute_weighted_f1( |
| | f1_scores: dict[int, float], |
| | thresholds_ms: Sequence[int] | None = None, |
| | ) -> float: |
| | """ |
| | Compute weighted F1-score where weights are inversely proportional to threshold. |
| | |
| | The weight for threshold T ms is 1 / (T / min_threshold). |
| | For example, with thresholds [3, 6, 9, ...]: |
| | - 3ms: weight = 1 |
| | - 6ms: weight = 0.5 |
| | - 9ms: weight = 0.333... |
| | |
| | Args: |
| | f1_scores: Dict mapping threshold (ms) to F1-score |
| | thresholds_ms: List of thresholds used (for weight calculation) |
| | |
| | Returns: |
| | Weighted F1-score |
| | """ |
| | if not f1_scores: |
| | return 0.0 |
| |
|
| | if thresholds_ms is None: |
| | thresholds_ms = sorted(f1_scores.keys()) |
| |
|
| | min_threshold = min(thresholds_ms) |
| | total_weight = 0.0 |
| | weighted_sum = 0.0 |
| |
|
| | for t in thresholds_ms: |
| | if t in f1_scores: |
| | weight = min_threshold / t |
| | weighted_sum += weight * f1_scores[t] |
| | total_weight += weight |
| |
|
| | return weighted_sum / total_weight if total_weight > 0 else 0.0 |
| |
|
| |
|
def compute_continuity_metrics(
    pred_times: Sequence[float],
    gt_times: Sequence[float],
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
    phase_threshold: float = 0.175,
    period_threshold: float = 0.175,
) -> dict:
    """
    Compute continuity-based beat tracking metrics using mir_eval.

    Both sequences are sorted and trimmed with ``mir_eval.beat.trim_beats``,
    which discards beats before ``min_beat_time``. They are then scored with
    ``mir_eval.beat.continuity``, which accounts for metrical level:
        - CMLt (Correct Metric Level Total): Accuracy at the correct metrical level
        - AMLt (Any Metric Level Total): Accuracy allowing for metrical variations
          (double/half tempo, off-beat, etc.)
        - CMLc/AMLc: Continuous versions (longest correct segment)

    If fewer than two beats survive trimming in either sequence, all four
    metrics are reported as 0.0 and mir_eval.beat.continuity is not called.

    Args:
        pred_times: Predicted beat times in seconds
        gt_times: Ground truth beat times in seconds
        min_beat_time: Minimum time to start evaluation (default: 5.0s).
            Set to 0.0 to use all beats, but note that early beats
            may not have stable inter-beat intervals.
        phase_threshold: Maximum phase error as ratio of beat interval (default: 0.175)
        period_threshold: Maximum period error as ratio of beat interval (default: 0.175)

    Returns:
        Dict containing:
            - 'CMLc': Correct Metric Level Continuous
            - 'CMLt': Correct Metric Level Total
            - 'AMLc': Any Metric Level Continuous
            - 'AMLt': Any Metric Level Total
    """
    metric_names = ("CMLc", "CMLt", "AMLc", "AMLt")

    def _prepare(times):
        # Sorted float64 array with beats before ``min_beat_time`` removed.
        ordered = np.sort(np.array(times, dtype=np.float64))
        return mir_eval.beat.trim_beats(ordered, min_beat_time=min_beat_time)

    estimated = _prepare(pred_times)
    reference = _prepare(gt_times)

    # Inter-beat intervals need at least two beats on each side.
    if min(len(reference), len(estimated)) < 2:
        return dict.fromkeys(metric_names, 0.0)

    scores = mir_eval.beat.continuity(
        reference,
        estimated,
        continuity_phase_threshold=phase_threshold,
        continuity_period_threshold=period_threshold,
    )
    return {name: float(value) for name, value in zip(metric_names, scores)}
| |
|
| |
|
def evaluate_beats(
    pred_times: Sequence[float],
    gt_times: Sequence[float],
    thresholds_ms: Sequence[int] | None = None,
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
) -> dict:
    """
    Score one sequence of predicted events against its reference sequence.

    For every timing threshold the events are matched with ``match_events``
    and summarised with ``compute_f1``. The per-threshold F1-scores are then
    combined with ``compute_weighted_f1``, and the continuity metrics are
    computed with ``compute_continuity_metrics``.

    Args:
        pred_times: Predicted beat times in seconds
        gt_times: Ground truth beat times in seconds
        thresholds_ms: Timing thresholds in milliseconds (default: 3ms to 30ms)
        min_beat_time: Minimum time for continuity metrics (default: 5.0s)

    Returns:
        Dict containing:
            - 'per_threshold': Dict[threshold_ms, {'precision', 'recall', 'f1',
              'tp', 'fp', 'fn'}]
            - 'f1_scores': Dict[threshold_ms, f1_score] (convenience access)
            - 'weighted_f1': Weighted F1-score across all thresholds
            - 'continuity': Dict with CMLc, CMLt, AMLc, AMLt metrics
            - 'num_predictions': Number of predictions
            - 'num_ground_truth': Number of ground truth events
    """
    thresholds = DEFAULT_THRESHOLDS_MS if thresholds_ms is None else thresholds_ms
    estimated = np.array(pred_times, dtype=np.float64)
    reference = np.array(gt_times, dtype=np.float64)

    per_threshold = {}
    for threshold_ms in thresholds:
        # Thresholds are given in milliseconds; matching works in seconds.
        tp, fp, fn = match_events(estimated, reference, threshold_ms / 1000.0)
        precision, recall, f1 = compute_f1(tp, fp, fn)
        per_threshold[threshold_ms] = dict(
            precision=precision, recall=recall, f1=f1, tp=tp, fp=fp, fn=fn
        )
    f1_scores = {t: stats["f1"] for t, stats in per_threshold.items()}

    return {
        "per_threshold": per_threshold,
        "f1_scores": f1_scores,
        "weighted_f1": compute_weighted_f1(f1_scores, thresholds),
        "continuity": compute_continuity_metrics(pred_times, gt_times, min_beat_time),
        "num_predictions": len(estimated),
        "num_ground_truth": len(reference),
    }
| |
|
| |
|
def evaluate_track(
    pred_beats: Sequence[float],
    pred_downbeats: Sequence[float],
    gt_beats: Sequence[float],
    gt_downbeats: Sequence[float],
    thresholds_ms: Sequence[int] | None = None,
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
) -> dict:
    """
    Score the beats and the downbeats of one track with ``evaluate_beats``.

    Beats and downbeats are evaluated independently, with the same thresholds
    and the same ``min_beat_time``.

    Args:
        pred_beats: Predicted beat times in seconds
        pred_downbeats: Predicted downbeat times in seconds
        gt_beats: Ground truth beat times in seconds
        gt_downbeats: Ground truth downbeat times in seconds
        thresholds_ms: Timing thresholds in milliseconds
        min_beat_time: Minimum time for continuity metrics (default: 5.0s)

    Returns:
        Dict containing:
            - 'beats': Results from evaluate_beats for beats
            - 'downbeats': Results from evaluate_beats for downbeats
            - 'combined_weighted_f1': Unweighted mean of the beat and downbeat
              weighted F1-scores
    """
    # Beats are evaluated first, then downbeats.
    beat_results, downbeat_results = (
        evaluate_beats(pred, gt, thresholds_ms, min_beat_time)
        for pred, gt in ((pred_beats, gt_beats), (pred_downbeats, gt_downbeats))
    )

    return {
        "beats": beat_results,
        "downbeats": downbeat_results,
        "combined_weighted_f1": (
            beat_results["weighted_f1"] + downbeat_results["weighted_f1"]
        )
        / 2,
    }
| |
|
| |
|
def evaluate_all(
    predictions: Sequence[dict],
    ground_truths: Sequence[dict],
    thresholds_ms: Sequence[int] | None = None,
    min_beat_time: float = DEFAULT_MIN_BEAT_TIME,
    verbose: bool = False,
) -> dict:
    """
    Evaluate beat and downbeat predictions for a collection of tracks.

    Each track is scored with ``evaluate_track``. The per-track scores are
    then averaged with ``np.mean``, giving each track equal weight.

    Args:
        predictions: One dict per track with 'beats' and 'downbeats' keys
            holding predicted times in seconds
        ground_truths: One dict per track with the same keys holding the
            reference times; aligned index-by-index with ``predictions``
        thresholds_ms: Timing thresholds in milliseconds
            (default: DEFAULT_THRESHOLDS_MS)
        min_beat_time: Minimum time for continuity metrics (default: 5.0s)
        verbose: If True, print a one-line summary for each track as soon as
            it has been evaluated

    Returns:
        Dict containing:
            - 'per_track': List of evaluate_track results, in input order
            - 'mean_beat_weighted_f1': Mean weighted F1 for beats
            - 'mean_downbeat_weighted_f1': Mean weighted F1 for downbeats
            - 'mean_combined_weighted_f1': Mean combined weighted F1
            - 'beat_f1_by_threshold': Mean F1 per threshold for beats
            - 'downbeat_f1_by_threshold': Mean F1 per threshold for downbeats
            - 'beat_continuity': Mean continuity metrics for beats
            - 'downbeat_continuity': Mean continuity metrics for downbeats
            - 'num_tracks': Number of tracks evaluated

        NOTE: with zero tracks every mean is taken over an empty list. NumPy
        reports that as nan and emits a RuntimeWarning.

    Raises:
        ValueError: If ``predictions`` and ``ground_truths`` differ in length.
    """
    n_tracks = len(predictions)
    if n_tracks != len(ground_truths):
        raise ValueError(
            f"Number of predictions ({len(predictions)}) must match "
            f"number of ground truths ({len(ground_truths)})"
        )

    thresholds = DEFAULT_THRESHOLDS_MS if thresholds_ms is None else thresholds_ms
    metric_names = ("CMLc", "CMLt", "AMLc", "AMLt")
    levels = ("beats", "downbeats")

    # Per-track score histories, averaged once all tracks are done.
    per_track = []
    weighted = {"beats": [], "downbeats": [], "combined": []}
    f1_history = {level: {t: [] for t in thresholds} for level in levels}
    continuity_history = {level: {m: [] for m in metric_names} for level in levels}

    for track_idx, (pred, gt) in enumerate(zip(predictions, ground_truths)):
        track = evaluate_track(
            pred_beats=pred["beats"],
            pred_downbeats=pred["downbeats"],
            gt_beats=gt["beats"],
            gt_downbeats=gt["downbeats"],
            thresholds_ms=thresholds,
            min_beat_time=min_beat_time,
        )
        per_track.append(track)
        weighted["combined"].append(track["combined_weighted_f1"])

        for level in levels:
            level_result = track[level]
            weighted[level].append(level_result["weighted_f1"])
            for t in thresholds:
                f1_history[level][t].append(level_result["f1_scores"][t])
            for m in metric_names:
                continuity_history[level][m].append(level_result["continuity"][m])

        if verbose:
            beat_cont = track["beats"]["continuity"]
            print(
                f"Track {track_idx}: Beat F1={track['beats']['weighted_f1']:.4f}, "
                f"CMLt={beat_cont['CMLt']:.4f}, AMLt={beat_cont['AMLt']:.4f}, "
                f"Downbeat F1={track['downbeats']['weighted_f1']:.4f}, "
                f"Combined={track['combined_weighted_f1']:.4f}"
            )

    def _mean(values):
        return float(np.mean(values))

    return {
        "per_track": per_track,
        "mean_beat_weighted_f1": _mean(weighted["beats"]),
        "mean_downbeat_weighted_f1": _mean(weighted["downbeats"]),
        "mean_combined_weighted_f1": _mean(weighted["combined"]),
        "beat_f1_by_threshold": {
            t: _mean(v) for t, v in f1_history["beats"].items()
        },
        "downbeat_f1_by_threshold": {
            t: _mean(v) for t, v in f1_history["downbeats"].items()
        },
        "beat_continuity": {
            m: _mean(v) for m, v in continuity_history["beats"].items()
        },
        "downbeat_continuity": {
            m: _mean(v) for m, v in continuity_history["downbeats"].items()
        },
        "num_tracks": len(predictions),
    }
| |
|
| |
|
def format_results(results: dict, title: str = "Evaluation Results") -> str:
    """
    Render evaluation results as a plain-text report.

    Two result shapes are recognised:
        - a dataset summary from ``evaluate_all``, identified by the
          'num_tracks' key. The report shows overall means, per-threshold F1
          tables and, when present, the beat/downbeat continuity means.
        - a single-track result from ``evaluate_track``, identified by the
          'beats' and 'downbeats' keys. The report shows per-level weighted
          F1, event counts, optional continuity scores and the combined
          weighted F1.
    Any other dict yields only the title and its underline.

    Args:
        results: Results dict from evaluate_all or evaluate_track
        title: Title for the report (underlined with '=' characters)

    Returns:
        The report lines joined with newline characters.
    """
    out = [title, "=" * len(title), ""]

    if "num_tracks" in results:
        out += [
            f"Number of tracks: {results['num_tracks']}",
            "",
            "Overall Metrics:",
            f" Mean Beat Weighted F1: {results['mean_beat_weighted_f1']:.4f}",
            f" Mean Downbeat Weighted F1: {results['mean_downbeat_weighted_f1']:.4f}",
            f" Mean Combined Weighted F1: {results['mean_combined_weighted_f1']:.4f}",
            "",
        ]

        for heading, key in (
            ("Beat F1 by Threshold:", "beat_f1_by_threshold"),
            ("Downbeat F1 by Threshold:", "downbeat_f1_by_threshold"),
        ):
            out.append(heading)
            out.extend(
                f" {t:2d}ms: {f1:.4f}" for t, f1 in sorted(results[key].items())
            )
            out.append("")

        # (metric key, long name) in display order.
        metric_labels = (
            ("CMLt", "Correct Metrical Level Total"),
            ("AMLt", "Any Metrical Level Total"),
            ("CMLc", "Correct Metrical Level Continuous"),
            ("AMLc", "Any Metrical Level Continuous"),
        )
        # Only the beat section is followed by a blank separator line.
        for heading, key, add_separator in (
            ("Beat Continuity Metrics:", "beat_continuity", True),
            ("Downbeat Continuity Metrics:", "downbeat_continuity", False),
        ):
            if key not in results:
                continue
            out.append(heading)
            scores = results[key]
            out.extend(
                f" {name}: {scores[name]:.4f} ({label})"
                for name, label in metric_labels
            )
            if add_separator:
                out.append("")

    elif "beats" in results and "downbeats" in results:
        for heading, key in (
            ("Beat Detection:", "beats"),
            ("Downbeat Detection:", "downbeats"),
        ):
            section = results[key]
            out.append(heading)
            out.append(f" Weighted F1: {section['weighted_f1']:.4f}")
            out.append(f" Predictions: {section['num_predictions']}")
            out.append(f" Ground Truth: {section['num_ground_truth']}")
            if "continuity" in section:
                cont = section["continuity"]
                out.append(f" CMLt: {cont['CMLt']:.4f} AMLt: {cont['AMLt']:.4f}")
                out.append(f" CMLc: {cont['CMLc']:.4f} AMLc: {cont['AMLc']:.4f}")
            out.append("")

        out.append(f"Combined Weighted F1: {results['combined_weighted_f1']:.4f}")

    return "\n".join(out)
| |
|
| |
|
if __name__ == "__main__":
    # Synthetic sanity check: a 30 s reference grid with a beat every 0.5 s
    # (120 BPM) and a downbeat every 2.0 s, plus jittered predictions.
    print("Running evaluation demo...\n")

    gt_beats = np.arange(0, 30, 0.5).tolist()
    gt_downbeats = np.arange(0, 30, 2.0).tolist()

    # Seed NumPy's global RNG so the demo is reproducible. The two normal()
    # draws below must stay in this order to reproduce the same noise.
    np.random.seed(42)
    beat_noise = np.random.normal(0, 0.005, len(gt_beats))
    # Drop the last two beats (misses) and add one spurious beat at 15.25 s,
    # halfway between two grid points (a false alarm).
    pred_beats = (np.array(gt_beats) + beat_noise).tolist()[:-2] + [15.25]
    downbeat_noise = np.random.normal(0, 0.003, len(gt_downbeats))
    pred_downbeats = (np.array(gt_downbeats) + downbeat_noise).tolist()

    track_results = evaluate_track(
        pred_beats=pred_beats,
        pred_downbeats=pred_downbeats,
        gt_beats=gt_beats,
        gt_downbeats=gt_downbeats,
    )
    print(format_results(track_results, "Single Track Demo"))
    print("\n" + "=" * 50 + "\n")

    # Evaluate the same track twice to exercise dataset-level averaging.
    predictions = [
        {"beats": pred_beats, "downbeats": pred_downbeats} for _ in range(2)
    ]
    ground_truths = [
        {"beats": gt_beats, "downbeats": gt_downbeats} for _ in range(2)
    ]

    all_results = evaluate_all(predictions, ground_truths, verbose=True)
    print()
    print(format_results(all_results, "Multi-Track Demo"))
| |
|