| import os |
| import sys |
| import pickle |
| import json |
| from json import encoder |
| from uniperceiver.config import configurable |
| from .build import EVALUATION_REGISTRY |
|
|
| import numpy as np |
| from sklearn.metrics import f1_score, matthews_corrcoef |
| from scipy.stats import pearsonr, spearmanr |
|
|
def simple_accuracy(preds, labels):
    """Return the fraction of positions where ``preds`` matches ``labels``.

    Both arguments are expected to support element-wise ``==`` that yields
    an object with a ``.mean()`` method (e.g. NumPy arrays of equal shape).
    """
    hits = preds == labels
    return hits.mean()
|
|
| @EVALUATION_REGISTRY.register() |
class GLUEEvaler(object):
    """Evaluator that computes the standard GLUE metrics for a single task.

    The task is selected by ``cfg.DATASETS.DATASET_NAME``. Supported names
    and the metrics reported for each:

    * ``CoLA``: accuracy and Matthews correlation coefficient.
    * ``QNLI``, ``RTE``, ``SST-2`` and any name starting with ``MNLI``:
      accuracy.
    * ``MRPC``, ``QQP``: accuracy and F1 score.
    * ``STS-B``: Pearson and Spearman correlation.
    """

    def __init__(self, cfg, *args, **kwargs):
        super(GLUEEvaler, self).__init__()
        self.task_name = cfg.DATASETS.DATASET_NAME
        # NOTE(review): not read anywhere in this class; kept in case external
        # code inspects it.
        self.tasks = [""]

    def _metric_kind(self):
        """Return the metric family for ``self.task_name``, or None if unsupported."""
        name = self.task_name
        if name == 'CoLA':
            return 'cola'
        if name in ('QNLI', 'RTE', 'SST-2') or name.startswith("MNLI"):
            return 'accuracy'
        if name in ('MRPC', 'QQP'):
            return 'accuracy_f1'
        if name == 'STS-B':
            return 'correlation'
        return None

    def eval(self, results, epoch):
        """Compute the task's metrics over per-sample prediction records.

        Args:
            results: iterable of dicts, each with keys ``"pred"`` and
                ``"label"``. For classification tasks ``pred`` must support
                ``.argmax().item()`` and ``label`` must be int-convertible; for
                STS-B ``pred`` must support ``.sigmoid().item()`` and ``label``
                must be float-convertible. Presumably ``pred`` is a torch
                tensor of per-sample logits -- TODO confirm against the caller.
            epoch: unused; accepted for interface compatibility.

        Returns:
            dict mapping metric name to value (see the class docstring for
            which keys each task produces).

        Raises:
            NotImplementedError: if ``self.task_name`` is not a supported
                task. Raised before ``results`` is consumed.
        """
        kind = self._metric_kind()
        # Validate first so an unsupported task fails immediately instead of
        # after a full pass over the (possibly large) result set.
        if kind is None:
            raise NotImplementedError(
                "GLUEEvaler does not support task {!r}".format(self.task_name))

        is_regression = kind == 'correlation'
        preds = []
        labels = []
        for record in results:
            if is_regression:
                # STS-B: squash the model output with a sigmoid; the
                # correlations below are computed on these squashed values.
                preds.append(float(record["pred"].sigmoid().item()))
                labels.append(float(record["label"]))
            else:
                # Classification: the predicted class is the argmax index.
                preds.append(record["pred"].argmax().item())
                labels.append(int(record["label"]))

        preds = np.array(preds)
        labels = np.array(labels)

        if is_regression:
            return {
                "pearson_corr": pearsonr(preds, labels)[0],
                "spearman_corr": spearmanr(preds, labels)[0],
            }

        metrics = {"accuracy": simple_accuracy(preds, labels)}
        if kind == 'cola':
            metrics["matthews_corrcoef"] = matthews_corrcoef(labels, preds)
        elif kind == 'accuracy_f1':
            metrics["f1_score"] = f1_score(y_true=labels, y_pred=preds)
        return metrics
|
|