"""
Parallel evaluation script for Qwen2.5-7B-Math (or your fine-tuned HF-format weights).

Single-GPU example:
    python simple_valid.py \
        --model_path /pfs/lichenyi/work/finetune_output_train1/checkpoint-300 \
        --data_path /pfs/lichenyi/work/evaluation/valid.json \
        --dtype bf16 \
        --use_system \
        --temperature 0.0

Multi-GPU example (4 GPUs):
    torchrun --nproc_per_node 4 simple_valid.py \
        --model_path /pfs/lichenyi/work/finetune_output_train1/checkpoint-300 \
        --data_path /pfs/lichenyi/work/evaluation/valid.json \
        --dtype bf16 \
        --use_system \
        --temperature 0.0

If --out_path is not passed explicitly, predictions are written to:
    /pfs/lichenyi/work/evaluation/predictions/predictions_<basename(model_path)>.json
e.g.:
    /pfs/lichenyi/work/evaluation/predictions/predictions_checkpoint-300.json
"""
|
|
import argparse
import json
import os
from typing import List, Dict, Any

import torch
import torch.distributed as dist
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
def load_model(
    model_path: str,
    load_in_8bit: bool,
    load_in_4bit: bool,
    dtype: str,
    device_map="auto",
):
    kwargs = {}
    if load_in_4bit:
        # 4-bit quantized loading.
        kwargs.update(dict(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16))
    elif load_in_8bit:
        # 8-bit quantized loading.
        kwargs.update(dict(load_in_8bit=True))
    else:
        # Full-precision weights: bf16 when requested and CUDA is available,
        # otherwise fall back to fp16.
        if dtype == "bf16" and torch.cuda.is_available():
            kwargs.update(dict(dtype=torch.bfloat16))
        else:
            kwargs.update(dict(dtype=torch.float16))

    if torch.cuda.is_available():
        # Allow TF32 matmuls for faster inference on Ampere+ GPUs.
        torch.backends.cuda.matmul.allow_tf32 = True

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map=device_map,
        trust_remote_code=True,
        **kwargs,
    )
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer
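# Note: the --load_in_4bit / --load_in_8bit paths assume the optional
# bitsandbytes package is installed alongside transformers/accelerate;
# from_pretrained will fail if these flags are set without it.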
|
|
|
|
def canonicalize_human(value: str) -> str:
    # Keep only the text before the first ":::" separator (drop any trailing annotation).
    return value.split(":::")[0].strip()
|
|
def decode_only_new(gen_ids: torch.Tensor, prompt_len: int, tokenizer) -> str:
    # Decode only the newly generated tokens (everything after the prompt).
    new_tokens = gen_ids[0, prompt_len:]
    text = tokenizer.decode(new_tokens, skip_special_tokens=False)

    # Cut the text at the first stop marker, if any.
    stop_markers = []

    # Tokenizer-defined EOS token.
    if getattr(tokenizer, "eos_token", None):
        stop_markers.append(tokenizer.eos_token)

    # Common Qwen / generic end-of-text markers.
    stop_markers.extend([
        "<|im_end|>",
        "<|endoftext|>",
        "<end_of_text>",
    ])

    for m in stop_markers:
        if m and m in text:
            text = text.split(m)[0]
            break

    # Keep only the first non-empty block of lines (stop at the first blank line).
    lines = text.splitlines()
    block = []
    for ln in lines:
        if ln.strip() == "":
            break
        block.append(ln)
    text = "\n".join(block).strip()

    return text
|
|
|
|
|
|
def build_model_inputs(messages, tokenizer, device):
    """
    Handle Qwen2.5-7B-Math both with and without a chat_template:
      - prefer tokenizer.apply_chat_template
      - if the Math model ships without a chat_template, fall back to simple
        string concatenation + tokenizer()
    """
    try:
        # return_dict=True is required so that apply_chat_template yields a
        # dict of tensors rather than a bare tensor of token ids.
        model_inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )
        model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
        return model_inputs
    except Exception:
        # No usable chat_template: build a plain-text prompt by hand.
        text_parts = []
        for m in messages:
            role = m["role"]
            content = m["content"]
            if role == "system":
                text_parts.append(f"[SYSTEM]\n{content}\n")
            elif role == "user":
                text_parts.append(f"[USER]\n{content}\n")

        text = "\n".join(text_parts) + "\n[ASSISTANT]\n"

        enc = tokenizer(
            text,
            return_tensors="pt",
        )
        enc = {k: v.to(device) for k, v in enc.items()}
        return enc
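# Illustrative shape of the fallback prompt built above when no chat_template
# is available (the system/user strings are placeholders, not real data):
#
#   [SYSTEM]
#   You are a helpful math assistant.
#
#   [USER]
#   What is 2 + 2?
#
#   [ASSISTANT]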
|
|
|
|
def setup_distributed():
    """
    Initialize distributed inference when launched via torchrun; otherwise fall
    back to a single process.
    Returns: (distributed, rank, world_size, local_rank)
    """
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    distributed = world_size > 1

    if not distributed:
        return False, 0, 1, 0

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    torch.cuda.set_device(local_rank)

    return True, rank, world_size, local_rank
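# Note: torchrun exports WORLD_SIZE / RANK / LOCAL_RANK for every worker, which
# is what setup_distributed() keys off; a plain `python simple_valid.py` launch
# leaves WORLD_SIZE unset and the script runs as a single process.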
|
|
|
|
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_path", type=str, required=True,
                    help="Local model path or HF model name, e.g. /pfs/.../Qwen2.5-7B-Math")
    ap.add_argument("--data_path", type=str, required=True, help="Path to the test-set JSON")
    ap.add_argument(
        "--out_path",
        type=str,
        default="",
        help="Path of the output prediction JSON; leave empty to derive it from model_path",
    )
    ap.add_argument("--max_new_tokens", type=int, default=128)
    ap.add_argument("--temperature", type=float, default=0.1)
    ap.add_argument("--top_p", type=float, default=0.95)
    ap.add_argument("--load_in_8bit", action="store_true")
    ap.add_argument("--load_in_4bit", action="store_true")
    ap.add_argument("--dtype", choices=["bf16", "fp16"], default="bf16")
    ap.add_argument("--use_system", action="store_true",
                    help="Also pass the sample's system prompt into the conversation")
    args = ap.parse_args()

    # Initialize (optional) distributed inference.
    distributed, rank, world_size, local_rank = setup_distributed()

    if distributed and rank == 0:
        print(f"[INFO] Distributed inference, world_size={world_size}")

    if distributed:
        # Pin the whole model onto this worker's GPU.
        device_map = {"": local_rank}
    else:
        device_map = "auto"
|
|
    # Resolve the output path. --out_path is the prediction JSON file itself;
    # when omitted, a default file name is derived from model_path.
    base_name = os.path.basename(os.path.normpath(args.model_path))
    if not base_name:
        base_name = os.path.basename(args.model_path.rstrip("/"))

    if args.out_path:
        out_path = args.out_path
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
    else:
        # Default location documented in the module docstring.
        base_out_dir = "/pfs/lichenyi/work/evaluation/predictions"
        os.makedirs(base_out_dir, exist_ok=True)
        out_path = os.path.join(
            base_out_dir,
            f"predictions_{base_name}.json",
        )

    if rank == 0:
        print(f"[INFO] Output path: {out_path}")
|
|
    # Load model and tokenizer.
    if rank == 0:
        print(f"[INFO] Loading model from {args.model_path} ...")

    model, tokenizer = load_model(
        args.model_path,
        args.load_in_8bit,
        args.load_in_4bit,
        args.dtype,
        device_map=device_map,
    )

    # Collect every token id that should terminate generation.
    extra_eos_tokens = ["<|im_end|>", "<|endoftext|>", "<end_of_text>"]
    eos_ids = set()

    if getattr(tokenizer, "eos_token_id", None) is not None:
        if isinstance(tokenizer.eos_token_id, int):
            eos_ids.add(tokenizer.eos_token_id)
        else:
            eos_ids.update(tokenizer.eos_token_id)

    vocab = tokenizer.get_vocab()
    for tok in extra_eos_tokens:
        if tok in vocab:
            eos_ids.add(vocab[tok])

    if len(eos_ids) == 0:
        eos_token_id = None
    elif len(eos_ids) == 1:
        eos_token_id = next(iter(eos_ids))
    else:
        # generate() accepts either a single id or a list of ids.
        eos_token_id = list(eos_ids)
|
|
|
|
    # Load the evaluation dataset.
    if rank == 0:
        print(f"[INFO] Loading dataset from {args.data_path} ...")

    with open(args.data_path, "r", encoding="utf-8") as f:
        dataset: List[Dict[str, Any]] = json.load(f)
    num_samples = len(dataset)

    # Round-robin sharding: rank r handles samples r, r + world_size, r + 2 * world_size, ...
    indices = list(range(rank, num_samples, world_size))

    if rank == 0:
        iter_indices = tqdm(indices, desc="Running inference")
    else:
        iter_indices = indices

    results = []
|
|
    for idx in iter_indices:
        item = dataset[idx]

        # Extract the system prompt, user question, and reference answer.
        system_text = item.get("system", "").strip()
        prompt_text = ""
        gt_text = ""

        for turn in item.get("conversations", []):
            if turn.get("from") == "human":
                prompt_text = canonicalize_human(turn.get("value", ""))
            elif turn.get("from") == "gpt":
                gt_text = turn.get("value", "").strip()

        # Build the chat messages.
        messages = []
        if args.use_system and system_text:
            messages.append({"role": "system", "content": system_text})
        messages.append({"role": "user", "content": prompt_text})

        # Tokenize (chat template if available, plain-text fallback otherwise).
        model_inputs = build_model_inputs(messages, tokenizer, model.device)

        # temperature == 0 means greedy decoding; only pass sampling knobs when sampling.
        gen_kwargs = dict(
            max_new_tokens=args.max_new_tokens,
            do_sample=args.temperature > 0,
            pad_token_id=tokenizer.pad_token_id,
        )
        if args.temperature > 0:
            gen_kwargs.update(dict(temperature=args.temperature, top_p=args.top_p))
        if eos_token_id is not None:
            gen_kwargs["eos_token_id"] = eos_token_id

        with torch.no_grad():
            output_ids = model.generate(
                **model_inputs,
                **gen_kwargs,
            )

        prompt_len = model_inputs["input_ids"].shape[-1]
        pred = decode_only_new(output_ids, prompt_len, tokenizer)

        results.append({
            "id": idx,
            "system": system_text if args.use_system else "",
            "prompt": prompt_text,
            "ground_truth": gt_text,
            "model_output": pred,
        })
|
|
    # Gather results from all ranks and write them out.
    if distributed:
        # all_gather_object collects each rank's local results onto every rank.
        all_results = [None for _ in range(world_size)]
        dist.all_gather_object(all_results, results)

        if rank == 0:
            merged = []
            for part in all_results:
                merged.extend(part)
            merged.sort(key=lambda x: x["id"])

            with open(out_path, "w", encoding="utf-8") as f:
                json.dump(merged, f, ensure_ascii=False, indent=2)
            print(f"[OK] Wrote {out_path} ({len(merged)} samples)")

        dist.barrier()
        dist.destroy_process_group()
    else:
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        print(f"[OK] Wrote {out_path} ({len(results)} samples)")
|
|
|
|
if __name__ == "__main__":
    main()
|
|