| r"""Compute active speaker detection performance for the AVA dataset.
|
| Please send any questions about this code to the Google Group ava-dataset-users:
|
| https://groups.google.com/forum/#!forum/ava-dataset-users
|
| Example usage:
|
| python -O get_ava_active_speaker_performance.py \
|
| -g testdata/eval.csv \
|
| -p testdata/predictions.csv \
|
| -v
|
| """
|
|
|
| from __future__ import absolute_import
|
| from __future__ import division
|
| from __future__ import print_function
|
|
|
| import argparse
|
| import logging
|
| import time, warnings
|
| import numpy as np
|
| import pandas as pd
|
| import matplotlib.pyplot as plt
|
| warnings.filterwarnings("ignore")
|
|
|
def compute_average_precision(precision, recall):
    """Compute Average Precision according to the definition in VOCdevkit.

    Precision is first replaced by its running maximum taken from the
    high-recall end, so that it never increases as recall increases. The
    result is the area under the resulting step-wise precision/recall curve.

    Args:
      precision: A 1-D float64 numpy array of precisions, or None.
      recall: A 1-D float64 numpy array of recalls with the same length as
        `precision`, in non-decreasing order, or None.

    Raises:
      ValueError: if the input is not of the correct format.

    Returns:
      average_precision: The area under the precision recall curve. NaN if
        precision and recall are both None, 0.0 if they are empty.
    """
    if precision is None:
        if recall is not None:
            raise ValueError("If precision is None, recall must also be None")
        # np.nan, not np.NAN: the upper-case alias was removed in NumPy 2.0.
        return np.nan

    if not isinstance(precision, np.ndarray) or not isinstance(
            recall, np.ndarray):
        raise ValueError("precision and recall must be numpy array")
    # np.float (an alias of the builtin float, i.e. float64) was removed in
    # NumPy 1.24, so the dtype is compared against np.float64 directly.
    if precision.dtype != np.float64 or recall.dtype != np.float64:
        raise ValueError("input must be float numpy array.")
    if len(precision) != len(recall):
        raise ValueError("precision and recall must be of the same size.")
    if not precision.size:
        return 0.0
    if np.amin(precision) < 0 or np.amax(precision) > 1:
        raise ValueError("Precision must be in the range of [0, 1].")
    if np.amin(recall) < 0 or np.amax(recall) > 1:
        raise ValueError("recall must be in the range of [0, 1].")
    # Written as "not all(<=)" so that NaN entries are rejected as well.
    if not np.all(recall[:-1] <= recall[1:]):
        raise ValueError("recall must be a non-decreasing array")

    # Sentinels pin the curve to recall 0 and recall 1.
    recall = np.concatenate([[0], recall, [1]])
    precision = np.concatenate([[0], precision, [0]])

    # Running maximum from the right: each precision becomes the best
    # precision achieved at that recall level or any higher one.
    precision = np.maximum.accumulate(precision[::-1])[::-1]

    # Sum rectangle areas only at the points where recall actually changes.
    indices = np.where(recall[1:] != recall[:-1])[0] + 1
    average_precision = np.sum(
        (recall[indices] - recall[indices - 1]) * precision[indices])
    return average_precision
|
|
|
|
|
def load_csv(filename, column_names):
    """Loads CSV from the filename using given column names.

    Adds a "uid" column built as "<frame_timestamp>:<entity_id>".

    Args:
      filename: Path to the CSV file to load, or an already opened file
        object (anything accepted by pandas.read_csv).
      column_names: A list of column names to read from the file. It must
        include "frame_timestamp" and "entity_id", which are needed for uid.

    Returns:
      df: A Pandas DataFrame containing the requested columns plus "uid".
    """
    # entity_id is an identifier, not a quantity: read it as text so that
    # purely numeric ids are not inferred as integers, which would make the
    # string concatenation below fail.
    df = pd.read_csv(filename, usecols=column_names, dtype={"entity_id": str})

    df["uid"] = df["frame_timestamp"].map(str) + ":" + df["entity_id"]
    return df
