| from typing import List, Dict |
| import numpy as np |
| import evaluate |
|
|
def compute_metrics_sentiment(eval_pred):
    """Plain accuracy for a sequence-classification eval step.

    Args:
        eval_pred: ``(logits, labels)`` pair as handed over by the HF
            Trainer; ``logits`` has shape (batch, num_classes) — assumed,
            TODO confirm against the Trainer config.

    Returns:
        ``{"accuracy": <python float>}``.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    # Boolean hit-mask averaged to a fraction; .item() unwraps the
    # numpy scalar into a plain Python float for JSON-friendly logging.
    hits = np.equal(predictions, labels)
    return {"accuracy": hits.mean().item()}
|
|
# Lazily-loaded seqeval metric. evaluate.load() touches the local cache (and
# possibly the Hub) on every call, so load it once per process instead of on
# every evaluation step — the original reloaded it each time.
_SEQEVAL = None


def compute_metrics_ner(eval_pred, label_list: List[str]) -> Dict[str, float]:
    """Compute seqeval precision/recall/F1/accuracy for token classification.

    Args:
        eval_pred: ``(logits, labels)`` pair; ``logits`` is assumed to be
            (batch, seq_len, num_labels) and ``labels`` uses ``-100`` at
            positions to ignore (special tokens / sub-word continuations —
            the HF convention; TODO confirm with the tokenization code).
        label_list: id -> tag-string mapping fed to seqeval.

    Returns:
        Dict with ``precision``, ``recall``, ``f1`` and ``accuracy`` taken
        from seqeval's ``overall_*`` fields.
    """
    global _SEQEVAL
    if _SEQEVAL is None:
        _SEQEVAL = evaluate.load("seqeval")

    logits, labels = eval_pred
    preds = logits.argmax(-1)

    # Single pass over the rows: the original ran two identical
    # zip-comprehensions over (preds, labels). Here each row is filtered
    # once and both tag streams are built from the surviving pairs.
    true_preds: List[List[str]] = []
    true_labels: List[List[str]] = []
    for pred_row, label_row in zip(preds, labels):
        kept = [(p, l) for p, l in zip(pred_row, label_row) if l != -100]
        true_preds.append([label_list[p] for p, _ in kept])
        true_labels.append([label_list[l] for _, l in kept])

    results = _SEQEVAL.compute(predictions=true_preds, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }
|
|