import functools
import os
import random

import albumentations
import cv2
import gradio as gr
import numpy as np
import timm
import torch


# Everything in this app (model weights and input tensors) is placed on the CPU.
device = torch.device("cpu")

# Maps a model output index to its paddy-disease class name; predict() pairs
# output i with labels[i].
labels = dict(
    enumerate(
        [
            "bacterial_leaf_blight",
            "bacterial_leaf_streak",
            "bacterial_panicle_blight",
            "blast",
            "brown_spot",
            "dead_heart",
            "downy_mildew",
            "hispa",
            "normal",
            "tungro",
        ]
    )
)

def inference_fn(model, image=None):
    """Score a single, un-batched image tensor with ``model``.

    The model is switched to eval mode, the image is moved to ``device`` and
    given a leading batch axis of size 1, and the forward pass runs without
    gradient tracking.

    Args:
        model: a torch module producing one logit per class.
        image: one image tensor without a batch dimension (presumably
            channels x height x width — confirm against the caller).

    Returns:
        A 1-D numpy array holding the sigmoid of every model output.
    """
    model.eval()
    batch = image.to(device).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
    # NOTE(review): sigmoid scores each class independently. If the checkpoint
    # was trained with cross-entropy, softmax would give probabilities that
    # sum to 1 — confirm against the training code.
    scores = logits.sigmoid().detach().cpu()
    return scores.numpy().flatten()


# ImageNet channel statistics used to normalize inputs for the EfficientNet
# backbone.
# NOTE(review): presumably the checkpoint was trained with this same
# normalization — confirm against the training pipeline.
_MEAN = (0.485, 0.456, 0.406)
_STD = (0.229, 0.224, 0.225)

# Deterministic preprocessing, built once at import time. There are deliberately
# no random flips here: inference must return the same scores for the same
# image (random augmentation would also make Gradio's interpretation output
# noisy). ``p=1.0`` makes Normalize always run; ``always_apply`` is deprecated
# in recent albumentations releases.
_INFERENCE_TRANSFORM = albumentations.Compose(
    [
        albumentations.Resize(256, 256),
        albumentations.Normalize(_MEAN, _STD, max_pixel_value=255.0, p=1.0),
    ]
)


def _preprocess(image):
    """Resize and normalize a height x width x channels image into a CHW float32 tensor.

    Pixel values are expected in 0-255 (``max_pixel_value=255.0``); the channel
    order is presumably RGB to match the ImageNet statistics — confirm.
    """
    transformed = _INFERENCE_TRANSFORM(image=image)["image"]
    chw = np.transpose(transformed, (2, 0, 1))
    return torch.tensor(chw, dtype=torch.float32)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build EfficientNet-B0 and load the fine-tuned weights, once per process.

    Cached so the checkpoint is read from disk on the first request only,
    rather than on every prediction.
    """
    model = timm.create_model("efficientnet_b0", pretrained=False, num_classes=len(labels))
    # torch.load unpickles the file; only ship/load checkpoints from a trusted source.
    state_dict = torch.load("paddy_model.pth", map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def predict(image=None) -> dict:
    """Classify a paddy image and return a score for every disease class.

    Args:
        image: image array as delivered by the Gradio ``Image`` input
            (height x width x channels, pixel values 0-255).

    Returns:
        Dict mapping each class name in ``labels`` to its float score.

    Raises:
        ValueError: if no image was supplied (``image`` is None), e.g. when the
            user submits the form without uploading a picture.
    """
    if image is None:
        raise ValueError("predict() needs an image; got None.")

    tensor = _preprocess(image)
    scores = inference_fn(_load_model(), tensor)
    # labels is keyed 0..N-1 in output order, so index i pairs with scores[i].
    return {name: float(scores[idx]) for idx, name in labels.items()}


# Build the web UI as a module-level ``demo`` object, but start the (blocking)
# server only when this file is run as a script, so importing the module, e.g.
# to call predict() directly, has no side effects.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(),
    outputs=gr.outputs.Label(num_top_classes=len(labels)),
    examples=["200005.jpg", "200006.jpg"],
    interpretation="default",
)

if __name__ == "__main__":
    demo.launch()