import os
from typing import Dict, List, Any

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from PIL import Image
|
|
|
|
|
|
# File name of the fine-tuned classifier weights; PreTrainedPipeline loads it
# with torch.load and feeds the result to load_state_dict.
MODEL_PATH = "website_classifier.pth"
|
|
| |
def process_image(image):
    """Preprocess one image into a single-item batch for the classifier.

    The image is forced to RGB, resized to 224x224, turned into a tensor and
    normalised per channel with mean 0.5 / std 0.5.

    Args:
        image: Source image; presumably a ``PIL.Image.Image`` (anything that
            exposes ``.convert("RGB")``) — confirm against callers.

    Returns:
        The preprocessed tensor with a leading batch dimension of size 1.
    """
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    pipeline = transforms.Compose(steps)
    # Converting to RGB first guarantees three channels for Normalize.
    return pipeline(image.convert("RGB")).unsqueeze(0)
| |
class PreTrainedPipeline:
    """Classify an image into one of ``self.classes`` with a fine-tuned ResNet-18.

    The torchvision ResNet-18 backbone has its final fully connected layer
    replaced by a linear head with one output per class. Its weights are
    loaded from the ``MODEL_PATH`` file.
    """

    def __init__(self, path=""):
        """Build the classifier and load its fine-tuned weights.

        Args:
            path: Directory containing the ``MODEL_PATH`` weights file. The
                default ``""`` resolves the file relative to the current
                working directory, exactly as before this argument was used.
        """
        self.classes = ['forum', 'general', 'marketplace']
        # Attribute name (typo included) kept for backward compatibility.
        self.classe_to_idx = {name: idx for idx, name in enumerate(self.classes)}
        self.class_to_idx = self.classe_to_idx

        # No pretrained download: the strict load_state_dict below overwrites
        # every parameter, so ImageNet weights would be fetched only to be
        # discarded.
        self.model = torchvision.models.resnet18()
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, len(self.classes))

        # Same preprocessing as process_image; kept for callers that read it.
        self.transform = transforms.Compose(
            [transforms.Resize((224, 224)),
             transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        weights_path = os.path.join(path, MODEL_PATH)
        # map_location lets checkpoints saved on a GPU load on CPU-only hosts.
        state_dict = torch.load(weights_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        # Inference mode: BatchNorm must use its running statistics instead of
        # the statistics of the single incoming image.
        self.model.eval()

        self.processor = process_image

    def __call__(self, data: Any) -> Dict[str, Any]:
        """Classify one image.

        Args:
            data: Either a dict holding the image under ``"inputs"`` (the key
                is popped from the dict) or the image itself.

        Returns:
            ``{"class": label}`` where ``label`` is one of ``self.classes``.
        """
        # Accept a bare image as well as the {"inputs": image} payload; for a
        # dict without "inputs" the dict itself is passed on, as before.
        image = data.pop("inputs", data) if isinstance(data, dict) else data
        batch = self.processor(image)

        # Plain forward pass (nn.Module has no ``generate``); no_grad avoids
        # building an autograd graph during inference.
        with torch.no_grad():
            logits = self.model(batch)

        # Only the first item of the batch is reported.
        predicted = int(logits.argmax(dim=1)[0])
        # Return the whole label string — indexing it would yield only its
        # first character (e.g. 'f' for 'forum').
        return {"class": self.classes[predicted]}