| from collections import defaultdict |
| from typing import Tuple |
|
|
| from rex.metrics import calc_p_r_f1_from_tp_fp_fn, safe_division |
| from rex.metrics.base import MetricBase |
| from rex.metrics.tagging import tagging_prf1 |
| from rex.utils.batch import decompose_batch_into_instances |
| from rex.utils.iteration import windowed_queue_iter |
| from rex.utils.random import generate_random_string_with_datetime |
| from sklearn.metrics import accuracy_score, matthews_corrcoef |
|
|
|
|
class MrcNERMetric(MetricBase):
    """Entity metric for MRC-style NER.

    Each raw instance carries one entity type per query; entities from
    multiple queries are merged by instance id before scoring.
    """

    def get_instances_from_batch(self, raw_batch: dict, out_batch: dict) -> Tuple:
        """Split a batch into aligned per-instance gold/pred records.

        Each record is ``{"id": ..., "ents": {(ent_type, ent), ...}}``.
        """
        batch_gold = decompose_batch_into_instances(raw_batch)
        batch_pred = out_batch["pred"]
        assert len(batch_gold) == len(batch_pred)

        gold_instances = []
        pred_instances = []
        for gold, pred in zip(batch_gold, batch_pred):
            ent_type = gold["ent_type"]
            gold_instances.append(
                {"id": gold["id"], "ents": {(ent_type, g) for g in gold["gold_ents"]}}
            )
            pred_instances.append(
                {"id": gold["id"], "ents": {(ent_type, p) for p in pred}}
            )

        return gold_instances, pred_instances

    def calculate_scores(self, golds: list, preds: list) -> dict:
        """Merge entities by instance id, then compute typed tagging P/R/F1."""
        id2gold = defaultdict(set)
        id2pred = defaultdict(set)
        for ins in golds:
            id2gold[ins["id"]].update(ins["ents"])
        for ins in preds:
            id2pred[ins["id"]].update(ins["ents"])
        assert len(id2gold) == len(id2pred)

        # Iterate gold ids so both lists stay index-aligned; missing pred
        # ids yield an empty set via the defaultdict.
        gold_ents = [id2gold[_id] for _id in id2gold]
        pred_ents = [id2pred[_id] for _id in id2gold]

        return tagging_prf1(gold_ents, pred_ents, type_idx=0)
|
|
|
|
class MrcSpanMetric(MetricBase):
    """Span metric for MRC-style extraction.

    Spans from multiple queries are merged by instance id before scoring;
    no type information is attached to the spans.
    """

    def get_instances_from_batch(self, raw_batch: dict, out_batch: dict) -> Tuple:
        """Split a batch into aligned per-instance gold/pred span records.

        Each record is ``{"id": ..., "spans": {span_tuple, ...}}``.
        """
        batch_gold = decompose_batch_into_instances(raw_batch)
        batch_pred = out_batch["pred"]
        assert len(batch_gold) == len(batch_pred)

        gold_instances = []
        pred_instances = []
        for gold, pred in zip(batch_gold, batch_pred):
            gold_instances.append(
                {
                    "id": gold["id"],
                    "spans": {tuple(span) for span in gold["gold_spans"]},
                }
            )
            pred_instances.append({"id": gold["id"], "spans": set(pred)})

        return gold_instances, pred_instances

    def calculate_scores(self, golds: list, preds: list) -> dict:
        """Merge spans by instance id, then compute untyped tagging P/R/F1."""
        id2gold = defaultdict(set)
        id2pred = defaultdict(set)
        for ins in golds:
            id2gold[ins["id"]].update(ins["spans"])
        for ins in preds:
            id2pred[ins["id"]].update(ins["spans"])
        assert len(id2gold) == len(id2pred)

        # Keep both lists index-aligned by walking gold ids; unseen pred ids
        # resolve to an empty set through the defaultdict.
        gold_spans = [id2gold[_id] for _id in id2gold]
        pred_spans = [id2pred[_id] for _id in id2gold]

        return tagging_prf1(gold_spans, pred_spans, type_idx=None)
