# Commit 7a0922e (verified) by codingmonster1234:
# add easy split for label and probability
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import timm
import json
import gradio as gr
# Device used for checkpoint loading and inference, and the square edge length
# (in pixels) that input images are resized to.
DEVICE = "cpu"
IMG_SIZE = 384

# Load the class mapping. predict() inverts it through classes.items(), so the
# file is expected to contain a JSON object of {class name: class index}.
# The encoding is given explicitly so that non-ASCII class names are decoded
# the same way regardless of the platform's locale encoding.
with open('classes.json', 'r', encoding='utf-8') as f:
    classes = json.load(f)
def load_model():
    """Create the classifier network and restore its trained weights.

    The architecture is built through timm without pretrained weights, then
    the ``model_state_dict`` entry of the ``model.pt`` checkpoint is loaded
    into it.

    Returns:
        The network, switched to evaluation mode.
    """
    net = timm.create_model(
        "tf_efficientnetv2_s",
        pretrained=False,
        num_classes=31,
        drop_rate=0.3,
    )

    # NOTE(review): torch.load unpickles the file, so model.pt must come from
    # a trusted source.
    checkpoint = torch.load('model.pt', map_location=DEVICE)
    state = checkpoint['model_state_dict']
    net.load_state_dict(state)

    # Inference only: put layers such as dropout into evaluation behaviour.
    net.eval()
    return net
# Preprocessing applied to every input image: resize to a fixed
# IMG_SIZE x IMG_SIZE square, then convert to a tensor. No normalisation step
# is applied here.
val_tfms = transforms.Compose([
    transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])

# Build the network once at import time so that every call to predict()
# reuses the same instance.
model = load_model()
def predict(image):
    """Classify an image and return the top class with its probability.

    Args:
        image: A PIL image (the Gradio input is declared with ``type="pil"``).
            ``None`` is rejected with a user-facing error.

    Returns:
        A string of the form ``"<label>|(<probability>)"``. The ``|``
        separator lets callers split the label from the probability.

    Raises:
        gr.Error: If no image was supplied.
        KeyError: If the predicted index has no entry in ``classes``.
    """
    if image is None:
        # Presumably what Gradio passes when the form is submitted without an
        # image; fail with a readable message instead of an error raised from
        # inside the transform pipeline.
        raise gr.Error("Please provide an image to classify.")

    # The network is created with timm's default of three input channels, so
    # force RGB in case a grayscale/RGBA/palette image reaches this function.
    rgb_image = image.convert("RGB")
    batch = val_tfms(rgb_image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)

    top_prob, top_idx = probs.topk(1, dim=1)

    # classes maps name -> index; invert it to look the winning index up.
    idx_to_class = {idx: name for name, idx in classes.items()}
    label = idx_to_class[top_idx[0][0].item()]
    return f"{label}|({top_prob[0][0].item()})"
# Expose predict() through an image-in / text-out web interface and start it.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
)
demo.launch()