"""
Custom Inference Handler for RNNLM (creative-help) on Hugging Face Inference Endpoints.

Implements EndpointHandler as described in:
https://huggingface.co/docs/inference-endpoints/en/guides/custom_handler

The handler loads the RNNLM model with entity adaptation support and serves
text generation requests via the Inference API.
"""
|
|
| import os |
| import sys |
| from typing import Any, Dict, List, Union |
|
|
|
|
class EndpointHandler:
    """
    Custom handler for RNNLM text generation on Hugging Face Inference Endpoints.

    Loads the model, tokenizer, and generation pipeline once at init;
    serves generation requests in ``__call__``.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler. Called once when the Endpoint starts.

        :param path: Path to the model repository (model weights, config,
                     tokenizer). Defaults to the current working directory.
        """
        self.path = os.path.abspath(path or ".")

        # Make the bundled rnnlm_model module importable from the repo dir.
        if self.path not in sys.path:
            sys.path.insert(0, self.path)

        # Imported lazily so sys.path is extended before the lookup.
        from transformers import AutoConfig, AutoModelForCausalLM
        from rnnlm_model import (
            RNNLMConfig,
            RNNLMForCausalLM,
            RNNLMTokenizer,
            RNNLMTextGenerationPipeline,
        )

        # Register the custom architecture with the Auto* factories.
        # Guarded: transformers raises ValueError if the model type is
        # already registered, so re-instantiating the handler in the same
        # process (endpoint reload, tests) must not crash here.
        try:
            AutoConfig.register("rnnlm", RNNLMConfig)
            AutoModelForCausalLM.register(RNNLMConfig, RNNLMForCausalLM)
        except ValueError:
            # Already registered by a previous handler instance.
            pass

        self.model = AutoModelForCausalLM.from_pretrained(
            self.path,
            trust_remote_code=True,
        )
        self.tokenizer = RNNLMTokenizer.from_pretrained(self.path)

        self.pipeline = RNNLMTextGenerationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data: Dict[str, Any]) -> Union[List[Dict[str, str]], Dict[str, Any]]:
        """
        Handle one inference request. Called on every API request.

        :param data: Request payload with "inputs" (prompt string or list of
                     prompts) and an optional "parameters" dict. For backward
                     compatibility, generation parameters may also be sent at
                     the top level of the payload.
        :return: List of dicts with "generated_text" key(s) on success, or a
                 single {"error": ...} dict on bad input / generation failure.
        """
        inputs = data.pop("inputs", None)
        if inputs is None:
            return {"error": "Missing 'inputs' in request body"}

        # Prefer an explicit "parameters" dict; fall back to the remaining
        # top-level keys for clients that send parameters inline.
        parameters = data.pop("parameters", data) or {}
        if not isinstance(parameters, dict):
            parameters = {}

        # Defaults, every one overridable by the caller — including
        # pad_token_id, which previously could not be overridden because it
        # was fixed before the pass-through loop below.
        gen_kwargs = {
            "max_new_tokens": parameters.get("max_new_tokens", 50),
            "do_sample": parameters.get("do_sample", True),
            "temperature": parameters.get("temperature", 1.0),
            "pad_token_id": parameters.get("pad_token_id", self.tokenizer.pad_token_id),
        }
        # Forward any additional generation parameters unchanged.
        for key, value in parameters.items():
            if key not in gen_kwargs:
                gen_kwargs[key] = value

        try:
            result = self.pipeline(inputs, **gen_kwargs)
        except Exception as e:
            # Surface generation failures to the client instead of a 500.
            return {"error": str(e)}

        # Normalize to the Inference API's list-of-dicts response shape.
        if isinstance(result, list):
            return result
        return [result] if isinstance(result, dict) else [{"generated_text": str(result)}]
|
|