| from typing import List |
| from src.interface import ModelInterface |
| from src.data.classification_result import ClassificationResult |
| from transformers import AutoImageProcessor, ResNetForImageClassification |
| import torch |
|
|
class Resnet50(ModelInterface):
    """Image classifier backed by the ``microsoft/resnet-50`` checkpoint.

    The image processor and the classification model are both loaded with
    ``from_pretrained`` when the instance is created and are reused for
    every call to :meth:`classify_image`.
    """

    def __init__(self):
        """Load the ``microsoft/resnet-50`` processor and model."""
        print('init... resnet-50 model')
        self.processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
        self.model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")

    def classify_image(self, image, top_k: int = 5) -> List[ClassificationResult]:
        """Return the most probable classes for ``image``.

        Args:
            image: Input forwarded unchanged to the image processor as
                ``images=image``. NOTE(review): presumably a single PIL
                image or array; confirm against callers. Only the first
                row of the model output is used, so a batch would yield
                results for its first image only.
            top_k: Maximum number of classes to return. Defaults to 5 and
                is clamped to the number of classes the model predicts.

        Returns:
            A list of ``ClassificationResult`` ordered from highest to
            lowest confidence, where ``confidence`` is the softmax
            probability as a Python ``float`` and ``class_name`` comes
            from ``self.model.config.id2label``.
        """
        inputs = self.processor(images=image, return_tensors="pt")

        # Inference only: skip autograd bookkeeping instead of building a
        # graph and detaching from it afterwards.
        with torch.no_grad():
            logits = self.model(**inputs).logits

        # Softmax over the class axis; keep the first (only) batch row.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]

        # Clamp so models with fewer than ``top_k`` classes do not fail.
        k = max(0, min(top_k, probabilities.shape[-1]))
        confidences, class_ids = torch.topk(probabilities, k)

        # ``tolist()`` yields plain Python ints/floats, so ``id2label`` is
        # indexed with real ``int`` keys and no NumPy round trip is needed.
        id2label = self.model.config.id2label
        return [
            ClassificationResult(class_name=id2label[class_id], confidence=confidence)
            for class_id, confidence in zip(class_ids.tolist(), confidences.tolist())
        ]