# handler.py
import torch
from transformers import pipeline

# check for GPU
device = 0 if torch.cuda.is_available() else -1
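# device index 0 = first CUDA GPU, -1 = run the pipelines on CPU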
# Reference snippet (commented out, not used by the handler below):
# single-image classification with pyrovision.
# from PIL import Image
# from torchvision.transforms import Compose, ConvertImageDtype, Normalize, PILToTensor, Resize
# from torchvision.transforms.functional import InterpolationMode
# from pyrovision.models import model_from_hf_hub
# model = model_from_hf_hub("pyronear/mobilenet_v3_small").eval()
# img = Image.open(path_to_an_image).convert("RGB")
# # Preprocessing
# config = model.default_cfg
# transform = Compose([
#     Resize(config['input_shape'][1:], interpolation=InterpolationMode.BILINEAR),
#     PILToTensor(),
#     ConvertImageDtype(torch.float32),
#     Normalize(config['mean'], config['std'])
# ])
# input_tensor = transform(img).unsqueeze(0)
# # Inference
# with torch.inference_mode():
#     output = model(input_tensor)
#     probs = output.squeeze(0).softmax(dim=0)
# multi-model list
multi_model_list = [
    {"model_id": "distilbert-base-uncased-finetuned-sst-2-english", "task": "text-classification"},
    {"model_id": "Helsinki-NLP/opus-mt-en-de", "task": "translation"},
    {"model_id": "facebook/bart-large-cnn", "task": "summarization"},
    {"model_id": "dslim/bert-base-NER", "task": "token-classification"},
    {"model_id": "textattack/bert-base-uncased-ag-news", "task": "text-classification"},
]
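# Illustrative request payload for this handler (field names follow __call__ below;
# "parameters" is optional and forwarded to the pipeline as keyword arguments):
# {
#     "model_id": "facebook/bart-large-cnn",
#     "inputs": "Text to summarize ...",
#     "parameters": {"max_length": 60}
# }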
class EndpointHandler:
    def __init__(self, path=""):
        self.multi_model = {}
        # load all the models in the list onto the device
        for model in multi_model_list:
            self.multi_model[model["model_id"]] = pipeline(
                model["task"], model=model["model_id"], device=device
            )

    def __call__(self, data):
        # deserialize the incoming request; pop model_id and parameters first so the
        # fallback `inputs = data` does not carry routing fields into the pipeline
        model_id = data.pop("model_id", None)
        parameters = data.pop("parameters", None)
        inputs = data.pop("inputs", data)
        # check if model_id is one of the loaded models
        if model_id is None or model_id not in self.multi_model:
            raise ValueError(
                f"model_id: {model_id} is not valid. Available models are: {list(self.multi_model.keys())}"
            )
        # pass inputs to the selected pipeline with all remaining kwargs
        if parameters is not None:
            prediction = self.multi_model[model_id](inputs, **parameters)
        else:
            prediction = self.multi_model[model_id](inputs)
        # return the raw pipeline output
        return prediction
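# Minimal local smoke test, a sketch for manual checking only: instantiating the
# handler downloads and loads all five models, which can take a while.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "model_id": "distilbert-base-uncased-finetuned-sst-2-english",
        "inputs": "I love this movie!",
    }
    print(handler(payload))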