| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| | from typing import Dict, Any, List |
| | from scipy.special import softmax |
| | import numpy as np |
| | import torch |
| |
|
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
| |
|
| | class EndpointHandler(): |
def __init__(self, path="."):
    """Load the tokenizer and the causal language model found at ``path``.

    Args:
        path: Location passed to ``from_pretrained`` for both the
            tokenizer and the model. Defaults to the current directory.
    """
    self.tokenizer = AutoTokenizer.from_pretrained(path)
    # Load first, then move the weights to the module-level ``device``.
    causal_lm = AutoModelForCausalLM.from_pretrained(path)
    self.model = causal_lm.to(device)
| | |
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Score each input token against the model's predictions.

    Args:
        data: Request payload. The text is read from (and removed from)
            the ``"inputs"`` key; when that key is absent the payload
            itself is handed to the tokenizer.

    Returns:
        A list with one dict per scored token, holding:
            ``input``: the decoded input token.
            ``rank``: zero-based position of that token among the logits
                sorted in descending order.
            ``prob``: summed softmax probability of all tokens ranked
                above it.
            ``most_likely``: the decoded highest-logit token.
            ``position``: index of the token in the scored sequence.
        All numeric values are plain Python ``int``/``float`` so the
        result can be serialized by any JSON encoder.
    """
    input_text = data.pop("inputs", data)
    input_ids = self.tokenizer(input_text, return_tensors="pt").to(device)

    # Inference only: skip autograd bookkeeping during the forward pass.
    with torch.no_grad():
        model_output = self.model(**input_ids)

    offset = self._best_offset(input_ids['input_ids'], model_output)
    self.logits = model_output.logits[0][offset:]
    # The first token of the sequence is dropped before comparison.
    self.inputs = input_ids['input_ids'][0].cpu().numpy()[1:]

    sorted_logits, indices = self.logits.sort(descending=True)
    indices = indices.cpu().numpy()
    self.sorted = sorted_logits.cpu().detach().numpy()
    # Highest-logit token id per position, computed once for all rows.
    most_likely_ids = self.logits.argmax(dim=-1).cpu().numpy()

    # Only positions that have both an input token and a logits row can
    # be scored; this avoids an IndexError when the two lengths differ.
    n_scored = min(len(self.inputs), len(indices))

    def parse_tokens(idx):
        # Where the actual token sits in the descending ordering.
        token_rank = int(np.where(indices[idx] == self.inputs[idx])[0][0])
        # Probability mass of every token ranked strictly above it.
        upper_prob = float(np.sum(softmax(self.sorted[idx])[:token_rank]))
        return {
            "input": self.tokenizer.decode(self.inputs[idx]),
            "rank": token_rank,
            "prob": upper_prob,
            "most_likely": self.tokenizer.decode(int(most_likely_ids[idx])),
            "position": idx,
        }

    return [parse_tokens(idx) for idx in range(n_scored)]
| | |
| | @staticmethod |
| | def _best_offset(inputs, outputs): |
| | """Calculates overlap between input and output tokens""" |
| | MAX_OFFSET = 10 |
| |
|
| | |
| | top_outputs = outputs.logits[0].argmax(dim=-1).cpu().numpy() |
| |
|
| | |
| | matches = np.zeros((len(inputs), len(top_outputs))) |
| | for i, input in enumerate(inputs[:MAX_OFFSET]): |
| | for j, output in enumerate(top_outputs[:i]): |
| | if input == output: |
| | matches[j, i] = 1 |