| import os |
| import cv2 |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import pandas as pd |
| import collections |
| import tempfile |
| from ultralytics import YOLO |
| import math |
|
|
class MouseTrackerAnalyzer:
    """Struggle analyzer for the mouse forced-swim test, built on Ultralytics object tracking.

    Every frame is passed to ``YOLO.track`` with a segmentation model. For each
    tracked mouse the per-frame struggle score is ``1 - IoU`` between its binary
    mask in the current frame and its mask the previous time the same track ID
    was seen (0.0 the first time an ID appears). A score greater than or equal
    to ``struggle_threshold`` marks the mouse as struggling in that frame.
    """

    def __init__(self, model_path, history_size=5, conf=0.25, iou=0.45, max_det=20,
                 verbose=False, struggle_threshold=0.3):
        """Load the segmentation model and reset all tracking state.

        Args:
            model_path: Path to an Ultralytics segmentation model file.
            history_size: Number of recent scores kept per track in ``histories``.
            conf: Detection confidence threshold forwarded to ``model.track``.
            iou: NMS IoU threshold forwarded to ``model.track``.
            max_det: Maximum number of detections per frame.
            verbose: Print diagnostic information while processing.
            struggle_threshold: Score at or above which a mouse counts as struggling.
        """
        self.model = YOLO(model_path, task="segment", verbose=False)
        self.history_size = history_size
        self.verbose = verbose
        self.struggle_threshold = struggle_threshold

        # Inference parameters forwarded to model.track().
        self.conf = conf
        self.iou = iou
        self.max_det = max_det

        # Drawing palette; a track is drawn with colors[track_id % len(colors)].
        self.colors = [
            (255, 0, 0),
            (0, 255, 0),
            (0, 0, 255),
            (255, 255, 0),
            (255, 0, 255),
            (0, 255, 255),
            (128, 0, 0),
            (128, 0, 128),
            (0, 128, 128),
            (192, 192, 192),
            (128, 128, 128),
            (255, 128, 0),
            (255, 0, 128),
            (0, 128, 255),
            (128, 255, 0),
            (0, 255, 128),
        ]

        # Per-track state, keyed by integer track ID.
        self.prev_masks = {}    # most recent binary mask of the track
        self.histories = {}     # deque of the last `history_size` scores
        # NOTE(review): histories are recorded but not read anywhere in this class;
        # is_struggling uses the instantaneous score only.
        self.track_ids = set()  # every ID seen in the current video

        # Per-video state, (re)initialised by init_video().
        self.cap = None
        self.writer = None
        self.fps = None
        self.frame_id = 0
        self.results = []
        self.start_frame = 0
        self.end_frame = 0

    def init_video(self, video_path, output_path=None, start_frame=0, end_frame=None):
        """Open a video, clamp the analysis range and optionally create an output writer.

        Args:
            video_path: Input video path.
            output_path: Annotated output path. A writer is only created when it ends
                with ``.mp4`` or ``.avi`` (case-insensitive) and OpenCV can open it.
            start_frame: First frame index to analyse.
            end_frame: Last frame index to analyse (inclusive); ``None`` means the last
                frame of the video.

        Returns:
            ``(total_frames, start_frame, end_frame)`` with the clamped range.

        Raises:
            IOError: If the video cannot be opened.
        """
        # Drop any capture/writer left from a previous video so a run without
        # output_path does not inherit (and "write" to) a stale writer.
        self._release()

        self.cap = cv2.VideoCapture(video_path)
        if not self.cap.isOpened():
            raise IOError(f"无法打开视频 {video_path}")

        width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = self.cap.get(cv2.CAP_PROP_FPS) or 30
        self.fps = max(fps, 1.0)
        total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

        if self.verbose:
            print(f"视频尺寸: {width}x{height}, 帧率: {fps}, 总帧数: {total_frames}")

        self.start_frame, self.end_frame = self._clamp_range(start_frame, end_frame, total_frames)

        if self.start_frame > 0:
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.start_frame)

        if output_path and output_path.lower().endswith(('.mp4', '.avi')):
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            # Use the sanitised frame rate so the writer never receives fps <= 0.
            writer = cv2.VideoWriter(output_path, fourcc, self.fps, (width, height))
            if writer.isOpened():
                self.writer = writer
                print(f"成功创建输出视频: {output_path}, 尺寸: {width}x{height}")
            else:
                # Keep self.writer None so later code does not pretend to write.
                writer.release()
                print(f"警告: 无法创建输出视频 {output_path}")

        self.frame_id = self.start_frame
        self.results = []
        self.prev_masks.clear()
        self.histories.clear()
        self.track_ids.clear()
        # NOTE(review): model.track(persist=True) keeps tracker state on self.model;
        # track IDs from a previously processed video may carry over — confirm
        # whether an explicit tracker reset is needed when reusing the analyzer.

        if self.verbose:
            print(f"视频初始化完成: 总帧数 {total_frames}, 分析范围 {self.start_frame}-{self.end_frame}")

        return total_frames, self.start_frame, self.end_frame

    @staticmethod
    def _clamp_range(start_frame, end_frame, total_frames):
        """Clamp ``[start_frame, end_frame]`` to valid, non-negative frame indices.

        ``end_frame=None`` selects the last frame. When the container reports a
        frame count (``total_frames > 0``) both ends are clamped to
        ``total_frames - 1``; otherwise the upper bound is left as given and a
        missing ``end_frame`` collapses the range to ``start_frame``. A reversed
        range is swapped.
        """
        has_count = total_frames > 0
        last = total_frames - 1
        start = max(start_frame, 0)
        if end_frame is None:
            end = last if has_count else start
        else:
            end = max(end_frame, 0)
        if has_count:
            start = min(start, last)
            end = min(end, last)
        if start > end:
            start, end = end, start
        return start, end

    def _release(self):
        """Release the current capture and writer, if any, and forget them."""
        cap = getattr(self, 'cap', None)
        writer = getattr(self, 'writer', None)
        if cap is not None:
            cap.release()
        if writer is not None:
            writer.release()
        self.cap = None
        self.writer = None

    def process_frame(self, frame, frame_id):
        """Track mice in one frame, score their motion and draw the overlay.

        Args:
            frame: Image as returned by ``cv2.VideoCapture.read``.
            frame_id: Absolute frame index (used only for verbose logging).

        Returns:
            ``(annotated, frame_results)``. ``annotated`` is a uint8 copy of the frame
            with masks, contours and labels; ``frame_results`` is a list of dicts with
            keys ``id``, ``score``, ``centroid`` and ``is_struggling``. When nothing is
            tracked, or on any error, an unannotated copy and ``[]`` are returned.
        """
        log = self.verbose and frame_id % 50 == 0
        if self.verbose and frame_id % 10 == 0:
            print(f"process_frame: 处理帧 {frame_id}")

        try:
            results = self.model.track(
                frame,
                persist=True,
                conf=self.conf,
                iou=self.iou,
                max_det=self.max_det,
                # Masks at original-frame resolution; the default masks.data is at the
                # letterboxed inference size, which a plain resize misaligns.
                retina_masks=True,
                verbose=False,
            )
            result = results[0]

            if result.boxes is None or len(result.boxes) == 0:
                if log:
                    print("没有检测到任何对象")
                return frame.copy(), []

            if getattr(result, 'masks', None) is None:
                if log:
                    print("没有检测到任何掩码")
                return frame.copy(), []

            masks = result.masks.data.cpu().numpy()
            track_ids = result.boxes.id
            if track_ids is None:
                if log:
                    print("没有获取到跟踪ID")
                return frame.copy(), []
            track_ids = track_ids.int().cpu().numpy()

            if log:
                print(f"检测到 {len(masks)} 个掩码,{len(track_ids)} 个跟踪ID")

            self.track_ids.update(int(t) for t in track_ids)

            frame_h, frame_w = frame.shape[:2]
            kernel = np.ones((5, 5), np.uint8)
            frame_results = []

            for i, (mask, raw_id) in enumerate(zip(masks, track_ids)):
                track_id = int(raw_id)

                # Binarise the soft mask and close small holes before comparing.
                bin_mask = (mask > 0.2).astype(np.uint8)
                bin_mask = cv2.morphologyEx(bin_mask, cv2.MORPH_CLOSE, kernel)
                if bin_mask.shape != (frame_h, frame_w):
                    bin_mask = cv2.resize(bin_mask, (frame_w, frame_h),
                                          interpolation=cv2.INTER_NEAREST)

                prev_mask = self.prev_masks.get(track_id)
                if prev_mask is not None:
                    iou = self._mask_iou(prev_mask, bin_mask)
                    score = 1 - iou
                    if log:
                        print(f"跟踪ID {track_id} 挣扎分数: {score:.4f} (IoU: {iou:.4f})")
                else:
                    score = 0.0
                    if log:
                        print(f"跟踪ID {track_id} 初始帧,分数为0")

                self.prev_masks[track_id] = bin_mask
                if track_id not in self.histories:
                    self.histories[track_id] = collections.deque(maxlen=self.history_size)
                self.histories[track_id].append(score)

                ys, xs = np.nonzero(bin_mask)
                if xs.size:
                    centroid = (int(xs.mean()), int(ys.mean()))
                else:
                    # Empty mask: fall back to the bounding-box centre.
                    box = result.boxes[i].xyxy.cpu().numpy()[0]
                    centroid = (int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2))

                frame_results.append({
                    'id': track_id,
                    'score': float(score),
                    'centroid': centroid,
                    'is_struggling': score >= self.struggle_threshold,
                })

            return self._draw_annotations(frame, frame_results), frame_results

        except Exception as e:
            # Best effort: one bad frame must not abort the whole video, but the
            # failure should not be silent either.
            import traceback
            print(f"处理帧时出错: {str(e)}")
            if self.verbose:
                traceback.print_exc()
            return frame.copy(), []

    @staticmethod
    def _mask_iou(prev_mask, mask):
        """Return the IoU of two binary masks (0.0 when both are empty).

        ``prev_mask`` is nearest-neighbour resized to ``mask``'s shape if they differ.
        """
        if prev_mask.shape != mask.shape:
            prev_mask = cv2.resize(prev_mask, (mask.shape[1], mask.shape[0]),
                                   interpolation=cv2.INTER_NEAREST)
        prev_on = prev_mask > 0
        cur_on = mask > 0
        union = np.logical_or(prev_on, cur_on).sum()
        if union == 0:
            return 0.0
        return float(np.logical_and(prev_on, cur_on).sum() / union)

    def _draw_annotations(self, frame, frame_results):
        """Return a copy of ``frame`` with per-mouse masks/labels and a summary bar."""
        annotated = frame.copy()
        frame_h, frame_w = frame.shape[:2]

        for res in frame_results:
            track_id = res['id']
            color = self.colors[track_id % len(self.colors)]

            mask = self.prev_masks.get(track_id)
            if mask is not None:
                if mask.shape != (frame_h, frame_w):
                    mask = cv2.resize(mask, (frame_w, frame_h), interpolation=cv2.INTER_NEAREST)
                overlay = np.zeros_like(frame)
                overlay[mask > 0] = color

                contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL,
                                               cv2.CHAIN_APPROX_SIMPLE)
                cv2.drawContours(annotated, contours, -1, color, 2)
                # Additive tint: overlay is zero outside the mask.
                cv2.addWeighted(annotated, 1.0, overlay, 0.4, 0, annotated)

            cx, cy = res['centroid']
            status_text = "Struggle" if res['is_struggling'] else "Static"
            cv2.putText(annotated, f"ID:{track_id} {status_text}", (cx, cy),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)

        # Summary bar across the top of the frame.
        cv2.rectangle(annotated, (0, 0), (frame_w, 40), (0, 0, 0), -1)
        struggling_count = sum(1 for r in frame_results if r['is_struggling'])
        cv2.putText(annotated, f"Total: {len(frame_results)} Struggling: {struggling_count}",
                    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

        if annotated.dtype != np.uint8:
            annotated = annotated.astype(np.uint8)
        return annotated

    def process_video(self, video_path, output_path=None, start_frame=0, end_frame=None, callback=None):
        """Process a whole video range; ``callback`` receives progress updates.

        Args:
            video_path: Input video path.
            output_path: Optional annotated output path (see ``init_video``). When
                given, the first annotated frame with detections is also saved as
                ``debug_frame.jpg`` in the same directory.
            start_frame: First frame index to analyse.
            end_frame: Last frame index to analyse (inclusive), ``None`` for the end.
            callback: Optional ``callback(progress_percent, annotated_frame, frame_results)``,
                called whenever the integer percentage changes.

        Returns:
            ``self.results``: one list of per-mouse dicts per frame read.
        """
        _total_frames, start, end = self.init_video(video_path, output_path, start_frame, end_frame)
        self.results = []

        frames_to_process = end - start + 1
        processed_frames = 0
        last_progress = -1
        debug_frame_saved = False

        try:
            for frame_id in range(start, end + 1):
                ret, frame = self.cap.read()
                if not ret:
                    break

                annotated, frame_res = self.process_frame(frame, frame_id)
                self.results.append(frame_res)

                # Only possible when an output path exists to derive the directory from.
                if output_path and not debug_frame_saved and frame_res:
                    debug_frame_path = os.path.join(os.path.dirname(output_path), "debug_frame.jpg")
                    cv2.imwrite(debug_frame_path, annotated)
                    print(f"调试: 保存了标注帧到 {debug_frame_path}")
                    debug_frame_saved = True

                if self.writer is not None and annotated.ndim == 3 and annotated.shape[2] == 3:
                    if frame_id == start:
                        print(f"调试: 写入标注帧到视频,形状: {annotated.shape}")
                    try:
                        self.writer.write(annotated)
                    except Exception as e:
                        import traceback
                        print(f"调试: 写入帧到视频时出错: {str(e)}")
                        traceback.print_exc()

                processed_frames += 1
                progress = int(100 * processed_frames / frames_to_process)
                if callback and progress != last_progress:
                    callback(progress, annotated, frame_res)
                    last_progress = progress
        finally:
            # Always finalise the output file, even if processing or the callback raised.
            self.cap.release()
            if self.writer is not None:
                self.writer.release()
                self.writer = None
                print(f"调试: 视频写入完成,保存到: {output_path}")

        return self.results

    def save_results(self, csv_path):
        """Export the analysis results to CSV.

        One row per (frame, mouse): ``frame_id, mouse_id, score, is_struggling``.
        Frame IDs are ``start_frame`` plus the index into ``self.results``.
        """
        import csv
        with open(csv_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(['frame_id', 'mouse_id', 'score', 'is_struggling'])
            for offset, frame_results in enumerate(self.results):
                frame_num = offset + self.start_frame
                for fr in frame_results:
                    writer.writerow([
                        frame_num,
                        fr['id'],
                        f"{fr['score']:.4f}",
                        1 if fr.get('is_struggling', False) else 0,
                    ])

    def generate_time_series_plot(self, threshold=None):
        """Plot each mouse's struggle score over time and save it as a PNG.

        One subplot per mouse, ordered left to right by mean centroid x. Each
        subplot shows raw and 5-frame smoothed scores, red spans over struggling
        frames and a dashed threshold line.

        Args:
            threshold: Threshold line to draw; defaults to ``self.struggle_threshold``.

        Returns:
            Path of a temporary PNG file (caller owns it), or ``None`` when there is
            too little data (< 10 frames), no mouse data, or an error occurred.
        """
        fig = None
        try:
            print(f"Starting to generate time series plot with {len(self.results)} frames of data")

            if not self.results or len(self.results) < 10:
                print("Not enough data for time series plot (need at least 10 frames)")
                return None

            if threshold is None:
                threshold = self.struggle_threshold

            fps = getattr(self, 'fps', None)
            if fps is None or fps <= 0:
                fps = 30
                print(f"Warning: Invalid frame rate detected, using default: {fps} fps")
            else:
                print(f"Using frame rate: {fps} fps")

            mouse_data = {}
            mouse_positions = {}
            for offset, frame_results in enumerate(self.results):
                frame_num = offset + self.start_frame
                for result in frame_results:
                    mouse_id = result['id']
                    if mouse_id not in mouse_data:
                        mouse_data[mouse_id] = {'frames': [], 'seconds': [], 'scores': [], 'struggling': []}
                        mouse_positions[mouse_id] = []
                    data = mouse_data[mouse_id]
                    data['frames'].append(frame_num)
                    data['seconds'].append(frame_num / fps)
                    data['scores'].append(result['score'])
                    data['struggling'].append(1 if result.get('is_struggling', False) else 0)
                    if 'centroid' in result:
                        mouse_positions[mouse_id].append(result['centroid'][0])

            print(f"Processed data for {len(mouse_data)} mice")
            if not mouse_data:
                print("No valid mouse data to plot")
                return None

            # Mice without any centroid sort last.
            avg_positions = {
                mid: (sum(xs) / len(xs) if xs else float('inf'))
                for mid, xs in mouse_positions.items()
            }
            sorted_mice = sorted(mouse_data, key=lambda mid: avg_positions.get(mid, float('inf')))
            print(f"Mice sorted from left to right: {sorted_mice}")

            num_mice = len(mouse_data)
            fig, axes = plt.subplots(num_mice, 1, figsize=(12, 4 * num_mice), sharex=True, squeeze=False)
            axes = axes[:, 0]
            half_frame = 0.5 / fps

            for position, mouse_id in enumerate(sorted_mice, start=1):
                data = mouse_data[mouse_id]
                ax = axes[position - 1]

                ax.plot(data['seconds'], self._smooth(data['scores'], window_size=5),
                        label="Smoothed", color='blue', linewidth=2)
                ax.plot(data['seconds'], data['scores'], label="Raw", color='lightblue',
                        alpha=0.5, linewidth=1)

                # One span per run of consecutive struggling frames (not one per frame).
                for first, last in self._struggle_runs(data['frames'], data['struggling']):
                    ax.axvspan(first / fps - half_frame, last / fps + half_frame, alpha=0.1, color='red')

                ax.axhline(y=threshold, color='r', linestyle='--', label=f"Threshold ({threshold:.2f})")
                ax.set_ylabel('Struggle Score')
                ax.set_title(f'Mouse {mouse_id} Struggle Score (Position: Left #{position})')
                ax.legend(loc='upper right')
                ax.grid(True)
                ax.set_ylim(-0.05, 1.05)

            bottom = axes[-1]
            bottom.set_xlabel('Time (seconds)')

            start_time = self.start_frame / fps
            end_time = (self.start_frame + len(self.results) - 1) / fps
            bottom.set_xlim(start_time, end_time)

            # Adaptive tick spacing: 1 s major / 0.1 s minor for spans up to 30 s,
            # coarser for longer recordings so labels do not overlap.
            major_step = self._nice_tick_step(end_time - start_time)
            minor_step = major_step / 10
            minor_ticks = np.arange(start_time, end_time + minor_step, minor_step)
            bottom.set_xticks(minor_ticks, minor=True)

            first_major = math.ceil(start_time / major_step)
            last_major = math.floor(end_time / major_step)
            major_ticks = np.arange(first_major, last_major + 1) * major_step
            bottom.set_xticks(major_ticks)
            bottom.set_xticklabels([f"{int(t)}" for t in major_ticks])

            bottom.grid(True, which='both')
            bottom.grid(which='minor', alpha=0.2)
            bottom.grid(which='major', alpha=0.5)

            fig.tight_layout()

            # Close the temp-file handle immediately; only its name is needed.
            with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
                plot_path = tmp.name
            fig.savefig(plot_path, dpi=150, bbox_inches='tight')

            print(f"Time series plot saved to: {plot_path}")
            return plot_path

        except Exception as e:
            import traceback
            print(f"Error generating time series plot: {str(e)}")
            traceback.print_exc()
            return None
        finally:
            # Release the figure even when plotting failed part-way.
            if fig is not None:
                plt.close(fig)

    @staticmethod
    def _smooth(values, window_size=5):
        """Centred moving average with shrinking windows at the edges.

        Sequences shorter than ``window_size`` are returned unchanged.
        """
        if len(values) < window_size:
            return values
        n = len(values)
        half = window_size // 2
        smoothed = []
        for i in range(n):
            window = values[max(0, i - half):min(n, i + half + 1)]
            smoothed.append(sum(window) / len(window))
        return smoothed

    @staticmethod
    def _struggle_runs(frames, flags):
        """Return ``(first_frame, last_frame)`` for each run of consecutive flagged frames."""
        runs = []
        for frame_num, flag in zip(frames, flags):
            if not flag:
                continue
            if runs and runs[-1][1] == frame_num - 1:
                runs[-1][1] = frame_num
            else:
                runs.append([frame_num, frame_num])
        return [tuple(run) for run in runs]

    @staticmethod
    def _nice_tick_step(span, max_ticks=30):
        """Pick a major tick interval (seconds) giving at most ``max_ticks`` ticks over ``span``."""
        for step in (1, 2, 5, 10, 15, 30, 60, 120, 300, 600, 1800, 3600):
            if span / step <= max_ticks:
                return float(step)
        return float(math.ceil(span / max_ticks / 3600) * 3600)
|
|
if __name__ == "__main__":
    import argparse

    # Command-line options as (flag, add_argument keyword options), in help order.
    cli_options = [
        ('--video', {'type': str, 'required': True, 'help': '输入视频路径'}),
        ('--model', {'type': str, 'required': True, 'help': '模型文件路径'}),
        ('--output', {'type': str, 'help': '输出视频路径'}),
        ('--csv', {'type': str, 'help': '输出CSV结果路径'}),
        ('--conf', {'type': float, 'default': 0.25, 'help': '置信度阈值'}),
        ('--iou', {'type': float, 'default': 0.45, 'help': 'IOU阈值'}),
        ('--max-det', {'type': int, 'default': 20, 'help': '最大检测数量'}),
        ('--threshold', {'type': float, 'default': 0.3, 'help': '挣扎阈值'}),
        ('--start', {'type': int, 'default': 0, 'help': '起始帧'}),
        ('--end', {'type': int, 'default': None, 'help': '结束帧'}),
        ('--verbose', {'action': 'store_true', 'help': '详细输出'}),
    ]

    parser = argparse.ArgumentParser(description="鼠强迫游泳实验挣扎度分析")
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # Missing (or empty) output paths default to files next to the input video.
    video_dir = os.path.dirname(args.video)
    video_stem = os.path.splitext(os.path.basename(args.video))[0]
    args.output = args.output or os.path.join(video_dir, f"{video_stem}_out.mp4")
    args.csv = args.csv or os.path.join(video_dir, f"{video_stem}_results.csv")

    analyzer = MouseTrackerAnalyzer(
        model_path=args.model,
        conf=args.conf,
        iou=args.iou,
        max_det=args.max_det,
        verbose=args.verbose,
    )
    analyzer.struggle_threshold = args.threshold

    def report_progress(progress, frame, results):
        """Print a one-line progress update; used as process_video's callback."""
        print(f"处理进度: {progress}%, 检测到 {len(results)} 个对象")

    analyzer.process_video(
        video_path=args.video,
        output_path=args.output,
        start_frame=args.start,
        end_frame=args.end,
        callback=report_progress,
    )
    analyzer.save_results(args.csv)

    plot_path = analyzer.generate_time_series_plot()
    if plot_path:
        print(f"挣扎度时序分析图已保存到: {plot_path}")

    print(f"分析完成,视频已保存到: {args.output}")
    print(f"结果数据已保存到: {args.csv}")