| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| Data handling utilities for PEFT benchmarking. |
| """ |
|
|
| import json |
| import os |
| from typing import Optional |
|
|
| from transformers import PreTrainedTokenizer |
| from utils import BenchmarkConfig |
|
|
|
|
| DEFAULT_PROMPTS_PATH = os.path.join(os.path.dirname(__file__), "configs", "prompts.json") |
|
|
|
|
def load_test_prompts(config: "BenchmarkConfig") -> dict[str, list[str]]:
    """
    Load benchmark prompts, grouped by category, from a JSON file.

    Args:
        config: Configuration object; its ``prompts_file`` attribute (when
            present and non-empty) names the JSON file to read. Falls back
            to the bundled default prompts file otherwise.

    Returns:
        Dictionary mapping category name to a list of prompt strings.

    Raises:
        FileNotFoundError: If the prompts file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Fall back to the default when the attribute is missing OR explicitly
    # set to None/empty — the previous getattr() default only covered the
    # missing-attribute case and would have passed None straight to open().
    prompts_file = getattr(config, "prompts_file", None) or DEFAULT_PROMPTS_PATH

    # Explicit encoding: prompt files may contain non-ASCII text and the
    # platform default encoding is not guaranteed to be UTF-8.
    with open(prompts_file, encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
def truncate_prompt_for_model(
    prompt: str,
    tokenizer: "PreTrainedTokenizer",
    max_length: Optional[int] = None,
    reserve_output_tokens: int = 50,
) -> str:
    """
    Truncate a prompt so that prompt + generated output fit the context window.

    Args:
        prompt: Input prompt text.
        tokenizer: Model tokenizer used to count and decode tokens.
        max_length: Maximum total sequence length. If None, the tokenizer's
            ``model_max_length`` is used when it is a sensible value;
            otherwise a conservative 2048 default.
        reserve_output_tokens: Number of tokens to leave free for the
            model's response.

    Returns:
        The prompt, truncated so that at most
        ``max_length - reserve_output_tokens`` tokens remain. May be the
        empty string when no room is left for the prompt at all.
    """
    if max_length is None:
        model_max = getattr(tokenizer, "model_max_length", None)
        # HF tokenizers with no configured limit report a huge sentinel
        # (VERY_LARGE_INTEGER, ~1e30) in model_max_length; the old hasattr
        # check accepted it, so truncation effectively never happened.
        # Treat anything implausibly large as "unset".
        if model_max is None or model_max > 1_000_000:
            max_length = 2048
        else:
            max_length = model_max

    # Clamp at zero: a negative value would silently slice tokens off the
    # *end* of the prompt (ids[:negative]) instead of keeping zero tokens.
    max_prompt_length = max(max_length - reserve_output_tokens, 0)

    input_ids = tokenizer.encode(prompt, return_tensors="pt")[0]
    if len(input_ids) <= max_prompt_length:
        return prompt

    truncated_ids = input_ids[:max_prompt_length]
    return tokenizer.decode(truncated_ids, skip_special_tokens=True)
|
|
|
|
def prepare_benchmark_prompts(
    config: BenchmarkConfig,
    tokenizer: PreTrainedTokenizer,
    max_input_length: Optional[int] = None,
    seed: int = 42,
) -> dict[str, list[str]]:
    """
    Prepare prompts for benchmarking, ensuring appropriate length and variety.
    Always returns all prompt categories for consistent benchmarking.

    Args:
        config: Benchmark configuration.
        tokenizer: Model tokenizer.
        max_input_length: Maximum input length (overrides model default if provided).
        seed: Random seed (kept for backwards compatibility).

    Returns:
        Dictionary with processed prompts by category (all categories included).
    """
    # Hoist the per-prompt config lookup out of the loop; same default (50).
    reserve = getattr(config, "reserve_output_tokens", 50)

    raw_prompts = load_test_prompts(config)

    # Every category survives; each prompt is token-truncated to fit the
    # model's context window while leaving room for the generated output.
    return {
        category: [
            truncate_prompt_for_model(
                text,
                tokenizer,
                max_length=max_input_length,
                reserve_output_tokens=reserve,
            )
            for text in texts
        ]
        for category, texts in raw_prompts.items()
    }
|
|