|
|
|
|
def calc_char_event(golds, preds):
    """
    Calculate char-level event argument scores

    References:
        - https://aistudio.baidu.com/aistudio/competition/detail/46/0/submit-result

    Args:
        golds: a list of gold answers (a list of `event_list`), len=#data,
            format is a list of `event_list`
        preds: a list of pred answers, len=#data
    """

    def _best_char_f1(gold_arg, pred_args):
        # Best char-overlap F1 between one gold argument and all predicted
        # arguments sharing its (event_type, role); 0.0 when none match.
        gtype, grole, gstring = gold_arg
        gchars = set(gstring)
        best = 0.0
        for ptype, prole, pstring in pred_args:
            if ptype != gtype or prole != grole:
                continue
            pchars = set(str(pstring))
            overlap = len(pchars & gchars)
            p = safe_division(overlap, len(pchars))
            r = safe_division(overlap, len(gchars))
            best = max(best, safe_division(2 * p * r, p + r))
        return best

    pscore = 0.0
    num_gargs = num_pargs = 0
    for _golds, _preds in zip(golds, preds):
        # Flatten every event's arguments into (event_type, role, argument).
        gold_args = [
            (evt.get("event_type"), arg.get("role"), arg.get("argument"))
            for evt in _golds
            for arg in evt.get("arguments", [])
        ]
        pred_args = [
            (evt.get("event_type"), arg.get("role"), arg.get("argument"))
            for evt in _preds
            for arg in evt.get("arguments", [])
        ]

        num_gargs += len(gold_args)
        num_pargs += len(pred_args)
        pscore += sum(_best_char_f1(ga, pred_args) for ga in gold_args)

    # Soft-match precision/recall: the summed best-match scores play the
    # role of true positives over the pred/gold argument counts.
    p = safe_division(pscore, num_pargs)
    r = safe_division(pscore, num_gargs)
    f1 = safe_division(2 * p * r, p + r)
    return {
        "p": p,
        "r": r,
        "f1": f1,
        "pscore": pscore,
        "num_pargs": num_pargs,
        "num_gargs": num_gargs,
    }
|
|
|
|
def calc_trigger_identification_metrics(golds, preds):
    """Micro P/R/F1 over event triggers, ignoring event types."""
    tp = fp = fn = 0
    for gold_events, pred_events in zip(golds, preds):
        gold_set = {evt["trigger"] for evt in gold_events}
        pred_set = {evt["trigger"] for evt in pred_events}
        hit = len(gold_set & pred_set)
        tp += hit
        fp += len(pred_set) - hit
        fn += len(gold_set) - hit
    return calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
|
|
|
|
def calc_trigger_classification_metrics(golds, preds):
    """Micro P/R/F1 over (trigger, event_type) pairs."""
    tp = fp = fn = 0
    for gold_events, pred_events in zip(golds, preds):
        gold_set = {(evt["trigger"], evt["event_type"]) for evt in gold_events}
        pred_set = {(evt["trigger"], evt["event_type"]) for evt in pred_events}
        hit = len(gold_set & pred_set)
        tp += hit
        fp += len(pred_set) - hit
        fn += len(gold_set) - hit
    return calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
|
|
|
|
def calc_arg_identification_metrics(golds, preds):
    """Calculate argument identification metrics.

    A predicted argument counts as identified when its (argument, event_type)
    pair appears among the gold arguments — the role is ignored for matching
    but kept when de-duplicating.

    Notice:
        An entity could take different roles in an event,
        so the base number must be calculated by
        (arg, event type, pos, role)
    """
    tp = fp = fn = 0
    for gold_events, pred_events in zip(golds, preds):
        gold_args = {
            (arg["role"], arg["argument"], evt["event_type"])
            for evt in gold_events
            for arg in evt["arguments"]
        }
        pred_args = {
            (arg["role"], arg["argument"], evt["event_type"])
            for evt in pred_events
            for arg in evt["arguments"]
        }

        # Drop the role (index 0) before matching.
        gold_wo_role = {ga[1:] for ga in gold_args}
        hit = sum(1 for pa in pred_args if pa[1:] in gold_wo_role)
        tp += hit
        fp += len(pred_args) - hit
        fn += len(gold_args) - hit
    return calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
|
|
|
|
def calc_arg_classification_metrics(golds, preds):
    """Micro P/R/F1 over (argument, role, event_type) triples."""
    tp = fp = fn = 0
    for gold_events, pred_events in zip(golds, preds):
        gold_set = {
            (arg["argument"], arg["role"], evt["event_type"])
            for evt in gold_events
            for arg in evt["arguments"]
        }
        pred_set = {
            (arg["argument"], arg["role"], evt["event_type"])
            for evt in pred_events
            for arg in evt["arguments"]
        }
        hit = len(gold_set & pred_set)
        tp += hit
        fp += len(pred_set) - hit
        fn += len(gold_set) - hit
    return calc_p_r_f1_from_tp_fp_fn(tp, fp, fn)
