"""Gradio demo: classify an uploaded image with a ``tf_efficientnetv2_s`` model.

At import time this module reads ``classes.json`` and ``model.pt`` from the
current working directory and builds the model and the Gradio interface.
The web server itself is only started when the file is run as a script.
"""

import json
from typing import Optional

import gradio as gr
import timm
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

DEVICE = "cpu"
IMG_SIZE = 384
# Size of the classifier head; must match the head stored in the checkpoint.
# NOTE(review): presumably equal to the number of entries in classes.json -
# confirm against the training setup.
NUM_CLASSES = 31

# Load classes. The mapping is inverted below and looked up by the predicted
# integer index, so the file is expected to map class name -> class index.
with open("classes.json", "r", encoding="utf-8") as f:
    classes = json.load(f)

# Built once here instead of on every prediction request.
IDX_TO_CLASS = {index: name for name, index in classes.items()}


def load_model() -> torch.nn.Module:
    """Build the network, load weights from ``model.pt`` and return it in eval mode.

    Returns:
        The model placed on ``DEVICE`` with ``eval()`` applied, so dropout is
        disabled at inference time.

    Raises:
        FileNotFoundError: If ``model.pt`` does not exist.
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        RuntimeError: If the stored weights do not fit the architecture.
    """
    net = timm.create_model(
        "tf_efficientnetv2_s",
        pretrained=False,
        num_classes=NUM_CLASSES,
        drop_rate=0.3,
    )
    # SECURITY: torch.load unpickles the file; only load checkpoints that come
    # from a trusted source.
    ckpt = torch.load("model.pt", map_location=DEVICE)
    net.load_state_dict(ckpt["model_state_dict"])
    # Keep the weights on the same device the inputs are sent to in predict().
    net.to(DEVICE)
    net.eval()
    return net


model = load_model()

# Inference-time preprocessing: fixed-size resize and conversion to a tensor.
# NOTE(review): no mean/std normalisation is applied here - confirm this
# matches the preprocessing used when the checkpoint was trained.
val_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])


def predict(image: Optional[Image.Image]) -> str:
    """Classify one image and return the top-1 label with its probability.

    Args:
        image: The PIL image supplied by the Gradio image component, or
            ``None`` when the user submitted the form without an image.

    Returns:
        A string of the form ``"<label>|(<probability>)"``.

    Raises:
        gr.Error: If ``image`` is ``None``; Gradio shows the message in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    # Force three channels so grayscale / RGBA inputs do not break the model
    # when predict() is called outside the Gradio component.
    batch = val_tfms(image.convert("RGB")).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)

    top_prob, top_idx = probs.topk(1, dim=1)
    label = IDX_TO_CLASS[top_idx[0][0].item()]
    return f"{label}|({top_prob[0][0].item()})"


demo = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs="text")

if __name__ == "__main__":
    # Start the (blocking) web server only when executed as a script, so the
    # module can be imported without side effects beyond loading the model.
    demo.launch()