Spaces:
Build error
Build error
| import streamlit as st | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
# Squeeze-and-Excitation building block used by the VanillaCNN_SE model.
class SEBlock(nn.Module):
    """Squeeze-and-Excitation block: re-weight channels with a learned gate.

    The input is globally average-pooled to one value per channel
    ("squeeze"). That vector goes through a bottleneck MLP
    (``fc1`` -> ReLU -> ``fc2`` -> sigmoid), which yields a per-channel
    gate in (0, 1) ("excitation"). The input feature map is then multiplied
    by the gate.

    Args:
        channels: Number of channels in the input feature map.
        reduction_ratio: Bottleneck reduction factor. The hidden layer has
            ``channels // reduction_ratio`` units, clamped to at least 1.
    """

    def __init__(self, channels, reduction_ratio=16):
        super().__init__()
        # Clamp to 1: when channels < reduction_ratio, the integer division
        # gives a zero-width bottleneck. The gate then degenerates to the
        # constant sigmoid(fc2.bias) and ignores the input. For the channel
        # counts VanillaCNN_SE uses (64..512) this equals the plain division,
        # so existing checkpoints load unchanged.
        hidden = max(1, channels // reduction_ratio)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(channels, hidden)
        self.fc2 = nn.Linear(hidden, channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Gate the channels of a 4-D ``x`` (batch, channels, H, W); same shape out."""
        batch_size, channels, _, _ = x.size()
        # Squeeze: one average per (sample, channel).
        y = self.global_avg_pool(x).view(batch_size, channels)
        # Excitation: bottleneck MLP -> per-channel weights in (0, 1).
        y = torch.relu(self.fc1(y))
        y = self.sigmoid(self.fc2(y)).view(batch_size, channels, 1, 1)
        # Broadcast the gate over the spatial dimensions.
        return x * y
class VanillaCNN_SE(nn.Module):
    """Four-stage convolutional classifier with squeeze-and-excitation.

    Each stage is a 3x3 conv (stride 1, padding 1), then batch norm, ReLU,
    a 2x2 max-pool and an SEBlock. Channel widths go
    3 -> 64 -> 128 -> 256 -> 512.

    The four pools halve the spatial size four times, so ``fc1`` expects a
    512 x 14 x 14 feature map. That corresponds to a 224 x 224 input image.

    Args:
        num_classes: Length of the output logit vector.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Attribute names and registration order define the state_dict keys
        # of the saved checkpoint, so keep them as they are.
        # Stage 1: 3 -> 64 channels.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=64)
        self.se1 = SEBlock(64)
        # Stage 2: 64 -> 128 channels.
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=128)
        self.se2 = SEBlock(128)
        # Stage 3: 128 -> 256 channels.
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=256)
        self.se3 = SEBlock(256)
        # Stage 4: 256 -> 512 channels.
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(num_features=512)
        self.se4 = SEBlock(512)
        # One parameter-free 2x2 pool, reused after every stage.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head: 14 * 14 = (224 / 2**4) ** 2 positions per channel.
        self.fc1 = nn.Linear(in_features=512 * 14 * 14, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=num_classes)

    def forward(self, x):
        """Map an image batch to raw (unnormalized) class logits."""
        stages = (
            (self.conv1, self.bn1, self.se1),
            (self.conv2, self.bn2, self.se2),
            (self.conv3, self.bn3, self.se3),
            (self.conv4, self.bn4, self.se4),
        )
        for conv, norm, excite in stages:
            x = excite(self.pool(torch.relu(norm(conv(x)))))
        flat = x.view(x.size(0), -1)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)
