| """Object detection demo with MobileNet SSD. |
| This model and code are based on |
| https://github.com/robmarkcole/object-detection-app |
| """ |
|
|
| import logging |
| import queue |
| from pathlib import Path |
| from typing import List, NamedTuple |
|
|
| import av |
| import cv2 |
| import numpy as np |
| import streamlit as st |
| from streamlit_webrtc import WebRtcMode, webrtc_streamer |
|
|
| from sample_utils.download import download_file |
|
|
# Directory that contains this script; the model files are stored beneath it.
HERE = Path(__file__).parent
ROOT = HERE

logger = logging.getLogger(__name__)

# Remote sources of the MobileNet SSD weights and network definition, paired
# with the local paths that are handed to ``download_file`` further down.
MODEL_URL = (
    "https://github.com/robmarkcole/object-detection-app"
    "/raw/master/model/MobileNetSSD_deploy.caffemodel"
)
MODEL_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.caffemodel"
PROTOTXT_URL = (
    "https://github.com/robmarkcole/object-detection-app"
    "/raw/master/model/MobileNetSSD_deploy.prototxt.txt"
)
PROTOTXT_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.prototxt.txt"
|
|
# Labels the detector can report. A label's position in this list is the
# numeric class id used to look it up (index 0 is "background").
CLASSES = [
    "background", "aeroplane", "bicycle", "bird", "boat",
    "bottle", "bus", "car", "cat", "chair",
    "cow", "diningtable", "dog", "horse", "motorbike",
    "person", "pottedplant", "sheep", "sofa", "train",
    "tvmonitor",
]
|
|
|
|
class Detection(NamedTuple):
    """One object reported by the detector for a single video frame."""

    # Index of the detected class within ``CLASSES``.
    class_id: int
    # Human-readable class name, i.e. ``CLASSES[class_id]``.
    label: str
    # Confidence reported by the network; compared against the UI threshold.
    score: float
    # Bounding box ``[xmin, ymin, xmax, ymax]`` scaled to the frame's pixels.
    box: np.ndarray
|
|
|
|
@st.cache_resource
def generate_label_colors():
    """Return one random 3-component colour per entry of ``CLASSES``.

    The result is a float array of shape ``(len(CLASSES), 3)`` whose values
    are drawn uniformly from ``[0, 255)``. The function is wrapped in
    ``st.cache_resource``, presumably so the palette stays the same across
    script reruns -- confirm against the Streamlit caching docs.
    """
    palette_shape = (len(CLASSES), 3)
    return np.random.uniform(low=0, high=255, size=palette_shape)
|
|
|
|
# Drawing colour for each class, indexed by class id.
COLORS = generate_label_colors()

# Fetch the weights first, then the network definition, passing the expected
# byte size of each file to the download helper.
# NOTE(review): download_file's exact use of expected_size is defined in
# sample_utils.download -- confirm there.
for _url, _local_path, _expected_size in (
    (MODEL_URL, MODEL_LOCAL_PATH, 23147564),
    (PROTOTXT_URL, PROTOTXT_LOCAL_PATH, 29353),
):
    download_file(_url, _local_path, expected_size=_expected_size)
|
|
|
|
| |
# Keep the loaded network in ``st.session_state`` so that it is only built
# when the key is not already present, then always read it back from there.
cache_key = "object_detection_dnn"
if cache_key not in st.session_state:
    st.session_state[cache_key] = cv2.dnn.readNetFromCaffe(
        str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH)
    )
