VisionLanguageGroup commited on
Commit
0b472f0
·
1 Parent(s): aff3c6f
Files changed (4) hide show
  1. app.py +30 -81
  2. inference_count.py +19 -39
  3. inference_seg.py +0 -42
  4. inference_track.py +28 -32
app.py CHANGED
@@ -18,7 +18,6 @@ from natsort import natsorted
18
  from huggingface_hub import HfApi, upload_file
19
  # import spaces
20
 
21
- # ===== 导入三个推理模块 =====
22
  from inference_seg import load_model as load_seg_model, run as run_seg
23
  from inference_count import load_model as load_count_model, run as run_count
24
  from inference_track import load_model as load_track_model, run as run_track
@@ -27,7 +26,6 @@ HF_TOKEN = os.getenv("HF_TOKEN")
27
  DATASET_REPO = "phoebe777777/celltool_feedback"
28
 
29
 
30
- # ===== 清理缓存目录 =====
31
  print("===== clearing cache =====")
32
  # cache_path = os.path.expanduser("~/.cache/")
33
  cache_path = os.path.expanduser("~/.cache/huggingface/gradio")
@@ -39,7 +37,6 @@ if os.path.exists(cache_path):
39
  except:
40
  pass
41
 
42
- # ===== 全局模型变量 =====
43
  SEG_MODEL = None
44
  SEG_DEVICE = torch.device("cpu")
45
 
@@ -50,7 +47,6 @@ TRACK_MODEL = None
50
  TRACK_DEVICE = torch.device("cpu")
51
 
52
  def load_all_models():
53
- """启动时加载所有模型"""
54
  global SEG_MODEL, SEG_DEVICE
55
  global COUNT_MODEL, COUNT_DEVICE
56
  global TRACK_MODEL, TRACK_DEVICE
@@ -76,14 +72,12 @@ def load_all_models():
76
 
77
  load_all_models()
78
 
79
- # ===== 保存用户反馈 =====
80
  DATASET_DIR = Path("solver_cache")
81
  DATASET_DIR.mkdir(parents=True, exist_ok=True)
82
 
