| import os |
| import cv2 |
| import numpy as np |
| import gradio as gr |
| import torch |
| from mouse_tracker import MouseTrackerAnalyzer |
| from huggingface_hub import hf_hub_download |
|
|
| |
# Detect whether we are running on Hugging Face Spaces: the `spaces` package
# exists only inside a Spaces runtime, so an ImportError means a local run.
# The flag is consumed later to pick the GPU-decorated entry point and the
# appropriate launch() arguments.
try:
    import spaces
    is_spaces = True
    print("检测到 Hugging Face Spaces 环境")
except ImportError:
    is_spaces = False
    print("在本地环境运行")
|
|
| |
# Base name of the tracking model weights; the format suffix selected in the
# UI (.onnx/.engine/.pt/.mlpackage) is appended by get_model_file_path().
model_base_name = "fst-v1.3-n"
# Frame count of the currently loaded video; written by select_video().
total_frames = 0
|
|
| |
def get_model_file_path(model_suffix):
    """Return the local path of the model weights file for *model_suffix*.

    The path is relative to the working directory, e.g. "./fst-v1.3-n.onnx".
    """
    return "./{}{}".format(model_base_name, model_suffix)
|
|
| |
def extract_frame(video_path, frame_num):
    """Grab frame *frame_num* from *video_path* and return it as an RGB array.

    Returns None when no path was given, the video cannot be opened, or the
    requested frame cannot be read.
    """
    if not video_path:
        return None
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        return None
    # Seek directly to the requested frame before reading.
    capture.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
    ok, bgr = capture.read()
    capture.release()
    # OpenCV decodes to BGR; the UI expects RGB.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) if ok else None
|
|
| |
def select_video(video_file, model_suffix):
    """Load a newly selected video: preview its first frame and size the sliders.

    Parameters
    ----------
    video_file : str | None
        Path to the chosen video (Gradio passes None when the input is cleared).
    model_suffix : str
        Model format suffix from the dropdown (e.g. ".onnx").

    Returns
    -------
    tuple
        (first frame as RGB ndarray or None, status text,
         start-frame slider update, end-frame slider update)
    """
    global total_frames
    if not video_file:
        return None, "请选择视频文件", gr.Slider(0,0,0), gr.Slider(0,0,0)
    # Open the capture exactly once and always release it: the original code
    # created a second VideoCapture just for the frame count and leaked it.
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        return None, "无法读取视频帧", gr.Slider(0,0,0), gr.Slider(0,0,0)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None, "无法读取视频帧", gr.Slider(0,0,0), gr.Slider(0,0,0)
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Re-range both sliders to the actual frame count of this video.
    start = gr.Slider(minimum=0, maximum=total_frames-1, value=0, step=1)
    end = gr.Slider(minimum=0, maximum=total_frames-1, value=total_frames-1, step=1)
    status = f"视频加载成功,总帧数: {total_frames}. 使用模型: {os.path.basename(get_model_file_path(model_suffix))}"
    return frame_rgb, status, start, end
|
|
| |
def preview_frame(video_file, frame_num):
    """Return (preview image, status text) for frame *frame_num* of the video."""
    if not video_file:
        return None, "请先选择视频文件"
    image = extract_frame(video_file, frame_num)
    # extract_frame yields None when the frame cannot be decoded.
    return (None, "无法读取指定帧") if image is None else (image, f"帧 {frame_num}")
|
|
| |
| def _start_analysis_impl(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold): |
| if not video: |
| return None, None, "请选择视频文件" |
| if start_frame >= end_frame: |
| return None, None, "起始帧必须小于结束帧" |
| |
| video_name = os.path.splitext(os.path.basename(video))[0] |
| output_path = os.path.join(os.path.dirname(video), f"{video_name}_out.mp4") |
| csv_path = os.path.join(os.path.dirname(video), f"{video_name}_results.csv") |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' |
| model_path = get_model_file_path(model_suffix) |
| if not os.path.exists(model_path): |
| if is_spaces: |
| try: |
| model_path = hf_hub_download( |
| repo_id="YOUR_HF_USERNAME/YOUR_REPO_NAME", |
| filename=f"weights/{model_base_name}{model_suffix}" |
| ) |
| except Exception: |
| print(f"下载模型失败: {model_path}") |
| else: |
| print(f"警告: 本地未找到模型文件 {model_path}") |
| |
| analyzer = MouseTrackerAnalyzer( |
| model_path=model_path, |
| conf=conf, |
| iou=iou, |
| max_det=max_det, |
| verbose=True |
| ) |
| analyzer.struggle_threshold = threshold |
| |
| analyzer.process_video( |
| video_path=video, |
| output_path=output_path, |
| start_frame=start_frame, |
| end_frame=end_frame, |
| callback=lambda prog, frm, res: print(f"进度: {prog}% 检测: {len(res)} 项") |
| ) |
| analyzer.save_results(csv_path) |
| |
| plot_path = None |
| if analyzer.results: |
| plot_path = analyzer.generate_time_series_plot() |
| status = f"分析完成。视频: {output_path}, CSV: {csv_path}" |
| if plot_path: |
| status += f", 图表: {plot_path}" |
| return output_path, plot_path, status |
|
|
| |
if is_spaces:
    # On Spaces, request a ZeroGPU slot (up to 120 s) for each analysis run.
    @spaces.GPU(duration=120)
    def start_analysis(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold):
        return _start_analysis_impl(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold)