net = st.session_state[cache_key]
|
|
# Minimum score a detection needs in order to be kept; adjustable from 0 to 1
# in steps of 0.05, starting at 0.5.
score_threshold = st.slider(
    "Score threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.5,
    step=0.05,
)
|
|
| |
| |
| |
| |
# Hand-off channel for detection results: the frame callback puts one list of
# detections per processed frame and the label table below reads them.
# ``maxsize=0`` means the queue is unbounded.
result_queue: "queue.Queue[List[Detection]]" = queue.Queue(maxsize=0)
|
|
|
|
def _parse_detections(output, frame_width, frame_height, min_score):
    """Convert the raw network output into a list of ``Detection`` tuples.

    Each output row holds 7 values; this code reads column 1 as the class
    id, column 2 as the score and columns 3-6 as the box corners. The box is
    multiplied by the frame size, so it is assumed to be expressed relative
    to the frame (0..1) -- TODO confirm against the model definition.
    Rows whose score is below ``min_score`` are dropped.
    """
    # Flatten to (N, 7) explicitly. ``squeeze()`` would also remove the N axis
    # when the network returns exactly one row, leaving a 1-D array on which
    # the column indexing below raises IndexError.
    rows = output.reshape(-1, output.shape[-1])
    rows = rows[rows[:, 2] >= min_score]

    scale = np.array([frame_width, frame_height, frame_width, frame_height])
    return [
        Detection(
            class_id=int(row[1]),
            label=CLASSES[int(row[1])],
            score=float(row[2]),
            box=row[3:7] * scale,
        )
        for row in rows
    ]


def _draw_detections(image, detections):
    """Draw a captioned bounding box on ``image`` (in place) per detection."""
    for detection in detections:
        caption = f"{detection.label}: {round(detection.score * 100, 2)}%"
        color = COLORS[detection.class_id]
        xmin, ymin, xmax, ymax = detection.box.astype("int")

        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)

        # Put the caption above the box, or just inside it when the box is
        # too close to the top edge for the text to fit.
        caption_y = ymin - 15 if ymin - 15 > 15 else ymin + 15
        cv2.putText(
            image,
            caption,
            (xmin, caption_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            color,
            2,
        )


def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    """Run object detection on one video frame and return it annotated.

    The frame is converted to a BGR array, resized to 300x300 for the
    network, and every detection scoring at least ``score_threshold`` is
    drawn onto the full-size image. The list of detections is also put on
    ``result_queue`` so that the label table can display it.

    Args:
        frame: Incoming video frame.

    Returns:
        A new frame built from the annotated BGR image.
    """
    image = frame.to_ndarray(format="bgr24")

    # Build the network input: 300x300, mean-subtracted and rescaled.
    blob = cv2.dnn.blobFromImage(
        image=cv2.resize(image, (300, 300)),
        scalefactor=0.007843,
        size=(300, 300),
        mean=(127.5, 127.5, 127.5),
    )
    net.setInput(blob)
    output = net.forward()

    h, w = image.shape[:2]
    detections = _parse_detections(output, w, h, score_threshold)
    _draw_detections(image, detections)

    # NOTE(review): the queue is unbounded and is only drained while the
    # label table is shown, so entries accumulate otherwise.
    result_queue.put(detections)

    return av.VideoFrame.from_ndarray(image, format="bgr24")
|
|
|
|
# Start the WebRTC component: video only (no audio), with every frame routed
# through ``video_frame_callback``.
# NOTE(review): the precise effect of SENDRECV and async_processing is
# defined by streamlit-webrtc -- see its documentation.
webrtc_ctx = webrtc_streamer(
    key="object-detection",
    mode=WebRtcMode.SENDRECV,
    media_stream_constraints={"video": True, "audio": False},
    video_frame_callback=video_frame_callback,
    async_processing=True,
)
|
|
# Live table of the most recently received detections, shown only while the
# checkbox is ticked and the stream is playing.
if st.checkbox("Show the detected labels", value=True) and webrtc_ctx.state.playing:
    labels_placeholder = st.empty()
    # ``result_queue.get()`` blocks until the frame callback publishes the
    # next list of detections. The loop has no exit condition of its own, so
    # nothing below it runs while it is active.
    # NOTE(review): frames and this table are presumably produced on
    # different threads, so the two are not strictly in sync -- confirm.
    while True:
        labels_placeholder.table(result_queue.get())
|
|
# Credit the project the model and code were taken from.
st.markdown(
    body=(
        "This demo uses a model and code from "
        "https://github.com/robmarkcole/object-detection-app. "
        "Many thanks to the project."
    )
)
|
|