|
|
|
|
|
def eq(a, b, tolerance=1e-09):
    """Checks whether two values agree to within an absolute tolerance.

    Only subtraction, abs() and <= are applied to the operands.

    Args:
      a: First operand.
      b: Second operand.
      tolerance: Largest absolute difference still treated as equal.

    Returns:
      The result of abs(a - b) <= tolerance.
    """
    difference = a - b
    return abs(difference) <= tolerance
|
|
|
|
|
def merge_groundtruth_and_predictions(df_groundtruth, df_predictions):
    """Merges groundtruth and prediction DataFrames.

    The returned DataFrame is merged on uid field and sorted in descending order
    by score field. Bounding boxes are checked to make sure they match between
    groundtruth and predictions.

    Args:
      df_groundtruth: A DataFrame with groundtruth data. Needs the columns
        uid, label and entity_box_{x1,y1,x2,y2}.
      df_predictions: A DataFrame with predictions data. Needs the same
        columns plus score.

    Raises:
      ValueError: if the two frames have a different number of non-null
        uids, if any prediction label is not SPEAKING_AUDIBLE, if a score is
        missing, or if a bounding box differs between the two frames.

    Returns:
      df_merged: A merged DataFrame, with rows matched on uid column. Columns
        present in both inputs carry the suffixes "_groundtruth" and
        "_prediction"; a boolean "bounding_box_correct" column is added.
    """
    if df_groundtruth["uid"].count() != df_predictions["uid"].count():
        raise ValueError(
            "Groundtruth and predictions CSV must have the same number of "
            "unique rows.")

    # Test every row and reduce with any(). Comparing unique() against a list
    # yields an array, whose truth value is ambiguous as soon as more than
    # one distinct label is present, so the intended message was never shown.
    if (df_predictions["label"] != "SPEAKING_AUDIBLE").any():
        raise ValueError(
            "Predictions CSV must contain only SPEAKING_AUDIBLE label.")

    if df_predictions["score"].count() < df_predictions["uid"].count():
        raise ValueError("Predictions CSV must contain score value for every row.")

    # validate="1:1" makes pandas reject duplicated uids on either side.
    # NOTE(review): this is pandas' default inner join, so a uid present in
    # only one of the two frames is dropped without an error.
    df_merged = df_groundtruth.merge(
        df_predictions,
        on="uid",
        suffixes=("_groundtruth", "_prediction"),
        validate="1:1")
    # Highest score first; reset_index() keeps the pre-sort position in an
    # "index" column and gives the rows a fresh 0..n-1 index.
    df_merged = df_merged.sort_values(by=["score"], ascending=False).reset_index()

    def _coordinate_matches(column):
        # Compares one box coordinate between the two sources, with tolerance.
        return eq(df_merged[column + "_groundtruth"],
                  df_merged[column + "_prediction"])

    df_merged["bounding_box_correct"] = (
        _coordinate_matches("entity_box_x1")
        & _coordinate_matches("entity_box_x2")
        & _coordinate_matches("entity_box_y1")
        & _coordinate_matches("entity_box_y2"))

    mismatched_uids = df_merged[~df_merged["bounding_box_correct"]]["uid"]
    if len(mismatched_uids) > 0:
        raise ValueError(
            "Mismatch between groundtruth and predictions bounding boxes found at "
            + str(list(mismatched_uids)))

    return df_merged
|
|
|
|
|
def get_all_positives(df_merged):
    """Returns how many groundtruth rows are labelled SPEAKING_AUDIBLE.

    Args:
      df_merged: A DataFrame with "label_groundtruth" and "uid" columns.

    Returns:
      The number of non-null uid values among the rows whose
      label_groundtruth equals "SPEAKING_AUDIBLE".
    """
    is_positive = df_merged["label_groundtruth"] == "SPEAKING_AUDIBLE"
    positive_uids = df_merged.loc[is_positive, "uid"]
    return positive_uids.count()
