| import gradio as gr |
| from PIL import Image |
| from joblib import load |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision.models import efficientnet_b0 |
| import torchvision.transforms as transforms |
|
|
|
|
class MultiModalClassifier(nn.Module):
    """Classifier that fuses EfficientNet-B0 image features with three scalar
    metadata inputs (age, anatomical site, sex).

    Args:
        num_classes: Number of raw output scores produced by ``fc2``.
        num_features: Width of the pooled image feature vector; it must equal
            the channel count emitted by the EfficientNet-B0 backbone for the
            concatenation in ``forward`` to line up with ``fc1``.
    """

    def __init__(self, num_classes, num_features):
        super().__init__()

        # NOTE(review): ``pretrained=`` is deprecated in newer torchvision
        # releases in favour of ``weights=``; left unchanged for compatibility
        # with the torchvision version this app was written against.
        backbone = efficientnet_b0(pretrained=True)
        # Reuse every top-level child of the backbone except the last one
        # (its classification head).
        backbone_layers = list(backbone.children())[:-1]
        self.efficientnet_features = nn.Sequential(*backbone_layers)

        # Each metadata input contributes exactly one scalar column.
        self.age_dim = self.anatom_site_dim = self.sex_dim = 1

        fused_width = num_features + self.age_dim + self.anatom_site_dim + self.sex_dim
        self.fc1 = nn.Linear(fused_width, 256)
        self.fc2 = nn.Linear(256, num_classes)

        # Only active in training mode; ``model.eval()`` turns it off.
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, image, age, anatom_site, sex):
        """Return raw (un-activated) scores of shape ``(batch, num_classes)``.

        ``image`` is presumably a ``(batch, 3, H, W)`` float tensor (TODO
        confirm). ``age``, ``anatom_site`` and ``sex`` are each reshaped to a
        single column, so each must hold one value per batch element.
        """
        feature_map = self.efficientnet_features(image)
        # Average over whatever spatial extent remains, then flatten to
        # (batch, channels).
        pooled = F.avg_pool2d(feature_map, feature_map.size()[2:])
        image_vec = pooled.view(image.size(0), -1)

        metadata = torch.cat([column.view(-1, 1) for column in (age, anatom_site, sex)], dim=1)
        fused = torch.cat((image_vec, metadata), dim=1)

        hidden = self.dropout(F.relu(self.fc1(fused)))
        return self.fc2(hidden)
|
|
| |
# Architecture settings. They must agree with the checkpoint loaded below,
# because load_state_dict (strict by default) rejects mismatched layer shapes.
num_classes = 1      # a single score; `predict` turns it into a probability with a sigmoid
num_features = 1280  # output channel width of the EfficientNet-B0 feature extractor

model = MultiModalClassifier(num_classes, num_features)

# Map every stored tensor onto the CPU, whatever device it was saved from.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (consider weights_only=True on PyTorch versions that support it).
checkpoint = torch.load(r'best_epoch_weights.pth', map_location=torch.device('cpu'))
model.load_state_dict(checkpoint)
del checkpoint  # values were copied into the model's parameters

# Inference mode: disables dropout and makes BatchNorm use running statistics.
model.eval()
|
|
| |
# Presumably a fitted scikit-learn scaler for the age feature (confirm).
# NOTE(review): joblib.load unpickles the file -- only load trusted artifacts.
age_scaler = load(r'age_approx_scaler.joblib')

# Inference-time image preprocessing: a fixed 224x224 resize followed by
# conversion to a float tensor in [0, 1]. No mean/std normalisation is applied,
# presumably mirroring the training pipeline (TODO confirm).
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
test_transform = transforms.Compose(_preprocessing_steps)
|
|
# Class index -> human-readable label for the model's thresholded output.
diagnosis_map = dict(enumerate(('benign', 'malignant')))

# Categorical metadata encodings. The integer codes are fed to the model as
# floats, so they presumably must match the training-time encoding (confirm).
sexes_mapping = {sex: code for code, sex in enumerate(('male', 'female'))}

anatom_site_mapping = {
    site: code
    for code, site in enumerate((
        'torso',
        'lower extremity',
        'head/neck',
        'upper extremity',
        'palms/soles',
        'oral/genital',
    ))
}
|
|
def predict(image, age, gender, anatom_site):
    """Classify a skin-lesion image together with patient metadata.

    Args:
        image: Image array as delivered by the Gradio ``"image"`` input
            (presumably an H x W x C uint8 numpy array -- confirm), or None
            when nothing was uploaded.
        age: Patient age; scaled with ``age_scaler`` before inference.
        gender: ``'Male'`` or ``'Female'``; matched case-insensitively
            against ``sexes_mapping``.
        anatom_site: A key of ``anatom_site_mapping``.

    Returns:
        A sentence naming the predicted diagnosis, or a short prompt asking
        the user to supply a missing input.

    Raises:
        KeyError: If ``gender`` or ``anatom_site`` is not a known category.
    """
    # Gradio passes None for inputs the user left empty; answer with a prompt
    # instead of failing deep inside PIL / the lookups below.
    if image is None:
        return "Please upload an image."
    if age is None or gender is None or anatom_site is None:
        return "Please fill in age, gender and anatomical site."

    # Force 3 channels: RGBA or grayscale uploads would otherwise produce a
    # tensor the EfficientNet stem cannot consume.
    pil_image = Image.fromarray(image).convert("RGB")
    image_tensor = test_transform(pil_image).float().unsqueeze(0)  # add batch dim

    sex_tensor = torch.tensor([[sexes_mapping[gender.lower()]]], dtype=torch.float32)
    site_tensor = torch.tensor([[anatom_site_mapping[anatom_site]]], dtype=torch.float32)

    # The scaler is called with a 2-D (1 sample x 1 feature) input.
    scaled_age = age_scaler.transform([[age]])
    age_tensor = torch.tensor(np.array(scaled_age), dtype=torch.float32)

    # Pure inference: skip autograd graph construction to save time and memory.
    with torch.no_grad():
        output = model(image_tensor, age_tensor, site_tensor, sex_tensor)

    output_sigmoid = torch.sigmoid(output)
    predicted_class = (output_sigmoid > 0.5).float()

    return f"The predicted_class is a {diagnosis_map[int(predicted_class)]}."
|
|
|
|
# Usage note shown under the title. Gradio's Interface labels its run button
# "Submit" by default, so the instructions refer to that label.
description_html = """
Fill in the required parameters and click 'Submit'.
"""

# One row per example, in input order: image path, age, gender, anatomical
# site. NOTE(review): the image paths must resolve from the working directory.
example_data = [
    ["ISIC_0000060_downsampled.jpg", 35, "Female", "torso"],
    ["ISIC_0068279.jpg", 45.0, "Female", "head/neck"],
]

# Dropdown choices are derived from the lookup tables `predict` uses, so the UI
# cannot offer a value that `predict` is unable to encode.
inputs = [
    "image",
    gr.Number(label="Age", minimum=0, maximum=120),
    gr.Dropdown([sex.capitalize() for sex in sexes_mapping], label="Gender"),
    gr.Dropdown(list(anatom_site_mapping), label="Anatomical Site"),
]

demo = gr.Interface(
    predict,
    inputs,
    outputs=gr.Textbox(label="Output", type="text"),
    title="Skin Cancer Diagnosis",
    description=description_html,
    allow_flagging='never',
    examples=example_data,
)
demo.launch()