| from typing import Dict, List, Any |
| from ultralytics import YOLO |
| import os |
| import torch |
| import torch.nn as nn |
| import torchvision.transforms as T |
| from PIL import Image |
|
|
| class LinearClassifier(torch.nn.Module): |
| def __init__(self, input_dim=384, output_dim=7): |
| super(LinearClassifier, self).__init__() |
|
|
| self.linear = torch.nn.Linear(input_dim, output_dim) |
| self.linear.weight.data.normal_(mean=0.0, std=0.01) |
| self.linear.bias.data.zero_() |
|
|
| def forward(self, x): |
| return self.linear(x) |
|
|
| class EndpointHandler(): |
| def __init__(self, path=""): |
| |
| self.dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14') |
| self.device = torch.device('cuda' if torch.cuda.is_available() else "cpu") |
| self.dinov2_vits14.to(self.device) |
| print('Successfully load dinov2_vits14 model') |
| |
| self.yolov8_model = YOLO(os.path.join(path, 'yolov8_2023-07-19_yolov8m.pt')) |
|
|
| self.linear_model = LinearClassifier() |
| self.linear_model.load_state_dict(torch.load(os.path.join(path, 'linear_2023-07-18_v0.2.pt'))) |
| self.linear_model.eval() |
| |
| self.transform_image = T.Compose([ |
| T.ToTensor(), |
| T.Resize(244), |
| T.CenterCrop(224), |
| T.Normalize([0.5], [0.5]) |
| ]) |
|
|
| with open(os.path.join(path, 'labels.txt'), 'r') as f: |
| self.labels = f.read().split(',') |
|
|
| self.name_en2vi = { |
| "loggerhead": "Quản đồng", |
| "green": "Vích", |
| "leatherback": "Rùa da", |
| "hawksbill": "Đồi mồi", |
| "kemp_ridley": "Vích Kemp", |
| "olive_ridley": "Đồi mồi dứa", |
| "flatback": "Rùa lưng phẳng" |
| } |
| |
| def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]: |
| """ |
| data args: |
| inputs (:obj: `str` | `PIL.Image` | `np.array`) |
| kwargs |
| Return: |
| A :obj:`list` | `dict`: will be serialized and returned |
| """ |
| |
| result = self.yolov8_model(data['inputs']) |
| |
| img = result[0].orig_img[:,:,::-1] |
| H, W, _ = img.shape |
| annotated = img.copy() |
| |
| try: |
| x1, y1, x2, y2 = result[0].boxes.xyxy.numpy().astype('int')[0] |
| if result[0].boxes.conf[0].item() < 0.75: |
| return img.tolist(), "🤔 Hmm... Vích AI không thấy bạn rùa nào trong bức ảnh này. Bạn hãy tải lên một bức hình khác nhé." |
| else: |
| annotated = result[0].plot(labels=False, conf=False)[:,:,::-1] |
| except: |
| |
| return img.tolist(), "🤔 Hmm... Vích AI không thấy bạn rùa nào trong bức ảnh này. Bạn hãy tải lên một bức hình khác nhé." |
|
|
| h, w = y2-y1, x2-x1 |
| offset = abs(h-w) // 2 |
| if h > w: |
| x1 = max(x1 - offset, 0) |
| x2 = min(x2 + offset, W) |
| else: |
| y1 = max(y1 - offset, 0) |
| y2 = min(y2 + offset, H) |
| cropped = img[y1:y2, x1:x2] |
|
|
| new_image = self.transform_image(Image.fromarray(cropped))[:3].unsqueeze(0) |
| embedding = self.dinov2_vits14(new_image.to(self.device)) |
| prediction = self.linear_model(embedding) |
| percentage = nn.Softmax(dim=1)(prediction).detach().numpy().round(2)[0].tolist() |
| result = {} |
| |
| for i in range(len(self.labels)): |
| result[self.name_en2vi[self.labels[i]]] = percentage[i] |
|
|
| |
| return annotated.tolist(), result |