from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
import torch
|
|
| def load_model(model_id): |
| tokenizer = AutoTokenizer.from_pretrained(model_id) |
| model = AutoModelForCausalLM.from_pretrained( |
| model_id, |
| device_map="auto", |
| torch_dtype=torch.float16, |
| load_in_4bit=True |
| ) |
| return model, tokenizer |
|
|
| class EndpointHandler: |
| def __init__(self, path=""): |
| self.model, self.tokenizer = load_model(path) |
| self.pipeline = TextGenerationPipeline( |
| model=self.model, |
| tokenizer=self.tokenizer |
| ) |
|
|
| def __call__(self, data): |
| |
| if isinstance(data, dict): |
| text = data.get("inputs", "") |
| else: |
| text = data |
|
|
| |
| generation_kwargs = { |
| "max_new_tokens": 512, |
| "temperature": 0.7, |
| "top_p": 0.95, |
| "repetition_penalty": 1.15, |
| "do_sample": True, |
| "pad_token_id": self.tokenizer.pad_token_id, |
| "eos_token_id": self.tokenizer.eos_token_id, |
| } |
|
|
| |
| if isinstance(data, dict) and "parameters" in data: |
| generation_kwargs.update(data["parameters"]) |
|
|
| try: |
| |
| outputs = self.pipeline( |
| text, |
| **generation_kwargs |
| ) |
|
|
| |
| if isinstance(outputs, list): |
| return [{"generated_text": output["generated_text"]} for output in outputs] |
| return [{"generated_text": outputs["generated_text"]}] |
|
|
| except Exception as e: |
| return [{"error": str(e)}] |