|
|
|
|
def calc_ent(golds, preds):
    """Typed entity P/R/F1 (element position 0 carries the type).

    Args:
        golds, preds: [(type, index list), ...]
    """
    return tagging_prf1(golds, preds, type_idx=0)
|
|
|
|
def calc_rel(golds, preds):
    """Relation metrics.

    Returns:
        dict with:
            ent: P/R/F1 over the entities obtained by flattening each
                relation tuple's elements after the type (index 0)
            rel: exact-tuple P/R/F1 over the full relations
    """
    gold_ents = [[part for rel in ins for part in rel[1:]] for ins in golds]
    pred_ents = [[part for rel in ins for part in rel[1:]] for ins in preds]

    return {
        "ent": tagging_prf1(gold_ents, pred_ents, type_idx=None),
        "rel": tagging_prf1(golds, preds, type_idx=None),
    }
|
|
|
|
def calc_cls(golds, preds):
    """Multi-label classification metrics.

    Args:
        golds, preds: per-instance collections of labels (aligned lists).

    Returns:
        dict with:
            mcc: Matthews correlation coefficient over the canonicalized
                label strings (0.0 when there are no instances)
            acc: exact-match accuracy over the canonicalized label strings
                (0.0 when there are no instances)
            mf1: tagging P/R/F1 over the raw label collections
    """
    metrics = {
        "mcc": -1,
        "acc": -1,
        "mf1": tagging_prf1(golds, preds, type_idx=None),
    }
    # Serialize each label collection into one canonical, order-independent
    # string so multi-label instances compare as single-class targets.
    y_true = [" ".join(sorted(gold)) for gold in golds]
    y_pred = [" ".join(sorted(pred)) for pred in preds]
    if y_true and y_pred:
        metrics["acc"] = accuracy_score(y_true, y_pred)
        # Bug fix: matthews_corrcoef used to be called unconditionally,
        # but sklearn raises on empty inputs — only compute it when there
        # is at least one sample.
        metrics["mcc"] = matthews_corrcoef(y_true, y_pred)
    else:
        metrics["acc"] = 0.0
        metrics["mcc"] = 0.0
    return metrics
|
|
|
|
def calc_span(golds, preds, mode="span"):
    """Exact-match and token-level F1 for (multi-part) spans.

    Args:
        golds, preds: per-instance lists of spans; each span is a tuple of
            parts, each part a tuple of token indices.
        mode: "span" — a multi-index part is an inclusive (start, end) pair;
            "w2" — a part enumerates its token indices directly.

    Returns:
        dict with "em" (exact-match ratio) and "f1" (token-level P/R/F1).
    """

    def _to_tokens(spans: list[tuple[tuple[int]]]) -> list[int]:
        # Flatten spans into the token indices they cover.
        tokens = []
        for span in spans:
            for part in span:
                if len(part) == 1:
                    tokens.append(part[0])
                elif len(part) > 1:
                    if mode == "w2":
                        tokens.extend(part)
                    elif mode == "span":
                        tokens.extend(range(part[0], part[1] + 1))
                    else:
                        raise ValueError
        return tokens

    exact = 0
    tp = fp = fn = 0
    for gold, pred in zip(golds, preds):
        exact += int(gold == pred)
        gold_toks = set(_to_tokens(gold))
        pred_toks = set(_to_tokens(pred))
        tp += len(gold_toks & pred_toks)
        fp += len(pred_toks - gold_toks)
        fn += len(gold_toks - pred_toks)

    metrics = {
        "em": exact / len(golds) if len(golds) > 0 else 0.0,
        "f1": calc_p_r_f1_from_tp_fp_fn(tp, fp, fn),
    }
    return metrics
