"""Streamlit demo: semantic segmentation of a single image with SegFormer.

The user either uploads an image or picks one of the bundled sample images.
Class names and the colour palette are read from ``class_dict_seg.csv``
(columns ``name``, `` r``, `` g``, `` b``).  The predicted class map is
colourised, blended 50/50 with the input image, and shown together with the
list of classes present in the prediction.
"""
import os

import numpy as np
import pandas as pd
import streamlit as st
import torch
from PIL import Image
from torch import nn
from transformers import SegformerForSemanticSegmentation, SegformerFeatureExtractor


def _colorize_segmentation(seg, palette, id2label):
    """Paint a class-index map with palette colours.

    Args:
        seg: 2-D integer numpy array of predicted class indices.
        palette: sequence of per-class colours, indexed by class id.
        id2label: mapping from class id to class name.

    Returns:
        A ``(color_seg, labels)`` tuple: ``color_seg`` is an ``(H, W, 3)``
        uint8 array, ``labels`` lists the names of the classes that occur in
        ``seg`` (in palette order).
    """
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    labels = []
    for label, color in enumerate(palette):
        # One mask per class, reused for both painting and the presence test.
        mask = seg == label
        if mask.any():
            color_seg[mask] = color
            labels.append(id2label[label])
    return color_seg, labels


st.title('Semantic Segmentation using SegFormer')
file_upload = st.file_uploader('Raw Input Image')
image_path = st.selectbox(
    'Choose any one image for inference',
    ('Select image', 'image1.jpg', 'image2.jpg', 'image3.jpg'))

# An uploaded file takes precedence over the sample-image selector.
raw_image = image_path if file_upload is None else file_upload

# 'Select image' is the selectbox placeholder: nothing to process yet.
if raw_image != 'Select image':
    df = pd.read_csv('class_dict_seg.csv')
    classes = df['name']
    # The column names carry a leading space in the CSV header.
    palette = df[[' r', ' g', ' b']].values
    id2label = classes.to_dict()
    label2id = {v: k for k, v in id2label.items()}

    # Normalise to 3-channel RGB: RGBA or grayscale inputs would otherwise
    # break the upsampling size computation and the blend with the colour mask.
    image = np.asarray(Image.open(raw_image).convert('RGB'))

    with st.spinner('Loading Model...'):
        # NOTE(review): Streamlit reruns this script on every interaction, so
        # the model is reloaded each time; consider caching the loaded model
        # (e.g. st.cache_resource) if the installed Streamlit supports it.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = SegformerForSemanticSegmentation.from_pretrained(
            "deep-learning-analytics/segformer_semantic_segmentation",
            ignore_mismatched_sizes=True,
            num_labels=len(id2label), id2label=id2label, label2id=label2id,
            reshape_last_stage=True)
        model = model.to(device)
        model.eval()

    with st.spinner('Preparing image...'):
        feature_extractor_inference = SegformerFeatureExtractor(do_random_crop=False, do_pad=False)
        pixel_values = feature_extractor_inference(image, return_tensors="pt").pixel_values.to(device)

    with st.spinner('Running inference...'):
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = model(pixel_values=pixel_values)

    with st.spinner('Postprocessing...'):
        logits = outputs.logits.cpu()

        # Resize the logits to the input image's (height, width) before
        # taking the per-pixel argmax.
        upsampled_logits = nn.functional.interpolate(logits,
                                                     size=image.shape[:2],
                                                     mode='bilinear',
                                                     align_corners=False)

        seg = upsampled_logits.argmax(dim=1)[0].numpy()
        color_seg, all_labels = _colorize_segmentation(seg, palette, id2label)

        # NOTE(review): this reverses the channel order of the colour mask
        # (RGB -> BGR) although the image itself stays RGB; kept as-is to
        # preserve the currently displayed colours -- confirm it is intended.
        color_seg = color_seg[..., ::-1]

        # 50/50 blend of the input image and the colour mask.
        img = image * 0.5 + color_seg * 0.5
        img = img.astype(np.uint8)

    st.image(img, caption="Segmented Image")
    st.header("Predicted Labels")
    for idx, label in enumerate(all_labels):
        st.subheader(f'{idx+1}) {label}')