| import re |
| import string |
|
|
| import jieba |
| from fuzzywuzzy import fuzz |
| import difflib |
|
|
| from typing import List |
| from collections import Counter |
| from rouge import Rouge |
|
|
def normalize_answer(s):
    """Return a canonical form of an English answer string.

    The text is lower-cased, ASCII punctuation is dropped, the stand-alone
    words "a", "an" and "the" are blanked out, and every run of whitespace
    is collapsed to a single space (leading/trailing whitespace removed).
    """
    lowered = s.lower()

    # Drop every character listed in string.punctuation.
    punctuation = set(string.punctuation)
    without_punctuation = "".join(
        char for char in lowered if char not in punctuation
    )

    # Articles are replaced by a space (not removed) so neighbouring words
    # never get glued together; the final split/join tidies the spacing.
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punctuation)

    return " ".join(without_articles.split())
|
|
|
|
# Punctuation stripped by normalize_zh_answer, built once at import time:
#   * ASCII punctuation,
#   * its full-width counterparts (U+FF01-U+FF5E sit at a fixed offset of
#     0xFEE0 from U+0021-U+007E),
#   * half-width CJK brackets/stops (U+FF5F-U+FF64),
#   * general CJK punctuation, dashes and curly quotes.
_ZH_PUNCTUATION = frozenset(
    string.punctuation
    + "".join(chr(ord(ch) + 0xFEE0) for ch in string.punctuation)
    + "\uff5f\uff60\uff61\uff62\uff63\uff64"
    + "。、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏⦅⦆"
)


def normalize_zh_answer(s):
    """Lower text and remove punctuation and all whitespace.

    Unlike the English normaliser, whitespace is removed entirely rather
    than collapsed, and both ASCII and Chinese/full-width punctuation are
    stripped.

    Args:
        s: The text to normalise.

    Returns:
        The lower-cased text with every punctuation character in
        ``_ZH_PUNCTUATION`` and every whitespace character removed.
    """
    lowered = s.lower()
    without_punctuation = "".join(
        ch for ch in lowered if ch not in _ZH_PUNCTUATION
    )
    # str.split() with no argument splits on any whitespace run, so joining
    # with "" deletes all whitespace.
    return "".join(without_punctuation.split())
|
|
def count_score(prediction, ground_truth, **kwargs):
    """Score a counting answer by the share of predicted numbers that match.

    Every run of digits in ``prediction`` is extracted; the score is the
    fraction of those runs whose text equals ``str(ground_truth)``.

    Returns:
        A float in [0.0, 1.0]; 0.0 when ``prediction`` contains no digits.
    """
    found = re.findall(r"\d+", prediction)
    if not found:
        return 0.0

    target = str(ground_truth)
    hits = sum(1 for candidate in found if candidate == target)
    return float(hits / len(found))
|
|
def retrieval_score(prediction, ground_truth, **kwargs):
    """Score a passage-retrieval answer against a "Paragraph N" reference.

    The paragraph id is taken from the first "Paragraph <digits>" occurrence
    in ``ground_truth``. The score is the fraction of digit runs in
    ``prediction`` that equal that id.

    Returns:
        A float in [0.0, 1.0]; 0.0 when ``prediction`` contains no digits.

    Raises:
        IndexError: If ``ground_truth`` contains no "Paragraph <digits>".
    """
    ground_truth_id = re.findall(r'Paragraph (\d+)', ground_truth)[0]

    candidates = re.findall(r"\d+", prediction)
    if not candidates:
        return 0.0

    return float(candidates.count(ground_truth_id) / len(candidates))
|
|
def retrieval_zh_score(prediction, ground_truth, **kwargs):
    """Score a Chinese passage-retrieval answer against a "段落N" reference.

    The paragraph id is taken from the first "段落<digits>" occurrence in
    ``ground_truth``. The score is the fraction of digit runs in
    ``prediction`` that equal that id.

    Returns:
        A float in [0.0, 1.0]; 0.0 when ``prediction`` contains no digits.

    Raises:
        IndexError: If ``ground_truth`` contains no "段落<digits>".
    """
    ground_truth_id = re.findall(r'段落(\d+)', ground_truth)[0]

    candidates = re.findall(r"\d+", prediction)
    if not candidates:
        return 0.0

    return float(candidates.count(ground_truth_id) / len(candidates))
|
|
def code_sim_score(prediction, ground_truth, **kwargs):
    """Compare the first plain code line of a prediction with the reference.

    Leading newlines are stripped from ``prediction``; the first line that
    contains none of '`', '#' or '//' is selected (an empty string if every
    line contains one of them) and compared with ``ground_truth`` using
    ``fuzz.ratio``.

    Returns:
        ``fuzz.ratio(...)`` divided by 100.
    """
    skip_markers = ('`', '#', '//')
    candidate_lines = prediction.lstrip('\n').split('\n')

    first_code_line = next(
        (
            line
            for line in candidate_lines
            if not any(marker in line for marker in skip_markers)
        ),
        "",
    )
    return fuzz.ratio(first_code_line, ground_truth) / 100