# Load the model once per server process. Streamlit re-executes this whole
# script on every widget interaction, so without caching the checkpoint would
# be re-read from disk and the network rebuilt on each rerun.
@st.cache_resource
def load_model(weights_path="vanilla_cnn_se.pth", num_classes=12):
    """Build VanillaCNN_SE and load trained weights from ``weights_path``.

    Args:
        weights_path: Path to a saved ``state_dict`` for VanillaCNN_SE.
        num_classes: Number of output classes the checkpoint was trained
            with. It must agree with the length of ``class_names``, which is
            used to decode predictions.

    Returns:
        The model with weights mapped to CPU, switched to eval mode so that
        BatchNorm uses its running statistics.
    """
    model = VanillaCNN_SE(num_classes=num_classes)
    # weights_only=True limits unpickling to tensors and plain containers.
    # That is all a state_dict needs, and it stops a tampered checkpoint from
    # executing arbitrary code on load.
    state_dict = torch.load(
        weights_path, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model


model = load_model()
# Human-readable labels, indexed by the position of the model's highest logit.
# NOTE(review): this order must match the label encoding used during training.
# It is not alphabetical (a torchvision ImageFolder would sort class folders
# alphabetically), so confirm it against the training code.
class_names = [
    "Maize",
    "Common wheat",
    "Common Chickweed",
    "Loose Silky-bent",
    "Charlock",
    "Cleavers",
    "Sugar beet",
    "Fat Hen",
    "Scentless Mayweed",
    "Small-flowered Cranesbill",
    "Shepherd’s Purse",
    "Black-grass",
]

# Inference preprocessing. The fixed 224x224 size is required by
# VanillaCNN_SE.fc1, which expects a 512 x 14 x 14 map after four 2x2 pools.
# ToTensor scales pixels to [0, 1]. NOTE(review): no Normalize step is
# applied; confirm this matches the training pipeline.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)
def mask_image(image, lower_hsv=(30, 40, 40), upper_hsv=(90, 255, 255)):
    """Black out every pixel whose colour lies outside a (green) HSV range.

    Args:
        image: A PIL image of any mode (non-RGB images are converted to RGB
            first), or an RGB array that ``np.array`` can consume.
        lower_hsv: Inclusive lower (H, S, V) bound, in OpenCV's 8-bit HSV
            convention. Hue is halved to fit a byte, so it spans 0-179.
        upper_hsv: Inclusive upper (H, S, V) bound, same convention.

    Returns:
        A PIL image of the same size. Pixels inside the range keep their
        original RGB value; all others are set to 0 (black).
    """
    # COLOR_RGB2HSV needs exactly 3 channels. Without this conversion an RGBA
    # PNG, a palette image or a grayscale upload would make cvtColor raise.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    image_np = np.array(image)
    hsv_img = cv2.cvtColor(image_np, cv2.COLOR_RGB2HSV)
    # 255 where the pixel lies within [lower_hsv, upper_hsv] on all three
    # channels, 0 elsewhere.
    mask = cv2.inRange(hsv_img, np.array(lower_hsv), np.array(upper_hsv))
    # Keep the original RGB values only where the mask is set.
    masked_img = cv2.bitwise_and(image_np, image_np, mask=mask)
    return Image.fromarray(masked_img)
def predict_class(image):
    """Return the label that the module-level ``model`` assigns to ``image``.

    The image is preprocessed with the module-level ``transform`` and run as
    a batch of one. The index of the highest-scoring logit is then looked up
    in ``class_names``.
    """
    batch = transform(image).unsqueeze(0)  # prepend a batch dimension of size 1
    with torch.no_grad():  # inference only: skip building the autograd graph
        logits = model(batch)
        best_index = logits.argmax(dim=1).item()
    return class_names[best_index]
# Streamlit UI. Streamlit re-runs this section top to bottom on every
# interaction.
st.title("Plant Seedling Classification")
st.write("Upload an image to classify the plant seedling and view the masked image.")

# File uploader: returns None until the user has picked a file.
uploaded_file = st.file_uploader("Choose an image file", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        # Force 3-channel RGB so PNGs with alpha or palette images work
        # downstream.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for
        # unreadable data, and OSError for truncated files. Show a message
        # instead of a raw traceback.
        st.error("Could not read the uploaded file as an image. Please upload a valid JPG or PNG.")
    else:
        masked_image = mask_image(image)
        # NOTE(review): classification uses the unmasked image. Confirm this
        # matches what the model was trained on.
        predicted_class = predict_class(image)

        st.image(image, caption="Original Image", use_column_width=True)
        st.image(masked_image, caption="Masked Image", use_column_width=True)
        st.write(f"### Predicted Class: {predicted_class}")