| | import os |
| | import logging |
| | import traceback |
| | from typing import Dict, List, Any |
| |
|
| | from nemo_skills.inference.server.code_execution_model import get_code_execution_model |
| | from nemo_skills.code_execution.sandbox import get_sandbox |
| | from nemo_skills.prompt.utils import get_prompt |
| |
|
| | |
# Module-level logger used by everything in this file. Root logging is
# configured at INFO on import so the handler's progress messages are emitted.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
| |
|
| |
|
class EndpointHandler:
    """Custom endpoint handler for NeMo Skills code execution inference.

    The handler is callable: each call takes a request dictionary, lazily
    builds the sandbox, the code execution model and the prompt on first use,
    runs generation and returns one result dictionary per input prompt.
    """

    # Generation defaults; request-level "parameters" override these.
    _DEFAULT_GENERATION_PARAMS = {
        "tokens_to_generate": 12000,
        "temperature": 0.0,
        "top_p": 0.95,
        "top_k": 0,
        "repetition_penalty": 1.0,
        "random_seed": 0,
    }

    def __init__(self):
        """Set up an uninitialized handler and read the prompt configuration.

        The prompt config and template are taken from the ``PROMPT_CONFIG_PATH``
        and ``PROMPT_TEMPLATE_PATH`` environment variables, falling back to
        ``"generic/math"`` and ``"openmath-instruct"``. Nothing heavy is
        created here; see ``_initialize_components``.
        """
        self.model = None
        self.prompt = None
        self.initialized = False

        self.prompt_config_path = os.getenv("PROMPT_CONFIG_PATH", "generic/math")
        self.prompt_template_path = os.getenv("PROMPT_TEMPLATE_PATH", "openmath-instruct")

    def _initialize_components(self):
        """Initialize the model, sandbox, and prompt components lazily.

        Does nothing once initialization has succeeded. On failure the
        exception is logged with its traceback and re-raised, and the handler
        is left untouched (``initialized`` stays ``False``), so the next
        request retries from scratch.

        Raises:
            Exception: whatever the sandbox, model or prompt factory raised.
        """
        if self.initialized:
            return

        try:
            logger.info("Initializing sandbox...")
            sandbox = get_sandbox(sandbox_type="local")

            logger.info("Initializing code execution model...")
            model = get_code_execution_model(
                server_type="vllm",
                sandbox=sandbox,
                host="127.0.0.1",
                port=5000,
            )

            logger.info("Initializing prompt...")
            prompt = None
            # An empty prompt config path disables prompt formatting.
            if self.prompt_config_path:
                prompt = get_prompt(
                    prompt_config=self.prompt_config_path,
                    prompt_template=self.prompt_template_path,
                )
        except Exception:
            # Surface the real cause instead of letting the caller trip over
            # a missing model later on.
            logger.exception("Failed to initialize the model")
            raise

        # Commit only after every component was built, so a failure never
        # leaves the handler half-initialized.
        self.model = model
        self.prompt = prompt
        self.initialized = True
        logger.info("All components initialized successfully")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference requests.

        Args:
            data: Dictionary containing the request data
                Expected keys:
                - inputs: str or list of str - the input prompts/problems
                - parameters: dict (optional) - generation parameters; a
                  missing or null value means "use the defaults"

        Returns:
            List of dictionaries, one per prompt, with the keys
            ``generated_text`` and ``code_rounds_executed``. If anything
            fails (including initialization or input validation), a
            single-element list ``[{"error": <message>, "generated_text": ""}]``
            is returned instead; no exception propagates to the caller.
        """
        try:
            self._initialize_components()

            inputs = data.get("inputs", "")
            parameters = data.get("parameters")
            if parameters is None:
                # Key absent or explicitly null: no overrides.
                parameters = {}

            if isinstance(inputs, str):
                prompts = [inputs]
            elif isinstance(inputs, list) and all(isinstance(item, str) for item in inputs):
                prompts = inputs
            else:
                raise ValueError("inputs must be a string or list of strings")

            extra_generate_params = {}
            if self.prompt is not None:
                prompts = [
                    self.prompt.fill({"problem": prompt_text, "total_code_executions": 8})
                    for prompt_text in prompts
                ]
                extra_generate_params = self.prompt.get_code_execution_args()

            # Precedence, lowest to highest: defaults, request parameters,
            # arguments supplied by the prompt.
            generation_params = dict(self._DEFAULT_GENERATION_PARAMS)
            generation_params.update(parameters)
            generation_params.update(extra_generate_params)

            logger.info(f"Processing {len(prompts)} prompt(s)")

            outputs = self.model.generate(prompts=prompts, **generation_params)

            results = [
                {
                    "generated_text": output.get("generation", ""),
                    "code_rounds_executed": output.get("code_rounds_executed", 0),
                }
                for output in outputs
            ]

            logger.info(f"Successfully processed {len(results)} request(s)")
            return results

        except Exception as e:
            # Top-level request boundary: log with traceback and report the
            # failure in the response rather than raising.
            logger.exception(f"Error processing request: {str(e)}")
            return [{"error": str(e), "generated_text": ""}]