File size: 3,605 Bytes
d31170d
 
 
 
 
 
d82f8ec
 
d31170d
 
 
 
 
 
 
 
 
 
 
 
 
9aa5f09
ac3ee8a
 
2c62d9e
f851410
6feb474
e5ef4e7
6feb474
e5ef4e7
6feb474
ac3ee8a
 
 
d31170d
 
 
 
 
473b629
d31170d
8f49a5b
d31170d
 
 
 
 
ac3ee8a
4f4cb25
d31170d
 
 
ac3ee8a
d31170d
 
 
 
 
6a19c7a
d31170d
59ac305
ac3ee8a
6caae34
ac3ee8a
 
 
 
 
 
 
b8cbc9e
ac3ee8a
 
 
2c62d9e
ac3ee8a
 
 
 
d31170d
ac3ee8a
 
 
d31170d
12e6dd9
d31170d
d0c5619
d31170d
 
f4315f4
ac3ee8a
f4315f4
e5ef4e7
f4315f4
d31170d
e5c1737
5b33f29
191feea
 
d31170d
55ede45
 
8d9210d
d31170d
 
f7aadc0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import torchvision
from torchvision import transforms
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from resnet_lightning import ResNet18Model
import gradio as gr

model = ResNet18Model.load_from_checkpoint("epoch=19-step=3920.ckpt")
# Switch to inference behaviour (as the Lightning docs do after
# load_from_checkpoint) so single-image forward passes do not use batch
# statistics / dropout. NOTE(review): assumes ResNet18Model contains
# BatchNorm like a standard ResNet-18 — eval() is harmless either way.
model.eval()

# Inverse of a per-channel Normalize(mean=0.5, std=0.23):
# Normalize(-m/s, 1/s) maps x -> x*s + m.
# NOTE(review): not used in the visible code, and inference() applies only
# ToTensor(); confirm whether training normalized inputs with these values.
inv_normalize = transforms.Normalize(
    mean=[-0.50 / 0.23, -0.50 / 0.23, -0.50 / 0.23],
    std=[1 / 0.23, 1 / 0.23, 1 / 0.23],
)

# Class names indexed by the position of the model's output logit.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Dropdown choices for the GradCAM target stage (model.layer1..layer3).
model_layer_names = ["1", "2", "3"]

def get_layer(layer_name, net=None):
    """Return the GradCAM target-layer list for ResNet stage 1, 2 or 3.

    Args:
        layer_name: Stage number as an int or a numeric string such as the
            "1"/"2"/"3" values offered by the layer dropdown.
        net: Model to take the layer from; defaults to the module-level
            ``model``.

    Returns:
        A one-element list holding the last sub-module of ``net.layer<N>``,
        or None when ``layer_name`` does not denote stage 1, 2 or 3.
    """
    print("layer name:", layer_name)
    # The UI delivers strings; comparing them to ints (as before) never
    # matched, so coerce first and treat anything non-numeric as unknown.
    try:
        stage = int(layer_name)
    except (TypeError, ValueError):
        return None
    if stage not in (1, 2, 3):
        return None
    if net is None:
        net = model
    return [getattr(net, f"layer{stage}")[-1]]

def resize_image_pil(image, new_width, new_height):
    """Scale an image to fit a ``new_width`` x ``new_height`` box, then pad to it.

    The aspect ratio is kept: the image is scaled by the smaller of the two
    per-axis factors using nearest-neighbour resampling, then cropped to the
    target box anchored at the top-left corner. Since the scaled image fits
    inside the box, the crop only extends it on the right/bottom, and Pillow
    fills that area outside the source with zeros.

    Args:
        image: A PIL image or array-like accepted by ``np.array`` and
            ``Image.fromarray``.
        new_width: Target width in pixels.
        new_height: Target height in pixels.

    Returns:
        A PIL image of size ``(new_width, new_height)``.
    """
    img = Image.fromarray(np.array(image))
    width, height = img.size

    scale = min(new_width / width, new_height / height)
    # Clamp to >= 1 px: for very elongated inputs int() would round the short
    # side down to 0, and Pillow refuses to resize to a zero-sized image.
    scaled_size = (max(1, int(width * scale)), max(1, int(height * scale)))
    resized = img.resize(scaled_size, Image.NEAREST)
    return resized.crop((0, 0, new_width, new_height))