83
  def save_feedback_to_hf(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
84
- """保存反馈到 Hugging Face Dataset"""
85
 
86
- # 如果没有 token,回退到本地存储
87
  if not HF_TOKEN:
88
  print("⚠️ No HF_TOKEN found, using local storage")
89
  save_feedback(query_id, feedback_type, feedback_text, img_path, bboxes)
@@ -102,13 +96,11 @@ def save_feedback_to_hf(query_id, feedback_type, feedback_text=None, img_path=No
102
  try:
103
  api = HfApi()
104
 
105
- # 创建临时文件
106
  filename = f"feedback_{query_id}_{int(time.time())}.json"
107
 
108
  with open(filename, 'w', encoding='utf-8') as f:
109
  json.dump(feedback_data, f, indent=2, ensure_ascii=False)
110
 
111
- # 上传到 dataset
112
  api.upload_file(
113
  path_or_fileobj=filename,
114
  path_in_repo=f"data/{filename}",
@@ -117,19 +109,17 @@ def save_feedback_to_hf(query_id, feedback_type, feedback_text=None, img_path=No
117
  token=HF_TOKEN
118
  )
119
 
120
- # 清理本地文件
121
  os.remove(filename)
122
 
123
  print(f"✅ Feedback saved to HF Dataset: {DATASET_REPO}")
124
 
125
  except Exception as e:
126
  print(f"⚠️ Failed to save to HF Dataset: {e}")
127
- # 回退到本地存储
128
  save_feedback(query_id, feedback_type, feedback_text, img_path, bboxes)
129
 
130
 
131
  def save_feedback(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
132
- """保存用户反馈到JSON文件"""
133
  feedback_data = {
134
  "query_id": query_id,
135
  "feedback_type": feedback_type,
@@ -154,9 +144,8 @@ def save_feedback(query_id, feedback_type, feedback_text=None, img_path=None, bb
154
  with feedback_file.open("w") as f:
155
  json.dump(feedback_data, f, indent=4, ensure_ascii=False)
156
 
157
- # ===== 辅助函数 =====
158
  def parse_first_bbox(bboxes):
159
- """解析第一个边界框"""
160
  if not bboxes:
161
  return None
162
  b = bboxes[0]
@@ -169,7 +158,7 @@ def parse_first_bbox(bboxes):
169
  return None
170
 
171
  def parse_bboxes(bboxes):
172
- """解析所有边界框"""
173
  if not bboxes:
174
  return None
175
 
@@ -185,7 +174,7 @@ def parse_bboxes(bboxes):
185
  return result
186
 
187
  def colorize_mask(mask: np.ndarray, num_colors: int = 512) -> np.ndarray:
188
- """将实例掩码转换为彩色图像"""
189
  def hsv_to_rgb(h, s, v):
190
  i = int(h * 6.0)
191
  f = h * 6.0 - i
@@ -210,10 +199,9 @@ def colorize_mask(mask: np.ndarray, num_colors: int = 512) -> np.ndarray:
210
  color_idx = mask % num_colors
211
  return palette_arr[color_idx]
212
 
213
- # ===== 分割功能 =====
214
  # @spaces.GPU
215
  def segment_with_choice(use_box_choice, annot_value):
216
- """分割主函数 - 每个实例不同颜色+轮廓"""
217
  if annot_value is None or len(annot_value) < 1:
218
  print("❌ No annotation input")
219
  return None, None
@@ -224,18 +212,12 @@ def segment_with_choice(use_box_choice, annot_value):
224
  print(f"🖼️ Image path: {img_path}")
225
  box_array = None
226
  if use_box_choice == "Yes" and bboxes:
227
- # box = parse_first_bbox(bboxes)
228
- # if box:
229
- # xmin, ymin, xmax, ymax = map(int, box)
230
- # box_array = [[xmin, ymin, xmax, ymax]]
231
- # print(f"📦 Using bounding box: {box_array}")
232
  box = parse_bboxes(bboxes)
233
  if box:
234
  box_array = box
235
  print(f"📦 Using bounding boxes: {box_array}")
236
 
237
 
238
- # 运行分割模型
239
  try:
240
  mask = run_seg(SEG_MODEL, img_path, box=box_array, device=SEG_DEVICE)
241
  print("📏 mask shape:", mask.shape, "dtype:", mask.dtype, "unique:", np.unique(mask))
@@ -243,13 +225,11 @@ def segment_with_choice(use_box_choice, annot_value):
243
  print(f"❌ Inference failed: {str(e)}")
244
  return None, None
245
 
246
- # 保存原始mask为TIF文件
247
  temp_mask_file = tempfile.NamedTemporaryFile(delete=False, suffix=".tif")
248
  mask_img = Image.fromarray(mask.astype(np.uint16))
249
  mask_img.save(temp_mask_file.name)
250
  print(f"💾 Original mask saved to: {temp_mask_file.name}")
251
 
252
- # 读取原图
253
  try:
254
  img = Image.open(img_path)
255
  print("📷 Image mode:", img.mode, "size:", img.size)
@@ -276,24 +256,19 @@ def segment_with_choice(use_box_choice, annot_value):
276
  print("⚠️ No instance found, returning dummy red image")
277
  return Image.new("RGB", mask.shape[::-1], (255, 0, 0)), None
278
 
279
- # ==== Color Overlay (每个实例一个颜色) ====
280
  overlay = img_np.copy()
281
  alpha = 0.5
282
- # cmap = cm.get_cmap("hsv", num_instances + 1)
283
 
284
  for inst_id in np.unique(inst_mask):
285
  if inst_id == 0:
286
  continue
287
  binary_mask = (inst_mask == inst_id).astype(np.uint8)
288
- # color = np.array(cmap(inst_id / (num_instances + 1))[:3]) # RGB only, ignore alpha
289
  color = get_well_spaced_color(inst_id)
290
  overlay[binary_mask == 1] = (1 - alpha) * overlay[binary_mask == 1] + alpha * color
291
 
292
- # 绘制轮廓
293
  contours = measure.find_contours(binary_mask, 0.5)
294
  for contour in contours:
295
  contour = contour.astype(np.int32)
296
- # 确保坐标在范围内
297
  valid_y = np.clip(contour[:, 0], 0, overlay.shape[0] - 1)
298
  valid_x = np.clip(contour[:, 1], 0, overlay.shape[1] - 1)
299
  overlay[valid_y, valid_x] = [1.0, 1.0, 0.0] # 黄色轮廓
@@ -302,7 +277,7 @@ def segment_with_choice(use_box_choice, annot_value):
302
 
303
  return Image.fromarray(overlay), temp_mask_file.name
304
 
305
- # ===== 计数功能 =====
306
  # @spaces.GPU
307
  def count_cells_handler(use_box_choice, annot_value):
308
  """Counting handler - supports bounding box, returns only density map"""
@@ -315,11 +290,6 @@ def count_cells_handler(use_box_choice, annot_value):
315
  print(f"🖼️ Image path: {image_path}")
316
  box_array = None
317
  if use_box_choice == "Yes" and bboxes:
318
- # box = parse_first_bbox(bboxes)
319
- # if box:
320
- # xmin, ymin, xmax, ymax = map(int, box)
321
- # box_array = [[xmin, ymin, xmax, ymax]]
322
- # print(f"📦 Using bounding box: {box_array}")
323
  box = parse_bboxes(bboxes)
324
  if box:
325
  box_array = box
@@ -341,7 +311,6 @@ def count_cells_handler(use_box_choice, annot_value):
341
 
342
  count = result['count']
343
  density_map = result['density_map']
344
- # save density map as temp file
345
  temp_density_file = tempfile.NamedTemporaryFile(delete=False, suffix=".npy")
346
  np.save(temp_density_file.name, density_map)
347
  print(f"💾 Density map saved to {temp_density_file.name}")
@@ -365,20 +334,16 @@ def count_cells_handler(use_box_choice, annot_value):
365
  return None, None
366
 
367
 
368
- # Normalize density map to [0, 1]
369
  density_normalized = density_map.copy()
370
  if density_normalized.max() > 0:
371
  density_normalized = (density_normalized - density_normalized.min()) / (density_normalized.max() - density_normalized.min())
372
 
373
- # Apply colormap
374
  cmap = cm.get_cmap("jet")
375
  alpha = 0.3
376
  density_colored = cmap(density_normalized)[:, :, :3] # RGB only, ignore alpha
377
 
378
- # Create overlay
379
  overlay = img_np.copy()
380
 
381
- # Blend only where density is significant (optional: threshold)
382
  threshold = 0.01 # Only overlay where density > 1% of max
383
  significant_mask = density_normalized > threshold
384
 
@@ -400,7 +365,6 @@ def count_cells_handler(use_box_choice, annot_value):
400
 
401
  return Image.fromarray(overlay), temp_density_file.name, result_text
402
 
403
- # return density_path, result_text
404
 
405
  except Exception as e:
406
  print(f"❌ Counting error: {e}")
@@ -408,7 +372,7 @@ def count_cells_handler(use_box_choice, annot_value):
408
  traceback.print_exc()
409
  return None, f"❌ Counting failed: {str(e)}"
410
 
411
- # ===== Tracking Functionality =====
412
  def find_tif_dir(root_dir):
413
  """Recursively find the first directory containing .tif files"""
414
  for dirpath, _, filenames in os.walk(root_dir):
@@ -502,14 +466,13 @@ def create_ctc_results_zip(output_dir):
502
  print(f"✅ ZIP created: {zip_path} ({os.path.getsize(zip_path) / 1024:.1f} KB)")
503
  return zip_path
504
 
505
- # 使用更智能的颜色分配 - 让相邻的ID颜色差异更大
506
  def get_well_spaced_color(track_id, num_colors=256):
507
  """Generate well-spaced colors, using contrasting colors for adjacent IDs"""
508
- # 使用质数跳跃来分散颜色
509
  golden_ratio = 0.618033988749895
510
  hue = (track_id * golden_ratio) % 1.0
511
-
512
- # 使用高饱和度和明度
513
  import colorsys
514
  rgb = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
515
  return np.array(rgb)
@@ -568,14 +531,7 @@ def create_tracking_visualization(tif_dir, output_dir, valid_tif_files):
568
  return valid_tif_files[0]
569
 
570
  print(f"📊 Found {len(mask_files)} mask files")
571
-
572
- # Create color map for consistent track IDs
573
- # Use a colormap with many distinct colors
574
- # try:
575
- # cmap = colormaps.get_cmap("hsv")
576
- # except:
577
- # from matplotlib import cm
578
- # cmap = cm.get_cmap("hsv")
579
 
580
  frames = []
581
  alpha = 0.3 # Transparency for overlay
@@ -710,19 +666,19 @@ def create_tracking_visualization(tif_dir, output_dir, valid_tif_files):
710
  # @spaces.GPU
711
  def track_video_handler(use_box_choice, first_frame_annot, zip_file_obj):
712
  """
713
- 支持 ZIP 压缩包上传的 Tracking 处理函数 - 支持首帧边界框
714
 
715
  Parameters:
716
  -----------
717
  use_box_choice : str
718
- "Yes" or "No" - 是否使用边界框
719
  first_frame_annot : tuple or None
720
  (image_path, bboxes) from BBoxAnnotator, only used if user annotated first frame
721
  zip_file_obj : File
722
  Uploaded ZIP file containing TIF sequence
723
  """
724
  if zip_file_obj is None:
725
- return None, "⚠️ 请上传包含视频帧的压缩包 (.zip)", None, None
726
 
727
  temp_dir = None
728
  output_temp_dir = None
@@ -734,11 +690,6 @@ def track_video_handler(use_box_choice, first_frame_annot, zip_file_obj):
734
  if isinstance(first_frame_annot, (list, tuple)) and len(first_frame_annot) > 1:
735
  bboxes = first_frame_annot[1]
736
  if bboxes:
737
- # box = parse_first_bbox(bboxes)
738
- # if box:
739
- # xmin, ymin, xmax, ymax = map(int, box)
740
- # box_array = [[xmin, ymin, xmax, ymax]]
741
- # print(f"📦 Using bounding box: {box_array}")
742
  box = parse_bboxes(bboxes)
743
  if box:
744
  box_array = box
@@ -880,9 +831,8 @@ def track_video_handler(use_box_choice, first_frame_annot, zip_file_obj):
880
 
881
 
882
 
883
- # ===== 示例图像 =====
884
  example_images_seg = [f for f in glob("example_imgs/seg/*")]
885
- # ["example_imgs/seg/003_img.png", "example_imgs/seg/1977_Well_F-5_Field_1.png"]
886
  example_images_cnt = [f for f in glob("example_imgs/cnt/*")]
887
  example_tracking_zips = [f for f in glob("example_imgs/tra/*.zip")]
888
 
@@ -909,7 +859,6 @@ with gr.Blocks(
909
  object-fit: contain !important;
910
  }
911
 
912
- /* 强制密度图容器和图片高度 */
913
  #density_map_output {
914
  height: 500px !important;
915
  }
@@ -939,7 +888,7 @@ with gr.Blocks(
939
 
940
  # 全局状态
941
  current_query_id = gr.State(str(uuid.uuid4()))
942
- user_uploaded_examples = gr.State(example_images_seg.copy()) # 初始化时包含原始示例
943
 
944
  with gr.Tabs():
945
  # ===== Tab 1: Segmentation =====
@@ -1031,27 +980,27 @@ with gr.Blocks(
1031
  visible=False
1032
  )
1033
 
1034
- # 绑定事件: 运行分割
1035
  run_seg_btn.click(
1036
  fn=segment_with_choice,
1037
  inputs=[use_box_radio, annotator],
1038
  outputs=[seg_output, download_mask_btn]
1039
  )
1040
 
1041
- # 清空按钮事件
1042
  clear_btn.click(
1043
  fn=lambda: None,
1044
  inputs=None,
1045
  outputs=annotator
1046
  )
1047
 
1048
- # 初始化Gallery显示
1049
  demo.load(
1050
  fn=lambda: example_images_seg.copy(),
1051
  outputs=example_gallery
1052
  )
1053
 
1054
- # 绑定事件: 上传示例图片
1055
  def add_to_gallery(img_path, current_imgs):
1056
  if not img_path:
1057
  return current_imgs
@@ -1072,7 +1021,7 @@ with gr.Blocks(
1072
  outputs=example_gallery
1073
  )
1074
 
1075
- # 绑定事件: 点击Gallery加载
1076
  def load_from_gallery(evt: gr.SelectData, all_imgs):
1077
  if evt.index is not None and evt.index < len(all_imgs):
1078
  return all_imgs[evt.index]
@@ -1084,7 +1033,7 @@ with gr.Blocks(
1084
  outputs=annotator
1085
  )
1086
 
1087
- # 绑定事���: 提交反馈
1088
  def submit_user_feedback(query_id, score, comment, annot_val):
1089
  try:
1090
  img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
@@ -1097,7 +1046,7 @@ with gr.Blocks(
1097
  # img_path=img_path,
1098
  # bboxes=bboxes
1099
  # )
1100
- # 使用 HF 存储
1101
  save_feedback_to_hf(
1102
  query_id=query_id,
1103
  feedback_type=f"score_{int(score)}",
@@ -1259,14 +1208,14 @@ with gr.Blocks(
1259
  outputs=[count_output, download_density_btn, count_status]
1260
  )
1261
 
1262
- # 清空按钮事件
1263
  clear_btn.click(
1264
  fn=lambda: None,
1265
  inputs=None,
1266
  outputs=count_annotator
1267
  )
1268
 
1269
- # 绑定事件: 提交反馈
1270
  def submit_user_feedback(query_id, score, comment, annot_val):
1271
  try:
1272
  img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
@@ -1279,7 +1228,7 @@ with gr.Blocks(
1279
  # img_path=img_path,
1280
  # bboxes=bboxes
1281
  # )
1282
- # 使用 HF 存储
1283
  save_feedback_to_hf(
1284
  query_id=query_id,
1285
  feedback_type=f"score_{int(score)}",
@@ -1581,14 +1530,14 @@ with gr.Blocks(
1581
  outputs=[track_download, track_output, track_download, track_first_frame_preview]
1582
  )
1583
 
1584
- # 清空按钮事件
1585
  clear_btn.click(
1586
  fn=lambda: None,
1587
  inputs=None,
1588
  outputs=track_first_frame_annotator
1589
  )
1590
 
1591
- # 绑定事件: 提交反馈
1592
  def submit_user_feedback(query_id, score, comment, annot_val):
1593
  try:
1594
  img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
@@ -1601,7 +1550,7 @@ with gr.Blocks(
1601
  # img_path=img_path,
1602
  # bboxes=bboxes
1603
  # )
1604
- # 使用 HF 存储
1605
  save_feedback_to_hf(
1606
  query_id=query_id,
1607
  feedback_type=f"score_{int(score)}",
 
18
  from huggingface_hub import HfApi, upload_file
19
  # import spaces
20
 
 
21
  from inference_seg import load_model as load_seg_model, run as run_seg
22
  from inference_count import load_model as load_count_model, run as run_count
23
  from inference_track import load_model as load_track_model, run as run_track
 
26
  DATASET_REPO = "phoebe777777/celltool_feedback"
27
 
28
 
 
29
  print("===== clearing cache =====")
30
  # cache_path = os.path.expanduser("~/.cache/")
31
  cache_path = os.path.expanduser("~/.cache/huggingface/gradio")
 
37
  except:
38
  pass
39
 
 
40
  SEG_MODEL = None
41
  SEG_DEVICE = torch.device("cpu")
42
 
 
47
  TRACK_DEVICE = torch.device("cpu")
48
 
49
  def load_all_models():
 
50
  global SEG_MODEL, SEG_DEVICE
51
  global COUNT_MODEL, COUNT_DEVICE
52
  global TRACK_MODEL, TRACK_DEVICE
 
72
 
73
  load_all_models()
74
 
 
75
  DATASET_DIR = Path("solver_cache")
76
  DATASET_DIR.mkdir(parents=True, exist_ok=True)
77
 
78
  def save_feedback_to_hf(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
79
+ """Save feedback to Hugging Face Dataset"""
80
 
 
81
  if not HF_TOKEN:
82
  print("⚠️ No HF_TOKEN found, using local storage")
83
  save_feedback(query_id, feedback_type, feedback_text, img_path, bboxes)
 
96
  try:
97
  api = HfApi()
98
 
 
99
  filename = f"feedback_{query_id}_{int(time.time())}.json"
100
 
101
  with open(filename, 'w', encoding='utf-8') as f:
102
  json.dump(feedback_data, f, indent=2, ensure_ascii=False)
103
 
 
104
  api.upload_file(
105
  path_or_fileobj=filename,
106
  path_in_repo=f"data/{filename}",
 
109
  token=HF_TOKEN
110
  )
111
 
 
112
  os.remove(filename)
113
 
114
  print(f"✅ Feedback saved to HF Dataset: {DATASET_REPO}")
115
 
116
  except Exception as e:
117
  print(f"⚠️ Failed to save to HF Dataset: {e}")
 
118
  save_feedback(query_id, feedback_type, feedback_text, img_path, bboxes)
119
 
120
 
121
  def save_feedback(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
122
+ """Save feedback to local JSON file"""
123
  feedback_data = {
124
  "query_id": query_id,
125
  "feedback_type": feedback_type,
 
144
  with feedback_file.open("w") as f:
145
  json.dump(feedback_data, f, indent=4, ensure_ascii=False)
146
 
 
147
  def parse_first_bbox(bboxes):
148
+ """Parse the first bounding box from the annotation input, supports dict or list format"""
149
  if not bboxes:
150
  return None
151
  b = bboxes[0]
 
158
  return None
159
 
160
  def parse_bboxes(bboxes):
161
+ """Parse all bounding boxes from the annotation input"""
162
  if not bboxes:
163
  return None
164
 
 
174
  return result
175
 
176
  def colorize_mask(mask: np.ndarray, num_colors: int = 512) -> np.ndarray:
177
+ """Convert a 2D mask of instance IDs to a color image for visualization."""
178
  def hsv_to_rgb(h, s, v):
179
  i = int(h * 6.0)
180
  f = h * 6.0 - i
 
199
  color_idx = mask % num_colors
200
  return palette_arr[color_idx]
201
 
 
202
  # @spaces.GPU
203
  def segment_with_choice(use_box_choice, annot_value):
204
+ """Segmentation handler - supports bounding box, returns colorized overlay and original mask path"""
205
  if annot_value is None or len(annot_value) < 1:
206
  print("❌ No annotation input")
207
  return None, None
 
212
  print(f"🖼️ Image path: {img_path}")
213
  box_array = None
214
  if use_box_choice == "Yes" and bboxes:
 
 
 
 
 
215
  box = parse_bboxes(bboxes)
216
  if box:
217
  box_array = box
218
  print(f"📦 Using bounding boxes: {box_array}")
219
 
220
 
 
221
  try:
222
  mask = run_seg(SEG_MODEL, img_path, box=box_array, device=SEG_DEVICE)
223
  print("📏 mask shape:", mask.shape, "dtype:", mask.dtype, "unique:", np.unique(mask))
 
225
  print(f"❌ Inference failed: {str(e)}")
226
  return None, None
227
 
 
228
  temp_mask_file = tempfile.NamedTemporaryFile(delete=False, suffix=".tif")
229
  mask_img = Image.fromarray(mask.astype(np.uint16))
230
  mask_img.save(temp_mask_file.name)
231
  print(f"💾 Original mask saved to: {temp_mask_file.name}")
232
 
 
233
  try:
234
  img = Image.open(img_path)
235
  print("📷 Image mode:", img.mode, "size:", img.size)
 
256
  print("⚠️ No instance found, returning dummy red image")
257
  return Image.new("RGB", mask.shape[::-1], (255, 0, 0)), None
258
 
 
259
  overlay = img_np.copy()
260
  alpha = 0.5
 
261
 
262
  for inst_id in np.unique(inst_mask):
263
  if inst_id == 0:
264
  continue
265
  binary_mask = (inst_mask == inst_id).astype(np.uint8)
 
266
  color = get_well_spaced_color(inst_id)
267
  overlay[binary_mask == 1] = (1 - alpha) * overlay[binary_mask == 1] + alpha * color
268
 
 
269
  contours = measure.find_contours(binary_mask, 0.5)
270
  for contour in contours:
271
  contour = contour.astype(np.int32)
 
272
  valid_y = np.clip(contour[:, 0], 0, overlay.shape[0] - 1)
273
  valid_x = np.clip(contour[:, 1], 0, overlay.shape[1] - 1)
274
  overlay[valid_y, valid_x] = [1.0, 1.0, 0.0] # 黄色轮廓
 
277
 
278
  return Image.fromarray(overlay), temp_mask_file.name
279
 
280
+
281
  # @spaces.GPU
282
  def count_cells_handler(use_box_choice, annot_value):
283
  """Counting handler - supports bounding box, returns only density map"""
 
290
  print(f"🖼️ Image path: {image_path}")
291
  box_array = None
292
  if use_box_choice == "Yes" and bboxes:
 
 
 
 
 
293
  box = parse_bboxes(bboxes)
294
  if box:
295
  box_array = box
 
311
 
312
  count = result['count']
313
  density_map = result['density_map']
 
314
  temp_density_file = tempfile.NamedTemporaryFile(delete=False, suffix=".npy")
315
  np.save(temp_density_file.name, density_map)
316
  print(f"💾 Density map saved to {temp_density_file.name}")
 
334
  return None, None
335
 
336
 
 
337
  density_normalized = density_map.copy()
338
  if density_normalized.max() > 0:
339
  density_normalized = (density_normalized - density_normalized.min()) / (density_normalized.max() - density_normalized.min())
340
 
 
341
  cmap = cm.get_cmap("jet")
342
  alpha = 0.3
343
  density_colored = cmap(density_normalized)[:, :, :3] # RGB only, ignore alpha
344
 
 
345
  overlay = img_np.copy()
346
 
 
347
  threshold = 0.01 # Only overlay where density > 1% of max
348
  significant_mask = density_normalized > threshold
349
 
 
365
 
366
  return Image.fromarray(overlay), temp_density_file.name, result_text
367
 
 
368
 
369
  except Exception as e:
370
  print(f"❌ Counting error: {e}")
 
372
  traceback.print_exc()
373
  return None, f"❌ Counting failed: {str(e)}"
374
 
375
+
376
  def find_tif_dir(root_dir):
377
  """Recursively find the first directory containing .tif files"""
378
  for dirpath, _, filenames in os.walk(root_dir):
 
466
  print(f"✅ ZIP created: {zip_path} ({os.path.getsize(zip_path) / 1024:.1f} KB)")
467
  return zip_path
468
 
469
+
470
  def get_well_spaced_color(track_id, num_colors=256):
471
  """Generate well-spaced colors, using contrasting colors for adjacent IDs"""
472
+
473
  golden_ratio = 0.618033988749895
474
  hue = (track_id * golden_ratio) % 1.0
475
+
 
476
  import colorsys
477
  rgb = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
478
  return np.array(rgb)
 
531
  return valid_tif_files[0]
532
 
533
  print(f"📊 Found {len(mask_files)} mask files")
534
+
 
 
 
 
 
 
 
535
 
536
  frames = []
537
  alpha = 0.3 # Transparency for overlay
 
666
  # @spaces.GPU
667
  def track_video_handler(use_box_choice, first_frame_annot, zip_file_obj):
668
  """
669
+ Tracking handler - processes a ZIP of TIF frames, supports bounding box, returns visualization and results ZIP
670
 
671
  Parameters:
672
  -----------
673
  use_box_choice : str
674
+ "Yes" or "No" - whether to use bounding box annotation for tracking
675
  first_frame_annot : tuple or None
676
  (image_path, bboxes) from BBoxAnnotator, only used if user annotated first frame
677
  zip_file_obj : File
678
  Uploaded ZIP file containing TIF sequence
679
  """
680
  if zip_file_obj is None:
681
+ return None, "⚠️ Please upload a ZIP file containing video frames (.zip)", None, None
682
 
683
  temp_dir = None
684
  output_temp_dir = None
 
690
  if isinstance(first_frame_annot, (list, tuple)) and len(first_frame_annot) > 1:
691
  bboxes = first_frame_annot[1]
692
  if bboxes:
 
 
 
 
 
693
  box = parse_bboxes(bboxes)
694
  if box:
695
  box_array = box
 
831
 
832
 
833
 
834
+ # ===== Example Images =====
835
  example_images_seg = [f for f in glob("example_imgs/seg/*")]
 
836
  example_images_cnt = [f for f in glob("example_imgs/cnt/*")]
837
  example_tracking_zips = [f for f in glob("example_imgs/tra/*.zip")]
838
 
 
859
  object-fit: contain !important;
860
  }
861
 
 
862
  #density_map_output {
863
  height: 500px !important;
864
  }
 
888
 
889
  # 全局状态
890
  current_query_id = gr.State(str(uuid.uuid4()))
891
+ user_uploaded_examples = gr.State(example_images_seg.copy())
892
 
893
  with gr.Tabs():
894
  # ===== Tab 1: Segmentation =====
 
980
  visible=False
981
  )
982
 
983
+ # click event for segmentation
984
  run_seg_btn.click(
985
  fn=segment_with_choice,
986
  inputs=[use_box_radio, annotator],
987
  outputs=[seg_output, download_mask_btn]
988
  )
989
 
990
+ # click event for clear button
991
  clear_btn.click(
992
  fn=lambda: None,
993
  inputs=None,
994
  outputs=annotator
995
  )
996
 
997
+ # init Gallery with example images
998
  demo.load(
999
  fn=lambda: example_images_seg.copy(),
1000
  outputs=example_gallery
1001
  )
1002
 
1003
+ # click event for image uploader
1004
  def add_to_gallery(img_path, current_imgs):
1005
  if not img_path:
1006
  return current_imgs
 
1021
  outputs=example_gallery
1022
  )
1023
 
1024
+ # click event for Gallery selection
1025
  def load_from_gallery(evt: gr.SelectData, all_imgs):
1026
  if evt.index is not None and evt.index < len(all_imgs):
1027
  return all_imgs[evt.index]
 
1033
  outputs=annotator
1034
  )
1035
 
1036
+ # click event for submitting feedback
1037
  def submit_user_feedback(query_id, score, comment, annot_val):
1038
  try:
1039
  img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
 
1046
  # img_path=img_path,
1047
  # bboxes=bboxes
1048
  # )
1049
+
1050
  save_feedback_to_hf(
1051
  query_id=query_id,
1052
  feedback_type=f"score_{int(score)}",
 
1208
  outputs=[count_output, download_density_btn, count_status]
1209
  )
1210
 
1211
+ # Clear selection
1212
  clear_btn.click(
1213
  fn=lambda: None,
1214
  inputs=None,
1215
  outputs=count_annotator
1216
  )
1217
 
1218
+ # Submit feedback
1219
  def submit_user_feedback(query_id, score, comment, annot_val):
1220
  try:
1221
  img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
 
1228
  # img_path=img_path,
1229
  # bboxes=bboxes
1230
  # )
1231
+
1232
  save_feedback_to_hf(
1233
  query_id=query_id,
1234
  feedback_type=f"score_{int(score)}",
 
1530
  outputs=[track_download, track_output, track_download, track_first_frame_preview]
1531
  )
1532
 
1533
+ # Clear selection
1534
  clear_btn.click(
1535
  fn=lambda: None,
1536
  inputs=None,
1537
  outputs=track_first_frame_annotator
1538
  )
1539
 
1540
+ # Submit feedback
1541
  def submit_user_feedback(query_id, score, comment, annot_val):
1542
  try:
1543
  img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
 
1550
  # img_path=img_path,
1551
  # bboxes=bboxes
1552
  # )
1553
+
1554
  save_feedback_to_hf(
1555
  query_id=query_id,
1556
  feedback_type=f"score_{int(score)}",
inference_count.py CHANGED
@@ -1,5 +1,4 @@
1
  # inference_count.py
2
- # 计数模型推理模块 - 独立版本
3
 
4
  import torch
5
  import numpy as np
@@ -15,24 +14,22 @@ DEVICE = torch.device("cpu")
15
 
16
  def load_model(use_box=False):
17
  """
18
- 加载计数模型
19
 
20
  Args:
21
- use_box: 是否使用边界框
22
 
23
  Returns:
24
- model: 加载的模型
25
- device: 设备
26
  """
27
  global MODEL, DEVICE
28
 
29
  try:
30
  print("🔄 Loading counting model...")
31
-
32
- # 初始化模型
33
  MODEL = CountingModule(use_box=use_box)
34
 
35
- # 从 Hugging Face Hub 下载权重
36
  ckpt_path = hf_hub_download(
37
  repo_id="phoebe777777/111",
38
  filename="microscopy_matching_cnt.pth",
@@ -42,7 +39,6 @@ def load_model(use_box=False):
42
 
43
  print(f"✅ Checkpoint downloaded: {ckpt_path}")
44
 
45
- # 加载权重
46
  MODEL.load_state_dict(
47
  torch.load(ckpt_path, map_location="cpu"),
48
  strict=True
@@ -71,20 +67,20 @@ def load_model(use_box=False):
71
  @torch.no_grad()
72
  def run(model, img_path, box=None, device="cpu", visualize=True):
73
  """
74
- 运行计数推理
75
 
76
  Args:
77
- model: 计数模型
78
- img_path: 图像路径
79
- box: 边界框 [[x1, y1, x2, y2], ...] None
80
- device: 设备
81
- visualize: 是否生成可视化
82
 
83
  Returns:
84
  result_dict: {
85
  'density_map': numpy array,
86
  'count': float,
87
- 'visualized_path': str (如果 visualize=True)
88
  }
89
  """
90
  print("DEVICE:", device)
@@ -107,7 +103,6 @@ def run(model, img_path, box=None, device="cpu", visualize=True):
107
  try:
108
  print(f"🔄 Running counting inference on {img_path}")
109
 
110
- # 运行推理 (调用你的模型的 forward 方法)
111
  with torch.no_grad():
112
  density_map, count = model(img_path, box)
113
 
@@ -118,11 +113,7 @@ def run(model, img_path, box=None, device="cpu", visualize=True):
118
  'count': count,
119
  'visualized_path': None
120
  }
121
-
122
- # 可视化
123
- # if visualize:
124
- # viz_path = visualize_result(img_path, density_map, count)
125
- # result['visualized_path'] = viz_path
126
 
127
  return result
128
 
@@ -140,47 +131,39 @@ def run(model, img_path, box=None, device="cpu", visualize=True):
140
 
141
  def visualize_result(image_path, density_map, count):
142
  """
143
- 可视化计数结果 (与你原来的可视化代码一致)
144
 
145
  Args:
146
- image_path: 原始图像路径
147
- density_map: 密度图 (numpy array)
148
- count: 计数值
149
 
150
  Returns:
151
- output_path: 可视化结果的临时文件路径
152
  """
153
  try:
154
  import skimage.io as io
155
 
156
- # 读取原始图像
157
  img = io.imread(image_path)
158
 
159
- # 处理不同格式的图像
160
  if len(img.shape) == 3 and img.shape[2] > 3:
161
  img = img[:, :, :3]
162
  if len(img.shape) == 2:
163
  img = np.stack([img]*3, axis=-1)
164
 
165
- # 归一化显示
166
  img_show = img.squeeze()
167
  density_map_show = density_map.squeeze()
168
 
169
- # 归一化图像
170
  img_show = (img_show - np.min(img_show)) / (np.max(img_show) - np.min(img_show) + 1e-8)
171
 
172
- # 创建可视化 (与你原来的代码一致)
173
  fig, ax = plt.subplots(figsize=(8, 6))
174
 
175
- # 右图: 密度图叠加
176
  ax.imshow(img_show)
177
  ax.imshow(density_map_show, cmap='jet', alpha=0.5)
178
  ax.axis('off')
179
- # ax.set_title(f"Predicted density map, count: {count:.1f}")
180
 
181
  plt.tight_layout()
182
 
183
- # 保存到临时文件
184
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
185
  plt.savefig(temp_file.name, dpi=300)
186
  plt.close()
@@ -195,13 +178,11 @@ def visualize_result(image_path, density_map, count):
195
  return image_path
196
 
197
 
198
- # ===== 测试代码 =====
199
  if __name__ == "__main__":
200
  print("="*60)
201
  print("Testing Counting Model")
202
  print("="*60)
203
-
204
- # 测试模型加载
205
  model, device = load_model(use_box=False)
206
 
207
  if model is not None:
@@ -209,7 +190,6 @@ if __name__ == "__main__":
209
  print("Model loaded successfully, testing inference...")
210
  print("="*60)
211
 
212
- # 测试推理
213
  test_image = "example_imgs/1977_Well_F-5_Field_1.png"
214
 
215
  if os.path.exists(test_image):
 
1
  # inference_count.py
 
2
 
3
  import torch
4
  import numpy as np
 
14
 
15
  def load_model(use_box=False):
16
  """
17
+ load counting model from Hugging Face Hub
18
 
19
  Args:
20
+ use_box: use bounding box as input (default: False)
21
 
22
  Returns:
23
+ model: loaded counting model
24
+ device: device
25
  """
26
  global MODEL, DEVICE
27
 
28
  try:
29
  print("🔄 Loading counting model...")
30
+
 
31
  MODEL = CountingModule(use_box=use_box)
32
 
 
33
  ckpt_path = hf_hub_download(
34
  repo_id="phoebe777777/111",
35
  filename="microscopy_matching_cnt.pth",
 
39
 
40
  print(f"✅ Checkpoint downloaded: {ckpt_path}")
41
 
 
42
  MODEL.load_state_dict(
43
  torch.load(ckpt_path, map_location="cpu"),
44
  strict=True
 
67
  @torch.no_grad()
68
  def run(model, img_path, box=None, device="cpu", visualize=True):
69
  """
70
+ Run counting inference on a single image
71
 
72
  Args:
73
+ model: loaded counting model
74
+ img_path: image path
75
+ box: bounding box [[x1, y1, x2, y2], ...] or None
76
+ device: device
77
+ visualize: whether to generate visualization
78
 
79
  Returns:
80
  result_dict: {
81
  'density_map': numpy array,
82
  'count': float,
83
+ 'visualized_path': str (if visualize=True)
84
  }
85
  """
86
  print("DEVICE:", device)
 
103
  try:
104
  print(f"🔄 Running counting inference on {img_path}")
105
 
 
106
  with torch.no_grad():
107
  density_map, count = model(img_path, box)
108
 
 
113
  'count': count,
114
  'visualized_path': None
115
  }
116
+
 
 
 
 
117
 
118
  return result
119
 
 
131
 
132
  def visualize_result(image_path, density_map, count):
133
  """
134
+ Visualize counting results (consistent with your original visualization code)
135
 
136
  Args:
137
+ image_path: original image path
138
+ density_map: numpy array of predicted density map
139
+ count
140
 
141
  Returns:
142
+ output_path: temporary file path of the visualization result
143
  """
144
  try:
145
  import skimage.io as io
146
 
 
147
  img = io.imread(image_path)
148
 
 
149
  if len(img.shape) == 3 and img.shape[2] > 3:
150
  img = img[:, :, :3]
151
  if len(img.shape) == 2:
152
  img = np.stack([img]*3, axis=-1)
153
 
 
154
  img_show = img.squeeze()
155
  density_map_show = density_map.squeeze()
156
 
 
157
  img_show = (img_show - np.min(img_show)) / (np.max(img_show) - np.min(img_show) + 1e-8)
158
 
 
159
  fig, ax = plt.subplots(figsize=(8, 6))
160
 
 
161
  ax.imshow(img_show)
162
  ax.imshow(density_map_show, cmap='jet', alpha=0.5)
163
  ax.axis('off')
 
164
 
165
  plt.tight_layout()
166
 
 
167
  temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
168
  plt.savefig(temp_file.name, dpi=300)
169
  plt.close()
 
178
  return image_path
179
 
180
 
 
181
  if __name__ == "__main__":
182
  print("="*60)
183
  print("Testing Counting Model")
184
  print("="*60)
185
+
 
186
  model, device = load_model(use_box=False)
187
 
188
  if model is not None:
 
190
  print("Model loaded successfully, testing inference...")
191
  print("="*60)
192
 
 
193
  test_image = "example_imgs/1977_Well_F-5_Field_1.png"
194
 
195
  if os.path.exists(test_image):
inference_seg.py CHANGED
@@ -43,45 +43,3 @@ def run(model, img_path, box=None, device="cpu"):
43
  output = model(img_path, box=box)
44
  mask = output
45
  return mask
46
- # import os
47
- # import torch
48
- # import numpy as np
49
- # from huggingface_hub import hf_hub_download
50
- # from segmentation import SegmentationModule
51
-
52
- # MODEL = None
53
- # DEVICE = torch.device("cpu")
54
-
55
- # def load_model(use_box=False):
56
- # global MODEL, DEVICE
57
-
58
- # # === 优化1: 使用 /data 缓存模型,避免写入 .cache ===
59
- # cache_dir = "/data/cellseg_model_cache"
60
- # os.makedirs(cache_dir, exist_ok=True)
61
-
62
- # ckpt_path = hf_hub_download(
63
- # repo_id="Shengxiao0709/cellsegmodel",
64
- # filename="microscopy_matching_seg.pth",
65
- # token=None,
66
- # local_dir=cache_dir, # ✅ 下载到 /data
67
- # local_dir_use_symlinks=False, # ✅ 避免软链接问题
68
- # force_download=False # ✅ 已存在时不重复下载
69
- # )
70
-
71
- # # === 优化2: 加载模型 ===
72
- # MODEL = SegmentationModule(use_box=use_box)
73
- # state_dict = torch.load(ckpt_path, map_location="cpu")
74
- # MODEL.load_state_dict(state_dict, strict=False)
75
- # MODEL.eval()
76
-
77
- # DEVICE = torch.device("cpu")
78
- # print(f"✅ Model loaded from {ckpt_path}")
79
- # return MODEL, DEVICE
80
-
81
-
82
- # @torch.no_grad()
83
- # def run(model, img_path, box=None, device="cpu"):
84
- # output = model(img_path, box=box)
85
- # mask = output["pred"]
86
- # mask = (mask > 0).astype(np.uint8)
87
- # return mask
 
43
  output = model(img_path, box=box)
44
  mask = output
45
  return mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
inference_track.py CHANGED
@@ -1,5 +1,4 @@
1
  # inference_track.py
2
- # 视频跟踪模型推理模块
3
 
4
  import torch
5
  import numpy as np
@@ -15,14 +14,14 @@ DEVICE = torch.device("cpu")
15
 
16
  def load_model(use_box=False):
17
  """
18
- 加载跟踪模型
19
 
20
  Args:
21
- use_box: 是否使用边界框
22
 
23
  Returns:
24
- model: 加载的模型
25
- device: 设备
26
  """
27
  global MODEL, DEVICE
28
 
@@ -32,7 +31,7 @@ def load_model(use_box=False):
32
  # 初始化模型
33
  MODEL = TrackingModule(use_box=use_box)
34
 
35
- # Hugging Face Hub 下载权重
36
  ckpt_path = hf_hub_download(
37
  repo_id="phoebe777777/111",
38
  filename="microscopy_matching_tra.pth",
@@ -42,14 +41,14 @@ def load_model(use_box=False):
42
 
43
  print(f"✅ Checkpoint downloaded: {ckpt_path}")
44
 
45
- # 加载权重
46
  MODEL.load_state_dict(
47
  torch.load(ckpt_path, map_location="cpu"),
48
  strict=True
49
  )
50
  MODEL.eval()
51
 
52
- # 设置设备
53
  if torch.cuda.is_available():
54
  DEVICE = torch.device("cuda")
55
  MODEL.move_to_device(DEVICE)
@@ -72,21 +71,21 @@ def load_model(use_box=False):
72
  @torch.no_grad()
73
  def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
74
  """
75
- 运行视频跟踪推理
76
 
77
  Args:
78
- model: 跟踪模型
79
- video_dir: 视频帧序列目录 (包含连续的图像文件)
80
- box: 边界框 (可选)
81
- device: 设备
82
- output_dir: 输出目录
83
 
84
  Returns:
85
  result_dict: {
86
- 'track_graph': TrackGraph对象,
87
- 'masks': 分割掩码数组 (T, H, W),
88
- 'output_dir': 输出目录路径,
89
- 'num_tracks': 跟踪轨迹数量
90
  }
91
  """
92
  if model is None:
@@ -101,11 +100,11 @@ def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
101
  try:
102
  print(f"🔄 Running tracking inference on {video_dir}")
103
 
104
- # 运行跟踪
105
  track_graph, masks = model.track(
106
  file_dir=video_dir,
107
  boxes=box,
108
- mode="greedy", # 可选: "greedy", "greedy_nodiv", "ilp"
109
  dataname="tracking_result"
110
  )
111
 
@@ -113,7 +112,7 @@ def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
113
  if not os.path.exists(output_dir):
114
  os.makedirs(output_dir)
115
 
116
- # 转换为CTC格式并保存
117
  print("🔄 Converting to CTC format...")
118
  ctc_tracks, masks_tracked = graph_to_ctc(
119
  track_graph,
@@ -122,7 +121,6 @@ def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
122
  )
123
  print(f"✅ CTC results saved to {output_dir}")
124
 
125
- # num_tracks = len(track_graph.tracks())
126
 
127
  print(f"✅ Tracking completed")
128
 
@@ -131,7 +129,6 @@ def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
131
  'masks': masks,
132
  'masks_tracked': masks_tracked,
133
  'output_dir': output_dir,
134
- # 'num_tracks': num_tracks
135
  }
136
 
137
  return result
@@ -151,36 +148,35 @@ def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
151
 
152
  def visualize_tracking_result(masks_tracked, output_path):
153
  """
154
- 可视化跟踪结果 (可选)
155
 
156
  Args:
157
- masks_tracked: 跟踪后的掩码 (T, H, W)
158
- output_path: 输出视频路径
159
 
160
  Returns:
161
- output_path: 视频文件路径
162
  """
163
  try:
164
  import cv2
165
  import matplotlib.pyplot as plt
166
  from matplotlib import cm
167
 
168
- # 获取时间帧数
169
  T, H, W = masks_tracked.shape
170
 
171
- # 创建颜色映射
172
  unique_ids = np.unique(masks_tracked)
173
  num_colors = len(unique_ids)
174
  cmap = cm.get_cmap('tab20', num_colors)
175
 
176
- # 创建视频写入器
177
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
178
  out = cv2.VideoWriter(output_path, fourcc, 5.0, (W, H))
179
 
180
  for t in range(T):
181
  frame = masks_tracked[t]
182
 
183
- # 创建彩色图像
184
  colored_frame = np.zeros((H, W, 3), dtype=np.uint8)
185
  for i, obj_id in enumerate(unique_ids):
186
  if obj_id == 0:
@@ -189,7 +185,7 @@ def visualize_tracking_result(masks_tracked, output_path):
189
  color = np.array(cmap(i % num_colors)[:3]) * 255
190
  colored_frame[mask] = color
191
 
192
- # 转换为BGR (OpenCV格式)
193
  colored_frame_bgr = cv2.cvtColor(colored_frame, cv2.COLOR_RGB2BGR)
194
  out.write(colored_frame_bgr)
195
 
 
1
  # inference_track.py
 
2
 
3
  import torch
4
  import numpy as np
 
14
 
15
  def load_model(use_box=False):
16
  """
17
+ load tracking model from Hugging Face Hub
18
 
19
  Args:
20
+ use_box: use bounding box as input (default: False)
21
 
22
  Returns:
23
+ model: loaded tracking model
24
+ device
25
  """
26
  global MODEL, DEVICE
27
 
 
31
  # 初始化模型
32
  MODEL = TrackingModule(use_box=use_box)
33
 
34
+ # Load checkpoint from Hugging Face Hub
35
  ckpt_path = hf_hub_download(
36
  repo_id="phoebe777777/111",
37
  filename="microscopy_matching_tra.pth",
 
41
 
42
  print(f"✅ Checkpoint downloaded: {ckpt_path}")
43
 
44
+ # Load weights
45
  MODEL.load_state_dict(
46
  torch.load(ckpt_path, map_location="cpu"),
47
  strict=True
48
  )
49
  MODEL.eval()
50
 
51
+ # Move model to device
52
  if torch.cuda.is_available():
53
  DEVICE = torch.device("cuda")
54
  MODEL.move_to_device(DEVICE)
 
71
  @torch.no_grad()
72
  def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
73
  """
74
+ run tracking inference on video frames
75
 
76
  Args:
77
+ model: loaded tracking model
78
+ video_dir: directory of video frame sequence (contains consecutive image files)
79
+ box: bounding box (optional)
80
+ device: device
81
+ output_dir: output directory
82
 
83
  Returns:
84
  result_dict: {
85
+ 'track_graph': TrackGraph object containing tracking results,
86
+ 'masks': tracked masks (T, H, W),
87
+ 'output_dir': output directory path,
88
+ 'num_tracks': number of tracked trajectories
89
  }
90
  """
91
  if model is None:
 
100
  try:
101
  print(f"🔄 Running tracking inference on {video_dir}")
102
 
103
+ # Run tracking
104
  track_graph, masks = model.track(
105
  file_dir=video_dir,
106
  boxes=box,
107
+ mode="greedy", # Optional: "greedy", "greedy_nodiv", "ilp"
108
  dataname="tracking_result"
109
  )
110
 
 
112
  if not os.path.exists(output_dir):
113
  os.makedirs(output_dir)
114
 
115
+ # Convert tracking results to CTC format and save
116
  print("🔄 Converting to CTC format...")
117
  ctc_tracks, masks_tracked = graph_to_ctc(
118
  track_graph,
 
121
  )
122
  print(f"✅ CTC results saved to {output_dir}")
123
 
 
124
 
125
  print(f"✅ Tracking completed")
126
 
 
129
  'masks': masks,
130
  'masks_tracked': masks_tracked,
131
  'output_dir': output_dir,
 
132
  }
133
 
134
  return result
 
148
 
149
  def visualize_tracking_result(masks_tracked, output_path):
150
  """
151
+ visualize tracking results
152
 
153
  Args:
154
+ masks_tracked: masks with tracking results (T, H, W)
155
+ output_path: output video file path
156
 
157
  Returns:
158
+ output_path: output video file path
159
  """
160
  try:
161
  import cv2
162
  import matplotlib.pyplot as plt
163
  from matplotlib import cm
164
 
 
165
  T, H, W = masks_tracked.shape
166
 
167
+ # create a color map for unique track IDs
168
  unique_ids = np.unique(masks_tracked)
169
  num_colors = len(unique_ids)
170
  cmap = cm.get_cmap('tab20', num_colors)
171
 
172
+ # create video writer
173
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
174
  out = cv2.VideoWriter(output_path, fourcc, 5.0, (W, H))
175
 
176
  for t in range(T):
177
  frame = masks_tracked[t]
178
 
179
+ # create colored image
180
  colored_frame = np.zeros((H, W, 3), dtype=np.uint8)
181
  for i, obj_id in enumerate(unique_ids):
182
  if obj_id == 0:
 
185
  color = np.array(cmap(i % num_colors)[:3]) * 255
186
  colored_frame[mask] = color
187
 
188
+ # convert to BGR (OpenCV format)
189
  colored_frame_bgr = cv2.cvtColor(colored_frame, cv2.COLOR_RGB2BGR)
190
  out.write(colored_frame_bgr)
191