| |
|
| | from transformers import AutoModel, AutoTokenizer |
| | from dataset import get_data_for_evaluation |
| | from arguments import get_args |
| | from tqdm import tqdm |
| | import torch |
| | import os |
| |
|
| |
|
def run_retrieval(eval_data, documents, query_encoder, context_encoder, tokenizer, max_seq_len=512, device="cuda"):
    """Rank each document's chunks against its queries with a dual-encoder retriever.

    Args:
        eval_data: mapping of doc_id -> list of samples; each sample is a dict
            with 'query' (str) and 'gold_idx' (int index of the gold chunk).
        documents: mapping of doc_id -> list of text chunks for that document.
        query_encoder: HF encoder producing query embeddings ([CLS] pooling).
        context_encoder: HF encoder producing chunk embeddings ([CLS] pooling).
        tokenizer: tokenizer shared by both encoders.
        max_seq_len: truncation length for tokenization (default 512).
        device: device the encoders live on (default "cuda"; previously
            hard-coded — parameterized so CPU-only runs are possible).

    Returns:
        (ranked_indices_list, gold_index_list): per-sample ranked chunk
        indices (most similar first) and the matching gold chunk indices.
    """
    ranked_indices_list = []
    gold_index_list = []
    for doc_id in tqdm(eval_data):
        context_list = documents[doc_id]

        with torch.no_grad():
            # Encode every chunk of this document; the [CLS] position
            # (token 0) of the last hidden state is the embedding.
            context_embs = []
            for chunk in context_list:
                chunk_ids = tokenizer(chunk, max_length=max_seq_len, truncation=True, return_tensors="pt").to(device)
                c_emb = context_encoder(input_ids=chunk_ids.input_ids, attention_mask=chunk_ids.attention_mask)
                context_embs.append(c_emb.last_hidden_state[:, 0, :])
            context_embs = torch.cat(context_embs, dim=0)

            # Encode every query for this document, recording its gold index
            # in lockstep so the two output lists stay aligned per sample.
            sample_list = eval_data[doc_id]
            query_embs = []
            for item in sample_list:
                gold_index_list.append(item['gold_idx'])
                query_ids = tokenizer(item['query'], max_length=max_seq_len, truncation=True, return_tensors="pt").to(device)
                q_emb = query_encoder(input_ids=query_ids.input_ids, attention_mask=query_ids.attention_mask)
                query_embs.append(q_emb.last_hidden_state[:, 0, :])
            query_embs = torch.cat(query_embs, dim=0)

            # Dot-product similarity between every (query, chunk) pair, then
            # sort chunk indices from most to least similar per query.
            similarities = query_embs.matmul(context_embs.transpose(0, 1))
            ranked_results = torch.argsort(similarities, dim=-1, descending=True)
            ranked_indices_list.extend(ranked_results.tolist())

    return ranked_indices_list, gold_index_list
| |
|
| |
|
def calculate_recall(ranked_indices_list, gold_index_list, topk):
    """Compute and print the top-k recall of a retrieval run.

    A sample counts as a hit when its gold chunk index appears among the
    first ``topk`` ranked indices.

    Args:
        ranked_indices_list: per-sample lists of chunk indices, best first.
        gold_index_list: per-sample gold chunk index (aligned with the above).
        topk: ranking cutoff.

    Returns:
        The recall value in [0, 1]; 0.0 when there are no samples.
        (Previously the value was only printed, never returned.)
    """
    if not ranked_indices_list:
        # Guard against ZeroDivisionError on an empty evaluation set.
        recall = 0.0
    else:
        hit = sum(
            1
            for ranked_indices, gold_index in zip(ranked_indices_list, gold_index_list)
            if gold_index in ranked_indices[:topk]
        )
        recall = hit / len(ranked_indices_list)

    print("top-%d recall score: %.4f" % (topk, recall))
    return recall
| |
|
| |
|
def main():
    """Entry point: load the dual encoders, run retrieval on the chosen
    dataset, and report top-1/5/20 recall."""
    args = get_args()

    # Both encoders share one tokenizer; load it from the query-encoder path.
    tokenizer = AutoTokenizer.from_pretrained(args.query_encoder_path)

    query_encoder = AutoModel.from_pretrained(args.query_encoder_path)
    context_encoder = AutoModel.from_pretrained(args.context_encoder_path)
    query_encoder.to("cuda")
    query_encoder.eval()
    context_encoder.to("cuda")
    context_encoder.eval()

    # Resolve the (data, documents) paths for the requested dataset.
    if args.eval_dataset == "doc2dial":
        input_datapath = os.path.join(args.data_folder, args.doc2dial_datapath)
        input_docpath = os.path.join(args.data_folder, args.doc2dial_docpath)
    elif args.eval_dataset == "quac":
        input_datapath = os.path.join(args.data_folder, args.quac_datapath)
        input_docpath = os.path.join(args.data_folder, args.quac_docpath)
    elif args.eval_dataset == "qrecc":
        input_datapath = os.path.join(args.data_folder, args.qrecc_datapath)
        input_docpath = os.path.join(args.data_folder, args.qrecc_docpath)
    elif args.eval_dataset == "topiocqa" or args.eval_dataset == "inscit":
        # Query extraction exists, but these datasets retrieve against a
        # Wikipedia corpus that must be downloaded separately.
        raise Exception("We have prepared the function to get queries, but a wikipedia corpus needs to be downloaded")
    else:
        raise Exception("Please input a correct eval_dataset name!")

    eval_data, documents = get_data_for_evaluation(input_datapath, input_docpath, args.eval_dataset)

    ranked_indices_list, gold_index_list = run_retrieval(eval_data, documents, query_encoder, context_encoder, tokenizer)
    print("number of the total test samples: %d" % len(ranked_indices_list))

    print("evaluating on %s" % args.eval_dataset)
    topk_list = [1, 5, 20]
    for topk in topk_list:
        calculate_recall(ranked_indices_list, gold_index_list, topk=topk)
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|
| |
|