def inference(input_img, show_gradcam, layer_name, num_classes, transparancy = 0.5):
    """Classify one image with the CIFAR-10 ResNet and optionally overlay GradCAM.

    Args:
        input_img: Input image (PIL image or array-like). It is fitted/padded
            to 32x32 by ``resize_image_pil`` and converted to 3-channel RGB.
        show_gradcam: If truthy, return a GradCAM heat-map overlay; otherwise
            return the resized input image.
        layer_name: ResNet stage for GradCAM, forwarded to ``get_layer``.
        num_classes: How many top-ranked classes to report; clamped to
            ``[0, len(classes)]``.
        transparancy: ``image_weight`` passed to ``show_cam_on_image``, i.e.
            the weight of the original image in the blend (name kept for
            backward compatibility).

    Returns:
        ``(label, visualization, confidences)``: the top-1 class name, a
        32x32x3 image array, and a dict mapping the top ``num_classes`` class
        names to their softmax probabilities.

    Raises:
        ValueError: if GradCAM is requested for a layer ``get_layer`` rejects.
    """
    print(show_gradcam, layer_name, num_classes, transparancy)
    # Force RGB so grayscale/RGBA inputs still give a (32, 32, 3) array.
    org_img = np.array(resize_image_pil(input_img, 32, 32).convert("RGB"))

    # NOTE(review): only ToTensor() (scale to [0, 1], HWC -> CHW) is applied;
    # confirm the model does not expect Normalize(0.5, 0.23) inputs.
    input_tensor = transforms.ToTensor()(org_img).unsqueeze(0)
    outputs = model(input_tensor)
    logits = outputs.detach().flatten()
    probs = torch.softmax(logits, dim=0)

    order = torch.argsort(logits, descending=True).tolist()
    top_k = max(0, min(int(num_classes), len(classes)))
    confidences = {classes[i]: float(probs[i]) for i in order[:top_k]}

    # Top-1 is the argmax *index*. The previous torch.max(outputs, 1)[0]
    # took the max logit *value* and used it as a class index.
    predicted_label = classes[order[0]]

    if show_gradcam:
        target_layers = get_layer(layer_name)
        print("target layer", target_layers)
        if target_layers is None:
            raise ValueError(f"Unknown GradCAM layer: {layer_name!r}")
        # The context manager releases the hooks GradCAM attaches to the model.
        with GradCAM(model=model, target_layers=target_layers) as cam:
            grayscale_cam = cam(input_tensor=input_tensor)[0, :]
        visualization = show_cam_on_image(org_img / 255, grayscale_cam,
                                          use_rgb=True,
                                          image_weight=transparancy)
    else:
        visualization = org_img

    return predicted_label, visualization, confidences
    
def _inference_from_ui(image, num_classes, show_gradcam, layer_name, opacity):
    """Adapt the UI's input order to ``inference``'s positional signature.

    gr.Interface passes component values positionally in ``inputs`` order
    (image, number of classes, checkbox, dropdown, slider), whereas
    ``inference`` takes (image, show_gradcam, layer_name, num_classes,
    transparancy). The dropdown yields a string, so the layer is converted
    to int for ``get_layer``.
    """
    return inference(image, show_gradcam, int(layer_name), num_classes, opacity)


demo = gr.Interface(
    _inference_from_ui,
    inputs=[
        gr.Image(width=256, height=256, label="Input image"),
        gr.Number(value=3, maximum=10, minimum=1, step=1.0, precision=0, label="Number of classes to display"),
        gr.Checkbox(True, label="Show GradCAM Image"),
        # Default must be one of the (string) choices.
        gr.Dropdown(model_layer_names, value="3", label="Which layer for Gradcam"),
        gr.Slider(0, 1, value=0.5, label="Overall opacity of the overlay"),
    ],
    outputs=[
        gr.Label(label="Class", container=True, show_label=True),
        gr.Image(width=256, height=256, label="Output Image"),
        gr.Label(label="Confidences", container=True, show_label=True),
    ],
    title="CIFAR 10 trained on ResNet model in pytorch lightning with Gradcam",
    description=" A simple gradio inference to infer on resnet18 model",
    # Example values follow the `inputs` order; the layer must be a valid choice.
    examples=[["cat.jpg", 3, True, "3", 0.4]],
)

if __name__ == "__main__":
    demo.launch()