Sravanth Ganta commited on
Commit
83cbde9
·
1 Parent(s): d56d2e7

Cancer Detector App

Browse files
Files changed (3) hide show
  1. .gitignore.txt +7 -0
  2. app.py +109 -0
  3. requirement.txt.txt +2 -0
.gitignore.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ flagged/
2
+ *.pt
3
+ *.png
4
+ *.jpg
5
+ *.mp4
6
+ *.mkv
7
+ gradio_cached_examples/
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import cv2
3
+ import requests
4
+ import os
5
+
6
+ import os
7
+ from PIL import Image
8
+ import timm
9
+ from torchvision import datasets
10
+ import torchvision
11
+ import torch
12
+ from torchvision.transforms import transforms
13
+ import numpy as np
14
+ from PIL import ImageFile
15
+ import matplotlib.pyplot as plt
16
+ import json
17
+ import warnings
18
+ import time
19
+ import glob
20
+ import shutil
21
+
22
+ warnings.filterwarnings("ignore")
23
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
24
+
25
+
26
+ def predict(image, model, device, class_name):
27
+
28
+ prediction_transform = transforms.Compose([transforms.Resize(size=(224, 224)),
29
+ transforms.ToTensor(),
30
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
31
+ try:
32
+ image = prediction_transform(image)[:3,:,:].unsqueeze(0)
33
+ except:
34
+ image = image.convert('RGB')
35
+ image = prediction_transform(image)[:3,:,:].unsqueeze(0)
36
+
37
+ if device == 'cuda':
38
+ if torch.cuda.is_available():
39
+ image = image.cuda()
40
+ else:
41
+ print("You don't have cuda")
42
+
43
+ with torch.no_grad():
44
+ model.eval()
45
+ pred = model(image)
46
+
47
+
48
+ idx = torch.argmax(pred)
49
+
50
+ prob = pred[0][idx].item()*100
51
+
52
+ return prob, class_name[idx]
53
+
54
+
55
+ model = timm.create_model('resnet50', pretrained=True)
56
+
57
+ model.fc = torch.nn.Sequential(torch.nn.Linear(2048, 256),
58
+ torch.nn.Dropout(0.2),
59
+ torch.nn.ReLU(),
60
+ torch.nn.Linear(256, 64),
61
+ torch.nn.Dropout(0.2),
62
+ torch.nn.ReLU(),
63
+ torch.nn.Linear(64, 32),
64
+ torch.nn.Dropout(0.2),
65
+ torch.nn.ReLU(),
66
+ torch.nn.Linear(32, 4),
67
+ torch.nn.Softmax()
68
+ )
69
+
70
+ model.load_state_dict(torch.load('model_ResNet50_acc_max.pt',map_location=torch.device('cpu')))
71
+
72
+
73
+
74
+ class_name = ['adenocarcinoma',
75
+ 'large.cell.carcinoma',
76
+ 'normal',
77
+ 'squamous.cell.carcinoma']
78
+
79
+
80
+ display_prob = True
81
+ show=True
82
+ path = glob.glob('*.png')
83
+
84
+ def show_preds_image(path):
85
+ for image in path:
86
+ img = Image.open(image)
87
+ if show:
88
+ plt.imshow(img)
89
+ plt.show()
90
+ class_name = ['adenocarcinoma',
91
+ 'large.cell.carcinoma',
92
+ 'normal',
93
+ 'squamous.cell.carcinoma']
94
+ prob, result = predict(img, model, 'cpu', class_name)
95
+ if display_prob:
96
+ print('Probability of {} : {:.6f}'.format(result, prob))
97
+
98
+ return prob, result
99
+
100
+ inputs_image = [
101
+ gr.components.Image(type="filepath", label="Input Image"),
102
+ ]
103
+
104
+ interface_image = gr.Interface(
105
+ fn=show_preds_image,
106
+ inputs=inputs_image,
107
+ title="Cancer detector",
108
+ cache_examples=False,
109
+ )
requirement.txt.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ timm>=0.9.2
2
+ opencv-python>=4.8.0.74