from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


class EndpointHandler:
    def __init__(self, path: str = ""):
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Prepend a BOS token to prompts; honored by Llama-style tokenizers
        # that expose `add_bos_token` (a harmless attribute otherwise).
        self.tokenizer.add_bos_token = True

        # bfloat16 on GPU; float32 on CPU, where half-precision kernels are
        # often unimplemented or very slow for generation.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        ).to("cuda" if torch.cuda.is_available() else "cpu")

        # Reuse the loaded model and tokenizer so the pipeline does not load a
        # second copy from `path`; the dtype is inherited from the model, so it
        # is not passed again here.
        self.generator = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if torch.cuda.is_available() else -1,
            return_full_text=False,  # return only the completion, not the prompt
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle a request of the form {"inputs": str, "parameters": dict}."""
        prompt = data.get("inputs", "")
        if not prompt:
            return {"error": "Missing 'inputs' field."}

        # Caller-supplied "parameters" override these sampling defaults.
        defaults = {
            "max_new_tokens": 100,
            "do_sample": True,
            "temperature": 0.7,
            "top_p": 0.9,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        generation_args = {**defaults, **data.get("parameters", {})}
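
        # Example request body this merge supports (illustrative values only):
        #   {"inputs": "Hello", "parameters": {"max_new_tokens": 32}}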

        try:
            outputs = self.generator(prompt, **generation_args)
            output_text = outputs[0]["generated_text"].strip()

            # If the completion used its whole token budget, it was likely cut
            # off by length rather than by an EOS token. Count the tokens with
            # add_special_tokens=False so an implicit BOS does not inflate it.
            output_len = len(self.tokenizer.encode(output_text, add_special_tokens=False))
            finish_reason = "length" if output_len >= generation_args["max_new_tokens"] else "stop"

            # Mirror an OpenAI-style chat completion response shape.
            return {
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": output_text,
                    },
                    "finish_reason": finish_reason,
                }]
            }

        except Exception as e:
            # Surface generation failures to the caller instead of raising.
            return {"error": str(e)}
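

# A minimal local smoke test, assuming this is a Hugging Face Inference
# Endpoints-style handler and that model weights live at "./model" (a
# hypothetical path). In production the endpoint runtime instantiates
# EndpointHandler itself, so this block is purely illustrative.
if __name__ == "__main__":
    handler = EndpointHandler(path="./model")
    response = handler({
        "inputs": "Write a haiku about the sea.",
        "parameters": {"max_new_tokens": 64, "temperature": 0.8},
    })
    print(response)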