Spaces:
Running on Zero
Running on Zero
Commit ·
0b472f0
1
Parent(s): aff3c6f
update
Browse files- app.py +30 -81
- inference_count.py +19 -39
- inference_seg.py +0 -42
- inference_track.py +28 -32
app.py
CHANGED
|
@@ -18,7 +18,6 @@ from natsort import natsorted
|
|
| 18 |
from huggingface_hub import HfApi, upload_file
|
| 19 |
# import spaces
|
| 20 |
|
| 21 |
-
# ===== 导入三个推理模块 =====
|
| 22 |
from inference_seg import load_model as load_seg_model, run as run_seg
|
| 23 |
from inference_count import load_model as load_count_model, run as run_count
|
| 24 |
from inference_track import load_model as load_track_model, run as run_track
|
|
@@ -27,7 +26,6 @@ HF_TOKEN = os.getenv("HF_TOKEN")
|
|
| 27 |
DATASET_REPO = "phoebe777777/celltool_feedback"
|
| 28 |
|
| 29 |
|
| 30 |
-
# ===== 清理缓存目录 =====
|
| 31 |
print("===== clearing cache =====")
|
| 32 |
# cache_path = os.path.expanduser("~/.cache/")
|
| 33 |
cache_path = os.path.expanduser("~/.cache/huggingface/gradio")
|
|
@@ -39,7 +37,6 @@ if os.path.exists(cache_path):
|
|
| 39 |
except:
|
| 40 |
pass
|
| 41 |
|
| 42 |
-
# ===== 全局模型变量 =====
|
| 43 |
SEG_MODEL = None
|
| 44 |
SEG_DEVICE = torch.device("cpu")
|
| 45 |
|
|
@@ -50,7 +47,6 @@ TRACK_MODEL = None
|
|
| 50 |
TRACK_DEVICE = torch.device("cpu")
|
| 51 |
|
| 52 |
def load_all_models():
|
| 53 |
-
"""启动时加载所有模型"""
|
| 54 |
global SEG_MODEL, SEG_DEVICE
|
| 55 |
global COUNT_MODEL, COUNT_DEVICE
|
| 56 |
global TRACK_MODEL, TRACK_DEVICE
|
|
@@ -76,14 +72,12 @@ def load_all_models():
|
|
| 76 |
|
| 77 |
load_all_models()
|
| 78 |
|
| 79 |
-
# ===== 保存用户反馈 =====
|
| 80 |
DATASET_DIR = Path("solver_cache")
|
| 81 |
DATASET_DIR.mkdir(parents=True, exist_ok=True)
|
| 82 |
|
| 83 |
def save_feedback_to_hf(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
|
| 84 |
-
"""
|
| 85 |
|
| 86 |
-
# 如果没有 token,回退到本地存储
|
| 87 |
if not HF_TOKEN:
|
| 88 |
print("⚠️ No HF_TOKEN found, using local storage")
|
| 89 |
save_feedback(query_id, feedback_type, feedback_text, img_path, bboxes)
|
|
@@ -102,13 +96,11 @@ def save_feedback_to_hf(query_id, feedback_type, feedback_text=None, img_path=No
|
|
| 102 |
try:
|
| 103 |
api = HfApi()
|
| 104 |
|
| 105 |
-
# 创建临时文件
|
| 106 |
filename = f"feedback_{query_id}_{int(time.time())}.json"
|
| 107 |
|
| 108 |
with open(filename, 'w', encoding='utf-8') as f:
|
| 109 |
json.dump(feedback_data, f, indent=2, ensure_ascii=False)
|
| 110 |
|
| 111 |
-
# 上传到 dataset
|
| 112 |
api.upload_file(
|
| 113 |
path_or_fileobj=filename,
|
| 114 |
path_in_repo=f"data/{filename}",
|
|
@@ -117,19 +109,17 @@ def save_feedback_to_hf(query_id, feedback_type, feedback_text=None, img_path=No
|
|
| 117 |
token=HF_TOKEN
|
| 118 |
)
|
| 119 |
|
| 120 |
-
# 清理本地文件
|
| 121 |
os.remove(filename)
|
| 122 |
|
| 123 |
print(f"✅ Feedback saved to HF Dataset: {DATASET_REPO}")
|
| 124 |
|
| 125 |
except Exception as e:
|
| 126 |
print(f"⚠️ Failed to save to HF Dataset: {e}")
|
| 127 |
-
# 回退到本地存储
|
| 128 |
save_feedback(query_id, feedback_type, feedback_text, img_path, bboxes)
|
| 129 |
|
| 130 |
|
| 131 |
def save_feedback(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
|
| 132 |
-
"""
|
| 133 |
feedback_data = {
|
| 134 |
"query_id": query_id,
|
| 135 |
"feedback_type": feedback_type,
|
|
@@ -154,9 +144,8 @@ def save_feedback(query_id, feedback_type, feedback_text=None, img_path=None, bb
|
|
| 154 |
with feedback_file.open("w") as f:
|
| 155 |
json.dump(feedback_data, f, indent=4, ensure_ascii=False)
|
| 156 |
|
| 157 |
-
# ===== 辅助函数 =====
|
| 158 |
def parse_first_bbox(bboxes):
|
| 159 |
-
"""
|
| 160 |
if not bboxes:
|
| 161 |
return None
|
| 162 |
b = bboxes[0]
|
|
@@ -169,7 +158,7 @@ def parse_first_bbox(bboxes):
|
|
| 169 |
return None
|
| 170 |
|
| 171 |
def parse_bboxes(bboxes):
|
| 172 |
-
"""
|
| 173 |
if not bboxes:
|
| 174 |
return None
|
| 175 |
|
|
@@ -185,7 +174,7 @@ def parse_bboxes(bboxes):
|
|
| 185 |
return result
|
| 186 |
|
| 187 |
def colorize_mask(mask: np.ndarray, num_colors: int = 512) -> np.ndarray:
|
| 188 |
-
"""
|
| 189 |
def hsv_to_rgb(h, s, v):
|
| 190 |
i = int(h * 6.0)
|
| 191 |
f = h * 6.0 - i
|
|
@@ -210,10 +199,9 @@ def colorize_mask(mask: np.ndarray, num_colors: int = 512) -> np.ndarray:
|
|
| 210 |
color_idx = mask % num_colors
|
| 211 |
return palette_arr[color_idx]
|
| 212 |
|
| 213 |
-
# ===== 分割功能 =====
|
| 214 |
# @spaces.GPU
|
| 215 |
def segment_with_choice(use_box_choice, annot_value):
|
| 216 |
-
"""
|
| 217 |
if annot_value is None or len(annot_value) < 1:
|
| 218 |
print("❌ No annotation input")
|
| 219 |
return None, None
|
|
@@ -224,18 +212,12 @@ def segment_with_choice(use_box_choice, annot_value):
|
|
| 224 |
print(f"🖼️ Image path: {img_path}")
|
| 225 |
box_array = None
|
| 226 |
if use_box_choice == "Yes" and bboxes:
|
| 227 |
-
# box = parse_first_bbox(bboxes)
|
| 228 |
-
# if box:
|
| 229 |
-
# xmin, ymin, xmax, ymax = map(int, box)
|
| 230 |
-
# box_array = [[xmin, ymin, xmax, ymax]]
|
| 231 |
-
# print(f"📦 Using bounding box: {box_array}")
|
| 232 |
box = parse_bboxes(bboxes)
|
| 233 |
if box:
|
| 234 |
box_array = box
|
| 235 |
print(f"📦 Using bounding boxes: {box_array}")
|
| 236 |
|
| 237 |
|
| 238 |
-
# 运行分割模型
|
| 239 |
try:
|
| 240 |
mask = run_seg(SEG_MODEL, img_path, box=box_array, device=SEG_DEVICE)
|
| 241 |
print("📏 mask shape:", mask.shape, "dtype:", mask.dtype, "unique:", np.unique(mask))
|
|
@@ -243,13 +225,11 @@ def segment_with_choice(use_box_choice, annot_value):
|
|
| 243 |
print(f"❌ Inference failed: {str(e)}")
|
| 244 |
return None, None
|
| 245 |
|
| 246 |
-
# 保存原始mask为TIF文件
|
| 247 |
temp_mask_file = tempfile.NamedTemporaryFile(delete=False, suffix=".tif")
|
| 248 |
mask_img = Image.fromarray(mask.astype(np.uint16))
|
| 249 |
mask_img.save(temp_mask_file.name)
|
| 250 |
print(f"💾 Original mask saved to: {temp_mask_file.name}")
|
| 251 |
|
| 252 |
-
# 读取原图
|
| 253 |
try:
|
| 254 |
img = Image.open(img_path)
|
| 255 |
print("📷 Image mode:", img.mode, "size:", img.size)
|
|
@@ -276,24 +256,19 @@ def segment_with_choice(use_box_choice, annot_value):
|
|
| 276 |
print("⚠️ No instance found, returning dummy red image")
|
| 277 |
return Image.new("RGB", mask.shape[::-1], (255, 0, 0)), None
|
| 278 |
|
| 279 |
-
# ==== Color Overlay (每个实例一个颜色) ====
|
| 280 |
overlay = img_np.copy()
|
| 281 |
alpha = 0.5
|
| 282 |
-
# cmap = cm.get_cmap("hsv", num_instances + 1)
|
| 283 |
|
| 284 |
for inst_id in np.unique(inst_mask):
|
| 285 |
if inst_id == 0:
|
| 286 |
continue
|
| 287 |
binary_mask = (inst_mask == inst_id).astype(np.uint8)
|
| 288 |
-
# color = np.array(cmap(inst_id / (num_instances + 1))[:3]) # RGB only, ignore alpha
|
| 289 |
color = get_well_spaced_color(inst_id)
|
| 290 |
overlay[binary_mask == 1] = (1 - alpha) * overlay[binary_mask == 1] + alpha * color
|
| 291 |
|
| 292 |
-
# 绘制轮廓
|
| 293 |
contours = measure.find_contours(binary_mask, 0.5)
|
| 294 |
for contour in contours:
|
| 295 |
contour = contour.astype(np.int32)
|
| 296 |
-
# 确保坐标在范围内
|
| 297 |
valid_y = np.clip(contour[:, 0], 0, overlay.shape[0] - 1)
|
| 298 |
valid_x = np.clip(contour[:, 1], 0, overlay.shape[1] - 1)
|
| 299 |
overlay[valid_y, valid_x] = [1.0, 1.0, 0.0] # 黄色轮廓
|
|
@@ -302,7 +277,7 @@ def segment_with_choice(use_box_choice, annot_value):
|
|
| 302 |
|
| 303 |
return Image.fromarray(overlay), temp_mask_file.name
|
| 304 |
|
| 305 |
-
|
| 306 |
# @spaces.GPU
|
| 307 |
def count_cells_handler(use_box_choice, annot_value):
|
| 308 |
"""Counting handler - supports bounding box, returns only density map"""
|
|
@@ -315,11 +290,6 @@ def count_cells_handler(use_box_choice, annot_value):
|
|
| 315 |
print(f"🖼️ Image path: {image_path}")
|
| 316 |
box_array = None
|
| 317 |
if use_box_choice == "Yes" and bboxes:
|
| 318 |
-
# box = parse_first_bbox(bboxes)
|
| 319 |
-
# if box:
|
| 320 |
-
# xmin, ymin, xmax, ymax = map(int, box)
|
| 321 |
-
# box_array = [[xmin, ymin, xmax, ymax]]
|
| 322 |
-
# print(f"📦 Using bounding box: {box_array}")
|
| 323 |
box = parse_bboxes(bboxes)
|
| 324 |
if box:
|
| 325 |
box_array = box
|
|
@@ -341,7 +311,6 @@ def count_cells_handler(use_box_choice, annot_value):
|
|
| 341 |
|
| 342 |
count = result['count']
|
| 343 |
density_map = result['density_map']
|
| 344 |
-
# save density map as temp file
|
| 345 |
temp_density_file = tempfile.NamedTemporaryFile(delete=False, suffix=".npy")
|
| 346 |
np.save(temp_density_file.name, density_map)
|
| 347 |
print(f"💾 Density map saved to {temp_density_file.name}")
|
|
@@ -365,20 +334,16 @@ def count_cells_handler(use_box_choice, annot_value):
|
|
| 365 |
return None, None
|
| 366 |
|
| 367 |
|
| 368 |
-
# Normalize density map to [0, 1]
|
| 369 |
density_normalized = density_map.copy()
|
| 370 |
if density_normalized.max() > 0:
|
| 371 |
density_normalized = (density_normalized - density_normalized.min()) / (density_normalized.max() - density_normalized.min())
|
| 372 |
|
| 373 |
-
# Apply colormap
|
| 374 |
cmap = cm.get_cmap("jet")
|
| 375 |
alpha = 0.3
|
| 376 |
density_colored = cmap(density_normalized)[:, :, :3] # RGB only, ignore alpha
|
| 377 |
|
| 378 |
-
# Create overlay
|
| 379 |
overlay = img_np.copy()
|
| 380 |
|
| 381 |
-
# Blend only where density is significant (optional: threshold)
|
| 382 |
threshold = 0.01 # Only overlay where density > 1% of max
|
| 383 |
significant_mask = density_normalized > threshold
|
| 384 |
|
|
@@ -400,7 +365,6 @@ def count_cells_handler(use_box_choice, annot_value):
|
|
| 400 |
|
| 401 |
return Image.fromarray(overlay), temp_density_file.name, result_text
|
| 402 |
|
| 403 |
-
# return density_path, result_text
|
| 404 |
|
| 405 |
except Exception as e:
|
| 406 |
print(f"❌ Counting error: {e}")
|
|
@@ -408,7 +372,7 @@ def count_cells_handler(use_box_choice, annot_value):
|
|
| 408 |
traceback.print_exc()
|
| 409 |
return None, f"❌ Counting failed: {str(e)}"
|
| 410 |
|
| 411 |
-
|
| 412 |
def find_tif_dir(root_dir):
|
| 413 |
"""Recursively find the first directory containing .tif files"""
|
| 414 |
for dirpath, _, filenames in os.walk(root_dir):
|
|
@@ -502,14 +466,13 @@ def create_ctc_results_zip(output_dir):
|
|
| 502 |
print(f"✅ ZIP created: {zip_path} ({os.path.getsize(zip_path) / 1024:.1f} KB)")
|
| 503 |
return zip_path
|
| 504 |
|
| 505 |
-
|
| 506 |
def get_well_spaced_color(track_id, num_colors=256):
|
| 507 |
"""Generate well-spaced colors, using contrasting colors for adjacent IDs"""
|
| 508 |
-
|
| 509 |
golden_ratio = 0.618033988749895
|
| 510 |
hue = (track_id * golden_ratio) % 1.0
|
| 511 |
-
|
| 512 |
-
# 使用高饱和度和明度
|
| 513 |
import colorsys
|
| 514 |
rgb = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
|
| 515 |
return np.array(rgb)
|
|
@@ -568,14 +531,7 @@ def create_tracking_visualization(tif_dir, output_dir, valid_tif_files):
|
|
| 568 |
return valid_tif_files[0]
|
| 569 |
|
| 570 |
print(f"📊 Found {len(mask_files)} mask files")
|
| 571 |
-
|
| 572 |
-
# Create color map for consistent track IDs
|
| 573 |
-
# Use a colormap with many distinct colors
|
| 574 |
-
# try:
|
| 575 |
-
# cmap = colormaps.get_cmap("hsv")
|
| 576 |
-
# except:
|
| 577 |
-
# from matplotlib import cm
|
| 578 |
-
# cmap = cm.get_cmap("hsv")
|
| 579 |
|
| 580 |
frames = []
|
| 581 |
alpha = 0.3 # Transparency for overlay
|
|
@@ -710,19 +666,19 @@ def create_tracking_visualization(tif_dir, output_dir, valid_tif_files):
|
|
| 710 |
# @spaces.GPU
|
| 711 |
def track_video_handler(use_box_choice, first_frame_annot, zip_file_obj):
|
| 712 |
"""
|
| 713 |
-
|
| 714 |
|
| 715 |
Parameters:
|
| 716 |
-----------
|
| 717 |
use_box_choice : str
|
| 718 |
-
"Yes" or "No" -
|
| 719 |
first_frame_annot : tuple or None
|
| 720 |
(image_path, bboxes) from BBoxAnnotator, only used if user annotated first frame
|
| 721 |
zip_file_obj : File
|
| 722 |
Uploaded ZIP file containing TIF sequence
|
| 723 |
"""
|
| 724 |
if zip_file_obj is None:
|
| 725 |
-
return None, "⚠️
|
| 726 |
|
| 727 |
temp_dir = None
|
| 728 |
output_temp_dir = None
|
|
@@ -734,11 +690,6 @@ def track_video_handler(use_box_choice, first_frame_annot, zip_file_obj):
|
|
| 734 |
if isinstance(first_frame_annot, (list, tuple)) and len(first_frame_annot) > 1:
|
| 735 |
bboxes = first_frame_annot[1]
|
| 736 |
if bboxes:
|
| 737 |
-
# box = parse_first_bbox(bboxes)
|
| 738 |
-
# if box:
|
| 739 |
-
# xmin, ymin, xmax, ymax = map(int, box)
|
| 740 |
-
# box_array = [[xmin, ymin, xmax, ymax]]
|
| 741 |
-
# print(f"📦 Using bounding box: {box_array}")
|
| 742 |
box = parse_bboxes(bboxes)
|
| 743 |
if box:
|
| 744 |
box_array = box
|
|
@@ -880,9 +831,8 @@ def track_video_handler(use_box_choice, first_frame_annot, zip_file_obj):
|
|
| 880 |
|
| 881 |
|
| 882 |
|
| 883 |
-
# =====
|
| 884 |
example_images_seg = [f for f in glob("example_imgs/seg/*")]
|
| 885 |
-
# ["example_imgs/seg/003_img.png", "example_imgs/seg/1977_Well_F-5_Field_1.png"]
|
| 886 |
example_images_cnt = [f for f in glob("example_imgs/cnt/*")]
|
| 887 |
example_tracking_zips = [f for f in glob("example_imgs/tra/*.zip")]
|
| 888 |
|
|
@@ -909,7 +859,6 @@ with gr.Blocks(
|
|
| 909 |
object-fit: contain !important;
|
| 910 |
}
|
| 911 |
|
| 912 |
-
/* 强制密度图容器和图片高度 */
|
| 913 |
#density_map_output {
|
| 914 |
height: 500px !important;
|
| 915 |
}
|
|
@@ -939,7 +888,7 @@ with gr.Blocks(
|
|
| 939 |
|
| 940 |
# 全局状态
|
| 941 |
current_query_id = gr.State(str(uuid.uuid4()))
|
| 942 |
-
user_uploaded_examples = gr.State(example_images_seg.copy())
|
| 943 |
|
| 944 |
with gr.Tabs():
|
| 945 |
# ===== Tab 1: Segmentation =====
|
|
@@ -1031,27 +980,27 @@ with gr.Blocks(
|
|
| 1031 |
visible=False
|
| 1032 |
)
|
| 1033 |
|
| 1034 |
-
#
|
| 1035 |
run_seg_btn.click(
|
| 1036 |
fn=segment_with_choice,
|
| 1037 |
inputs=[use_box_radio, annotator],
|
| 1038 |
outputs=[seg_output, download_mask_btn]
|
| 1039 |
)
|
| 1040 |
|
| 1041 |
-
#
|
| 1042 |
clear_btn.click(
|
| 1043 |
fn=lambda: None,
|
| 1044 |
inputs=None,
|
| 1045 |
outputs=annotator
|
| 1046 |
)
|
| 1047 |
|
| 1048 |
-
#
|
| 1049 |
demo.load(
|
| 1050 |
fn=lambda: example_images_seg.copy(),
|
| 1051 |
outputs=example_gallery
|
| 1052 |
)
|
| 1053 |
|
| 1054 |
-
#
|
| 1055 |
def add_to_gallery(img_path, current_imgs):
|
| 1056 |
if not img_path:
|
| 1057 |
return current_imgs
|
|
@@ -1072,7 +1021,7 @@ with gr.Blocks(
|
|
| 1072 |
outputs=example_gallery
|
| 1073 |
)
|
| 1074 |
|
| 1075 |
-
#
|
| 1076 |
def load_from_gallery(evt: gr.SelectData, all_imgs):
|
| 1077 |
if evt.index is not None and evt.index < len(all_imgs):
|
| 1078 |
return all_imgs[evt.index]
|
|
@@ -1084,7 +1033,7 @@ with gr.Blocks(
|
|
| 1084 |
outputs=annotator
|
| 1085 |
)
|
| 1086 |
|
| 1087 |
-
#
|
| 1088 |
def submit_user_feedback(query_id, score, comment, annot_val):
|
| 1089 |
try:
|
| 1090 |
img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
|
|
@@ -1097,7 +1046,7 @@ with gr.Blocks(
|
|
| 1097 |
# img_path=img_path,
|
| 1098 |
# bboxes=bboxes
|
| 1099 |
# )
|
| 1100 |
-
|
| 1101 |
save_feedback_to_hf(
|
| 1102 |
query_id=query_id,
|
| 1103 |
feedback_type=f"score_{int(score)}",
|
|
@@ -1259,14 +1208,14 @@ with gr.Blocks(
|
|
| 1259 |
outputs=[count_output, download_density_btn, count_status]
|
| 1260 |
)
|
| 1261 |
|
| 1262 |
-
#
|
| 1263 |
clear_btn.click(
|
| 1264 |
fn=lambda: None,
|
| 1265 |
inputs=None,
|
| 1266 |
outputs=count_annotator
|
| 1267 |
)
|
| 1268 |
|
| 1269 |
-
#
|
| 1270 |
def submit_user_feedback(query_id, score, comment, annot_val):
|
| 1271 |
try:
|
| 1272 |
img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
|
|
@@ -1279,7 +1228,7 @@ with gr.Blocks(
|
|
| 1279 |
# img_path=img_path,
|
| 1280 |
# bboxes=bboxes
|
| 1281 |
# )
|
| 1282 |
-
|
| 1283 |
save_feedback_to_hf(
|
| 1284 |
query_id=query_id,
|
| 1285 |
feedback_type=f"score_{int(score)}",
|
|
@@ -1581,14 +1530,14 @@ with gr.Blocks(
|
|
| 1581 |
outputs=[track_download, track_output, track_download, track_first_frame_preview]
|
| 1582 |
)
|
| 1583 |
|
| 1584 |
-
#
|
| 1585 |
clear_btn.click(
|
| 1586 |
fn=lambda: None,
|
| 1587 |
inputs=None,
|
| 1588 |
outputs=track_first_frame_annotator
|
| 1589 |
)
|
| 1590 |
|
| 1591 |
-
#
|
| 1592 |
def submit_user_feedback(query_id, score, comment, annot_val):
|
| 1593 |
try:
|
| 1594 |
img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
|
|
@@ -1601,7 +1550,7 @@ with gr.Blocks(
|
|
| 1601 |
# img_path=img_path,
|
| 1602 |
# bboxes=bboxes
|
| 1603 |
# )
|
| 1604 |
-
|
| 1605 |
save_feedback_to_hf(
|
| 1606 |
query_id=query_id,
|
| 1607 |
feedback_type=f"score_{int(score)}",
|
|
|
|
| 18 |
from huggingface_hub import HfApi, upload_file
|
| 19 |
# import spaces
|
| 20 |
|
|
|
|
| 21 |
from inference_seg import load_model as load_seg_model, run as run_seg
|
| 22 |
from inference_count import load_model as load_count_model, run as run_count
|
| 23 |
from inference_track import load_model as load_track_model, run as run_track
|
|
|
|
| 26 |
DATASET_REPO = "phoebe777777/celltool_feedback"
|
| 27 |
|
| 28 |
|
|
|
|
| 29 |
print("===== clearing cache =====")
|
| 30 |
# cache_path = os.path.expanduser("~/.cache/")
|
| 31 |
cache_path = os.path.expanduser("~/.cache/huggingface/gradio")
|
|
|
|
| 37 |
except:
|
| 38 |
pass
|
| 39 |
|
|
|
|
| 40 |
SEG_MODEL = None
|
| 41 |
SEG_DEVICE = torch.device("cpu")
|
| 42 |
|
|
|
|
| 47 |
TRACK_DEVICE = torch.device("cpu")
|
| 48 |
|
| 49 |
def load_all_models():
|
|
|
|
| 50 |
global SEG_MODEL, SEG_DEVICE
|
| 51 |
global COUNT_MODEL, COUNT_DEVICE
|
| 52 |
global TRACK_MODEL, TRACK_DEVICE
|
|
|
|
| 72 |
|
| 73 |
load_all_models()
|
| 74 |
|
|
|
|
| 75 |
DATASET_DIR = Path("solver_cache")
|
| 76 |
DATASET_DIR.mkdir(parents=True, exist_ok=True)
|
| 77 |
|
| 78 |
def save_feedback_to_hf(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
|
| 79 |
+
"""Save feedback to Hugging Face Dataset"""
|
| 80 |
|
|
|
|
| 81 |
if not HF_TOKEN:
|
| 82 |
print("⚠️ No HF_TOKEN found, using local storage")
|
| 83 |
save_feedback(query_id, feedback_type, feedback_text, img_path, bboxes)
|
|
|
|
| 96 |
try:
|
| 97 |
api = HfApi()
|
| 98 |
|
|
|
|
| 99 |
filename = f"feedback_{query_id}_{int(time.time())}.json"
|
| 100 |
|
| 101 |
with open(filename, 'w', encoding='utf-8') as f:
|
| 102 |
json.dump(feedback_data, f, indent=2, ensure_ascii=False)
|
| 103 |
|
|
|
|
| 104 |
api.upload_file(
|
| 105 |
path_or_fileobj=filename,
|
| 106 |
path_in_repo=f"data/{filename}",
|
|
|
|
| 109 |
token=HF_TOKEN
|
| 110 |
)
|
| 111 |
|
|
|
|
| 112 |
os.remove(filename)
|
| 113 |
|
| 114 |
print(f"✅ Feedback saved to HF Dataset: {DATASET_REPO}")
|
| 115 |
|
| 116 |
except Exception as e:
|
| 117 |
print(f"⚠️ Failed to save to HF Dataset: {e}")
|
|
|
|
| 118 |
save_feedback(query_id, feedback_type, feedback_text, img_path, bboxes)
|
| 119 |
|
| 120 |
|
| 121 |
def save_feedback(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
|
| 122 |
+
"""Save feedback to local JSON file"""
|
| 123 |
feedback_data = {
|
| 124 |
"query_id": query_id,
|
| 125 |
"feedback_type": feedback_type,
|
|
|
|
| 144 |
with feedback_file.open("w") as f:
|
| 145 |
json.dump(feedback_data, f, indent=4, ensure_ascii=False)
|
| 146 |
|
|
|
|
| 147 |
def parse_first_bbox(bboxes):
|
| 148 |
+
"""Parse the first bounding box from the annotation input, supports dict or list format"""
|
| 149 |
if not bboxes:
|
| 150 |
return None
|
| 151 |
b = bboxes[0]
|
|
|
|
| 158 |
return None
|
| 159 |
|
| 160 |
def parse_bboxes(bboxes):
|
| 161 |
+
"""Parse all bounding boxes from the annotation input"""
|
| 162 |
if not bboxes:
|
| 163 |
return None
|
| 164 |
|
|
|
|
| 174 |
return result
|
| 175 |
|
| 176 |
def colorize_mask(mask: np.ndarray, num_colors: int = 512) -> np.ndarray:
|
| 177 |
+
"""Convert a 2D mask of instance IDs to a color image for visualization."""
|
| 178 |
def hsv_to_rgb(h, s, v):
|
| 179 |
i = int(h * 6.0)
|
| 180 |
f = h * 6.0 - i
|
|
|
|
| 199 |
color_idx = mask % num_colors
|
| 200 |
return palette_arr[color_idx]
|
| 201 |
|
|
|
|
| 202 |
# @spaces.GPU
|
| 203 |
def segment_with_choice(use_box_choice, annot_value):
|
| 204 |
+
"""Segmentation handler - supports bounding box, returns colorized overlay and original mask path"""
|
| 205 |
if annot_value is None or len(annot_value) < 1:
|
| 206 |
print("❌ No annotation input")
|
| 207 |
return None, None
|
|
|
|
| 212 |
print(f"🖼️ Image path: {img_path}")
|
| 213 |
box_array = None
|
| 214 |
if use_box_choice == "Yes" and bboxes:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
box = parse_bboxes(bboxes)
|
| 216 |
if box:
|
| 217 |
box_array = box
|
| 218 |
print(f"📦 Using bounding boxes: {box_array}")
|
| 219 |
|
| 220 |
|
|
|
|
| 221 |
try:
|
| 222 |
mask = run_seg(SEG_MODEL, img_path, box=box_array, device=SEG_DEVICE)
|
| 223 |
print("📏 mask shape:", mask.shape, "dtype:", mask.dtype, "unique:", np.unique(mask))
|
|
|
|
| 225 |
print(f"❌ Inference failed: {str(e)}")
|
| 226 |
return None, None
|
| 227 |
|
|
|
|
| 228 |
temp_mask_file = tempfile.NamedTemporaryFile(delete=False, suffix=".tif")
|
| 229 |
mask_img = Image.fromarray(mask.astype(np.uint16))
|
| 230 |
mask_img.save(temp_mask_file.name)
|
| 231 |
print(f"💾 Original mask saved to: {temp_mask_file.name}")
|
| 232 |
|
|
|
|
| 233 |
try:
|
| 234 |
img = Image.open(img_path)
|
| 235 |
print("📷 Image mode:", img.mode, "size:", img.size)
|
|
|
|
| 256 |
print("⚠️ No instance found, returning dummy red image")
|
| 257 |
return Image.new("RGB", mask.shape[::-1], (255, 0, 0)), None
|
| 258 |
|
|
|
|
| 259 |
overlay = img_np.copy()
|
| 260 |
alpha = 0.5
|
|
|
|
| 261 |
|
| 262 |
for inst_id in np.unique(inst_mask):
|
| 263 |
if inst_id == 0:
|
| 264 |
continue
|
| 265 |
binary_mask = (inst_mask == inst_id).astype(np.uint8)
|
|
|
|
| 266 |
color = get_well_spaced_color(inst_id)
|
| 267 |
overlay[binary_mask == 1] = (1 - alpha) * overlay[binary_mask == 1] + alpha * color
|
| 268 |
|
|
|
|
| 269 |
contours = measure.find_contours(binary_mask, 0.5)
|
| 270 |
for contour in contours:
|
| 271 |
contour = contour.astype(np.int32)
|
|
|
|
| 272 |
valid_y = np.clip(contour[:, 0], 0, overlay.shape[0] - 1)
|
| 273 |
valid_x = np.clip(contour[:, 1], 0, overlay.shape[1] - 1)
|
| 274 |
overlay[valid_y, valid_x] = [1.0, 1.0, 0.0] # 黄色轮廓
|
|
|
|
| 277 |
|
| 278 |
return Image.fromarray(overlay), temp_mask_file.name
|
| 279 |
|
| 280 |
+
|
| 281 |
# @spaces.GPU
|
| 282 |
def count_cells_handler(use_box_choice, annot_value):
|
| 283 |
"""Counting handler - supports bounding box, returns only density map"""
|
|
|
|
| 290 |
print(f"🖼️ Image path: {image_path}")
|
| 291 |
box_array = None
|
| 292 |
if use_box_choice == "Yes" and bboxes:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
box = parse_bboxes(bboxes)
|
| 294 |
if box:
|
| 295 |
box_array = box
|
|
|
|
| 311 |
|
| 312 |
count = result['count']
|
| 313 |
density_map = result['density_map']
|
|
|
|
| 314 |
temp_density_file = tempfile.NamedTemporaryFile(delete=False, suffix=".npy")
|
| 315 |
np.save(temp_density_file.name, density_map)
|
| 316 |
print(f"💾 Density map saved to {temp_density_file.name}")
|
|
|
|
| 334 |
return None, None
|
| 335 |
|
| 336 |
|
|
|
|
| 337 |
density_normalized = density_map.copy()
|
| 338 |
if density_normalized.max() > 0:
|
| 339 |
density_normalized = (density_normalized - density_normalized.min()) / (density_normalized.max() - density_normalized.min())
|
| 340 |
|
|
|
|
| 341 |
cmap = cm.get_cmap("jet")
|
| 342 |
alpha = 0.3
|
| 343 |
density_colored = cmap(density_normalized)[:, :, :3] # RGB only, ignore alpha
|
| 344 |
|
|
|
|
| 345 |
overlay = img_np.copy()
|
| 346 |
|
|
|
|
| 347 |
threshold = 0.01 # Only overlay where density > 1% of max
|
| 348 |
significant_mask = density_normalized > threshold
|
| 349 |
|
|
|
|
| 365 |
|
| 366 |
return Image.fromarray(overlay), temp_density_file.name, result_text
|
| 367 |
|
|
|
|
| 368 |
|
| 369 |
except Exception as e:
|
| 370 |
print(f"❌ Counting error: {e}")
|
|
|
|
| 372 |
traceback.print_exc()
|
| 373 |
return None, f"❌ Counting failed: {str(e)}"
|
| 374 |
|
| 375 |
+
|
| 376 |
def find_tif_dir(root_dir):
|
| 377 |
"""Recursively find the first directory containing .tif files"""
|
| 378 |
for dirpath, _, filenames in os.walk(root_dir):
|
|
|
|
| 466 |
print(f"✅ ZIP created: {zip_path} ({os.path.getsize(zip_path) / 1024:.1f} KB)")
|
| 467 |
return zip_path
|
| 468 |
|
| 469 |
+
|
| 470 |
def get_well_spaced_color(track_id, num_colors=256):
|
| 471 |
"""Generate well-spaced colors, using contrasting colors for adjacent IDs"""
|
| 472 |
+
|
| 473 |
golden_ratio = 0.618033988749895
|
| 474 |
hue = (track_id * golden_ratio) % 1.0
|
| 475 |
+
|
|
|
|
| 476 |
import colorsys
|
| 477 |
rgb = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
|
| 478 |
return np.array(rgb)
|
|
|
|
| 531 |
return valid_tif_files[0]
|
| 532 |
|
| 533 |
print(f"📊 Found {len(mask_files)} mask files")
|
| 534 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
|
| 536 |
frames = []
|
| 537 |
alpha = 0.3 # Transparency for overlay
|
|
|
|
| 666 |
# @spaces.GPU
|
| 667 |
def track_video_handler(use_box_choice, first_frame_annot, zip_file_obj):
|
| 668 |
"""
|
| 669 |
+
Tracking handler - processes a ZIP of TIF frames, supports bounding box, returns visualization and results ZIP
|
| 670 |
|
| 671 |
Parameters:
|
| 672 |
-----------
|
| 673 |
use_box_choice : str
|
| 674 |
+
"Yes" or "No" - whether to use bounding box annotation for tracking
|
| 675 |
first_frame_annot : tuple or None
|
| 676 |
(image_path, bboxes) from BBoxAnnotator, only used if user annotated first frame
|
| 677 |
zip_file_obj : File
|
| 678 |
Uploaded ZIP file containing TIF sequence
|
| 679 |
"""
|
| 680 |
if zip_file_obj is None:
|
| 681 |
+
return None, "⚠️ Please upload a ZIP file containing video frames (.zip)", None, None
|
| 682 |
|
| 683 |
temp_dir = None
|
| 684 |
output_temp_dir = None
|
|
|
|
| 690 |
if isinstance(first_frame_annot, (list, tuple)) and len(first_frame_annot) > 1:
|
| 691 |
bboxes = first_frame_annot[1]
|
| 692 |
if bboxes:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 693 |
box = parse_bboxes(bboxes)
|
| 694 |
if box:
|
| 695 |
box_array = box
|
|
|
|
| 831 |
|
| 832 |
|
| 833 |
|
| 834 |
+
# ===== Example Images =====
|
| 835 |
example_images_seg = [f for f in glob("example_imgs/seg/*")]
|
|
|
|
| 836 |
example_images_cnt = [f for f in glob("example_imgs/cnt/*")]
|
| 837 |
example_tracking_zips = [f for f in glob("example_imgs/tra/*.zip")]
|
| 838 |
|
|
|
|
| 859 |
object-fit: contain !important;
|
| 860 |
}
|
| 861 |
|
|
|
|
| 862 |
#density_map_output {
|
| 863 |
height: 500px !important;
|
| 864 |
}
|
|
|
|
| 888 |
|
| 889 |
# 全局状态
|
| 890 |
current_query_id = gr.State(str(uuid.uuid4()))
|
| 891 |
+
user_uploaded_examples = gr.State(example_images_seg.copy())
|
| 892 |
|
| 893 |
with gr.Tabs():
|
| 894 |
# ===== Tab 1: Segmentation =====
|
|
|
|
| 980 |
visible=False
|
| 981 |
)
|
| 982 |
|
| 983 |
+
# click event for segmentation
|
| 984 |
run_seg_btn.click(
|
| 985 |
fn=segment_with_choice,
|
| 986 |
inputs=[use_box_radio, annotator],
|
| 987 |
outputs=[seg_output, download_mask_btn]
|
| 988 |
)
|
| 989 |
|
| 990 |
+
# click event for clear button
|
| 991 |
clear_btn.click(
|
| 992 |
fn=lambda: None,
|
| 993 |
inputs=None,
|
| 994 |
outputs=annotator
|
| 995 |
)
|
| 996 |
|
| 997 |
+
# init Gallery with example images
|
| 998 |
demo.load(
|
| 999 |
fn=lambda: example_images_seg.copy(),
|
| 1000 |
outputs=example_gallery
|
| 1001 |
)
|
| 1002 |
|
| 1003 |
+
# click event for image uploader
|
| 1004 |
def add_to_gallery(img_path, current_imgs):
|
| 1005 |
if not img_path:
|
| 1006 |
return current_imgs
|
|
|
|
| 1021 |
outputs=example_gallery
|
| 1022 |
)
|
| 1023 |
|
| 1024 |
+
# click event for Gallery selection
|
| 1025 |
def load_from_gallery(evt: gr.SelectData, all_imgs):
|
| 1026 |
if evt.index is not None and evt.index < len(all_imgs):
|
| 1027 |
return all_imgs[evt.index]
|
|
|
|
| 1033 |
outputs=annotator
|
| 1034 |
)
|
| 1035 |
|
| 1036 |
+
# click event for submitting feedback
|
| 1037 |
def submit_user_feedback(query_id, score, comment, annot_val):
|
| 1038 |
try:
|
| 1039 |
img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
|
|
|
|
| 1046 |
# img_path=img_path,
|
| 1047 |
# bboxes=bboxes
|
| 1048 |
# )
|
| 1049 |
+
|
| 1050 |
save_feedback_to_hf(
|
| 1051 |
query_id=query_id,
|
| 1052 |
feedback_type=f"score_{int(score)}",
|
|
|
|
| 1208 |
outputs=[count_output, download_density_btn, count_status]
|
| 1209 |
)
|
| 1210 |
|
| 1211 |
+
# Clear selection
|
| 1212 |
clear_btn.click(
|
| 1213 |
fn=lambda: None,
|
| 1214 |
inputs=None,
|
| 1215 |
outputs=count_annotator
|
| 1216 |
)
|
| 1217 |
|
| 1218 |
+
# Submit feedback
|
| 1219 |
def submit_user_feedback(query_id, score, comment, annot_val):
|
| 1220 |
try:
|
| 1221 |
img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
|
|
|
|
| 1228 |
# img_path=img_path,
|
| 1229 |
# bboxes=bboxes
|
| 1230 |
# )
|
| 1231 |
+
|
| 1232 |
save_feedback_to_hf(
|
| 1233 |
query_id=query_id,
|
| 1234 |
feedback_type=f"score_{int(score)}",
|
|
|
|
| 1530 |
outputs=[track_download, track_output, track_download, track_first_frame_preview]
|
| 1531 |
)
|
| 1532 |
|
| 1533 |
+
# Clear selection
|
| 1534 |
clear_btn.click(
|
| 1535 |
fn=lambda: None,
|
| 1536 |
inputs=None,
|
| 1537 |
outputs=track_first_frame_annotator
|
| 1538 |
)
|
| 1539 |
|
| 1540 |
+
# Submit feedback
|
| 1541 |
def submit_user_feedback(query_id, score, comment, annot_val):
|
| 1542 |
try:
|
| 1543 |
img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
|
|
|
|
| 1550 |
# img_path=img_path,
|
| 1551 |
# bboxes=bboxes
|
| 1552 |
# )
|
| 1553 |
+
|
| 1554 |
save_feedback_to_hf(
|
| 1555 |
query_id=query_id,
|
| 1556 |
feedback_type=f"score_{int(score)}",
|
inference_count.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
# inference_count.py
|
| 2 |
-
# 计数模型推理模块 - 独立版本
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import numpy as np
|
|
@@ -15,24 +14,22 @@ DEVICE = torch.device("cpu")
|
|
| 15 |
|
| 16 |
def load_model(use_box=False):
|
| 17 |
"""
|
| 18 |
-
|
| 19 |
|
| 20 |
Args:
|
| 21 |
-
use_box:
|
| 22 |
|
| 23 |
Returns:
|
| 24 |
-
model:
|
| 25 |
-
device:
|
| 26 |
"""
|
| 27 |
global MODEL, DEVICE
|
| 28 |
|
| 29 |
try:
|
| 30 |
print("🔄 Loading counting model...")
|
| 31 |
-
|
| 32 |
-
# 初始化模型
|
| 33 |
MODEL = CountingModule(use_box=use_box)
|
| 34 |
|
| 35 |
-
# 从 Hugging Face Hub 下载权重
|
| 36 |
ckpt_path = hf_hub_download(
|
| 37 |
repo_id="phoebe777777/111",
|
| 38 |
filename="microscopy_matching_cnt.pth",
|
|
@@ -42,7 +39,6 @@ def load_model(use_box=False):
|
|
| 42 |
|
| 43 |
print(f"✅ Checkpoint downloaded: {ckpt_path}")
|
| 44 |
|
| 45 |
-
# 加载权重
|
| 46 |
MODEL.load_state_dict(
|
| 47 |
torch.load(ckpt_path, map_location="cpu"),
|
| 48 |
strict=True
|
|
@@ -71,20 +67,20 @@ def load_model(use_box=False):
|
|
| 71 |
@torch.no_grad()
|
| 72 |
def run(model, img_path, box=None, device="cpu", visualize=True):
|
| 73 |
"""
|
| 74 |
-
|
| 75 |
|
| 76 |
Args:
|
| 77 |
-
model:
|
| 78 |
-
img_path:
|
| 79 |
-
box:
|
| 80 |
-
device:
|
| 81 |
-
visualize:
|
| 82 |
|
| 83 |
Returns:
|
| 84 |
result_dict: {
|
| 85 |
'density_map': numpy array,
|
| 86 |
'count': float,
|
| 87 |
-
'visualized_path': str (
|
| 88 |
}
|
| 89 |
"""
|
| 90 |
print("DEVICE:", device)
|
|
@@ -107,7 +103,6 @@ def run(model, img_path, box=None, device="cpu", visualize=True):
|
|
| 107 |
try:
|
| 108 |
print(f"🔄 Running counting inference on {img_path}")
|
| 109 |
|
| 110 |
-
# 运行推理 (调用你的模型的 forward 方法)
|
| 111 |
with torch.no_grad():
|
| 112 |
density_map, count = model(img_path, box)
|
| 113 |
|
|
@@ -118,11 +113,7 @@ def run(model, img_path, box=None, device="cpu", visualize=True):
|
|
| 118 |
'count': count,
|
| 119 |
'visualized_path': None
|
| 120 |
}
|
| 121 |
-
|
| 122 |
-
# 可视化
|
| 123 |
-
# if visualize:
|
| 124 |
-
# viz_path = visualize_result(img_path, density_map, count)
|
| 125 |
-
# result['visualized_path'] = viz_path
|
| 126 |
|
| 127 |
return result
|
| 128 |
|
|
@@ -140,47 +131,39 @@ def run(model, img_path, box=None, device="cpu", visualize=True):
|
|
| 140 |
|
| 141 |
def visualize_result(image_path, density_map, count):
|
| 142 |
"""
|
| 143 |
-
|
| 144 |
|
| 145 |
Args:
|
| 146 |
-
image_path:
|
| 147 |
-
density_map:
|
| 148 |
-
count
|
| 149 |
|
| 150 |
Returns:
|
| 151 |
-
output_path:
|
| 152 |
"""
|
| 153 |
try:
|
| 154 |
import skimage.io as io
|
| 155 |
|
| 156 |
-
# 读取原始图像
|
| 157 |
img = io.imread(image_path)
|
| 158 |
|
| 159 |
-
# 处理不同格式的图像
|
| 160 |
if len(img.shape) == 3 and img.shape[2] > 3:
|
| 161 |
img = img[:, :, :3]
|
| 162 |
if len(img.shape) == 2:
|
| 163 |
img = np.stack([img]*3, axis=-1)
|
| 164 |
|
| 165 |
-
# 归一化显示
|
| 166 |
img_show = img.squeeze()
|
| 167 |
density_map_show = density_map.squeeze()
|
| 168 |
|
| 169 |
-
# 归一化图像
|
| 170 |
img_show = (img_show - np.min(img_show)) / (np.max(img_show) - np.min(img_show) + 1e-8)
|
| 171 |
|
| 172 |
-
# 创建可视化 (与你原来的代码一致)
|
| 173 |
fig, ax = plt.subplots(figsize=(8, 6))
|
| 174 |
|
| 175 |
-
# 右图: 密度图叠加
|
| 176 |
ax.imshow(img_show)
|
| 177 |
ax.imshow(density_map_show, cmap='jet', alpha=0.5)
|
| 178 |
ax.axis('off')
|
| 179 |
-
# ax.set_title(f"Predicted density map, count: {count:.1f}")
|
| 180 |
|
| 181 |
plt.tight_layout()
|
| 182 |
|
| 183 |
-
# 保存到临时文件
|
| 184 |
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
|
| 185 |
plt.savefig(temp_file.name, dpi=300)
|
| 186 |
plt.close()
|
|
@@ -195,13 +178,11 @@ def visualize_result(image_path, density_map, count):
|
|
| 195 |
return image_path
|
| 196 |
|
| 197 |
|
| 198 |
-
# ===== 测试代码 =====
|
| 199 |
if __name__ == "__main__":
|
| 200 |
print("="*60)
|
| 201 |
print("Testing Counting Model")
|
| 202 |
print("="*60)
|
| 203 |
-
|
| 204 |
-
# 测试模型加载
|
| 205 |
model, device = load_model(use_box=False)
|
| 206 |
|
| 207 |
if model is not None:
|
|
@@ -209,7 +190,6 @@ if __name__ == "__main__":
|
|
| 209 |
print("Model loaded successfully, testing inference...")
|
| 210 |
print("="*60)
|
| 211 |
|
| 212 |
-
# 测试推理
|
| 213 |
test_image = "example_imgs/1977_Well_F-5_Field_1.png"
|
| 214 |
|
| 215 |
if os.path.exists(test_image):
|
|
|
|
| 1 |
# inference_count.py
|
|
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
|
|
|
| 14 |
|
| 15 |
def load_model(use_box=False):
|
| 16 |
"""
|
| 17 |
+
load counting model from Hugging Face Hub
|
| 18 |
|
| 19 |
Args:
|
| 20 |
+
use_box: use bounding box as input (default: False)
|
| 21 |
|
| 22 |
Returns:
|
| 23 |
+
model: loaded counting model
|
| 24 |
+
device: device
|
| 25 |
"""
|
| 26 |
global MODEL, DEVICE
|
| 27 |
|
| 28 |
try:
|
| 29 |
print("🔄 Loading counting model...")
|
| 30 |
+
|
|
|
|
| 31 |
MODEL = CountingModule(use_box=use_box)
|
| 32 |
|
|
|
|
| 33 |
ckpt_path = hf_hub_download(
|
| 34 |
repo_id="phoebe777777/111",
|
| 35 |
filename="microscopy_matching_cnt.pth",
|
|
|
|
| 39 |
|
| 40 |
print(f"✅ Checkpoint downloaded: {ckpt_path}")
|
| 41 |
|
|
|
|
| 42 |
MODEL.load_state_dict(
|
| 43 |
torch.load(ckpt_path, map_location="cpu"),
|
| 44 |
strict=True
|
|
|
|
| 67 |
@torch.no_grad()
|
| 68 |
def run(model, img_path, box=None, device="cpu", visualize=True):
|
| 69 |
"""
|
| 70 |
+
Run counting inference on a single image
|
| 71 |
|
| 72 |
Args:
|
| 73 |
+
model: loaded counting model
|
| 74 |
+
img_path: image path
|
| 75 |
+
box: bounding box [[x1, y1, x2, y2], ...] or None
|
| 76 |
+
device: device
|
| 77 |
+
visualize: whether to generate visualization
|
| 78 |
|
| 79 |
Returns:
|
| 80 |
result_dict: {
|
| 81 |
'density_map': numpy array,
|
| 82 |
'count': float,
|
| 83 |
+
'visualized_path': str (if visualize=True)
|
| 84 |
}
|
| 85 |
"""
|
| 86 |
print("DEVICE:", device)
|
|
|
|
| 103 |
try:
|
| 104 |
print(f"🔄 Running counting inference on {img_path}")
|
| 105 |
|
|
|
|
| 106 |
with torch.no_grad():
|
| 107 |
density_map, count = model(img_path, box)
|
| 108 |
|
|
|
|
| 113 |
'count': count,
|
| 114 |
'visualized_path': None
|
| 115 |
}
|
| 116 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
return result
|
| 119 |
|
|
|
|
| 131 |
|
| 132 |
def visualize_result(image_path, density_map, count):
|
| 133 |
"""
|
| 134 |
+
Visualize counting results (consistent with your original visualization code)
|
| 135 |
|
| 136 |
Args:
|
| 137 |
+
image_path: original image path
|
| 138 |
+
density_map: numpy array of predicted density map
|
| 139 |
+
count
|
| 140 |
|
| 141 |
Returns:
|
| 142 |
+
output_path: temporary file path of the visualization result
|
| 143 |
"""
|
| 144 |
try:
|
| 145 |
import skimage.io as io
|
| 146 |
|
|
|
|
| 147 |
img = io.imread(image_path)
|
| 148 |
|
|
|
|
| 149 |
if len(img.shape) == 3 and img.shape[2] > 3:
|
| 150 |
img = img[:, :, :3]
|
| 151 |
if len(img.shape) == 2:
|
| 152 |
img = np.stack([img]*3, axis=-1)
|
| 153 |
|
|
|
|
| 154 |
img_show = img.squeeze()
|
| 155 |
density_map_show = density_map.squeeze()
|
| 156 |
|
|
|
|
| 157 |
img_show = (img_show - np.min(img_show)) / (np.max(img_show) - np.min(img_show) + 1e-8)
|
| 158 |
|
|
|
|
| 159 |
fig, ax = plt.subplots(figsize=(8, 6))
|
| 160 |
|
|
|
|
| 161 |
ax.imshow(img_show)
|
| 162 |
ax.imshow(density_map_show, cmap='jet', alpha=0.5)
|
| 163 |
ax.axis('off')
|
|
|
|
| 164 |
|
| 165 |
plt.tight_layout()
|
| 166 |
|
|
|
|
| 167 |
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
|
| 168 |
plt.savefig(temp_file.name, dpi=300)
|
| 169 |
plt.close()
|
|
|
|
| 178 |
return image_path
|
| 179 |
|
| 180 |
|
|
|
|
| 181 |
if __name__ == "__main__":
|
| 182 |
print("="*60)
|
| 183 |
print("Testing Counting Model")
|
| 184 |
print("="*60)
|
| 185 |
+
|
|
|
|
| 186 |
model, device = load_model(use_box=False)
|
| 187 |
|
| 188 |
if model is not None:
|
|
|
|
| 190 |
print("Model loaded successfully, testing inference...")
|
| 191 |
print("="*60)
|
| 192 |
|
|
|
|
| 193 |
test_image = "example_imgs/1977_Well_F-5_Field_1.png"
|
| 194 |
|
| 195 |
if os.path.exists(test_image):
|
inference_seg.py
CHANGED
|
@@ -43,45 +43,3 @@ def run(model, img_path, box=None, device="cpu"):
|
|
| 43 |
output = model(img_path, box=box)
|
| 44 |
mask = output
|
| 45 |
return mask
|
| 46 |
-
# import os
|
| 47 |
-
# import torch
|
| 48 |
-
# import numpy as np
|
| 49 |
-
# from huggingface_hub import hf_hub_download
|
| 50 |
-
# from segmentation import SegmentationModule
|
| 51 |
-
|
| 52 |
-
# MODEL = None
|
| 53 |
-
# DEVICE = torch.device("cpu")
|
| 54 |
-
|
| 55 |
-
# def load_model(use_box=False):
|
| 56 |
-
# global MODEL, DEVICE
|
| 57 |
-
|
| 58 |
-
# # === 优化1: 使用 /data 缓存模型,避免写入 .cache ===
|
| 59 |
-
# cache_dir = "/data/cellseg_model_cache"
|
| 60 |
-
# os.makedirs(cache_dir, exist_ok=True)
|
| 61 |
-
|
| 62 |
-
# ckpt_path = hf_hub_download(
|
| 63 |
-
# repo_id="Shengxiao0709/cellsegmodel",
|
| 64 |
-
# filename="microscopy_matching_seg.pth",
|
| 65 |
-
# token=None,
|
| 66 |
-
# local_dir=cache_dir, # ✅ 下载到 /data
|
| 67 |
-
# local_dir_use_symlinks=False, # ✅ 避免软链接问题
|
| 68 |
-
# force_download=False # ✅ 已存在时不重复下载
|
| 69 |
-
# )
|
| 70 |
-
|
| 71 |
-
# # === 优化2: 加载模型 ===
|
| 72 |
-
# MODEL = SegmentationModule(use_box=use_box)
|
| 73 |
-
# state_dict = torch.load(ckpt_path, map_location="cpu")
|
| 74 |
-
# MODEL.load_state_dict(state_dict, strict=False)
|
| 75 |
-
# MODEL.eval()
|
| 76 |
-
|
| 77 |
-
# DEVICE = torch.device("cpu")
|
| 78 |
-
# print(f"✅ Model loaded from {ckpt_path}")
|
| 79 |
-
# return MODEL, DEVICE
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
# @torch.no_grad()
|
| 83 |
-
# def run(model, img_path, box=None, device="cpu"):
|
| 84 |
-
# output = model(img_path, box=box)
|
| 85 |
-
# mask = output["pred"]
|
| 86 |
-
# mask = (mask > 0).astype(np.uint8)
|
| 87 |
-
# return mask
|
|
|
|
| 43 |
output = model(img_path, box=box)
|
| 44 |
mask = output
|
| 45 |
return mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference_track.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
# inference_track.py
|
| 2 |
-
# 视频跟踪模型推理模块
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import numpy as np
|
|
@@ -15,14 +14,14 @@ DEVICE = torch.device("cpu")
|
|
| 15 |
|
| 16 |
def load_model(use_box=False):
|
| 17 |
"""
|
| 18 |
-
|
| 19 |
|
| 20 |
Args:
|
| 21 |
-
use_box:
|
| 22 |
|
| 23 |
Returns:
|
| 24 |
-
model:
|
| 25 |
-
device
|
| 26 |
"""
|
| 27 |
global MODEL, DEVICE
|
| 28 |
|
|
@@ -32,7 +31,7 @@ def load_model(use_box=False):
|
|
| 32 |
# 初始化模型
|
| 33 |
MODEL = TrackingModule(use_box=use_box)
|
| 34 |
|
| 35 |
-
#
|
| 36 |
ckpt_path = hf_hub_download(
|
| 37 |
repo_id="phoebe777777/111",
|
| 38 |
filename="microscopy_matching_tra.pth",
|
|
@@ -42,14 +41,14 @@ def load_model(use_box=False):
|
|
| 42 |
|
| 43 |
print(f"✅ Checkpoint downloaded: {ckpt_path}")
|
| 44 |
|
| 45 |
-
#
|
| 46 |
MODEL.load_state_dict(
|
| 47 |
torch.load(ckpt_path, map_location="cpu"),
|
| 48 |
strict=True
|
| 49 |
)
|
| 50 |
MODEL.eval()
|
| 51 |
|
| 52 |
-
#
|
| 53 |
if torch.cuda.is_available():
|
| 54 |
DEVICE = torch.device("cuda")
|
| 55 |
MODEL.move_to_device(DEVICE)
|
|
@@ -72,21 +71,21 @@ def load_model(use_box=False):
|
|
| 72 |
@torch.no_grad()
|
| 73 |
def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
|
| 74 |
"""
|
| 75 |
-
|
| 76 |
|
| 77 |
Args:
|
| 78 |
-
model:
|
| 79 |
-
video_dir:
|
| 80 |
-
box:
|
| 81 |
-
device:
|
| 82 |
-
output_dir:
|
| 83 |
|
| 84 |
Returns:
|
| 85 |
result_dict: {
|
| 86 |
-
'track_graph': TrackGraph
|
| 87 |
-
'masks':
|
| 88 |
-
'output_dir':
|
| 89 |
-
'num_tracks':
|
| 90 |
}
|
| 91 |
"""
|
| 92 |
if model is None:
|
|
@@ -101,11 +100,11 @@ def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
|
|
| 101 |
try:
|
| 102 |
print(f"🔄 Running tracking inference on {video_dir}")
|
| 103 |
|
| 104 |
-
#
|
| 105 |
track_graph, masks = model.track(
|
| 106 |
file_dir=video_dir,
|
| 107 |
boxes=box,
|
| 108 |
-
mode="greedy", #
|
| 109 |
dataname="tracking_result"
|
| 110 |
)
|
| 111 |
|
|
@@ -113,7 +112,7 @@ def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
|
|
| 113 |
if not os.path.exists(output_dir):
|
| 114 |
os.makedirs(output_dir)
|
| 115 |
|
| 116 |
-
#
|
| 117 |
print("🔄 Converting to CTC format...")
|
| 118 |
ctc_tracks, masks_tracked = graph_to_ctc(
|
| 119 |
track_graph,
|
|
@@ -122,7 +121,6 @@ def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
|
|
| 122 |
)
|
| 123 |
print(f"✅ CTC results saved to {output_dir}")
|
| 124 |
|
| 125 |
-
# num_tracks = len(track_graph.tracks())
|
| 126 |
|
| 127 |
print(f"✅ Tracking completed")
|
| 128 |
|
|
@@ -131,7 +129,6 @@ def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
|
|
| 131 |
'masks': masks,
|
| 132 |
'masks_tracked': masks_tracked,
|
| 133 |
'output_dir': output_dir,
|
| 134 |
-
# 'num_tracks': num_tracks
|
| 135 |
}
|
| 136 |
|
| 137 |
return result
|
|
@@ -151,36 +148,35 @@ def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
|
|
| 151 |
|
| 152 |
def visualize_tracking_result(masks_tracked, output_path):
|
| 153 |
"""
|
| 154 |
-
|
| 155 |
|
| 156 |
Args:
|
| 157 |
-
masks_tracked:
|
| 158 |
-
output_path:
|
| 159 |
|
| 160 |
Returns:
|
| 161 |
-
output_path:
|
| 162 |
"""
|
| 163 |
try:
|
| 164 |
import cv2
|
| 165 |
import matplotlib.pyplot as plt
|
| 166 |
from matplotlib import cm
|
| 167 |
|
| 168 |
-
# 获取时间帧数
|
| 169 |
T, H, W = masks_tracked.shape
|
| 170 |
|
| 171 |
-
#
|
| 172 |
unique_ids = np.unique(masks_tracked)
|
| 173 |
num_colors = len(unique_ids)
|
| 174 |
cmap = cm.get_cmap('tab20', num_colors)
|
| 175 |
|
| 176 |
-
#
|
| 177 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 178 |
out = cv2.VideoWriter(output_path, fourcc, 5.0, (W, H))
|
| 179 |
|
| 180 |
for t in range(T):
|
| 181 |
frame = masks_tracked[t]
|
| 182 |
|
| 183 |
-
#
|
| 184 |
colored_frame = np.zeros((H, W, 3), dtype=np.uint8)
|
| 185 |
for i, obj_id in enumerate(unique_ids):
|
| 186 |
if obj_id == 0:
|
|
@@ -189,7 +185,7 @@ def visualize_tracking_result(masks_tracked, output_path):
|
|
| 189 |
color = np.array(cmap(i % num_colors)[:3]) * 255
|
| 190 |
colored_frame[mask] = color
|
| 191 |
|
| 192 |
-
#
|
| 193 |
colored_frame_bgr = cv2.cvtColor(colored_frame, cv2.COLOR_RGB2BGR)
|
| 194 |
out.write(colored_frame_bgr)
|
| 195 |
|
|
|
|
| 1 |
# inference_track.py
|
|
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
|
|
|
| 14 |
|
| 15 |
def load_model(use_box=False):
|
| 16 |
"""
|
| 17 |
+
load tracking model from Hugging Face Hub
|
| 18 |
|
| 19 |
Args:
|
| 20 |
+
use_box: use bounding box as input (default: False)
|
| 21 |
|
| 22 |
Returns:
|
| 23 |
+
model: loaded tracking model
|
| 24 |
+
device
|
| 25 |
"""
|
| 26 |
global MODEL, DEVICE
|
| 27 |
|
|
|
|
| 31 |
# 初始化模型
|
| 32 |
MODEL = TrackingModule(use_box=use_box)
|
| 33 |
|
| 34 |
+
# Load checkpoint from Hugging Face Hub
|
| 35 |
ckpt_path = hf_hub_download(
|
| 36 |
repo_id="phoebe777777/111",
|
| 37 |
filename="microscopy_matching_tra.pth",
|
|
|
|
| 41 |
|
| 42 |
print(f"✅ Checkpoint downloaded: {ckpt_path}")
|
| 43 |
|
| 44 |
+
# Load weights
|
| 45 |
MODEL.load_state_dict(
|
| 46 |
torch.load(ckpt_path, map_location="cpu"),
|
| 47 |
strict=True
|
| 48 |
)
|
| 49 |
MODEL.eval()
|
| 50 |
|
| 51 |
+
# Move model to device
|
| 52 |
if torch.cuda.is_available():
|
| 53 |
DEVICE = torch.device("cuda")
|
| 54 |
MODEL.move_to_device(DEVICE)
|
|
|
|
| 71 |
@torch.no_grad()
|
| 72 |
def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
|
| 73 |
"""
|
| 74 |
+
run tracking inference on video frames
|
| 75 |
|
| 76 |
Args:
|
| 77 |
+
model: loaded tracking model
|
| 78 |
+
video_dir: directory of video frame sequence (contains consecutive image files)
|
| 79 |
+
box: bounding box (optional)
|
| 80 |
+
device: device
|
| 81 |
+
output_dir: output directory
|
| 82 |
|
| 83 |
Returns:
|
| 84 |
result_dict: {
|
| 85 |
+
'track_graph': TrackGraph object containing tracking results,
|
| 86 |
+
'masks': tracked masks (T, H, W),
|
| 87 |
+
'output_dir': output directory path,
|
| 88 |
+
'num_tracks': number of tracked trajectories
|
| 89 |
}
|
| 90 |
"""
|
| 91 |
if model is None:
|
|
|
|
| 100 |
try:
|
| 101 |
print(f"🔄 Running tracking inference on {video_dir}")
|
| 102 |
|
| 103 |
+
# Run tracking
|
| 104 |
track_graph, masks = model.track(
|
| 105 |
file_dir=video_dir,
|
| 106 |
boxes=box,
|
| 107 |
+
mode="greedy", # Optional: "greedy", "greedy_nodiv", "ilp"
|
| 108 |
dataname="tracking_result"
|
| 109 |
)
|
| 110 |
|
|
|
|
| 112 |
if not os.path.exists(output_dir):
|
| 113 |
os.makedirs(output_dir)
|
| 114 |
|
| 115 |
+
# Convert tracking results to CTC format and save
|
| 116 |
print("🔄 Converting to CTC format...")
|
| 117 |
ctc_tracks, masks_tracked = graph_to_ctc(
|
| 118 |
track_graph,
|
|
|
|
| 121 |
)
|
| 122 |
print(f"✅ CTC results saved to {output_dir}")
|
| 123 |
|
|
|
|
| 124 |
|
| 125 |
print(f"✅ Tracking completed")
|
| 126 |
|
|
|
|
| 129 |
'masks': masks,
|
| 130 |
'masks_tracked': masks_tracked,
|
| 131 |
'output_dir': output_dir,
|
|
|
|
| 132 |
}
|
| 133 |
|
| 134 |
return result
|
|
|
|
| 148 |
|
| 149 |
def visualize_tracking_result(masks_tracked, output_path):
|
| 150 |
"""
|
| 151 |
+
visualize tracking results
|
| 152 |
|
| 153 |
Args:
|
| 154 |
+
masks_tracked: masks with tracking results (T, H, W)
|
| 155 |
+
output_path: output video file path
|
| 156 |
|
| 157 |
Returns:
|
| 158 |
+
output_path: output video file path
|
| 159 |
"""
|
| 160 |
try:
|
| 161 |
import cv2
|
| 162 |
import matplotlib.pyplot as plt
|
| 163 |
from matplotlib import cm
|
| 164 |
|
|
|
|
| 165 |
T, H, W = masks_tracked.shape
|
| 166 |
|
| 167 |
+
# create a color map for unique track IDs
|
| 168 |
unique_ids = np.unique(masks_tracked)
|
| 169 |
num_colors = len(unique_ids)
|
| 170 |
cmap = cm.get_cmap('tab20', num_colors)
|
| 171 |
|
| 172 |
+
# create video writer
|
| 173 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
| 174 |
out = cv2.VideoWriter(output_path, fourcc, 5.0, (W, H))
|
| 175 |
|
| 176 |
for t in range(T):
|
| 177 |
frame = masks_tracked[t]
|
| 178 |
|
| 179 |
+
# create colored image
|
| 180 |
colored_frame = np.zeros((H, W, 3), dtype=np.uint8)
|
| 181 |
for i, obj_id in enumerate(unique_ids):
|
| 182 |
if obj_id == 0:
|
|
|
|
| 185 |
color = np.array(cmap(i % num_colors)[:3]) * 255
|
| 186 |
colored_frame[mask] = color
|
| 187 |
|
| 188 |
+
# convert to BGR (OpenCV format)
|
| 189 |
colored_frame_bgr = cv2.cvtColor(colored_frame, cv2.COLOR_RGB2BGR)
|
| 190 |
out.write(colored_frame_bgr)
|
| 191 |
|