| import torch |
| import torch.nn.functional as F |
| from torchvision import transforms |
| from PIL import Image |
| import timm |
| import json |
| import gradio as gr |
|
|
# Device used for inference.
DEVICE = "cpu"
# Side length, in pixels, of the square input fed to the classifier.
IMG_SIZE = 384

# Class mapping loaded from classes.json.
# NOTE(review): values are presumably integer class indices (keys are labels);
# confirm against the training script.
# JSON is UTF-8 by definition, so decode it explicitly rather than relying on the
# platform's locale encoding (which would mangle non-ASCII class names).
with open("classes.json", "r", encoding="utf-8") as f:
    classes = json.load(f)
|
|
def load_model(checkpoint_path="model.pt", num_classes=31):
    """Build the EfficientNetV2-S classifier and load its trained weights.

    Args:
        checkpoint_path: File produced by ``torch.save`` whose
            ``'model_state_dict'`` entry holds the network's weights.
        num_classes: Width of the classification head. It must match the
            checkpoint; the default 31 is the value this app was built with.

    Returns:
        The model on ``DEVICE``, switched to eval mode.
    """
    model = timm.create_model(
        "tf_efficientnetv2_s",
        pretrained=False,  # weights come from the checkpoint below
        num_classes=num_classes,
        drop_rate=0.3,  # dropout is disabled by eval(), so this is inert at inference
    )
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    ckpt = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(ckpt["model_state_dict"])
    # map_location only relocates the checkpoint tensors; load_state_dict copies
    # them into the model's own parameters, which were created on the CPU.
    # Move the model explicitly so it sits on the same device as the inputs.
    model.to(DEVICE)
    model.eval()
    return model
|
|
# Build the network once at import time; every prediction reuses this instance.
model = load_model()


# Inference-time preprocessing: resize to a fixed IMG_SIZE x IMG_SIZE square,
# then turn the PIL image into a tensor. No mean/std normalization step is applied.
# NOTE(review): confirm this matches the preprocessing used during training.
_eval_steps = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
]
val_tfms = transforms.Compose(_eval_steps)
|
|
def predict(image):
    """Classify one image and report the most likely class.

    Args:
        image: A PIL image (as delivered by ``gr.Image(type="pil")``), or
            ``None`` when the form is submitted without an image.

    Returns:
        ``"<label>|(<probability>)"`` for the top-1 class, or a short message
        when no image was provided.
    """
    if image is None:
        return "No image provided"
    # ToTensor keeps every band of the PIL image, so RGBA / grayscale / palette
    # inputs would give the network the wrong number of channels; force RGB.
    img = val_tfms(image.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(img)
    probs = F.softmax(logits, dim=1)
    top_prob, top_idx = probs.topk(1, dim=1)
    # classes appears to map label -> index; invert it to look up the label
    # for the predicted integer index.
    idx_to_class = {v: k for k, v in classes.items()}
    label = idx_to_class[top_idx[0][0].item()]
    return f"{label}|({top_prob[0][0].item()})"
|
|
# Web UI: a PIL image input wired to predict(), with a plain-text output.
demo = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="text")

# Start the server only when run as a script, so importing this module
# (e.g. to reuse predict) does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()
|
|