|
|
|
|
|
def calculate_precision_recall(df_merged):
    """Calculates precision and recall arrays going through df_merged row-wise.

    Rows are taken in their current order; row k (1-based) gets the precision
    and recall of the top-k rows. The columns is_tp, tp, precision and recall
    are added to df_merged in place.

    Args:
      df_merged: A DataFrame with the columns uid, score, label_groundtruth
        and label_prediction. The caller is expected to have ordered it
        (merge_groundtruth_and_predictions sorts by descending score).

    Returns:
      A (precision, recall) tuple of numpy arrays with one entry per row.
    """
    is_positive = df_merged["label_groundtruth"] == "SPEAKING_AUDIBLE"
    # Same count as get_all_positives(): non-null uids among positive rows.
    all_positives = df_merged.loc[is_positive, "uid"].count()

    # 1 where both groundtruth and prediction say SPEAKING_AUDIBLE, else 0.
    df_merged["is_tp"] = np.where(
        is_positive & (df_merged["label_prediction"] == "SPEAKING_AUDIBLE"),
        1, 0)

    # Running number of true positives down the rows.
    df_merged["tp"] = df_merged["is_tp"].cumsum()

    # Number of rows considered so far. Derived from the row position rather
    # than from df_merged.index, so the result does not depend on the frame
    # carrying a default 0..n-1 index.
    rows_seen = np.arange(1, len(df_merged) + 1)
    df_merged["precision"] = df_merged["tp"] / rows_seen

    df_merged["recall"] = df_merged["tp"] / all_positives
    logging.info(
        "\n%s\n",
        df_merged.head(10)[[
            "uid", "score", "label_groundtruth", "is_tp", "tp", "precision",
            "recall"
        ]])

    return np.array(df_merged["precision"]), np.array(df_merged["recall"])
|
|
|
|
|
def run_evaluation(groundtruth, predictions):
    """Runs AVA Active Speaker evaluation, printing average precision result.

    Args:
      groundtruth: CSV source with the groundtruth rows, passed to load_csv.
      predictions: CSV source with the prediction rows, passed to load_csv.
        It must provide a "score" column on top of the groundtruth columns.

    Returns:
      The average precision scaled to a percentage.
    """
    groundtruth_columns = [
        "video_id", "frame_timestamp", "entity_box_x1", "entity_box_y1",
        "entity_box_x2", "entity_box_y2", "label", "entity_id"
    ]
    # Predictions carry the same columns plus the confidence score.
    prediction_columns = groundtruth_columns + ["score"]

    groundtruth_frame = load_csv(groundtruth, column_names=groundtruth_columns)
    predictions_frame = load_csv(predictions, column_names=prediction_columns)

    merged_frame = merge_groundtruth_and_predictions(
        groundtruth_frame, predictions_frame)
    precision_values, recall_values = calculate_precision_recall(merged_frame)

    percent_ap = 100 * compute_average_precision(precision_values, recall_values)
    print("average precision: %2.2f%%" % (percent_ap))
    return percent_ap
|
|
|
|
|
def parse_arguments():
    """Parses command-line flags.

    Returns:
      args: an argparse.Namespace with args.groundtruth and args.predictions
        (both required, each opened for reading by argparse.FileType) and
        the boolean args.verbose.
    """
    parser = argparse.ArgumentParser()

    # Both CSV inputs are declared the same way; only flags and help differ.
    csv_inputs = (
        ("-g", "--groundtruth", "CSV file containing ground truth."),
        ("-p", "--predictions",
         "CSV file containing active speaker predictions."),
    )
    for short_flag, long_flag, description in csv_inputs:
        parser.add_argument(
            short_flag,
            long_flag,
            help=description,
            type=argparse.FileType("r"),
            required=True)

    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Increase output verbosity.")
    return parser.parse_args()
|
|
|
|
|
def main():
    """Runs the evaluation for the command-line arguments.

    Parses the flags, enables DEBUG logging when --verbose is given, runs the
    evaluation and logs the elapsed wall-clock time.

    Returns:
      The value returned by run_evaluation.
    """
    # Local import: only needed to recognise the shared stdin stream below.
    import sys

    start = time.time()
    args = parse_arguments()
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    try:
        mAP = run_evaluation(
            groundtruth=args.groundtruth, predictions=args.predictions)
    finally:
        # argparse.FileType opens the inputs but nothing closed them before.
        # Release them here, leaving alone sys.stdin (what FileType returns
        # for "-") and anything that is not a file object.
        for source in (args.groundtruth, args.predictions):
            if source is not sys.stdin and hasattr(source, "close"):
                source.close()
    logging.info("Computed in %s seconds", time.time() - start)
    return mAP
|
|
|
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()