# coding: utf-8
"""
Custom Inference Handler for RNNLM (creative-help) on Hugging Face Inference Endpoints.

Implements EndpointHandler as described in:
https://huggingface.co/docs/inference-endpoints/en/guides/custom_handler

The handler loads the RNNLM model with entity adaptation support and serves
text generation requests via the Inference API.
"""
import os
import sys
from typing import Any, Dict, List, Union


class EndpointHandler:
    """
    Custom handler for RNNLM text generation on Hugging Face Inference Endpoints.

    Loads the model, tokenizer, and pipeline at init; serves generation
    requests in __call__.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler. Called once when the Endpoint starts.

        :param path: Path to the model repository (model weights, config,
            tokenizer). Defaults to the current directory when empty.
        """
        self.path = os.path.abspath(path or ".")

        # Add the model repo to sys.path so `rnnlm_model` is importable.
        if self.path not in sys.path:
            sys.path.insert(0, self.path)

        # Register the custom model architecture with Transformers.
        # Imports are deferred to init so the module can be imported without
        # the model repo on sys.path.
        from transformers import AutoConfig, AutoModelForCausalLM
        from rnnlm_model import (
            RNNLMConfig,
            RNNLMForCausalLM,
            RNNLMTokenizer,
            RNNLMTextGenerationPipeline,
        )

        AutoConfig.register("rnnlm", RNNLMConfig)
        AutoModelForCausalLM.register(RNNLMConfig, RNNLMForCausalLM)

        # Load model and tokenizer.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.path,
            trust_remote_code=True,
        )
        self.tokenizer = RNNLMTokenizer.from_pretrained(self.path)

        # Create the text generation pipeline with entity adaptation.
        self.pipeline = RNNLMTextGenerationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data: Dict[str, Any]) -> Union[List[Dict[str, str]], Dict[str, Any]]:
        """
        Handle an inference request. Called on every API request.

        :param data: Request payload with "inputs" (prompt string or list of
            prompts) and an optional "parameters" dict. If "parameters" is
            absent, any extra top-level keys (besides "inputs") are treated as
            generation parameters. The payload is NOT mutated.
        :return: On success, a list of dicts with "generated_text" key(s);
            on failure, a single {"error": ...} dict.
        """
        inputs = data.get("inputs")
        if inputs is None:
            return {"error": "Missing 'inputs' in request body"}

        # Resolve generation parameters without mutating the caller's payload
        # (previously keys were popped out of `data` in place). When the
        # "parameters" key is missing, fall back to the remaining top-level
        # keys, mirroring the common Inference API convention.
        if "parameters" in data:
            parameters = data["parameters"] or {}
        else:
            parameters = {k: v for k, v in data.items() if k != "inputs"}
        if not isinstance(parameters, dict):
            parameters = {}

        # Default generation parameters; caller-supplied values take priority.
        gen_kwargs = {
            "max_new_tokens": parameters.get("max_new_tokens", 50),
            "do_sample": parameters.get("do_sample", True),
            "temperature": parameters.get("temperature", 1.0),
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        # Pass through any other params (top_p, top_k, repetition_penalty, ...).
        for key, value in parameters.items():
            gen_kwargs.setdefault(key, value)

        # Run generation; surface failures as a JSON-friendly error payload
        # instead of letting the Endpoint return a raw 500 traceback. Broad
        # catch is deliberate: this is the top-level API boundary.
        try:
            result = self.pipeline(inputs, **gen_kwargs)
        except Exception as e:
            return {"error": str(e)}

        # The Inference API expects a list of dicts (one per sequence).
        if isinstance(result, list):
            return result
        if isinstance(result, dict):
            return [result]
        return [{"generated_text": str(result)}]