| |
|
|
import os
from tempfile import NamedTemporaryFile
from typing import Dict, List, Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
from streamlit_drawable_canvas import st_canvas

from lane_detection import LABEL_MAP, YOLOVideoDetector
|
|
|
|
| |
|
|
def extract_four_points(js) -> Optional[List[Tuple[int, int]]]:
    """
    Return exactly four (x, y) clicks from streamlit_drawable_canvas JSON, or None.

    Only ``circle`` and ``rect`` objects are counted. A circle contributes
    ``(left + radius, top + radius)``. A rect carries no ``radius`` and so
    contributes its ``(left, top)`` corner. Coordinates are truncated to ``int``
    and are in the canvas' own (display) pixel space.

    Args:
        js: ``json_data`` from ``st_canvas``. It is None/empty before anything
            is drawn; otherwise a dict whose ``"objects"`` entry lists the shapes.

    Returns:
        Four (x, y) tuples, in the order the shapes appear in ``"objects"``,
        when exactly four circle/rect objects are present. None when there is
        no data or the number of points is not exactly 4.
    """
    if not js or "objects" not in js:
        return None
    # Treat an explicit null object list as "nothing drawn" instead of
    # failing on iteration.
    objects = js["objects"] or []
    pts = [
        (
            int(obj["left"] + obj.get("radius", 0)),
            int(obj["top"] + obj.get("radius", 0)),
        )
        for obj in objects
        if obj.get("type") in {"circle", "rect"}
    ]
    return pts if len(pts) == 4 else None
|
|
|
|
def draw_poly(img: np.ndarray, pts: List[Tuple[int, int]], color: Tuple[int, int, int]):
    """
    Outline the polygon through ``pts`` directly on ``img`` (modified in place).

    The outline is closed (the last vertex joins back to the first), 2 px thick
    and anti-aliased. ``color`` is handed to OpenCV unchanged, so its channel
    order must match ``img``'s. Returns None.
    """
    vertices = np.asarray(pts, dtype=np.int32)
    cv2.polylines(
        img,
        [vertices],
        isClosed=True,
        color=color,
        thickness=2,
        lineType=cv2.LINE_AA,
    )
|
|
|
|
| |
|
|
# Page setup. set_page_config must be the first Streamlit call in the script.
# The icon strings are restored from mis-decoded UTF-8 ("π¦" was the
# traffic-light emoji).
st.set_page_config(page_title="🚦 Multi-Lane Congestion Demo", layout="wide")
st.title("🚦 Multi-Lane Vehicle Congestion Demo")
|
|
| |
# First run of a session: seed every piece of wizard state in one go.
# "num_lanes" doubles as the sentinel telling us initialisation has happened.
if "num_lanes" not in st.session_state:
    st.session_state.update(
        num_lanes=None,        # chosen lane count (None until step 1 is done)
        current_lane=0,        # index of the lane currently being drawn
        lanes=[],              # per-lane polygons in original-frame pixels
        video_path=None,       # on-disk copy of the uploaded video
        video_uploaded=False,  # whether the upload step has completed
    )
|
|
| |
# Step 1: ask how many lane polygons the user wants to draw. Nothing below
# this renders until the count is chosen.
if st.session_state.num_lanes is None:
    n = st.number_input(
        "How many lanes would you like to define? (1-8)",
        min_value=1,
        max_value=8,
        value=2,
    )
    if st.button("✔ Set Number of Lanes"):
        st.session_state.num_lanes = int(n)
        # One slot per lane; each is filled with its polygon once confirmed.
        st.session_state.lanes = [None] * st.session_state.num_lanes
        # Rerun right away so the next step appears on this click instead of
        # after another interaction. Uses st.rerun where available, otherwise
        # the older st.experimental_rerun.
        (getattr(st, "rerun", None) or st.experimental_rerun)()
    st.stop()
|
|
| |
# Step 2: accept a video upload and persist it to disk once per session.
if not st.session_state.video_uploaded:
    uploaded = st.file_uploader(
        "Upload video (formats: mp4, avi, mov, mkv)",
        type=["mp4", "avi", "mov", "mkv"],
    )
    if not uploaded:
        st.stop()
    # cv2.VideoCapture needs a real file path, so copy the upload into a temp
    # file that survives after the handle is closed (delete=False). The
    # `with` closes the handle, so the bytes are on disk before OpenCV reads.
    suffix = os.path.splitext(uploaded.name)[1]
    with NamedTemporaryFile(delete=False, suffix=suffix) as tmpfile:
        tmpfile.write(uploaded.getvalue())
    st.session_state.video_path = tmpfile.name
    st.session_state.video_uploaded = True
|
|
| |
# Grab the first frame as the backdrop for drawing lanes. The capture is
# released even if reading raises.
cap = cv2.VideoCapture(st.session_state.video_path)
try:
    ret, first_frame = cap.read()
