File size: 3,468 Bytes
edbfc07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# coding: utf-8
"""
Custom Inference Handler for RNNLM (creative-help) on Hugging Face Inference Endpoints.

Implements EndpointHandler as described in:
https://huggingface.co/docs/inference-endpoints/en/guides/custom_handler

The handler loads the RNNLM model with entity adaptation support and serves
text generation requests via the Inference API.
"""

import logging
import os
import sys
from typing import Any, Dict, List, Union


class EndpointHandler:
    """
    Custom handler for RNNLM text generation on Hugging Face Inference Endpoints.

    Loads the model, tokenizer, and pipeline once in ``__init__`` and serves
    generation requests in ``__call__``.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler. Called when the Endpoint starts.

        :param path: Path to the model repository (model weights, config, tokenizer).
            An empty value means the current working directory; it is stored as an
            absolute path in ``self.path``.
        """
        self.path = os.path.abspath(path or ".")

        # Put the model repo first on sys.path: ``rnnlm_model`` lives in the repo
        # and is only importable once its directory is on the path.
        if self.path not in sys.path:
            sys.path.insert(0, self.path)

        from transformers import AutoConfig, AutoModelForCausalLM
        from rnnlm_model import (
            RNNLMConfig,
            RNNLMForCausalLM,
            RNNLMTokenizer,
            RNNLMTextGenerationPipeline,
        )

        # Register the custom architecture so the Auto* factories can resolve
        # the "rnnlm" model type when loading from the repo.
        AutoConfig.register("rnnlm", RNNLMConfig)
        AutoModelForCausalLM.register(RNNLMConfig, RNNLMForCausalLM)

        # Load model and tokenizer from the repository directory.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.path,
            trust_remote_code=True,
        )
        self.tokenizer = RNNLMTokenizer.from_pretrained(self.path)

        # Text generation pipeline with entity adaptation.
        self.pipeline = RNNLMTextGenerationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data: Dict[str, Any]) -> Union[List[Dict[str, str]], Dict[str, Any]]:
        """
        Handle one inference request. Called on every API request.

        :param data: Request payload with "inputs" (prompt string or list) and an
            optional "parameters" dict. When "parameters" is absent, any remaining
            top-level keys are used as generation parameters. ``data`` itself is
            not modified.
        :return: A list of result dicts: the pipeline's list as-is, a dict result
            wrapped in a list, or ``[{"generated_text": str(result)}]`` otherwise.
            On failure (missing "inputs" or an exception during generation) a single
            ``{"error": message}`` dict is returned instead.
        """
        # Shallow copy so popping keys below does not mutate the caller's payload.
        payload = dict(data)

        inputs = payload.pop("inputs", None)
        if inputs is None:
            return {"error": "Missing 'inputs' in request body"}

        # No "parameters" key -> treat the remaining top-level keys as parameters.
        # A null or non-dict "parameters" value means "no parameters".
        parameters = payload.pop("parameters", payload) or {}
        if not isinstance(parameters, dict):
            parameters = {}

        # Defaults the caller may override; pad_token_id always comes from the tokenizer.
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": parameters.get("max_new_tokens", 50),
            "do_sample": parameters.get("do_sample", True),
            "temperature": parameters.get("temperature", 1.0),
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        # Pass through any other params (top_p, top_k, repetition_penalty, etc.)
        # without overwriting the keys set above.
        for key, value in parameters.items():
            gen_kwargs.setdefault(key, value)

        try:
            result = self.pipeline(inputs, **gen_kwargs)
        except Exception as exc:  # request boundary: report to the client, keep serving
            # Keep the full traceback in the endpoint logs; the client gets the message.
            logging.getLogger(__name__).exception("RNNLM text generation failed")
            return {"error": str(exc)}

        # Normalize to a list of dicts (the API expects a list, e.g. for batches).
        if isinstance(result, list):
            return result
        if isinstance(result, dict):
            return [result]
        return [{"generated_text": str(result)}]