"""
Custom Inference Handler for RNNLM (creative-help) on Hugging Face Inference Endpoints.

Implements EndpointHandler as described in:
https://huggingface.co/docs/inference-endpoints/en/guides/custom_handler

The handler loads the RNNLM model with entity adaptation support and serves
text generation requests via the Inference API.
"""
| |
|
| | import os |
| | import sys |
| | from typing import Any, Dict, List, Union |
| |
|
| |
|
class EndpointHandler:
    """
    Custom handler for RNNLM text generation on Hugging Face Inference Endpoints.

    Loads the model, tokenizer, and generation pipeline once at startup
    (``__init__``) and serves text-generation requests via ``__call__``.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler. Called once when the Endpoint starts.

        :param path: Path to the model repository (model weights, config,
                     tokenizer). Defaults to the current working directory.
        """
        self.path = os.path.abspath(path or ".")

        # The custom model code (rnnlm_model.py) ships inside the model
        # repository, so the repo directory must be importable.
        if self.path not in sys.path:
            sys.path.insert(0, self.path)

        # Deferred imports: rnnlm_model only resolves after the sys.path
        # patch above, and transformers is only needed at endpoint startup.
        from transformers import AutoConfig, AutoModelForCausalLM
        from rnnlm_model import (
            RNNLMConfig,
            RNNLMForCausalLM,
            RNNLMTokenizer,
            RNNLMTextGenerationPipeline,
        )

        # Register the custom architecture so from_pretrained can resolve
        # the "rnnlm" model_type from the repo's config.json.
        AutoConfig.register("rnnlm", RNNLMConfig)
        AutoModelForCausalLM.register(RNNLMConfig, RNNLMForCausalLM)

        self.model = AutoModelForCausalLM.from_pretrained(
            self.path,
            trust_remote_code=True,
        )
        self.tokenizer = RNNLMTokenizer.from_pretrained(self.path)

        self.pipeline = RNNLMTextGenerationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data: Dict[str, Any]) -> Union[List[Dict[str, str]], Dict[str, Any]]:
        """
        Handle one inference request. Called on every API request.

        :param data: Request payload with "inputs" (prompt string or list of
                     strings) and an optional "parameters" dict of generation
                     kwargs. As a legacy fallback, generation kwargs may also
                     appear flattened next to "inputs".
        :return: List of dicts carrying "generated_text", or a single
                 ``{"error": ...}`` dict on failure.
        """
        inputs = data.pop("inputs", None)
        if inputs is None:
            return {"error": "Missing 'inputs' in request body"}

        # Prefer the nested "parameters" dict; fall back to the remaining
        # top-level keys for legacy flattened payloads ("inputs" was popped
        # above, so only generation kwargs are left in `data`).
        parameters = data.pop("parameters", data) or {}
        if not isinstance(parameters, dict):
            parameters = {}

        # Fall back to the EOS token for padding when the tokenizer defines
        # no pad token — generate() needs a concrete pad_token_id.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = getattr(self.tokenizer, "eos_token_id", None)

        gen_kwargs = {
            "max_new_tokens": parameters.get("max_new_tokens", 50),
            "do_sample": parameters.get("do_sample", True),
            "temperature": parameters.get("temperature", 1.0),
            # Honor a caller-supplied pad_token_id, like the keys above.
            "pad_token_id": parameters.get("pad_token_id", pad_token_id),
        }
        # Forward any extra generation parameters verbatim (defaults above
        # already consumed their keys, so nothing is overwritten here).
        for key, value in parameters.items():
            if key not in gen_kwargs:
                gen_kwargs[key] = value

        try:
            result = self.pipeline(inputs, **gen_kwargs)
        except Exception as e:
            # Surface generation failures as a JSON error payload rather
            # than letting the endpoint return an opaque 500.
            return {"error": str(e)}

        # Normalize to the Inference API contract: a list of dicts with a
        # "generated_text" key.
        if isinstance(result, list):
            return result
        return [result] if isinstance(result, dict) else [{"generated_text": str(result)}]
| |
|