finally:
    cap.release()
if not ret or first_frame is None:
    st.error("❌ Could not read the first frame of the video.")
    st.stop()

# OpenCV decodes to BGR; everything below (canvas, previews, overlays) is RGB.
frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
h_orig, w_orig = frame_rgb.shape[:2]
|
|
| |
# Cap the canvas width at MAX_W px. `scale` is display/original; lane clicks
# are divided by it later to get back to original-resolution pixels.
MAX_W = 800
scale = MAX_W / w_orig if w_orig > MAX_W else 1.0
disp_w, disp_h = (MAX_W, int(h_orig * scale)) if w_orig > MAX_W else (w_orig, h_orig)
frame_disp = (
    cv2.resize(frame_rgb, (disp_w, disp_h), interpolation=cv2.INTER_AREA)
    if w_orig > MAX_W
    else frame_rgb.copy()
)
|
|
| |
# Per-lane drawing colours, picked by `lane index % len(colors)`. They are
# drawn with OpenCV onto RGB frames, so each tuple reads as (R, G, B) there.
colors = [
    (0, 255, 0),    # green
    (255, 0, 0),    # red
    (0, 0, 255),    # blue
    (255, 255, 0),  # yellow
    (255, 0, 255),  # magenta
    (0, 255, 255),  # cyan
    (128, 255, 0),  # chartreuse
    (255, 128, 0),  # orange
]
|
|
# Step 3: collect one 4-point polygon per lane, one lane per page render.
if st.session_state.current_lane < st.session_state.num_lanes:
    idx = st.session_state.current_lane
    color = colors[idx % len(colors)]
    st.subheader(f"2️⃣ Click exactly 4 points for Lane #{idx+1}")
    st.caption("Draw 4 small circles on the image, then press **Confirm Lane**.")

    # The canvas background and previews are RGB, so the tuple is emitted in
    # R, G, B order; the stroke then matches the preview/overlay colour.
    stroke_hex = "#{:02X}{:02X}{:02X}".format(*color)

    canvas = st_canvas(
        fill_color="rgba(0,0,0,0)",
        stroke_width=2,
        stroke_color=stroke_hex,
        background_image=Image.fromarray(frame_disp),
        drawing_mode="point",
        key=f"lane_canvas_{idx}",
        height=disp_h,
        width=disp_w,
        update_streamlit=True,
    )

    # Points are in display (possibly downscaled) pixel space.
    pts_scaled = extract_four_points(canvas.json_data)
    if pts_scaled:
        preview = frame_disp.copy()
        draw_poly(preview, pts_scaled, color)
        st.image(preview, caption=f"Preview - Lane {idx+1}", use_column_width=True)

    if st.button(f"Confirm Lane {idx+1}"):
        if pts_scaled and len(pts_scaled) == 4:
            # Undo the display downscale to store original-resolution coords.
            orig_pts = [(int(x / scale), int(y / scale)) for (x, y) in pts_scaled]
            st.session_state.lanes[idx] = orig_pts
            st.session_state.current_lane += 1
            # Rerun now so the next lane's canvas (or the summary) shows up
            # immediately instead of after another interaction.
            (getattr(st, "rerun", None) or st.experimental_rerun)()
        else:
            st.warning("⚠ Please click exactly 4 points.")
    st.stop()
|
|
| |
st.subheader("✅ All lanes defined:")
# Overlay every confirmed lane on the full-resolution first frame, in the
# same per-lane colour used while drawing, with an "L<n>" label.
confirm_img = frame_rgb.copy()
for i, poly in enumerate(st.session_state.lanes):
    c = colors[i % len(colors)]
    draw_poly(confirm_img, poly, c)
    # Label just above the first vertex, but keep the text baseline at least
    # 30 px below the top edge so the label is not drawn off-image.
    label_org = (poly[0][0], max(poly[0][1] - 10, 30))
    cv2.putText(
        confirm_img, f"L{i+1}", label_org,
        cv2.FONT_HERSHEY_SIMPLEX, 1.0, c, 2, cv2.LINE_AA,
    )
st.image(confirm_img, caption="All lane regions overlaid", use_column_width=True)
|
|
| |
# Thresholds used to colour the per-lane PCE plot points.
st.subheader("🔧 Congestion Thresholds")
col1, col2 = st.columns(2)
with col1:
    low_thresh = st.number_input(
        "Green if PCE <",
        min_value=0.0,
        max_value=20.0,
        value=3.5,
        step=0.1,
        format="%.1f",
        help="Values below this will be colored green",
    )
with col2:
    high_thresh = st.number_input(
        "Red if PCE >",
        min_value=0.0,
        max_value=20.0,
        value=6.5,
        step=0.1,
        format="%.1f",
        help="Values above this will be colored red",
    )
