# creative-help / handler.py
# Uploaded by roemmele via huggingface_hub (commit edbfc07, verified)
# coding: utf-8
"""
Custom Inference Handler for RNNLM (creative-help) on Hugging Face Inference Endpoints.
Implements EndpointHandler as described in:
https://huggingface.co/docs/inference-endpoints/en/guides/custom_handler
The handler loads the RNNLM model with entity adaptation support and serves
text generation requests via the Inference API.
"""
import os
import sys
from typing import Any, Dict, List, Union
class EndpointHandler:
    """
    Custom handler for RNNLM text generation on Hugging Face Inference Endpoints.

    Loads the model, tokenizer, and generation pipeline once at startup
    (``__init__``) and serves text-generation requests on every API call
    (``__call__``).
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler. Called once when the Endpoint starts.

        :param path: Path to the model repository (model weights, config,
                     tokenizer). Defaults to the current directory when empty.
        """
        self.path = os.path.abspath(path or ".")
        # Make the repo directory importable so `rnnlm_model` (shipped
        # alongside the weights) can be loaded below.
        if self.path not in sys.path:
            sys.path.insert(0, self.path)

        # Imports are deferred to __init__ so this module can be imported
        # without transformers / rnnlm_model being installed.
        from transformers import AutoConfig, AutoModelForCausalLM
        from rnnlm_model import (
            RNNLMConfig,
            RNNLMForCausalLM,
            RNNLMTokenizer,
            RNNLMTextGenerationPipeline,
        )

        # Register the custom architecture with the Auto* factories so
        # from_pretrained can resolve the "rnnlm" model type.
        AutoConfig.register("rnnlm", RNNLMConfig)
        AutoModelForCausalLM.register(RNNLMConfig, RNNLMForCausalLM)

        # Load model and tokenizer from the repository path.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.path,
            trust_remote_code=True,
        )
        self.tokenizer = RNNLMTokenizer.from_pretrained(self.path)
        # Text generation pipeline (with entity adaptation, per rnnlm_model).
        self.pipeline = RNNLMTextGenerationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data: Dict[str, Any]) -> Union[List[Dict[str, str]], Dict[str, Any]]:
        """
        Handle one inference request. Called on every API request.

        :param data: Request payload with "inputs" (prompt string or list of
                     strings) and an optional "parameters" dict. When no
                     "parameters" key is present, any remaining top-level keys
                     of the payload are treated as generation parameters.
        :return: List of dicts each carrying a "generated_text" key, or a
                 single ``{"error": ...}`` dict on a bad request / failure.
        """
        inputs = data.pop("inputs", None)
        if inputs is None:
            return {"error": "Missing 'inputs' in request body"}
        # Fall back to the remaining top-level keys when "parameters" is
        # absent; `or {}` also guards against an explicit null.
        parameters = data.pop("parameters", data) or {}
        if not isinstance(parameters, dict):
            parameters = {}

        # Default generation parameters; each may be overridden per request.
        gen_kwargs = {
            "max_new_tokens": parameters.get("max_new_tokens", 50),
            "do_sample": parameters.get("do_sample", True),
            "temperature": parameters.get("temperature", 1.0),
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        # Pass through any other params (top_p, top_k, repetition_penalty, ...).
        for k, v in parameters.items():
            if k not in gen_kwargs:
                gen_kwargs[k] = v

        try:
            result = self.pipeline(inputs, **gen_kwargs)
        except Exception as e:
            # Surface the failure to the caller instead of crashing the endpoint.
            return {"error": str(e)}

        # The Inference API expects a list of dicts, even for a single prompt.
        if isinstance(result, list):
            return result
        return [result] if isinstance(result, dict) else [{"generated_text": str(result)}]