| |
| """ |
| @author: Yehao Li |
| @contact: yehaoli.sysu@gmail.com |
| """ |
| import os |
| import sys |
| import pickle |
| import json |
| from json import encoder |
| from .build import EVALUATION_REGISTRY |
|
|
| @EVALUATION_REGISTRY.register() |
| class VQAEvaler(object): |
| def __init__(self, cfg, annfile, output_dir): |
| super(VQAEvaler, self).__init__() |
| label2ans_path = os.path.join(cfg.DATALOADER.ANNO_FOLDER, "trainval_label2ans.pkl") |
| ori_annotation = json.load(open(os.path.join(cfg.DATALOADER.ANNO_FOLDER, "v2_mscoco_val2014_annotations.json"))) |
| self.id2type = {t['question_id']: t['question_type'] for t in ori_annotation['annotations']} |
| self.id2type_answer = {t['question_id']: t['answer_type'] for t in ori_annotation['annotations']} |
| self.label2ans = pickle.load(open(label2ans_path, "rb")) |
|
|
| self.id2label = {} |
| if len(annfile) > 0: |
| answers_val = pickle.load(open(annfile, "rb")) |
| for datum in answers_val: |
| quesid = datum['question_id'] |
| self.id2label[quesid] = {} |
| for i, label in enumerate(datum['labels']): |
| label_str = self.label2ans[label] |
| self.id2label[quesid][label_str] = datum['scores'][i] |
|
|
| if output_dir is not None: |
| self.output_dir = os.path.join(output_dir, 'results') |
| if not os.path.exists(self.output_dir): |
| os.mkdir(self.output_dir) |
| else: |
| self.output_dir = None |
|
|
| def eval(self, results, epoch): |
| for res in results: |
| res['answer'] = self.label2ans[res['answer']] |
| if self.output_dir is not None: |
| json.dump(results, open(os.path.join(self.output_dir, str(epoch) + '.json'), "w", encoding="utf-8")) |
|
|
| accuracy = 0. |
| acc_by_type = dict() |
| acc_by_type_answer = dict() |
| for result in results: |
| quesid = result['question_id'] |
| ans = result['answer'] |
| if quesid not in self.id2type: |
| print("Test Stage has no target") |
| return { "accuracy": 0.0 } |
| type = self.id2type[quesid] |
| ans_type = self.id2type_answer[quesid] |
| if type not in acc_by_type: |
| acc_by_type[type] = [0,0] |
| if ans_type not in acc_by_type_answer: |
| acc_by_type_answer[ans_type] = [0,0] |
| if quesid not in self.id2label: |
| return { "accuracy": 0.0 } |
|
|
| datum = self.id2label[quesid] |
| acc_by_type[type][1] += 1 |
| acc_by_type_answer[ans_type][1] += 1 |
| if ans in datum: |
| accuracy += datum[ans] |
| acc_by_type[type][0] += datum[ans] |
| acc_by_type_answer[ans_type][0] += datum[ans] |
|
|
| accuracy = accuracy / len(results) |
| print('vqa acc: {}'.format(accuracy*100)) |
| for k in acc_by_type.keys(): |
| acc_by_type[k] = acc_by_type[k][0] / acc_by_type[k][1] * 100 |
| for k in acc_by_type_answer.keys(): |
| acc_by_type_answer[k] = acc_by_type_answer[k][0] / acc_by_type_answer[k][1] * 100 |
| print('vqa acc by question type: ', acc_by_type) |
| print("vqa acc by answer type: ", acc_by_type_answer) |
| return { "accuracy": accuracy } |
|
|