st.caption("Values between green/red thresholds will be yellow.")
# The "green" test is applied first when colouring, so an inverted pair
# silently turns the band between the two thresholds green.
if low_thresh > high_thresh:
    st.warning(
        f"Green threshold ({low_thresh:.1f}) is above the red threshold "
        f"({high_thresh:.1f}); values between them will be colored green."
    )
|
|
# Passenger-car-equivalent (PCE) weight per detected vehicle class (car = 1.0).
PCE = {
    "auto": 0.8,
    "bus": 4.0,
    "car": 1.0,
    "electric-rickshaw": 0.8,
    "large-sized-truck": 4.5,
    "medium-sized-truck": 3.5,
    "motorbike": 0.5,
    "small-sized-truck": 3.0,
}


def _lane_pce(df: pd.DataFrame, rid: int) -> pd.Series:
    """Return the per-row PCE total for lane ``rid``.

    Sums ``count * PCE[vehicle_type]`` over the ``"{vehicle_type}_{rid}"``
    count columns. A missing column contributes 0. Non-numeric/NaN cells count
    as 0, and counts are truncated to int before weighting.
    """
    total = pd.Series(0.0, index=df.index)
    for vehicle_type, factor in PCE.items():
        col = f"{vehicle_type}_{rid}"
        if col in df.columns:
            counts = pd.to_numeric(df[col], errors="coerce").fillna(0).astype(int)
            total = total + counts * factor
    return total


def _congestion_colors(values: np.ndarray, low: float, high: float) -> List[str]:
    """Map each value to "green" (< low), "red" (> high) or "yellow" otherwise.

    The "< low" test wins when both hold; NaN values map to "yellow".
    """
    return np.where(
        values < low, "green", np.where(values > high, "red", "yellow")
    ).tolist()


if st.button("🚀 Run Congestion Analysis"):
    # Reserve a path for the annotated output video. delete=False keeps the
    # file after the handle is closed, so the detector can open it for writing.
    with NamedTemporaryFile(delete=False, suffix=".mp4") as out_file:
        out_tmp = out_file.name

    # Lane index (0-based) -> polygon vertices in original-resolution pixels.
    regions: Dict[int, List[Tuple[int, int]]] = {
        i: st.session_state.lanes[i] for i in range(st.session_state.num_lanes)
    }

    detector = YOLOVideoDetector(
        "Weights/last.pt",
        st.session_state.video_path,
        out_tmp,
        regions,
    )
    # Detector settings. Their exact semantics live in lane_detection;
    # NOTE(review): `conf` is presumably the detection confidence cut-off.
    detector.classes = list(LABEL_MAP.keys())
    detector.conf = 0.35
    detector.scale = 1.5

    with st.spinner("Processing video — this may take a while..."):
        df = detector.process_video()

    st.success("✅ Detection + annotation complete!")

    # Per-lane PCE plus a 5-row trailing mean (partial windows at the start).
    for rid in regions:
        df[f"PCE_lane{rid}"] = _lane_pce(df, rid)
        df[f"PCE_lane{rid}_avg"] = (
            df[f"PCE_lane{rid}"].rolling(window=5, min_periods=1).mean()
        )

    # One stacked subplot per lane. squeeze=False keeps `axes` 2-D even for a
    # single lane.
    lane_ids = list(regions)
    fig, axes = plt.subplots(
        len(lane_ids), 1, figsize=(10, 3 * len(lane_ids)), sharex=True, squeeze=False
    )
    frames = df["Frame Number"].values
    for rid, ax in zip(lane_ids, axes[:, 0]):
        y = df[f"PCE_lane{rid}_avg"].to_numpy()
        ax.plot(frames, y, color="gray", linewidth=1.2)
        ax.scatter(
            frames,
            y,
            c=_congestion_colors(y, low_thresh, high_thresh),
            s=20,
            edgecolors="black",
            linewidths=0.3,
        )
        # Lanes are numbered from 1 everywhere else in the UI.
        ax.set_title(f"Lane {rid + 1} PCE (rolling average)")
        ax.set_ylabel("PCE")
        ax.grid(alpha=0.3)
    axes[-1, 0].set_xlabel("Frame Number")
    fig.tight_layout()

    st.subheader("📊 Lane-Wise Congestion Plots")
    st.pyplot(fig)
    # pyplot keeps figures alive until closed, and this script re-executes on
    # every Streamlit rerun, so release the figure explicitly.
    plt.close(fig)

    st.subheader("🎬 Annotated Output Video")
    with open(out_tmp, "rb") as f:
        st.video(f.read())

    csv_bytes = df.to_csv(index=False).encode("utf-8")
    st.download_button(
        label="Download full counts + PCE CSV",
        data=csv_bytes,
        file_name="counts_and_pce.csv",
        mime="text/csv",
    )
|
|