import ast
import json
import math
import os
import re

from datasets import Dataset

from opencompass.openicl.icl_evaluator.icl_base_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
from opencompass.utils import get_data_path

from .base import BaseDataset
|
|
@LOAD_DATASET.register_module()
class OpenSWIDataset(BaseDataset):
    """Loads every ``.jsonl`` file under ``path/name`` and flattens each
    record into the fields used downstream for evaluation."""

    @staticmethod
    def load(path: str, name: str):
        new_data = []
        path = os.path.join(get_data_path(path), name)
        for file in os.listdir(path):
            if file.endswith('.jsonl'):
                final_path = os.path.join(path, file)
                with open(final_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        data = json.loads(line)
                        new_data.append({
                            'id': data['id_ddm'],
                            'prompt': data['dialogs'][0]['content'],
                            'ground_truth': data['ground_truth'],
                            # The file stem names the subset the record
                            # belongs to.
                            'subset': file.split('.')[0]
                        })
        dataset = Dataset.from_list(new_data)
        return dataset
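
# A record of the shape the loader above expects; the field values here are
# purely illustrative:
#     {"id_ddm": "...", "dialogs": [{"content": "..."}],
#      "ground_truth": [0.2, 0.5]}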
|
|
def extract_list(text):
    """Parse the last bracketed list in ``text``; return None on failure."""
    # Non-greedy match across newlines; the last bracketed span is taken,
    # since it is usually the model's final answer.
    matches = re.findall(r'(\[.*?\])', text, re.DOTALL)
    if matches:
        raw_list_str = matches[-1]
        try:
            question_list = ast.literal_eval(raw_list_str)
            return question_list
        except Exception:
            return None
    else:
        return None
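
# Illustrative behaviour of the helper above:
#     extract_list('Final scores: [0.2, 0.5, 0.9]')  -> [0.2, 0.5, 0.9]
#     extract_list('no bracketed list here')         -> None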
|
|
@ICL_EVALUATORS.register_module()
class OpenSWIMSEEvaluator(BaseEvaluator):
    """RMSE evaluator for OpenSWI: parses a list of numeric scores from
    each prediction and compares it element-wise with the reference."""

    def __init__(self) -> None:
        super().__init__()

    def score(self, predictions, references):
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different '
                'length'
            }

        avg_score = 0
        avg_valid = []
        details = []
        for prediction, reference in zip(predictions, references):
            pred = extract_list(prediction)
            ans = reference
            # Invalid when no list was parsed or any element is not a
            # number (ints are accepted alongside floats, since
            # ast.literal_eval parses ``[1, 2]`` to ints).
            if not pred or not all(
                    isinstance(x, (int, float)) for x in pred):
                detail = {'pred': None, 'answer': ans, 'valid': False}
                pred = [0] * len(ans)
            else:
                detail = {'pred': pred, 'answer': ans, 'valid': True}
                if len(pred) < len(ans):
                    # Too few values: pad with zeros and mark invalid.
                    detail['valid'] = False
                    pred = pred + [0] * (len(ans) - len(pred))
                elif len(pred) > len(ans):
                    # Too many values: truncate and mark invalid.
                    detail['valid'] = False
                    pred = pred[:len(ans)]
            avg_valid.append(detail['valid'])
            # Per-sample RMSE between the reference and the (possibly
            # padded or truncated) prediction.
            squared_errors = [(a - p)**2 for a, p in zip(ans, pred)]
            rmse_score = math.sqrt(sum(squared_errors) / len(squared_errors))
            detail['score'] = rmse_score
            avg_score += rmse_score
            details.append(detail)

        score = avg_score / len(predictions)
        valid = sum(avg_valid) / len(avg_valid)

        return {'score': score, 'valid': valid * 100, 'details': details}
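
# Example usage (illustrative values; run inside an OpenCompass setup, as
# this module uses relative imports):
#     evaluator = OpenSWIMSEEvaluator()
#     result = evaluator.score(
#         predictions=['The difficulty scores are [0.1, 0.2, 0.9].'],
#         references=[[0.0, 0.2, 1.0]],
#     )
#     result['score']  -> ~0.0816 (mean per-sample RMSE)
#     result['valid']  -> 100.0 (percentage of parseable predictions)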