else:
    # Locally there is nothing to wrap: expose the implementation directly
    # instead of duplicating a pass-through wrapper.
    start_analysis = _start_analysis_impl
|
|
| |
def create_interface():
    """Build and return the Gradio Blocks UI for the analysis app.

    Layout: a left column with the video input, model/detector settings and
    action buttons; a right column with a preview tab and a results tab.
    Event wiring connects video selection, frame preview and analysis start.
    """
    with gr.Blocks(title="鼠强迫游泳挣扎度分析") as app:
        gr.Markdown("# 鼠强迫游泳测试挣扎度分析 (对象跟踪)")
        with gr.Row():
            with gr.Column(scale=1):
                video_input = gr.Video(label="输入视频")
                model_format = gr.Dropdown(
                    label="模型格式",
                    choices=[".onnx", ".engine", ".pt", ".mlpackage"],
                    value=".onnx",
                    interactive=True
                )
                device_info = gr.Textbox(
                    label="系统信息",
                    value=f"设备: {'GPU' if torch.cuda.is_available() else 'CPU'}",
                    interactive=False
                )
                # Detector parameters forwarded to MouseTrackerAnalyzer.
                conf = gr.Slider(0.1, 0.9, value=0.25, step=0.05, label="置信度阈值")
                iou = gr.Slider(0.1, 0.9, value=0.45, step=0.05, label="IoU阈值")
                max_det = gr.Slider(1, 50, value=20, step=1, label="最大检测数")
                threshold = gr.Slider(0, 1, value=0.3, step=0.01, label="挣扎阈值")
                # Placeholder ranges; select_video() re-ranges these to the
                # actual frame count once a video is loaded.
                start_frame = gr.Slider(0, 999999, value=0, step=1, label="起始帧")
                end_frame = gr.Slider(0, 999999, value=999999, step=1, label="结束帧")
                preview_btn = gr.Button("预览帧")
                start_btn = gr.Button("开始分析", variant="primary")
            with gr.Column(scale=2):
                with gr.Tab("预览"):
                    preview_image = gr.Image(label="预览图像", type="numpy", height=400)
                    status_text = gr.Textbox(label="状态", interactive=False)
                with gr.Tab("结果"):
                    output_video = gr.Video(label="分析结果视频")
                    result_plot = gr.Image(label="挣扎分数时间序列")
                    result_status = gr.Textbox(label="分析状态", interactive=False)

        # Event wiring: load video -> preview + slider ranges; preview button
        # -> single frame; start button -> full analysis pipeline.
        video_input.change(select_video, inputs=[video_input, model_format], outputs=[preview_image, status_text, start_frame, end_frame])
        preview_btn.click(preview_frame, inputs=[video_input, start_frame], outputs=[preview_image, status_text])
        start_btn.click(
            start_analysis,
            inputs=[video_input, model_format, conf, iou, max_det, start_frame, end_frame, threshold],
            outputs=[output_video, result_plot, result_status]
        )
    return app
|
|
if __name__ == "__main__":
    # Proxy environment variables can interfere with Gradio's local server;
    # drop them before launching.
    for proxy_var in ('http_proxy', 'https_proxy', 'all_proxy'):
        os.environ.pop(proxy_var, None)
    device_label = 'GPU' if torch.cuda.is_available() else 'CPU'
    print(f"设备: {device_label}")
    print(f"默认模型路径: {get_model_file_path('.onnx')}")
    app = create_interface()
    if is_spaces:
        # Spaces controls host/port itself.
        app.launch()
    else:
        app.launch(server_name="0.0.0.0", server_port=7860, share=False)
|
|