File size: 1,944 Bytes
93cf8c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0149497
93cf8c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


class EndpointHandler:
    """Inference Endpoint handler for a ChatML-formatted causal language model.

    The tokenizer and model are loaded once from ``path``. Each call wraps the
    request text in a ChatML prompt (unless it already contains one), runs
    ``model.generate`` and returns only the newly generated text.
    """

    # ChatML turn delimiters; "<|im_end|>" is also used as a stop token.
    _IM_START = "<|im_start|>"
    _IM_END = "<|im_end|>"
    # Defaults applied when a request omits a parameter or sends null for it.
    _DEFAULT_MAX_NEW_TOKENS = 1024
    _DEFAULT_TEMPERATURE = 0.3

    def __init__(self, path=""):
        """Load the tokenizer and a float16 model from ``path``.

        Args:
            path: Model directory or repository id passed to
                ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            # Delegate weight placement across available devices.
            device_map="auto"
        )
        self.model.eval()

    @staticmethod
    def _get_param(parameters, key, default, cast):
        """Return ``cast(parameters[key])``, or ``default`` if missing/None."""
        value = parameters.get(key)
        return default if value is None else cast(value)

    def _format_prompt(self, text: str) -> str:
        """Wrap raw user text in the ChatML template; ChatML input is kept."""
        if self._IM_START in text:
            return text
        # NOTE(review): the system turn ends with "\n" before <|im_end|> while
        # the user turn does not. Kept verbatim; confirm against the training
        # template before normalizing.
        return (
            f"{self._IM_START}system\n"
            f"You are an expert network architect "
            f"with CCDE-level expertise.\n{self._IM_END}\n"
            f"{self._IM_START}user\n{text}{self._IM_END}\n"
            f"{self._IM_START}assistant\n"
        )

    def _stop_token_ids(self) -> list:
        """Return the ids that end generation: <|im_end|> and the EOS token."""
        tokenizer = self.tokenizer
        ids = []
        im_end_id = tokenizer.convert_tokens_to_ids(self._IM_END)
        # An unknown token maps to the unk id (or None); never stop on that.
        if im_end_id is not None and im_end_id != tokenizer.unk_token_id:
            ids.append(im_end_id)
        eos_id = tokenizer.eos_token_id
        if eos_id is not None and eos_id not in ids:
            ids.append(eos_id)
        return ids

    def __call__(self, data: dict) -> dict:
        """Generate a completion for a single request.

        Args:
            data: Request payload. ``data["inputs"]`` is the prompt text,
                either raw text or an already ChatML-formatted conversation.
                Optional ``data["parameters"]`` may set ``max_new_tokens``
                (default 1024) and ``temperature`` (default 0.3). A
                temperature ``<= 0`` selects greedy decoding.

        Returns:
            ``{"generated_text": str}`` holding only the newly generated
            text, with stop markers removed and whitespace stripped.

        Raises:
            TypeError: if ``inputs`` is present but is not a string.
        """
        inputs = data.get("inputs")
        if inputs is None:
            inputs = ""
        if not isinstance(inputs, str):
            raise TypeError(
                f"'inputs' must be a string, got {type(inputs).__name__}"
            )
        # A JSON payload with "parameters": null yields None, not a dict.
        parameters = data.get("parameters") or {}

        max_new_tokens = self._get_param(
            parameters, "max_new_tokens", self._DEFAULT_MAX_NEW_TOKENS, int
        )
        temperature = self._get_param(
            parameters, "temperature", self._DEFAULT_TEMPERATURE, float
        )

        prompt = self._format_prompt(inputs)
        tokenized = self.tokenizer(
            prompt,
            return_tensors="pt"
        ).to(self.model.device)

        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        # Pass temperature only when sampling. Greedy decoding must not get
        # one: transformers warns about it, and rejects values <= 0.
        if gen_kwargs["do_sample"]:
            gen_kwargs["temperature"] = temperature
        stop_ids = self._stop_token_ids()
        if stop_ids:
            gen_kwargs["eos_token_id"] = stop_ids

        with torch.inference_mode():
            output = self.model.generate(  # type: ignore[union-attr]
                **tokenized,
                **gen_kwargs
            )

        # Decode only the tokens generated after the prompt.
        new_tokens = output[0][tokenized["input_ids"].shape[1]:]
        response = self.tokenizer.decode(
            new_tokens,
            skip_special_tokens=False
        )

        # decode() keeps special tokens here, so strip the stop markers.
        for marker in (self._IM_END, self.tokenizer.eos_token):
            if isinstance(marker, str) and marker:
                response = response.replace(marker, "")

        return {"generated_text": response.strip()}