|
|
|
|
class MultiPartSpanMetric(MetricBase):
    """Metric for a unified multi-part-span decoding scheme.

    Each instance carries a set of "multi-part spans" plus a per-instance
    ``span_to_label`` mapping whose entries say which task a leading label
    span encodes ("cls", "ent", "rel", "event", "discontinuous_ent",
    "hyper_rel"). ``calculate_scores`` decomposes the shared span sets into
    task-specific predictions and scores each task separately.
    """

    def _encode_span_to_label_dict(self, span_to_label: dict) -> list:
        # Store the mapping as a list of {"key": ..., "val": ...} records so
        # tuple keys survive serialization (tuples cannot be JSON dict keys).
        span_to_label_list = []
        for key, val in span_to_label.items():
            span_to_label_list.append({"key": key, "val": val})
        return span_to_label_list

    def _decode_span_to_label(self, span_to_label_list: list) -> dict:
        # Inverse of `_encode_span_to_label_dict`: rebuild the dict with
        # tuple keys (keys may have become lists in between).
        span_to_label = {}
        for content in span_to_label_list:
            span_to_label[tuple(content["key"])] = content["val"]
        return span_to_label

    def get_instances_from_batch(self, raw_batch: dict, out_batch: dict) -> Tuple:
        """Decompose a batch into aligned per-instance gold/pred records.

        Gold records also keep the encoded span-to-label mapping and the raw
        gold content so `calculate_scores` can decode task labels later.
        """
        gold_instances = []
        pred_instances = []

        batch_gold = decompose_batch_into_instances(raw_batch)
        assert len(batch_gold) == len(out_batch["pred"])

        for i, gold in enumerate(batch_gold):
            # Fall back to a generated id when the raw data has none.
            ins_id = gold["raw"].get("id", generate_random_string_with_datetime())

            span_to_label_list = self._encode_span_to_label_dict(gold["span_to_label"])
            gold["span_to_label"] = span_to_label_list
            gold_instances.append(
                {
                    "id": ins_id,
                    "span_to_label_list": span_to_label_list,
                    "raw_gold_content": gold,
                    "spans": set(
                        tuple(multi_part_span) for multi_part_span in gold["spans"]
                    ),
                }
            )
            pred_instances.append(
                {
                    "id": ins_id,
                    "spans": set(
                        tuple(multi_part_span)
                        for multi_part_span in out_batch["pred"][i]
                    ),
                }
            )

        return gold_instances, pred_instances

    def calculate_scores(self, golds: list, preds: list) -> dict:
        """Decode task-specific structures from the span sets and score them.

        Returns a dict of per-task metrics: general span matching, cls, ent,
        rel, event (trigger/argument identification & classification plus
        char-level scores), discontinuous entities, hyper relations, and the
        leftover untyped spans.
        """
        # Per-task accumulators across all instances.
        general_gold_spans, general_pred_spans = [], []

        gold_cls_list, pred_cls_list = [], []

        gold_ent_list, pred_ent_list = [], []

        gold_rel_list, pred_rel_list = [], []

        gold_event_list, pred_event_list = [], []

        gold_span_list, pred_span_list = [], []

        gold_discon_ent_list, pred_discon_ent_list = [], []

        gold_hyper_rel_list, pred_hyper_rel_list = [], []

        for gold, pred in zip(golds, preds):
            general_gold_spans.append(gold["spans"])
            general_pred_spans.append(pred["spans"])
            # NOTE(review): predictions are decoded with the GOLD instance's
            # span-to-label mapping — presumably the mapping is shared per
            # instance; confirm against the encoder.
            span_to_label = self._decode_span_to_label(gold["span_to_label_list"])
            gold_clses, pred_clses = [], []
            gold_ents, pred_ents = [], []
            gold_rels, pred_rels = [], []
            # Events are grouped by trigger span: "lm" spans set the event
            # type, "lr" spans attach (argument, role) pairs.
            gold_trigger_to_event = defaultdict(
                lambda: {"event_type": "", "arguments": []}
            )
            pred_trigger_to_event = defaultdict(
                lambda: {"event_type": "", "arguments": []}
            )
            gold_events, pred_events = [], []
            gold_spans, pred_spans = [], []
            gold_discon_ents, pred_discon_ents = [], []
            gold_hyper_rels, pred_hyper_rels = [], []

            raw_schema = gold["raw_gold_content"]["raw"]["schema"]
            # Route each gold span to its task by the label of span[0];
            # length checks guard against malformed span arity.
            for span in gold["spans"]:
                if span[0] in span_to_label:
                    label = span_to_label[span[0]]
                    if label["task"] == "cls" and len(span) == 1:
                        gold_clses.append(label["string"])
                    elif label["task"] == "ent" and len(span) == 2:
                        gold_ents.append((label["string"], *span[1:]))
                    elif label["task"] == "rel" and len(span) == 3:
                        gold_rels.append((label["string"], *span[1:]))
                    elif label["task"] == "event":
                        if label["type"] == "lm" and len(span) == 2:
                            gold_trigger_to_event[span[1]]["event_type"] = label["string"]
                        elif label["type"] == "lr" and len(span) == 3:
                            gold_trigger_to_event[span[1]]["arguments"].append(
                                {"argument": span[2], "role": label["string"]}
                            )
                    elif label["task"] == "discontinuous_ent" and len(span) > 1:
                        gold_discon_ents.append((label["string"], *span[1:]))
                    elif label["task"] == "hyper_rel" and len(span) == 5 and span[3] in span_to_label:
                        # span layout: (rel label, head, tail, qualifier label, value)
                        q_label = span_to_label[span[3]]
                        gold_hyper_rels.append((label["string"], span[1], span[2], q_label["string"], span[4]))
                else:
                    # Unlabeled spans fall through to the generic span task.
                    gold_spans.append(tuple(span))
            for trigger, item in gold_trigger_to_event.items():
                # Keep only arguments whose role is legal for the event type.
                legal_roles = raw_schema["event"][item["event_type"]]
                gold_events.append(
                    {
                        "trigger": trigger,
                        "event_type": item["event_type"],
                        "arguments": [
                            arg
                            for arg in filter(
                                lambda arg: arg["role"] in legal_roles,
                                item["arguments"],
                            )
                        ],
                    }
                )

            # Same routing for predicted spans.
            for span in pred["spans"]:
                if span[0] in span_to_label:
                    label = span_to_label[span[0]]
                    if label["task"] == "cls" and len(span) == 1:
                        pred_clses.append(label["string"])
                    elif label["task"] == "ent" and len(span) == 2:
                        pred_ents.append((label["string"], *span[1:]))
                    elif label["task"] == "rel" and len(span) == 3:
                        pred_rels.append((label["string"], *span[1:]))
                    elif label["task"] == "event":
                        if label["type"] == "lm" and len(span) == 2:
                            pred_trigger_to_event[span[1]]["event_type"] = label["string"]
                        elif label["type"] == "lr" and len(span) == 3:
                            pred_trigger_to_event[span[1]]["arguments"].append(
                                {"argument": span[2], "role": label["string"]}
                            )
                    elif label["task"] == "discontinuous_ent" and len(span) > 1:
                        pred_discon_ents.append((label["string"], *span[1:]))
                    elif label["task"] == "hyper_rel" and len(span) == 5 and span[3] in span_to_label:
                        q_label = span_to_label[span[3]]
                        pred_hyper_rels.append((label["string"], span[1], span[2], q_label["string"], span[4]))
                else:
                    pred_spans.append(tuple(span))
            for trigger, item in pred_trigger_to_event.items():
                # Unlike the gold pass, drop predicted events whose type is
                # not in the schema (e.g. no "lm" span was predicted).
                if item["event_type"] not in raw_schema["event"]:
                    continue
                legal_roles = raw_schema["event"][item["event_type"]]
                pred_events.append(
                    {
                        "trigger": trigger,
                        "event_type": item["event_type"],
                        "arguments": [
                            arg
                            for arg in filter(
                                lambda arg: arg["role"] in legal_roles,
                                item["arguments"],
                            )
                        ],
                    }
                )

            gold_cls_list.append(gold_clses)
            pred_cls_list.append(pred_clses)
            gold_ent_list.append(gold_ents)
            pred_ent_list.append(pred_ents)
            gold_rel_list.append(gold_rels)
            pred_rel_list.append(pred_rels)
            gold_event_list.append(gold_events)
            pred_event_list.append(pred_events)
            gold_span_list.append(gold_spans)
            pred_span_list.append(pred_spans)
            gold_discon_ent_list.append(gold_discon_ents)
            pred_discon_ent_list.append(pred_discon_ents)
            gold_hyper_rel_list.append(gold_hyper_rels)
            pred_hyper_rel_list.append(pred_hyper_rels)

        metrics = {
            "general_spans": tagging_prf1(
                general_gold_spans, general_pred_spans, type_idx=None
            ),
            "cls": calc_cls(gold_cls_list, pred_cls_list),
            "ent": calc_ent(gold_ent_list, pred_ent_list),
            "rel": calc_rel(gold_rel_list, pred_rel_list),
            "event": {
                "trigger_id": calc_trigger_identification_metrics(
                    gold_event_list, pred_event_list
                ),
                "trigger_cls": calc_trigger_classification_metrics(
                    gold_event_list, pred_event_list
                ),
                "arg_id": calc_arg_identification_metrics(
                    gold_event_list, pred_event_list
                ),
                "arg_cls": calc_arg_classification_metrics(
                    gold_event_list, pred_event_list
                ),
                "char_event": calc_char_event(gold_event_list, pred_event_list),
            },
            "discontinuous_ent": tagging_prf1(
                gold_discon_ent_list, pred_discon_ent_list, type_idx=None
            ),
            "hyper_rel": tagging_prf1(
                gold_hyper_rel_list, pred_hyper_rel_list, type_idx=None
            ),

            "span": calc_span(gold_span_list, pred_span_list),
        }

        return metrics
|
|