File size: 1,103 Bytes
a99d095
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a0922e
a99d095
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import timm
import json
import gradio as gr

# All inference runs on this device; it is also the map_location used when
# loading the checkpoint.
DEVICE = "cpu"
# Side length (pixels) of the square every input image is resized to.
IMG_SIZE = 384

# Load the label mapping. predict() inverts this dict (value -> key), so the
# values are presumably the integer class indices the model was trained with.
# NOTE(review): confirm the key/value direction against the training script.
# An explicit encoding keeps non-ASCII class names from depending on the locale.
with open("classes.json", "r", encoding="utf-8") as f:
    classes = json.load(f)

def load_model(checkpoint_path="model.pt", num_classes=31):
    """Build the tf_efficientnetv2_s classifier and load trained weights.

    Args:
        checkpoint_path: Path to a checkpoint saved with ``torch.save``. It must
            be a dict holding the weights under ``"model_state_dict"``.
            Defaults to ``"model.pt"``.
        num_classes: Size of the classification head. It must match the
            checkpoint, or ``load_state_dict`` will fail. Defaults to 31.

    Returns:
        The model with weights loaded, switched to eval mode.
    """
    # pretrained=False: all weights come from the checkpoint loaded below.
    model = timm.create_model(
        "tf_efficientnetv2_s",
        pretrained=False,
        num_classes=num_classes,
        # Dropout is inactive after eval(); presumably kept to mirror the
        # training configuration.
        drop_rate=0.3,
    )
    # SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location=DEVICE)
    model.load_state_dict(ckpt["model_state_dict"])
    # Inference mode: disables dropout and uses stored batch-norm statistics.
    model.eval()
    return model

# Build the classifier once at import time; every predict() call reuses it.
model = load_model()

# Preprocessing applied to each uploaded image before inference:
#   1. resize to a fixed IMG_SIZE x IMG_SIZE square;
#   2. convert the PIL image to a (C, H, W) float tensor.
# NOTE(review): there is no transforms.Normalize step. Confirm this matches the
# preprocessing used when the checkpoint was trained.
_val_steps = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
]
val_tfms = transforms.Compose(_val_steps)

# Inverse of ``classes`` (predicted index -> class label). It is built once at
# import time instead of on every request.
_IDX_TO_CLASS = {idx: label for label, idx in classes.items()}


def predict(image):
    """Classify one image and report its top-1 label and probability.

    Args:
        image: PIL image delivered by ``gr.Image(type="pil")``. It is ``None``
            when the user submits without uploading anything.

    Returns:
        A string ``"<label>|(<probability>)"``, where the probability is the
        softmax score of the highest-scoring class.

    Raises:
        gr.Error: If no image was provided. Gradio shows this message in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # create_model() is called without in_chans, so the network takes
    # 3-channel input; normalise grayscale/RGBA/palette images defensively.
    if image.mode != "RGB":
        image = image.convert("RGB")
    # unsqueeze(0) adds the batch dimension (batch of one image).
    img = val_tfms(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(img)
        probs = F.softmax(logits, dim=1)
        top_prob, top_idx = probs.topk(1, dim=1)
    label = _IDX_TO_CLASS[top_idx[0][0].item()]
    return f"{label}|({top_prob[0][0].item()})"

# Web UI: upload one image, get back the "<label>|(<probability>)" text from
# predict(). The name `demo` is the one Gradio's reload mode looks for.
demo = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="text")

# Start the server only when run as a script (e.g. `python app.py`), not when
# the module is merely imported.
if __name__ == "__main__":
    demo.launch()