| import cv2 |
| import numpy as np |
| import os |
| import torch |
| import onnxruntime as ort |
| import time |
| from functools import wraps |
| import argparse |
| from PIL import Image |
| from io import BytesIO |
| import streamlit as st |
|
|
| |
| |
| |
| |
| |
|
|
| |
|
|
# App-wide switch read by the Streamlit script below: when True the upload is
# split into 416x416 tiles that are each processed and stitched back together;
# when False only the centre 416x416 crop is processed.
mosaic: bool = True
|
|
| def center_crop(img, new_height, new_width): |
| height, width, _ = img.shape |
| start_x = width//2 - new_width//2 |
| start_y = height//2 - new_height//2 |
| return img[start_y:start_y+new_height, start_x:start_x+new_width] |
|
|
|
|
| def mosaic_crop(img, size): |
| height, width, _ = img.shape |
| padding_height = (size - height % size) % size |
| padding_width = (size - width % size) % size |
|
|
| padded_img = cv2.copyMakeBorder(img, 0, padding_height, 0, padding_width, cv2.BORDER_CONSTANT, value=[0, 0, 0]) |
| tiles = [padded_img[x:x+size, y:y+size] for x in range(0, padded_img.shape[0], size) for y in range(0, padded_img.shape[1], size)] |
|
|
| return tiles, padded_img.shape[0] // size, padded_img.shape[1] // size, padding_height, padding_width |
|
|
| def stitch_tiles(tiles, rows, cols, size): |
| return np.concatenate([np.concatenate([tiles[i*cols + j] for j in range(cols)], axis=1) for i in range(rows)], axis=0) |
|
|
|
|
| def timing_decorator(func): |
| @wraps(func) |
| def wrapper(*args, **kwargs): |
| start_time = time.time() |
| result = func(*args, **kwargs) |
| end_time = time.time() |
|
|
| duration = end_time - start_time |
| print(f"Function '{func.__name__}' took {duration:.6f} seconds") |
| return result |
|
|
| return wrapper |
|
|
| @timing_decorator |
| def process_image(session, img, colors, mosaic=False): |
| if not mosaic: |
| |
| img = center_crop(img, 416, 416) |
| blob = cv2.dnn.blobFromImage(img, 1/255.0, (416, 416), swapRB=True, crop=False) |
|
|
| |
| output = session.run(None, {session.get_inputs()[0].name: blob}) |
|
|
| |
| output_img = output[0].squeeze(0).transpose(1, 2, 0) |
| output_img = (output_img * 122).clip(0, 255).astype(np.uint8) |
| output_mask = output_img.max(axis=2) |
|
|
| output_mask_color = np.zeros((416, 416, 3), dtype=np.uint8) |
|
|
| |
| for class_idx in np.unique(output_mask): |
| if class_idx in colors: |
| output_mask_color[output_mask == class_idx] = colors[class_idx] |
|
|
| |
| transparent_mask = (output_mask == 122) |
|
|
| |
| transparent_mask = np.stack([transparent_mask]*3, axis=-1) |
|
|
| |
| output_mask_color[transparent_mask] = img[transparent_mask] |
|
|
| |
| overlay = cv2.addWeighted(img, 0.6, output_mask_color, 0.4, 0) |
|
|
| return overlay |
| |
|
|
| |
| cuda = torch.cuda.is_available() |
|
|
| if cuda: |
| print("We have a GPU!") |
| providers = ['CUDAExecutionProvider'] if cuda else ['CPUExecutionProvider'] |
|
|
| session = ort.InferenceSession('end2end.onnx', providers=providers) |
|
|
|
|
| |
| colors = {0: (0, 0, 255), 122: (0, 0, 0), 244: (0, 255, 255)} |
|
|
| def load_image(uploaded_file): |
| try: |
| image = Image.open(uploaded_file) |
| return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) |
| except Exception as e: |
| st.write("Could not load image: ", e) |
| return None |
|
|
|
|
| st.title("OpenLander ONNX app") |
| st.write("Upload an image to process with the ONNX OpenLander model!") |
| st.write("Bear in mind that this model is **much less refined** than the embedded models at the moment.") |
|
|
|
|
| uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"]) |
| if uploaded_file is not None: |
| img = load_image(uploaded_file) |
| if img.shape[2] == 4: |
| img = img[:, :, :3] |
| img_processed = None |
|
|
| if st.button('Process'): |
| with st.spinner('Processing...'): |
| start = time.time() |
| if mosaic: |
| tiles, rows, cols, padding_height, padding_width = mosaic_crop(img, 416) |
| processed_tiles = [process_image(session, tile, colors, mosaic=True) for tile in tiles] |
| overlay = stitch_tiles(processed_tiles, rows, cols, 416) |
|
|
| |
| overlay = overlay[:overlay.shape[0]-padding_height, :overlay.shape[1]-padding_width] |
| img_processed = overlay |
| else: |
| img_processed = process_image(session, img, colors) |
| end = time.time() |
| st.write(f"Processing time: {end - start} seconds") |
|
|
| st.image(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), caption='Uploaded Image.', use_column_width=True) |
|
|
| if img_processed is not None: |
| st.image(cv2.cvtColor(img_processed, cv2.COLOR_BGR2RGB), caption='Processed Image.', use_column_width=True) |
| st.write("Red => obstacle ||| Yellow => Human obstacle ||| no color => clear for landing or delivery ") |
|
|