|
|
def classification_score(prediction, ground_truth, **kwargs):
    """Score a classification answer by class-name matching.

    Every class name that occurs as a substring of ``prediction`` counts as
    a match. Matches that are proper substrings of ``ground_truth`` are
    discarded, since they match only because the longer correct label is
    present. If ``ground_truth`` is among the remaining matches the score is
    ``1 / number_of_remaining_matches``; otherwise it is 0.0.

    Args:
        prediction: Model output text.
        ground_truth: The correct class name.
        **kwargs: Must contain ``all_classes``, an iterable of class names.

    Returns:
        A float in [0.0, 1.0].

    Raises:
        KeyError: If ``all_classes`` is not supplied.
    """
    all_classes = kwargs["all_classes"]
    matched = [name for name in all_classes if name in prediction]

    # Build a new list instead of calling .remove() while iterating: removing
    # during iteration skips the element that follows each removed one, which
    # left spurious partial matches in place and deflated the score.
    matched = [
        name
        for name in matched
        if name == ground_truth or name not in ground_truth
    ]

    if ground_truth in matched:
        return 1.0 / len(matched)
    return 0.0
| |
def rouge_score(prediction, ground_truth, **kwargs):
    """Return the ROUGE-L F-score of ``prediction`` against ``ground_truth``.

    Scoring is best-effort: if the ``Rouge`` scorer raises for this pair,
    the pair is scored 0.0 instead of aborting the evaluation.

    Returns:
        The ``["rouge-l"]["f"]`` entry of the averaged scores, or 0.0 when
        scoring fails.
    """
    rouge = Rouge()
    try:
        scores = rouge.get_scores([prediction], [ground_truth], avg=True)
    except Exception:
        # Deliberately broad, but not a bare ``except``: KeyboardInterrupt
        # and SystemExit must still propagate so a run can be stopped.
        return 0.0
    return scores["rouge-l"]["f"]
|
|
def rouge_zh_score(prediction, ground_truth, **kwargs):
    """Return the ROUGE-L F-score for Chinese text.

    Both strings are segmented with ``jieba.cut`` (``cut_all=False``) and
    the segments are re-joined with single spaces before being passed to
    ``rouge_score``.
    """
    segmented_prediction = " ".join(jieba.cut(prediction, cut_all=False))
    segmented_ground_truth = " ".join(jieba.cut(ground_truth, cut_all=False))
    return rouge_score(segmented_prediction, segmented_ground_truth)
|
|
def f1_score(prediction, ground_truth, **kwargs):
    """Compute the multiset-overlap F1 between two token sequences.

    The overlap is the size of the multiset intersection of the two
    sequences; precision and recall are that overlap divided by the length
    of ``prediction`` and ``ground_truth`` respectively.

    Returns:
        The harmonic mean of precision and recall, or 0 when the sequences
        share no element.
    """
    overlap = sum((Counter(prediction) & Counter(ground_truth)).values())
    if not overlap:
        return 0

    precision = 1.0 * overlap / len(prediction)
    recall = 1.0 * overlap / len(ground_truth)
    return (2 * precision * recall) / (precision + recall)
|
|
def qa_f1_score(prediction, ground_truth, **kwargs):
    """Return the token-level F1 between two English answers.

    Both strings go through ``normalize_answer`` and are split on
    whitespace; the resulting token lists are compared with ``f1_score``.
    """
    return f1_score(
        normalize_answer(prediction).split(),
        normalize_answer(ground_truth).split(),
    )
|
|
|
|
def qa_f1_zh_score(prediction, ground_truth, **kwargs):
    """Return the token-level F1 between two Chinese answers.

    Each string is segmented with ``jieba.cut`` (``cut_all=False``), every
    segment is passed through ``normalize_zh_answer``, and segments that
    normalise to the empty string are dropped before ``f1_score`` is applied.
    """

    def _normalized_tokens(text):
        # Segment, normalise each segment, and keep only non-empty results.
        normalized = (
            normalize_zh_answer(segment)
            for segment in jieba.cut(text, cut_all=False)
        )
        return [segment for segment in normalized if segment]

    return f1_score(
        _normalized_tokens(prediction),
        _normalized_tokens(ground_truth),
    )
|
|
def qa_em_score(prediction, ground_truth, **kwargs):
    """Return 1 if one normalised answer contains the other, else 0.

    Both strings go through ``normalize_answer``; the match is a lenient
    containment test in either direction rather than strict equality.

    An empty string is a substring of every string, so without a guard a
    prediction (or ground truth) that normalises to "" would match anything.
    When either side is empty after normalisation, the score is 1 only if
    both sides are empty.

    Returns:
        1 on a match, 0 otherwise.
    """
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    # Guard the degenerate case: ``"" in anything`` is always True.
    if not normalized_prediction or not normalized_ground_truth:
        return 1 if normalized_prediction == normalized_ground_truth else 0

    if (
        normalized_prediction in normalized_ground_truth
        or normalized_ground_truth in normalized_prediction
    ):
        return 1
    return 0
|
|
|
|
| |