| import os |
| import logging |
| import traceback |
| from typing import Dict, List, Any |
|
|
| from nemo_skills.inference.server.code_execution_model import get_code_execution_model |
| from nemo_skills.code_execution.sandbox import get_sandbox |
| from nemo_skills.prompt.utils import get_prompt |
|
|
| |
# Configure root logging once at import time so handler logs reach stdout.
# NOTE(review): basicConfig here assumes this module owns logging setup —
# it is a no-op if the hosting framework has already configured handlers;
# confirm against the deployment environment.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
class EndpointHandler:
    """Custom endpoint handler for NeMo Skills code execution inference.

    Lazily constructs a local sandbox, a code-execution-capable model client,
    and (optionally) a prompt template, then serves generation requests via
    ``__call__``. Heavy setup is deferred to the first request so importing
    this module stays cheap.
    """

    # getLogger returns the same singleton as the module-level ``logger``
    # (identical name), so this alias adds no duplicate handlers — it just
    # keeps the class self-contained.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        """Read prompt configuration from the environment; defer heavy setup."""
        self.model = None
        self.prompt = None
        self.initialized = False

        # Environment-driven so the same image can serve different prompt
        # setups without a code change.
        self.prompt_config_path = os.getenv("PROMPT_CONFIG_PATH", "generic/math")
        self.prompt_template_path = os.getenv("PROMPT_TEMPLATE_PATH", "openmath-instruct")

    def _initialize_components(self):
        """Initialize the model, sandbox, and prompt components lazily.

        Idempotent: returns immediately once initialization has succeeded.

        Raises:
            RuntimeError: if any component fails to initialize; the original
                exception is chained as ``__cause__``.
        """
        if self.initialized:
            return

        try:
            self._logger.info("Initializing sandbox...")
            sandbox = get_sandbox(sandbox_type="local")

            self._logger.info("Initializing code execution model...")
            self.model = get_code_execution_model(
                server_type="vllm",
                sandbox=sandbox,
                host="127.0.0.1",
                port=5000,
            )

            self._logger.info("Initializing prompt...")
            if self.prompt_config_path:
                self.prompt = get_prompt(
                    prompt_config=self.prompt_config_path,
                    prompt_template=self.prompt_template_path,
                )

            self.initialized = True
            self._logger.info("All components initialized successfully")
        except Exception as exc:
            # BUG FIX: the previous version swallowed the error (warning with
            # no detail, no re-raise), leaving ``self.model`` as None and
            # later surfacing as an opaque AttributeError in ``__call__``.
            # Log the full traceback and re-raise so callers see the cause.
            self._logger.exception("Failed to initialize components")
            raise RuntimeError("Failed to initialize inference components") from exc

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Process inference requests.

        Args:
            data: Dictionary containing the request data.
                Expected keys:
                    - inputs: str or list of str - the input prompts/problems
                    - parameters: dict (optional) - generation parameters

        Returns:
            List of dictionaries containing the generated responses. On any
            failure a single-element list ``[{"error": ..., "generated_text": ""}]``
            is returned instead of raising (endpoint boundary contract).
        """
        try:
            self._initialize_components()

            inputs = data.get("inputs", "")
            parameters = data.get("parameters", {})

            # Normalize to a list; copy defensively so we never mutate or
            # rebind shared caller state.
            if isinstance(inputs, str):
                prompts = [inputs]
            elif isinstance(inputs, list):
                prompts = list(inputs)
            else:
                raise ValueError("inputs must be a string or list of strings")

            extra_generate_params: Dict[str, Any] = {}
            if self.prompt is not None:
                # Wrap each raw problem in the configured prompt template.
                # NOTE(review): total_code_executions=8 mirrors the original
                # hard-coded budget — confirm against the prompt config.
                prompts = [
                    self.prompt.fill({"problem": text, "total_code_executions": 8})
                    for text in prompts
                ]
                extra_generate_params = self.prompt.get_code_execution_args()

            # Defaults; user parameters override them, and code-execution
            # args override both so sandbox control tokens stay consistent
            # with the prompt template.
            generation_params: Dict[str, Any] = {
                "tokens_to_generate": 12000,
                "temperature": 0.0,
                "top_p": 0.95,
                "top_k": 0,
                "repetition_penalty": 1.0,
                "random_seed": 0,
            }
            generation_params.update(parameters)
            generation_params.update(extra_generate_params)

            self._logger.info("Processing %d prompt(s)", len(prompts))

            outputs = self.model.generate(prompts=prompts, **generation_params)

            results = [
                {
                    "generated_text": output.get("generation", ""),
                    "code_rounds_executed": output.get("code_rounds_executed", 0),
                }
                for output in outputs
            ]

            self._logger.info("Successfully processed %d request(s)", len(results))
            return results

        except Exception as exc:
            # Endpoint boundary: never propagate; return a structured error
            # payload that preserves the original response shape.
            self._logger.error("Error processing request: %s", exc)
            self._logger.error(traceback.format_exc())
            return [{"error": str(exc), "generated_text": ""}]