withoutpaper commited on
Commit
96ec22e
·
1 Parent(s): ec207a2

Clean upload with model only

Browse files
gradio_demo.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import gradio as gr
6
+ from models.ResNet_model101 import ResNet101
7
+
8
+ # 定义模型类别
9
+ class_names = {
10
+ 'akiec': 'ACTINIC KERATOSIS',
11
+ 'bcc': 'BASAL CELL CARCINOMA',
12
+ 'bkl': 'BENIGN KERATOSIS-LIKE LESIONS',
13
+ 'df': 'DERMATOFIBROMA',
14
+ 'mel': 'MELANOMA',
15
+ 'nv': 'MELANOCYTIC NEVI',
16
+ 'vasc': 'VASCULAR LESIONS'
17
+ }
18
+ labels = list(class_names.values())
19
+
20
+ # 图像预处理流程
21
+ data_transform = transforms.Compose([
22
+ transforms.Resize((256, 256)),
23
+ transforms.ToTensor(),
24
+ transforms.Normalize(mean=[0.7633, 0.5458, 0.5704], std=[0.09, 0.1188, 0.1334])
25
+ ])
26
+
27
+ # 加载模型
28
+ device = torch.device("cpu")
29
+ model = ResNet101(dropout_prob=0.5)
30
+ model.load_state_dict(torch.load("pth/resnet101_model.pth", map_location=device))
31
+ model.to(device)
32
+ model.eval()
33
+
34
+ # 推理函数
35
+ def classify_skin_image(image: Image.Image):
36
+ image = image.convert("RGB")
37
+ tensor = data_transform(image).unsqueeze(0).to(device)
38
+ with torch.no_grad():
39
+ output = model(tensor)
40
+ pred = output.argmax(dim=1).item()
41
+ confidence = torch.nn.functional.softmax(output, dim=1)[0][pred].item()
42
+ return {labels[pred]: float(confidence)}
43
+
44
+ # 构建 Gradio 界面
45
+ demo = gr.Interface(
46
+ fn=classify_skin_image,
47
+ inputs=gr.Image(type="pil"),
48
+ outputs=gr.Label(num_top_classes=3),
49
+ title="Skin Cancer Classifier (ResNet101)",
50
+ description="Upload a skin lesion image and the model will classify it into one of seven categories.",
51
+ allow_flagging="never"
52
+ )
53
+
54
+ if __name__ == "__main__":
55
+ demo.launch()
models/ResNet_model101.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from torchvision import models
3
+
4
+
5
+ class ResNet101(nn.Module):
6
+ def __init__(self, num_classes=7, dropout_prob=0.4):
7
+ super(ResNet101, self).__init__()
8
+ self.resnet101 = models.resnet101(pretrained=True)
9
+
10
+ # Add Dropout between the fully connected layers
11
+ in_features = self.resnet101.fc.in_features
12
+ self.resnet101.fc = nn.Sequential(
13
+ nn.Linear(in_features, 512),
14
+ nn.ReLU(),
15
+ nn.Dropout(p=dropout_prob),
16
+ nn.Linear(512, num_classes)
17
+ )
18
+
19
+ def forward(self, x):
20
+ return self.resnet101(x)
pth/resnet101_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf700e92303306823efe1ec60b4508d55937d639bb0debea477977b51334c9ad
3
+ size 174859770
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ torch==-2.1.2+cu121-cp38
3
+ torchvision==0.16.2+cu121-cp38
4
+ pillow
templates/index.html ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Skin Cancer Image Classification Demo</title>
7
+ <style>
8
+ /* Reset basic styles */
9
+ * {
10
+ margin: 0;
11
+ padding: 0;
12
+ box-sizing: border-box;
13
+ }
14
+
15
+ /* Body styling */
16
+ body {
17
+ font-family: 'Arial', sans-serif;
18
+ background: #f4f7f6;
19
+ color: #333;
20
+ text-align: center;
21
+ padding: 30px 0;
22
+ }
23
+
24
+ /* Container for content */
25
+ .container {
26
+ width: 90%;
27
+ max-width: 1000px;
28
+ margin: auto;
29
+ background-color: white;
30
+ border-radius: 10px;
31
+ box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
32
+ padding: 20px;
33
+ overflow: hidden;
34
+ }
35
+
36
+ /* Title */
37
+ h1 {
38
+ font-size: 2.5em;
39
+ color: #4CAF50;
40
+ margin-bottom: 20px;
41
+ }
42
+
43
+ h3 {
44
+ font-size: 1.2em;
45
+ margin-top: 20px;
46
+ color: #333;
47
+ }
48
+
49
+ /* Image section styling */
50
+ .image-container {
51
+ display: flex;
52
+ justify-content: center;
53
+ flex-wrap: wrap;
54
+ gap: 20px;
55
+ margin-bottom: 20px;
56
+ }
57
+
58
+ .image-container img {
59
+ width: 180px;
60
+ height: 180px;
61
+ object-fit: cover;
62
+ border-radius: 8px;
63
+ transition: transform 0.3s ease-in-out;
64
+ cursor: pointer;
65
+ border: 2px solid transparent; /* Initially no border */
66
+ }
67
+
68
+ .image-container img:hover {
69
+ transform: scale(1.1);
70
+ }
71
+
72
+ /* Highlight selected image */
73
+ .image-container img.selected {
74
+ border: 4px solid black; /* Border when selected */
75
+ }
76
+
77
+ /* Button Container */
78
+ .button-container {
79
+ margin-top: 20px;
80
+ display: flex;
81
+ flex-direction: column;
82
+ align-items: center;
83
+ }
84
+
85
+ /* Input fields styling */
86
+ input[type="file"], input[type="text"] {
87
+ padding: 10px;
88
+ border-radius: 5px;
89
+ border: 1px solid #ddd;
90
+ width: 60%;
91
+ margin: 10px 0;
92
+ font-size: 1em;
93
+ }
94
+
95
+ input[type="text"] {
96
+ width: 50%;
97
+ }
98
+
99
+ /* Classify button styling */
100
+ button {
101
+ padding: 12px 25px;
102
+ background-color: #4CAF50;
103
+ color: white;
104
+ border: none;
105
+ border-radius: 25px;
106
+ font-size: 1.1em;
107
+ cursor: pointer;
108
+ transition: background-color 0.3s ease, transform 0.2s ease-in-out;
109
+ }
110
+
111
+ button:hover {
112
+ background-color: #45a049;
113
+ transform: translateY(-3px);
114
+ }
115
+
116
+ button:active {
117
+ transform: translateY(2px);
118
+ }
119
+
120
+ /* Feedback styling */
121
+ .feedback {
122
+ margin-top: 20px;
123
+ font-size: 1.2em;
124
+ color: #4CAF50;
125
+ }
126
+
127
+ .feedback.error {
128
+ color: red;
129
+ }
130
+
131
+ /* Media Queries for responsiveness */
132
+ @media (max-width: 768px) {
133
+ .image-container {
134
+ flex-direction: column;
135
+ align-items: center;
136
+ }
137
+
138
+ input[type="file"], input[type="text"] {
139
+ width: 80%;
140
+ }
141
+
142
+ button {
143
+ width: 80%;
144
+ }
145
+ }
146
+ </style>
147
+ </head>
148
+ <body>
149
+ <div class="container">
150
+ <h1>Skin Cancer Image Classification Demo</h1>
151
+ <form method="POST" enctype="multipart/form-data" onsubmit="return validateForm()">
152
+ <!-- Image selection section -->
153
+ <div class="image-container">
154
+ {% for image in image_urls %}
155
+ <div>
156
+ <img src="{{ image }}" alt="Image {{ loop.index }}" onclick="selectImage(this)">
157
+ <input type="radio" name="image_choice" value="{{ loop.index0 }}" id="image{{ loop.index0 }}" style="display:none;">
158
+ </div>
159
+ {% endfor %}
160
+ </div>
161
+
162
+ <!-- Upload image section -->
163
+ <div class="button-container">
164
+ <h3>Or Upload Your Own Image</h3>
165
+ <input type="file" name="image_file" accept="image/*"><br><br>
166
+ <label for="real_category">Real Category: </label>
167
+ <select id="real_category" name="real_category">
168
+ <option value="" disabled selected>Select the category</option>
169
+ <option value="ACTINIC KERATOSIS">ACTINIC KERATOSIS</option>
170
+ <option value="BASAL CELL CARCINOMA">BASAL CELL CARCINOMA</option>
171
+ <option value="BENIGN KERATOSIS-LIKE LESIONS">BENIGN KERATOSIS-LIKE LESIONS</option>
172
+ <option value="DERMATOFIBROMA">DERMATOFIBROMA</option>
173
+ <option value="MELANOMA">MELANOMA</option>
174
+ <option value="MELANOCYTIC NEVI">MELANOCYTIC NEVI</option>
175
+ <option value="VASCULAR LESIONS">VASCULAR LESIONS</option>
176
+ </select><br><br>
177
+ <button type="submit">Classify Image</button>
178
+ </div>
179
+ </form>
180
+
181
+ {% if feedback %}
182
+ <div class="feedback {% if feedback == 'ResNet101: Oops :( I hope I can do better next time' %}error{% endif %}">
183
+ <p><strong>Real Category:</strong> {{ real_category }}</p>
184
+ <p><strong>Prediction:</strong> {{ predicted_category }}</p>
185
+ <p>{{ feedback }}</p>
186
+ </div>
187
+ {% endif %}
188
+ </div>
189
+ <script>
190
+ let selectedImage = null; // Variable to track the selected image
191
+
192
+ function selectImage(imgElement) {
193
+ // If the clicked image is already selected, deselect it
194
+ if (imgElement === selectedImage) {
195
+ imgElement.classList.remove('selected');
196
+ var radioButton = imgElement.nextElementSibling; // Find the corresponding radio button
197
+ radioButton.checked = false;
198
+ selectedImage = null; // Reset the selected image tracker
199
+ } else {
200
+ // Deselect the previously selected image (if any)
201
+ if (selectedImage !== null) {
202
+ selectedImage.classList.remove('selected');
203
+ var previousRadioButton = selectedImage.nextElementSibling;
204
+ previousRadioButton.checked = false;
205
+ }
206
+
207
+ // Select the clicked image and apply the 'selected' class for the border
208
+ imgElement.classList.add('selected');
209
+
210
+ // Mark the radio button as selected
211
+ var radioButton = imgElement.nextElementSibling; // Find the corresponding radio button
212
+ radioButton.checked = true;
213
+
214
+ // Update the selected image tracker
215
+ selectedImage = imgElement;
216
+ }
217
+ }
218
+
219
+ function validateForm() {
220
+ var imageChoice = document.querySelector('input[name="image_choice"]:checked');
221
+ var imageFile = document.querySelector('input[name="image_file"]').files[0];
222
+ var realCategory = document.getElementById('real_category').value.trim();
223
+
224
+ if ((imageChoice && imageFile) || (!imageChoice && !imageFile)) {
225
+ alert("Please either select an image or upload your own image, but not both.");
226
+ return false;
227
+ }
228
+
229
+ if (imageFile && !realCategory) {
230
+ alert("Please select the real category for the uploaded image.");
231
+ return false;
232
+ }
233
+
234
+ return true;
235
+ }
236
+ </script>
237
+
238
+ </body>
239
+ </html>