| import torch |
| import gradio as gr |
| import json |
| from torchvision import transforms |
| import torch.nn.functional as F |
|
|
# Path to the TorchScript-serialized screen classifier
# (ResNet backbone, per the filename: noisy-student + web350k training).
TORCHSCRIPT_PATH = "res/screenclassification-resnet-noisystudent+web350k.torchscript"
# JSON file holding the label<->index mapping under the "label2Idx" key.
LABELS_PATH = "res/class_map_enrico.json"
# Target size passed to transforms.Resize below (an int resizes the
# shorter image side to this value, preserving aspect ratio).
IMG_SIZE = 128
|
|
| model = torch.jit.load(TORCHSCRIPT_PATH) |
|
|
# Read the class map and keep only the label-name -> class-index table.
with open(LABELS_PATH, "r") as label_file:
    label2Idx = json.load(label_file)["label2Idx"]
| |
# Preprocessing pipeline: resize shorter side to IMG_SIZE, convert to a
# float tensor in [0, 1], then normalize each channel to roughly [-1, 1].
_preprocess_steps = [
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
img_transforms = transforms.Compose(_preprocess_steps)
| |
def predict(img):
    """Classify a UI screenshot and return per-class confidences.

    Args:
        img: PIL image supplied by the Gradio ``Image(type="pil")`` input.

    Returns:
        Dict mapping every class label name to a float softmax
        confidence in [0, 1] (Gradio's ``Label`` output format).
    """
    # Gradio can hand over RGBA or grayscale images; Normalize above expects
    # exactly 3 channels, so coerce to RGB first.
    img = img.convert("RGB")
    img_input = img_transforms(img).unsqueeze(0)  # add batch dim -> (1, C, H, W)
    # Inference only: disable autograd so no computation graph is built
    # (saves memory and time on every request).
    with torch.no_grad():
        predictions = F.softmax(model(img_input), dim=-1)[0]
    return {label: float(predictions[int(idx)]) for label, idx in label2Idx.items()}
| |
# Bundled screenshots offered as clickable examples in the Gradio UI.
example_imgs = [
    "res/example.jpg",
    "res/screenlane-snapchat-profile.jpg",
    "res/screenlane-snapchat-settings.jpg",
    "res/example_pair1.jpg",
    "res/example_pair2.jpg"
]
|
|
# Wire the classifier into a Gradio demo: PIL image in, top-5 label bars out.
interface = gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs=gr.Label(num_top_classes=5), examples=example_imgs)


# Start the local Gradio server (blocks until the process is interrupted).
interface.launch()
|
|