import argparse
import json

import torch
from peft import PeftModel
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import LlamaForCausalLM, LlamaTokenizer

from utils import *
from collator import TestCollator
from prompt import all_prompt
from evaluate import get_topk_results, get_metrics_results


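# Evaluation entry point: load an (optionally LoRA-adapted) LLaMA checkpoint,
# decode recommendations with constrained beam search, and average ranking
# metrics over one or more prompt templates.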
def test(args):
    set_seed(args.seed)
    print(vars(args))

    # An empty-string key maps the entire model onto the chosen GPU.
    device_map = {"": args.gpu_id}
    device = torch.device("cuda", args.gpu_id)

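    # The tokenizer is loaded from the fine-tuned checkpoint so that any item
    # tokens added during training are part of the vocabulary.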
    tokenizer = LlamaTokenizer.from_pretrained(args.ckpt_path)
    if args.lora:
        model = LlamaForCausalLM.from_pretrained(
            args.base_model,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            device_map=device_map,
        )
        # Grow the embedding table to match the tokenizer before the LoRA
        # weights (trained against the resized model) are loaded on top.
        model.resize_token_embeddings(len(tokenizer))
        model = PeftModel.from_pretrained(
            model,
            args.ckpt_path,
            torch_dtype=torch.bfloat16,
            device_map=device_map,
        )
    else:
        model = LlamaForCausalLM.from_pretrained(
            args.ckpt_path,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            device_map=device_map,
        )

    if args.test_prompt_ids == "all":
        # Evaluate every prompt template defined for the task; indexing
        # all_prompt directly also fails loudly on an unknown task instead of
        # leaving prompt_ids undefined.
        prompt_ids = range(len(all_prompt[args.test_task.lower()]))
    else:
        prompt_ids = [int(_) for _ in args.test_prompt_ids.split(",")]

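    # load_test_dataset and TestCollator are repo-local helpers; the collator
    # is expected to produce (tokenized inputs, target items) batches, which is
    # how each batch is unpacked in the loop below.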
    test_data = load_test_dataset(args)
    collator = TestCollator(args, tokenizer)
    all_items = test_data.get_all_items()

    prefix_allowed_tokens = test_data.get_prefix_allowed_tokens_fn(tokenizer)

    # No shuffling at test time: the averaged metrics are order-independent,
    # and a fixed order keeps the periodic progress prints reproducible.
    test_loader = DataLoader(test_data, batch_size=args.test_batch_size, collate_fn=collator,
                             shuffle=False, num_workers=4, pin_memory=True)

    print("data num:", len(test_data))

    model.eval()

    metrics = args.metrics.split(",")
    all_prompt_results = []
    with torch.no_grad():
        for prompt_id in prompt_ids:

            test_loader.dataset.set_prompt(prompt_id)
            metrics_results = {}
            total = 0

            for step, batch in enumerate(tqdm(test_loader)):
                inputs = batch[0].to(device)
                targets = batch[1]
                total += len(targets)

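                # Constrained beam search: prefix_allowed_tokens_fn restricts
                # each decoding step to tokens that can extend a valid item
                # identifier, and all num_beams beams are returned per input.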
                output = model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=10,
                    prefix_allowed_tokens_fn=prefix_allowed_tokens,
                    num_beams=args.num_beams,
                    num_return_sequences=args.num_beams,
                    output_scores=True,
                    return_dict_in_generate=True,
                    early_stopping=True,
                )
                output_ids = output["sequences"]
                scores = output["sequences_scores"]

                decoded = tokenizer.batch_decode(
                    output_ids, skip_special_tokens=True
                )
                topk_res = get_topk_results(decoded, scores, targets, args.num_beams,
                                            all_items=all_items if args.filter_items else None)
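                # (get_topk_results, a repo-local helper, is expected to regroup
                # the flat list of num_beams sequences per input, rank them by
                # score, and optionally filter out strings that are not items.)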

                batch_metrics_res = get_metrics_results(topk_res, metrics)
                for m, res in batch_metrics_res.items():
                    metrics_results[m] = metrics_results.get(m, 0.0) + res

                # Print running averages every 10 batches.
                if (step + 1) % 10 == 0:
                    print({m: v / total for m, v in metrics_results.items()})

            # Convert the accumulated sums into per-example averages.
            for m in metrics_results:
                metrics_results[m] = metrics_results[m] / total

            all_prompt_results.append(metrics_results)
            print("======================================================")
            print("Prompt {} results: ".format(prompt_id), metrics_results)
            print("======================================================")
            print("")

    mean_results = {}
    min_results = {}
    max_results = {}

    for m in metrics:
        all_res = [_[m] for _ in all_prompt_results]
        mean_results[m] = sum(all_res) / len(all_res)
        min_results[m] = min(all_res)
        max_results[m] = max(all_res)

    print("======================================================")
    print("Mean results: ", mean_results)
    print("Min results: ", min_results)
    print("Max results: ", max_results)
    print("======================================================")

    save_data = {
        "test_prompt_ids": args.test_prompt_ids,
        "mean_results": mean_results,
        "min_results": min_results,
        "max_results": max_results,
        "all_prompt_results": all_prompt_results,
    }

    with open(args.results_file, "w") as f:
        json.dump(save_data, f, indent=4)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="LLMRec_test")
    parser = parse_global_args(parser)
    parser = parse_dataset_args(parser)
    parser = parse_test_args(parser)

    args = parser.parse_args()

    test(args)
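
# Illustrative invocation (the exact flag names are defined by parse_global_args /
# parse_dataset_args / parse_test_args in the repo's utils module):
#   python test.py --ckpt_path <checkpoint_dir> --lora --base_model <hf_model> \
#       --test_task seqrec --test_prompt_ids all --num_beams 20 \
#       --results_file results.json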