Leema Krishna Murali committed on
Commit f3d0a26 · 1 Parent(s): a2f1ff3

Initial commit

app.py ADDED
@@ -0,0 +1,370 @@
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ from visualizer import draw_box_on_frame, create_comparison_strip
+ from preview import preview_trajectory
+ from pipeline_adapter import (
+     extract_first_frame,
+     load_all_frames,
+     run_pipeline_motion_edit,
+     run_pipeline_insertion  # ← need to add this
+ )
+
+
+ def build_interface():
+
+     # Load Qwen-Image-Edit once at startup (not per-click — model is ~20GB)
+     _qwen_edit_pipe = None
+     try:
+         from frame_editor import load_qwen_image_edit
+         _qwen_edit_pipe = load_qwen_image_edit(use_lightning=True, device="cuda")
+         print("Qwen-Image-Edit ready.")
+     except Exception as e:
+         print(f"Qwen-Image-Edit not available: {e}")
+
+     with gr.Blocks(title="TRACE Prototype", theme=gr.themes.Soft()) as demo:
+
+         gr.Markdown("# TRACE Prototype — Object Motion Editing")
+
+         with gr.Tabs():
+
+             # ── Tab 1: Motion Edit (existing) ─────────────────────────
+             # with gr.Tab("Motion Path Edit"):
+             #     gr.Markdown(
+             #         "Move an **existing object** in the video "
+             #         "to a new trajectory."
+             #     )
+
+             #     with gr.Row():
+             #         with gr.Column():
+             #             video_input_edit = gr.Video(label="Input Video")
+             #             video_info_edit = gr.Markdown("")
+
+             #         with gr.Column():
+             #             first_frame_edit = gr.Image(
+             #                 label="First Frame + Trajectory Preview",
+             #                 interactive=False
+             #             )
+
+             #     gr.Markdown("**Start Box** — draw around the object")
+             #     with gr.Row():
+             #         sx1 = gr.Number(label="x1", value=100, precision=0)
+             #         sy1 = gr.Number(label="y1", value=100, precision=0)
+             #         sx2 = gr.Number(label="x2", value=200, precision=0)
+             #         sy2 = gr.Number(label="y2", value=200, precision=0)
+
+             #     gr.Markdown("**End Box** — where you want it to go")
+             #     with gr.Row():
+             #         ex1 = gr.Number(label="x1", value=500, precision=0)
+             #         ey1 = gr.Number(label="y1", value=200, precision=0)
+             #         ex2 = gr.Number(label="x2", value=600, precision=0)
+             #         ey2 = gr.Number(label="y2", value=300, precision=0)
+
+             #     prompt_edit = gr.Textbox(
+             #         label="Scene Description",
+             #         placeholder="a dog running in a park..."
+             #     )
+
+             #     with gr.Row():
+             #         stage1_method = gr.Radio(
+             #             choices=["linear", "cotracker"],
+             #             value="linear",
+             #             label="Stage 1 Method"
+             #         )
+             #         use_vace_edit = gr.Checkbox(
+             #             label="Use VACE",
+             #             value=False
+             #         )
+
+             #     run_edit_btn = gr.Button("Run Motion Edit", variant="primary")
+
+             #     with gr.Row():
+             #         output_video_edit = gr.Video(label="Output Video")
+             #         metrics_edit = gr.Markdown("")
+
+             #     comparison_edit = gr.Image(label="Frame Comparison", interactive=False)
+
+             # ── Tab 2: Object Insertion (NEW — uses Qwen) ─────────────
+             with gr.Tab("Object Insertion"):
+                 gr.Markdown(
+                     "Insert a **new object** into the video using "
+                     "Qwen to edit the first frame, then propagate."
+                 )
+
+                 with gr.Row():
+                     with gr.Column():
+                         video_input_ins = gr.Video(label="Input Video")
+                         video_info_ins = gr.Markdown("")
+
+                     with gr.Column():
+                         first_frame_ins = gr.Image(
+                             label="First Frame Preview",
+                             interactive=False
+                         )
+
+                 gr.Markdown("**Insertion Box** — where to place the new object")
+                 with gr.Row():
+                     ix1 = gr.Number(label="x1", value=40, precision=0)
+                     iy1 = gr.Number(label="y1", value=40, precision=0)
+                     ix2 = gr.Number(label="x2", value=300, precision=0)
+                     iy2 = gr.Number(label="y2", value=300, precision=0)
+
+                 gr.Markdown("**End Box** — where the object should arrive")
+                 with gr.Row():
+                     iex1 = gr.Number(label="x1", value=500, precision=0)
+                     iey1 = gr.Number(label="y1", value=200, precision=0)
+                     iex2 = gr.Number(label="x2", value=600, precision=0)
+                     iey2 = gr.Number(label="y2", value=300, precision=0)
+
+                 # ── The Qwen-specific inputs ───────────────────────────
+                 gr.Markdown("**Object Description** — what Qwen will insert")
+                 with gr.Row():
+                     with gr.Column():
+                         object_description = gr.Textbox(
+                             label="Object to Insert (Qwen prompt)",
+                             placeholder="a red helium balloon with a white string",
+                             info="Qwen uses this to paint the object into frame 1"
+                         )
+                         scene_prompt = gr.Textbox(
+                             label="Full Scene Prompt (for video synthesis)",
+                             placeholder="a peaceful park scene with a red balloon"
+                         )
+
+                     with gr.Column():
+                         gr.Markdown("Using **Qwen-Image-Edit-2511** for object insertion")
+
+                         # use_vace_ins = gr.Checkbox(
+                         #     label="Use VACE",
+                         #     value=False
+                         # )
+
+                 # ── Qwen output preview before running video ───────────
+                 gr.Markdown("**Step 1 Preview** — see Qwen's edit before running video")
+                 preview_qwen_btn = gr.Button(
+                     "Preview First Frame Edit",
+                     variant="secondary"
+                 )
+                 edited_frame_preview = gr.Image(
+                     label="Qwen-Edited First Frame",
+                     interactive=False
+                 )
+                 qwen_status = gr.Markdown("")
+
+                 # gr.Markdown("---")
+                 # run_ins_btn = gr.Button(
+                 #     "Run Full Insertion Pipeline",
+                 #     variant="primary"
+                 # )
+
+                 # with gr.Row():
+                 #     output_video_ins = gr.Video(label="Output Video")
+                 #     metrics_ins = gr.Markdown("")
+
+                 # comparison_ins = gr.Image(
+                 #     label="Frame Comparison",
+                 #     interactive=False
+                 # )
+
+         # ── Wire Up Tab 1 ─────────────────────────────────────────────
+         #_state = {"frames": None, "first_frame": None}
+
+
+         # def on_video_upload_edit(video_path):
+         #     if video_path is None:
+         #         return None, "Upload a video."
+         #     first_frame = extract_first_frame(video_path)
+         #     _state["first_frame"] = first_frame
+         #     return Image.fromarray(first_frame), "Video loaded."
+
+         # def on_boxes_changed_edit(sx1, sy1, sx2, sy2, ex1, ey1, ex2, ey2):
+         #     if _state["first_frame"] is None:
+         #         return None
+         #     from preview import preview_trajectory
+         #     preview = preview_trajectory(
+         #         _state["first_frame"],
+         #         [sx1, sy1, sx2, sy2],
+         #         [ex1, ey1, ex2, ey2]
+         #     )
+         #     return Image.fromarray(preview)
+
+         # video_input_edit.change(
+         #     fn=on_video_upload_edit,
+         #     inputs=[video_input_edit],
+         #     outputs=[first_frame_edit, video_info_edit]
+         # )
+
+         # for inp in [sx1, sy1, sx2, sy2, ex1, ey1, ex2, ey2]:
+         #     inp.change(
+         #         fn=on_boxes_changed_edit,
+         #         inputs=[sx1, sy1, sx2, sy2, ex1, ey1, ex2, ey2],
+         #         outputs=[first_frame_edit]
+         #     )
+
+         # def on_run_edit(video_path, sx1, sy1, sx2, sy2, ex1, ey1, ex2, ey2,
+         #                 prompt, stage1_method, use_vace, progress=gr.Progress()):
+         #     if video_path is None:
+         #         raise gr.Error("Please upload a video first.")
+         #     if sx2 <= sx1 or sy2 <= sy1:
+         #         raise gr.Error("Start box is invalid: x2 must be > x1, y2 must be > y1")
+         #     if ex2 <= ex1 or ey2 <= ey1:
+         #         raise gr.Error("End box is invalid: x2 must be > x1, y2 must be > y1")
+
+         #     def prog(frac, msg):
+         #         progress(frac, desc=msg)
+
+         #     output_path, result_frames, pred_boxes, metrics = \
+         #         run_pipeline_motion_edit(
+         #             video_path=video_path,
+         #             start_box=[sx1, sy1, sx2, sy2],
+         #             end_box=[ex1, ey1, ex2, ey2],
+         #             prompt=prompt,
+         #             stage1_method=stage1_method,
+         #             use_vace=use_vace,
+         #             progress_callback=prog
+         #         )
+
+         #     if _state["frames"] is None:
+         #         _state["frames"] = load_all_frames(video_path)
+
+         #     comparison = create_comparison_strip(
+         #         _state["frames"],
+         #         result_frames,
+         #         pred_boxes,
+         #         sample_ts=[0, 20, 40, 60, 80]
+         #     )
+
+         #     return output_path, Image.fromarray(comparison), metrics
+
+
+         # run_edit_btn.click(
+         #     fn=on_run_edit,
+         #     inputs=[
+         #         video_input_edit,
+         #         sx1, sy1, sx2, sy2,
+         #         ex1, ey1, ex2, ey2,
+         #         prompt_edit, stage1_method, use_vace_edit
+         #     ],
+         #     outputs=[output_video_edit, comparison_edit, metrics_edit]
+         # )
+
+         # ── Wire Up Tab 2 (Qwen insertion) ────────────────────────────
+         _ins_state = {"first_frame": None, "edited_frame": None}
+
+
+         def on_video_upload_ins(video_path):
+             if video_path is None:
+                 return None, "Upload a video."
+             first_frame = extract_first_frame(video_path)
+             _ins_state["first_frame"] = first_frame
+             return Image.fromarray(first_frame), "Video loaded."
+
+         def on_preview_qwen(
+             video_path,
+             ix1, iy1, ix2, iy2,
+             object_description,
+             progress=gr.Progress()
+         ):
+             if _ins_state["first_frame"] is None:
+                 raise gr.Error("Upload a video first.")
+             if not object_description.strip():
+                 raise gr.Error("Enter an object description.")
+             if _qwen_edit_pipe is None:
+                 raise gr.Error("Qwen-Image-Edit failed to load at startup. Check logs.")
+
+             insertion_box = [ix1, iy1, ix2, iy2]
+
+             progress(0.3, "Editing first frame with Qwen-Image-Edit...")
+             from frame_editor import insert_object_qwen_edit
+             edited = insert_object_qwen_edit(
+                 first_frame=_ins_state["first_frame"],
+                 box=insertion_box,
+                 object_description=object_description,
+                 pipe=_qwen_edit_pipe,
+             )
+
+             _ins_state["edited_frame"] = edited
+
+             preview = draw_box_on_frame(
+                 edited,
+                 insertion_box,
+                 color=(255, 220, 0),
+                 label="inserted here"
+             )
+
+             progress(1.0, "Done!")
+             return (
+                 Image.fromarray(preview),
+                 "First frame edited."
+             )
+
+
+         def on_run_insertion(
+             video_path,
+             ix1, iy1, ix2, iy2,
+             iex1, iey1, iex2, iey2,
+             scene_prompt,
+             use_vace_ins,
+             progress=gr.Progress()
+         ):
+             if _ins_state["edited_frame"] is None:
+                 raise gr.Error(
+                     "Run 'Preview First Frame Edit' first — "
+                     "the edited frame is needed as appearance reference."
+                 )
+
+             output_path, result_frames, pred_boxes, metrics = \
+                 run_pipeline_insertion(
+                     video_path=video_path,
+                     edited_first_frame=_ins_state["edited_frame"],
+                     start_box=[ix1, iy1, ix2, iy2],
+                     end_box=[iex1, iey1, iex2, iey2],
+                     prompt=scene_prompt,
+                     use_vace=use_vace_ins,
+                     progress_callback=lambda f, m: progress(f, desc=m)
+                 )
+
+             frames = load_all_frames(video_path)
+             comparison = create_comparison_strip(
+                 frames, result_frames, pred_boxes
+             )
+
+             return (
+                 output_path,
+                 Image.fromarray(comparison),
+                 metrics
+             )
+
+         video_input_ins.change(
+             fn=on_video_upload_ins,
+             inputs=[video_input_ins],
+             outputs=[first_frame_ins, video_info_ins]
+         )
+
+         preview_qwen_btn.click(
+             fn=on_preview_qwen,
+             inputs=[
+                 video_input_ins,
+                 ix1, iy1, ix2, iy2,
+                 object_description,
+             ],
+             outputs=[edited_frame_preview, qwen_status]
+         )
+
+         # run_ins_btn.click(
+         #     fn=on_run_insertion,
+         #     inputs=[
+         #         video_input_ins,
+         #         ix1, iy1, ix2, iy2,
+         #         iex1, iey1, iex2, iey2,
+         #         scene_prompt,
+         #         use_vace_ins
+         #     ],
+         #     outputs=[output_video_ins, comparison_ins, metrics_ins]
+         # )
+
+     return demo
+
+
+ if __name__ == "__main__":
+     demo = build_interface()
+     demo.launch(share=True)
frame_editor.py ADDED
@@ -0,0 +1,117 @@
+ # frame_editor.py
+
+ import numpy as np
+ from PIL import Image
+ import torch
+ import cv2
+
+ def load_qwen_image_edit(use_lightning=True, device="cuda"):
+     from diffusers import QwenImageEditPlusPipeline, FlowMatchEulerDiscreteScheduler
+
+     scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+         "Qwen/Qwen-Image-Edit-2511", subfolder="scheduler"
+     )
+     pipe = QwenImageEditPlusPipeline.from_pretrained(
+         "Qwen/Qwen-Image-Edit-2511",
+         scheduler=scheduler,
+         torch_dtype=torch.bfloat16,
+     ).to(device)
+
+     if use_lightning:
+         pipe.load_lora_weights(
+             "lightx2v/Qwen-Image-Edit-2511-Lightning",
+             weight_name="Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors"
+         )
+         pipe.fuse_lora()
+
+     return pipe
+
+
+ def insert_object_qwen_edit(
+     first_frame,            # np.ndarray [H, W, 3] uint8 RGB
+     box,                    # [x1, y1, x2, y2]
+     object_description,     # e.g. "a red sports car"
+     pipe,
+     context_pad=60,         # pixels of context around box — helps Qwen understand scene
+     num_inference_steps=4,
+     guidance_scale=1.0,
+     seed=42,
+ ):
+     """
+     Inserts object into ONLY the bounding box region.
+     Background outside the box is pixel-identical to original.
+
+     Strategy:
+     1. Crop (box + padding) from original → gives Qwen scene context
+     2. Edit the crop with Qwen-Image-Edit
+     3. Extract only the box pixels from the edited crop
+     4. Paste back onto original frame
+     """
+     H, W = first_frame.shape[:2]
+     x1, y1, x2, y2 = [int(v) for v in box]
+
+     # --- Step 1: Crop with context padding ---
+     cx1 = max(0, x1 - context_pad)
+     cy1 = max(0, y1 - context_pad)
+     cx2 = min(W, x2 + context_pad)
+     cy2 = min(H, y2 + context_pad)
+
+     crop = first_frame[cy1:cy2, cx1:cx2].copy()  # [cH, cW, 3]
+     cH, cW = crop.shape[:2]
+
+     # Box coordinates relative to crop
+     lx1 = x1 - cx1
+     ly1 = y1 - cy1
+     lx2 = x2 - cx1
+     ly2 = y2 - cy1
+
+     # --- Step 2: Build focused edit instruction ---
+     prompt = (
+         f"Insert {object_description} in the region ({lx1},{ly1}) to ({lx2},{ly2}). "
+         f"Keep everything outside that region exactly the same. "
+         f"Match the scene lighting, shadows, and perspective."
+     )
+
+     generator = torch.Generator().manual_seed(seed)
+
+     edited = pipe(
+         image=[Image.fromarray(crop)],
+         prompt=prompt,
+         num_inference_steps=num_inference_steps,
+         true_cfg_scale=guidance_scale,
+         negative_prompt=" ",
+         generator=generator,
+     ).images[0]
+
+     edited_np = np.array(edited)  # [cH', cW', 3]
+
+     # Resize back if pipeline changed resolution
+     if edited_np.shape[:2] != (cH, cW):
+         edited_np = cv2.resize(edited_np, (cW, cH), interpolation=cv2.INTER_LINEAR)
+
+     # --- Step 3: Hard composite — only paste the box region back ---
+     result = first_frame.copy()
+     result[y1:y2, x1:x2] = edited_np[ly1:ly2, lx1:lx2]
+
+     return result  # [H, W, 3] uint8 RGB — background unchanged
+
+
+ def segment_existing_object(
+     first_frame: np.ndarray,
+     box: list,
+     sam2_predictor
+ ) -> np.ndarray:
+     """
+     Use SAM2 to get a precise mask of an existing object.
+     Returns: [H, W] binary float32 mask
+     """
+     sam2_predictor.set_image(first_frame)
+
+     input_box = np.array([box])
+     masks, scores, _ = sam2_predictor.predict(
+         box=input_box,
+         multimask_output=False
+     )
+
+     return masks[np.argmax(scores)].astype(np.float32)
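
The two entry points above can also be exercised on a single frame without launching the Gradio app. A minimal sketch, assuming the repo modules are on the import path, a CUDA GPU is available, and input.mp4 is a placeholder path:

# sketch: edit one frame on its own, outside app.py (input.mp4 is hypothetical)
import imageio
from utils.video_utils import load_video
from frame_editor import load_qwen_image_edit, insert_object_qwen_edit

frame = load_video("input.mp4", max_frames=1)[0]   # [H, W, 3] uint8 RGB
pipe = load_qwen_image_edit(use_lightning=True, device="cuda")
edited = insert_object_qwen_edit(
    first_frame=frame,
    box=[40, 40, 300, 300],                        # where to paint the new object
    object_description="a red helium balloon with a white string",
    pipe=pipe,
)
imageio.imwrite("edited_first_frame.png", edited)  # background outside the box is untouched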
pipeline.py ADDED
@@ -0,0 +1,214 @@
+ # pipeline.py
+
+ import numpy as np
+ import torch
+ from utils.video_utils import load_video, save_video
+ from utils.box_utils import boxes_to_mask_sequence
+ from stage1_approx import stage1_linear, stage1_cotracker
+ from stage2_vace import VACEWrapper, SimpleCompositeStage2
+
+
+ class TRACEPrototype:
+
+     def __init__(self, use_vace: bool = False, use_cotracker: bool = False):
+
+         # ── Stage 2: Video Synthesis ──────────────────────────────────
+         if use_vace:
+             self.stage2 = VACEWrapper()
+         else:
+             self.stage2 = SimpleCompositeStage2()
+
+         # ── CoTracker for Stage 1 ─────────────────────────────────────
+         self.cotracker = None
+         if use_cotracker:
+             try:
+                 self.cotracker = torch.hub.load(
+                     "facebookresearch/co-tracker",
+                     "cotracker3_online"
+                 ).cuda()
+                 print("CoTracker loaded.")
+             except Exception as e:
+                 print(f"CoTracker failed to load: {e}")
+                 print("Falling back to linear interpolation.")
+
+         # ── SAM2 for object segmentation ─────────────────────────────
+         self.sam2 = None
+         try:
+             from sam2.build_sam import build_sam2
+             from sam2.sam2_image_predictor import SAM2ImagePredictor
+             self.sam2 = SAM2ImagePredictor(
+                 build_sam2("sam2_hiera_large.pt")
+             )
+             print("SAM2 loaded.")
+         except Exception as e:
+             print(f"SAM2 not available: {e}")
+             print("Will use box masks directly instead of segmentation.")
+
+         # ── Qwen-Image-Edit for object insertion ──────────────────────
+         self.qwen_edit_pipe = None
+         try:
+             from frame_editor import load_qwen_image_edit
+             self.qwen_edit_pipe = load_qwen_image_edit(
+                 use_lightning=True, device="cuda"
+             )
+             print("Qwen-Image-Edit loaded.")
+         except Exception as e:
+             print(f"Qwen-Image-Edit not available: {e}")
+
+
+     def run_motion_edit(
+         self,
+         video_path: str,
+         keyboxes: dict,            # {frame_idx: [x1, y1, x2, y2]}
+         text_prompt: str,
+         output_path: str = None,
+         frames: np.ndarray = None  # pass directly to avoid reloading
+     ) -> np.ndarray:
+         """
+         Edit the trajectory of an existing object in the video.
+
+         keyboxes must include:
+         - frame 0: current object location (start)
+         - at least one other frame: target location (end)
+         """
+
+         # Load video if frames not passed directly
+         if frames is None:
+             frames = load_video(video_path)
+         T, H, W, _ = frames.shape
+
+         # ── Stage 1: Compute target trajectory ───────────────────────
+         if self.cotracker is not None:
+             pred_boxes = stage1_cotracker(
+                 frames, keyboxes, self.cotracker
+             )
+         else:
+             pred_boxes = stage1_linear(keyboxes, T)
+
+         # ── Build masks ───────────────────────────────────────────────
+         # Synthesis mask: where to PLACE the object (new trajectory)
+         synthesis_masks = boxes_to_mask_sequence(pred_boxes, H, W)
+
+         # Inpainting mask: where to ERASE the object (original position)
+         # Use SAM2 for precise mask if available, else use box directly
+         orig_box = keyboxes[0]
+         if self.sam2 is not None:
+             from frame_editor import segment_existing_object
+             seg_mask = segment_existing_object(
+                 frames[0], orig_box, self.sam2
+             )
+             # Propagate original mask roughly using linear boxes
+             orig_keyboxes = {0: orig_box}
+             orig_boxes = stage1_linear(orig_keyboxes, T)
+             inpaint_masks = boxes_to_mask_sequence(orig_boxes, H, W)
+             # Refine frame 0 with SAM2 mask
+             inpaint_masks[0] = seg_mask
+         else:
+             # Fallback: use box directly as inpaint mask
+             orig_keyboxes = {0: orig_box}
+             orig_boxes = stage1_linear(orig_keyboxes, T)
+             inpaint_masks = boxes_to_mask_sequence(orig_boxes, H, W)
+
+         # ── Stage 2: Synthesize video ─────────────────────────────────
+         if isinstance(self.stage2, VACEWrapper):
+             result = self.stage2.synthesize(
+                 original_frames=frames,
+                 synthesis_masks=synthesis_masks,
+                 inpaint_masks=inpaint_masks,
+                 first_frame_ref=frames[0],
+                 text_prompt=text_prompt
+             )
+         else:
+             # SimpleCompositeStage2: needs object crop
+             x1, y1, x2, y2 = [int(v) for v in orig_box]
+             obj_crop = frames[0, y1:y2, x1:x2]
+
+             if self.sam2 is not None:
+                 obj_mask = seg_mask[y1:y2, x1:x2]
+             else:
+                 obj_mask = np.ones(
+                     (y2 - y1, x2 - x1), dtype=np.float32
+                 )
+
+             result = self.stage2.synthesize(
+                 original_frames=frames,
+                 synthesis_masks=synthesis_masks,
+                 inpaint_masks=inpaint_masks,
+                 object_crop=obj_crop,
+                 object_mask=obj_mask
+             )
+
+         # ── Save if path provided ─────────────────────────────────────
+         if output_path is not None:
+             save_video(result, output_path)
+             print(f"Saved to {output_path}")
+
+         return result
+
+     def run_object_insertion(
+         self,
+         video_path: str,
+         object_description: str,
+         keyboxes: dict,            # {frame_idx: [x1, y1, x2, y2]}
+         text_prompt: str,
+         output_path: str = None,
+         frames: np.ndarray = None,
+     ) -> np.ndarray:
+         """
+         Insert a new object into the video and animate it along a trajectory.
+         Qwen-Image-Edit paints the object into frame 0 only.
+         Stage 2 propagates it through all frames.
+         """
+         if frames is None:
+             frames = load_video(video_path)
+         T, H, W, _ = frames.shape
+
+         # Stage 1: trajectory
+         pred_boxes = stage1_linear(keyboxes, T)
+
+         # Edit first frame with Qwen-Image-Edit
+         if self.qwen_edit_pipe is not None:
+             from frame_editor import insert_object_qwen_edit
+             edited_first_frame = insert_object_qwen_edit(
+                 first_frame=frames[0],
+                 box=pred_boxes[0],
+                 object_description=object_description,
+                 pipe=self.qwen_edit_pipe,
+             )
+         else:
+             print("Qwen-Image-Edit not available, using original first frame.")
+             edited_first_frame = frames[0]
+
+         # Synthesis masks: where to place object along trajectory
+         synthesis_masks = boxes_to_mask_sequence(pred_boxes, H, W)
+         # No inpaint masks needed — nothing to erase for insertion
+         inpaint_masks = np.zeros((T, H, W), dtype=np.uint8)
+
+         # Stage 2
+         if isinstance(self.stage2, VACEWrapper):
+             result = self.stage2.synthesize(
+                 original_frames=frames,
+                 synthesis_masks=synthesis_masks,
+                 inpaint_masks=inpaint_masks,
+                 first_frame_ref=edited_first_frame,
+                 text_prompt=text_prompt,
+             )
+         else:
+             x1, y1, x2, y2 = [int(v) for v in pred_boxes[0]]
+             obj_crop = edited_first_frame[y1:y2, x1:x2]
+             obj_mask = np.ones((y2 - y1, x2 - x1), dtype=np.float32)
+
+             result = self.stage2.synthesize(
+                 original_frames=frames,
+                 synthesis_masks=synthesis_masks,
+                 inpaint_masks=inpaint_masks,
+                 object_crop=obj_crop,
+                 object_mask=obj_mask,
+             )
+
+         if output_path is not None:
+             save_video(result, output_path)
+             print(f"Saved to {output_path}")
+
+         return result
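
TRACEPrototype can also be driven directly from a script. A minimal sketch of a motion edit with the compositing fallback (the video path and pixel coordinates are hypothetical, and an 81-frame clip is assumed):

# sketch: move an existing object with the SimpleCompositeStage2 fallback
from pipeline import TRACEPrototype

proto = TRACEPrototype(use_vace=False, use_cotracker=False)
result = proto.run_motion_edit(
    video_path="input.mp4",                     # placeholder path
    keyboxes={0: [100, 100, 200, 200],          # frame 0: current location
              80: [500, 200, 600, 300]},        # frame 80: target location
    text_prompt="a dog running in a park",
    output_path="motion_edit.mp4",
)
print(result.shape)                             # [T, H, W, 3] uint8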
pipeline_adapter.py ADDED
@@ -0,0 +1,216 @@
+ # pipeline_adapter.py
+
+ import numpy as np
+ import tempfile
+ from utils.video_utils import load_video, save_video
+ from skimage.metrics import peak_signal_noise_ratio, structural_similarity
+
+
+ def compute_psnr(original, result):
+     """Mean PSNR across all frames."""
+     scores = []
+     for f1, f2 in zip(original, result):
+         scores.append(peak_signal_noise_ratio(f1, f2, data_range=255))
+     return float(np.mean(scores))
+
+ def compute_ssim_video(original, result):
+     """Mean SSIM across all frames."""
+     scores = []
+     for f1, f2 in zip(original, result):
+         scores.append(structural_similarity(f1, f2, channel_axis=-1, data_range=255))
+     return float(np.mean(scores))
+
+ def compute_lpips_video(original, result, device="cuda"):
+     """Mean LPIPS across all frames (lower = better)."""
+     import torch
+     import lpips
+
+     loss_fn = lpips.LPIPS(net="alex").to(device)
+     scores = []
+
+     for f1, f2 in zip(original, result):
+         # Convert [H, W, 3] uint8 → [1, 3, H, W] float in [-1, 1]
+         t1 = torch.from_numpy(f1).permute(2, 0, 1).unsqueeze(0).float() / 127.5 - 1.0
+         t2 = torch.from_numpy(f2).permute(2, 0, 1).unsqueeze(0).float() / 127.5 - 1.0
+         t1, t2 = t1.to(device), t2.to(device)
+
+         with torch.no_grad():
+             score = loss_fn(t1, t2)
+         scores.append(score.item())
+
+     return float(np.mean(scores))
+
+
+ def extract_first_frame(video_path: str) -> np.ndarray:
+     frames = load_video(video_path, max_frames=1)
+     return frames[0]
+
+
+ def load_all_frames(video_path: str) -> np.ndarray:
+     return load_video(video_path, max_frames=81)
+
+
+ def run_pipeline_motion_edit(
+     video_path: str,
+     start_box: list,
+     end_box: list,
+     prompt: str,
+     stage1_method: str = "linear",
+     use_vace: bool = False,
+     progress_callback=None
+ ) -> tuple:
+     from pipeline import TRACEPrototype
+     from stage1_approx import stage1_linear, stage1_cotracker
+     # from evaluation.metrics import (
+     #     compute_psnr, compute_ssim_video, compute_lpips_video
+     # )
+
+     if progress_callback:
+         progress_callback(0.1, "Loading video...")
+
+     frames = load_all_frames(video_path)
+     T, H, W, _ = frames.shape
+     keyboxes = {0: start_box, T - 1: end_box}
+
+     proto = TRACEPrototype(
+         use_vace=use_vace,
+         use_cotracker=(stage1_method == "cotracker")
+     )
+
+     if progress_callback:
+         progress_callback(0.3, "Computing trajectory...")
+
+     if stage1_method == "cotracker" and proto.cotracker is not None:
+         pred_boxes = stage1_cotracker(frames, keyboxes, proto.cotracker)
+     else:
+         pred_boxes = stage1_linear(keyboxes, T)
+
+     if progress_callback:
+         progress_callback(0.5, "Running video synthesis...")
+
+     result = proto.run_motion_edit(
+         video_path=video_path,
+         keyboxes=keyboxes,
+         text_prompt=prompt,
+         output_path=None
+     )
+
+     tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
+     save_video(result, tmp.name)
+
+     if progress_callback:
+         progress_callback(0.9, "Computing metrics...")
+
+     psnr = compute_psnr(result, frames)
+     ssim = compute_ssim_video(result, frames)
+     lpips = compute_lpips_video(result, frames)
+
+     metrics_text = (
+         f"**Video Quality**\n"
+         f"- PSNR: {psnr:.2f} dB (TRACE paper: 20.48)\n"
+         f"- SSIM: {ssim:.3f} (TRACE paper: 0.71)\n"
+         f"- LPIPS: {lpips:.3f} (TRACE paper: 0.19)\n\n"
+         f"**Settings**\n"
+         f"- Stage 1: `{stage1_method}`\n"
+         f"- Frames: {T} | Resolution: {W}x{H}\n"
+     )
+
+     if progress_callback:
+         progress_callback(1.0, "Done!")
+
+     return tmp.name, result, pred_boxes, metrics_text
+
+
+ def run_pipeline_insertion(
+     video_path: str,
+     edited_first_frame: np.ndarray,  # Qwen/FLUX output — already edited
+     start_box: list,
+     end_box: list,
+     prompt: str,
+     use_vace: bool = False,
+     progress_callback=None
+ ) -> tuple:
+     """
+     Run insertion pipeline using a pre-edited first frame.
+     The first frame has already been modified by Qwen or FLUX-Fill
+     before this function is called — this function handles
+     the trajectory + video synthesis steps only.
+     """
+     from pipeline import TRACEPrototype
+     from stage1_approx import stage1_linear
+     from stage2_vace import VACEWrapper, SimpleCompositeStage2
+     from utils.box_utils import boxes_to_mask_sequence
+     # from evaluation.metrics import compute_psnr, compute_ssim_video
+
+     if progress_callback:
+         progress_callback(0.1, "Loading video...")
+
+     frames = load_all_frames(video_path)
+     T, H, W, _ = frames.shape
+     keyboxes = {0: start_box, T - 1: end_box}
+
+     if progress_callback:
+         progress_callback(0.3, "Computing trajectory...")
+
+     # Stage 1: interpolate trajectory
+     # (cotracker optional — linear fine for insertion prototype)
+     pred_boxes = stage1_linear(keyboxes, T)
+
+     # Build masks
+     synthesis_masks = boxes_to_mask_sequence(pred_boxes, H, W)
+     # No inpainting mask — object wasn't in original video
+     inpaint_masks = np.zeros_like(synthesis_masks)
+
+     if progress_callback:
+         progress_callback(0.5, "Running video synthesis...")
+
+     if use_vace:
+         stage2 = VACEWrapper()
+         result = stage2.synthesize(
+             original_frames=frames,
+             synthesis_masks=synthesis_masks,
+             inpaint_masks=inpaint_masks,
+             first_frame_ref=edited_first_frame,  # ← Qwen-edited frame
+             text_prompt=prompt
+         )
+     else:
+         # Debug mode: simple alpha compositing
+         stage2 = SimpleCompositeStage2()
+         x1, y1, x2, y2 = [int(v) for v in start_box]
+         obj_crop = edited_first_frame[y1:y2, x1:x2]
+
+         # Build object mask from non-black pixels in crop
+         obj_mask = (obj_crop.sum(axis=2) > 10).astype(np.float32)
+
+         result = stage2.synthesize(
+             original_frames=frames,
+             synthesis_masks=synthesis_masks,
+             inpaint_masks=inpaint_masks,
+             object_crop=obj_crop,
+             object_mask=obj_mask
+         )
+
+     if progress_callback:
+         progress_callback(0.9, "Saving output...")
+
+     tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
+     save_video(result, tmp.name)
+
+     psnr = compute_psnr(result, frames)
+     ssim = compute_ssim_video(result, frames)
+
+     metrics_text = (
+         f"**Insertion Result**\n"
+         f"- PSNR: {psnr:.2f} dB\n"
+         f"- SSIM: {ssim:.3f}\n\n"
+         f"**Settings**\n"
+         f"- First frame editor: Qwen/FLUX (run separately)\n"
+         f"- VACE synthesis: {'on' if use_vace else 'off (debug mode)'}\n"
+         f"- Frames: {T} | Resolution: {W}x{H}\n"
+     )
+
+     if progress_callback:
+         progress_callback(1.0, "Done!")
+
+     return tmp.name, result, pred_boxes, metrics_text
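
The metric helpers at the top of pipeline_adapter.py work on any pair of equally sized clips. A small sketch comparing two already-saved videos (file names are hypothetical):

# sketch: score an edited clip against the original
from utils.video_utils import load_video
from pipeline_adapter import compute_psnr, compute_ssim_video

original = load_video("input.mp4", max_frames=81)
edited = load_video("motion_edit.mp4", max_frames=81)   # must match original's size
print(f"PSNR: {compute_psnr(original, edited):.2f} dB")
print(f"SSIM: {compute_ssim_video(original, edited):.3f}")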
preview.py ADDED
@@ -0,0 +1,103 @@
+ # demo/preview.py
+ import numpy as np
+ import cv2
+ from visualizer import draw_box_on_frame, draw_trajectory_on_frame
+ from utils.box_utils import interpolate_boxes
+
+ def preview_trajectory(
+     first_frame: np.ndarray,   # [H, W, 3]
+     start_box: list,           # [x1, y1, x2, y2]
+     end_box: list,             # [x1, y1, x2, y2]
+     num_frames: int = 81
+ ) -> np.ndarray:
+     """
+     Shows the planned trajectory on the first frame BEFORE running.
+     User sees this immediately after drawing boxes — fast feedback.
+     """
+     keyboxes = {0: start_box, num_frames - 1: end_box}
+     boxes = interpolate_boxes(keyboxes, num_frames)
+
+     frame = first_frame.copy()
+
+     # Draw full trajectory path (center points)
+     centers = np.stack([
+         (boxes[:, 0] + boxes[:, 2]) / 2,
+         (boxes[:, 1] + boxes[:, 3]) / 2
+     ], axis=1).astype(int)
+
+     for i in range(1, len(centers)):
+         alpha = i / len(centers)
+         color = (
+             int(255 * (1 - alpha)),
+             int(200 * alpha),
+             255
+         )
+         cv2.line(frame,
+                  tuple(centers[i-1]),
+                  tuple(centers[i]),
+                  color, 2)
+
+     # Draw start box (solid yellow)
+     frame = draw_box_on_frame(
+         frame, start_box,
+         color=(255, 220, 0),
+         label="START",
+         dashed=False
+     )
+
+     # Draw end box (dashed yellow)
+     frame = draw_box_on_frame(
+         frame, end_box,
+         color=(255, 220, 0),
+         label="END",
+         dashed=True
+     )
+
+     # Draw a few intermediate boxes (faded)
+     for i in [20, 40, 60]:
+         if i < len(boxes):
+             frame = draw_box_on_frame(
+                 frame, boxes[i],
+                 color=(200, 200, 200),
+                 label=f"t={i}",
+                 dashed=True,
+                 thickness=1
+             )
+
+     return frame
+
+
+ def preview_trajectory_strip(
+     frames: np.ndarray,   # [T, H, W, 3]
+     start_box: list,
+     end_box: list,
+ ) -> np.ndarray:
+     """
+     Shows predicted box overlaid on 5 sampled frames.
+     Gives sense of how box moves through the video.
+     """
+     T = len(frames)
+     keyboxes = {0: start_box, T - 1: end_box}
+     boxes = interpolate_boxes(keyboxes, T)
+
+     sample_ts = [0, T//4, T//2, 3*T//4, T-1]
+     previews = []
+
+     for t in sample_ts:
+         frame = frames[t].copy()
+         frame = draw_box_on_frame(
+             frame, boxes[t],
+             color=(0, 255, 255),
+             label=f"t={t}",
+             dashed=(t > 0)
+         )
+         # Add small frame counter
+         H, W = frame.shape[:2]
+         progress = f"{t}/{T-1}"
+         cv2.putText(frame, progress, (W-80, H-10),
+                     cv2.FONT_HERSHEY_SIMPLEX, 0.5,
+                     (200, 200, 200), 1)
+         previews.append(frame)
+
+     return np.concatenate(previews, axis=1)  # horizontal strip
requirements.txt ADDED
@@ -0,0 +1,22 @@
+ # requirements.txt
+ torch>=2.1.0
+ torchvision
+ transformers>=4.40.0
+ git+https://github.com/huggingface/diffusers.git
+ torchao==0.11.0
+ peft
+ sentencepiece
+ opencv-python
+ numpy
+ scipy
+ Pillow
+ imageio[ffmpeg]
+ einops
+ accelerate
+
+ # Imported by app.py / pipeline_adapter.py:
+ gradio
+ scikit-image
+ lpips
+
+ # Install separately (need git clone):
+ # CoTracker3: github.com/facebookresearch/co-tracker
+ # SAM2: github.com/facebookresearch/segment-anything-2
+ # VACE: github.com/ali-vilab/VACE
+ # DA-v3: github.com/DepthAnything/Depth-Anything-V3
stage1_approx.py ADDED
@@ -0,0 +1,197 @@
+ # stage1_approx.py
+ import numpy as np
+ import torch
+ import cv2
+ from utils.box_utils import interpolate_boxes
+
+ # ── Option A: Pure Linear Interpolation ─────────────────────────────
+ # Best for: static camera or very slow camera movement
+ # Worst for: fast pans, zoom, handheld footage
+
+ def stage1_linear(
+     keyboxes: dict,
+     num_frames: int
+ ) -> np.ndarray:
+     """
+     Simplest possible Stage 1 substitute.
+     keyboxes: {frame_idx: [x1, y1, x2, y2]}
+     Returns: [T, 4] box sequence
+     """
+     return interpolate_boxes(keyboxes, num_frames, method="linear")
+
+
+ # ── Option B: DA-v3 Depth Warping ───────────────────────────────────
+ # Better for: moderate camera motion
+ # From Table 7: IoU=0.79, mAP=0.73 (vs TRACE 0.80, 0.91)
+ # Requires: DepthAnything-v3 + MegaSAM or RAFT optical flow
+
+ def stage1_depth_warp(
+     frames: np.ndarray,   # [T, H, W, 3]
+     keyboxes: dict,
+     depth_model,
+     flow_model=None
+ ) -> np.ndarray:
+     """
+     Project first-frame boxes to subsequent frames using depth + flow.
+     """
+     T, H, W, _ = frames.shape
+     first_frame = frames[0]
+
+     # Get depth for all frames
+     depths = []
+     for frame in frames:
+         d = depth_model.infer(frame)   # [H, W] depth map
+         depths.append(d)
+     depths = np.stack(depths)          # [T, H, W]
+
+     # Keyframes keep their user-provided boxes
+     result_boxes = np.zeros((T, 4))
+     for frame_idx, box in keyboxes.items():
+         result_boxes[frame_idx] = box
+
+     # For each unspecified frame, warp from nearest keybox
+     keyframe_ids = sorted(keyboxes.keys())
+
+     for t in range(T):
+         if t in keyboxes:
+             continue
+
+         # Find nearest keyframe
+         nearest_key = min(keyframe_ids, key=lambda k: abs(k - t))
+         ref_box = keyboxes[nearest_key]
+         ref_depth = depths[nearest_key]
+         tgt_depth = depths[t]
+
+         # Get depth at box center in reference frame
+         cx_ref = (ref_box[0] + ref_box[2]) / 2
+         cy_ref = (ref_box[1] + ref_box[3]) / 2
+         cx_ref_i, cy_ref_i = int(cx_ref), int(cy_ref)
+         d_ref = ref_depth[cy_ref_i, cx_ref_i]
+
+         # Use optical flow if available for center displacement
+         if flow_model is not None:
+             flow = flow_model.compute(
+                 frames[nearest_key], frames[t]
+             )   # [H, W, 2]
+             dx = flow[cy_ref_i, cx_ref_i, 0]
+             dy = flow[cy_ref_i, cx_ref_i, 1]
+         else:
+             dx, dy = 0, 0
+
+         # Warp center
+         cx_tgt = cx_ref + dx
+         cy_tgt = cy_ref + dy
+
+         # Scale box size by depth ratio
+         d_tgt = tgt_depth[int(cy_tgt), int(cx_tgt)]
+         scale = d_ref / (d_tgt + 1e-6)
+         bw = (ref_box[2] - ref_box[0]) * scale
+         bh = (ref_box[3] - ref_box[1]) * scale
+
+         result_boxes[t] = [
+             cx_tgt - bw/2, cy_tgt - bh/2,
+             cx_tgt + bw/2, cy_tgt + bh/2
+         ]
+
+     # Every frame is now filled: keyboxes kept as-is, the rest warped
+     return result_boxes
+
+
+ # ── Option C: CoTracker-Assisted Warping ────────────────────────────
+ # Best for: fast camera, most accurate without training
+ # Uses background point tracks to estimate camera motion
+
+ def stage1_cotracker(
+     frames: np.ndarray,   # [T, H, W, 3]
+     keyboxes: dict,
+     cotracker_model
+ ) -> np.ndarray:
+     """
+     Use CoTracker point tracks to estimate camera motion,
+     then warp keyboxes accordingly.
+     """
+     T, H, W, _ = frames.shape
+
+     # Build grid of background query points (avoid object region)
+     first_box = list(keyboxes.values())[0]
+
+     # Sample 100 background points (outside object box)
+     bg_points = _sample_background_points(
+         H, W, first_box, n_points=100
+     )   # [100, 2] (x, y)
+
+     # Track them across all frames
+     video_tensor = torch.from_numpy(frames).float()
+     video_tensor = video_tensor.permute(0, 3, 1, 2).unsqueeze(0)
+     # [1, T, 3, H, W]
+
+     queries = torch.zeros(1, len(bg_points), 3)
+     queries[0, :, 0] = 0   # query at frame 0
+     queries[0, :, 1] = torch.from_numpy(bg_points[:, 0])   # x
+     queries[0, :, 2] = torch.from_numpy(bg_points[:, 1])   # y
+
+     # Run tracking on the same device as the model
+     device = next(cotracker_model.parameters()).device
+     video_tensor = video_tensor.to(device)
+     queries = queries.to(device)
+
+     with torch.no_grad():
+         tracks, visibility = cotracker_model(
+             video_tensor, queries=queries
+         )
+     # tracks: [1, T, N_points, 2]
+     tracks = tracks[0].cpu().numpy()   # [T, N, 2]
+
+     # Estimate per-frame homography from background tracks
+     result_boxes = np.zeros((T, 4))
+     ref_points = tracks[0]   # [N, 2] at frame 0
+
+     for t in range(T):
+         if t in keyboxes:
+             result_boxes[t] = keyboxes[t]
+             continue
+
+         # Find nearest keyframe
+         nearest_key = min(keyboxes.keys(), key=lambda k: abs(k-t))
+         ref_box = keyboxes[nearest_key]
+
+         # Estimate transformation from nearest keyframe to frame t
+         src_pts = tracks[nearest_key]   # [N, 2]
+         dst_pts = tracks[t]             # [N, 2]
+
+         H_mat, mask = cv2.findHomography(
+             src_pts, dst_pts, cv2.RANSAC, 5.0
+         )
+
+         if H_mat is None:
+             result_boxes[t] = ref_box
+             continue
+
+         # Warp box corners through homography
+         corners = np.array([
+             [ref_box[0], ref_box[1]],
+             [ref_box[2], ref_box[1]],
+             [ref_box[2], ref_box[3]],
+             [ref_box[0], ref_box[3]]
+         ], dtype=np.float32).reshape(-1, 1, 2)
+
+         warped = cv2.perspectiveTransform(corners, H_mat)
+         warped = warped.reshape(-1, 2)
+
+         result_boxes[t] = [
+             warped[:, 0].min(), warped[:, 1].min(),
+             warped[:, 0].max(), warped[:, 1].max()
+         ]
+
+     return result_boxes
+
+
+ def _sample_background_points(H, W, object_box, n_points=100):
+     """Sample points outside the object bounding box"""
+     x1, y1, x2, y2 = object_box
+     points = []
+     attempts = 0
+     while len(points) < n_points and attempts < n_points * 10:
+         x = np.random.randint(0, W)
+         y = np.random.randint(0, H)
+         if not (x1 <= x <= x2 and y1 <= y <= y2):
+             points.append([x, y])
+         attempts += 1
+     return np.array(points, dtype=np.float32)
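
A quick sketch of the Stage 1 output format: two keyboxes become a dense [T, 4] trajectory, which utils/box_utils.py turns into per-frame masks (the 81-frame length and 480x832 resolution are assumptions for illustration):

# sketch: dense trajectory + masks from two keyboxes
from stage1_approx import stage1_linear
from utils.box_utils import boxes_to_mask_sequence

keyboxes = {0: [100, 100, 200, 200], 80: [500, 200, 600, 300]}
boxes = stage1_linear(keyboxes, num_frames=81)        # [81, 4] interpolated boxes
masks = boxes_to_mask_sequence(boxes, H=480, W=832)   # [81, 480, 832] binary masks
print(boxes[40])                                      # box roughly halfway along the path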
stage2_vace.py ADDED
@@ -0,0 +1,148 @@
+ # stage2_vace.py
+ import numpy as np
+ import torch
+ import cv2
+ from PIL import Image
+
+ class VACEWrapper:
+     def __init__(self, device="cuda"):
+         # The Wan2.1-VACE checkpoint pairs with diffusers' WanVACEPipeline,
+         # which takes video + mask + reference_images conditioning.
+         from diffusers import WanVACEPipeline
+
+         self.device = device
+         self.pipe = WanVACEPipeline.from_pretrained(
+             "Wan-AI/Wan2.1-VACE-1.3B-diffusers",
+             torch_dtype=torch.bfloat16,
+         ).to(device)
+         self.pipe.enable_model_cpu_offload()
+
+
+     def synthesize(
+         self,
+         original_frames,
+         synthesis_masks,
+         inpaint_masks,
+         first_frame_ref,
+         text_prompt="",
+     ):
+         T, orig_H, orig_W, _ = original_frames.shape
+
+         # Round to nearest multiple of 16 (VACE requirement)
+         H = (orig_H // 16) * 16
+         W = (orig_W // 16) * 16
+
+         if H != orig_H or W != orig_W:
+             original_frames = np.stack([cv2.resize(f, (W, H)) for f in original_frames])
+             first_frame_ref = cv2.resize(first_frame_ref, (W, H))
+             synthesis_masks = np.stack([
+                 cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST) for m in synthesis_masks
+             ])
+             inpaint_masks = np.stack([
+                 cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST) for m in inpaint_masks
+             ])
+
+         video_pil = [Image.fromarray(f) for f in original_frames]
+         # Union of the "place here" and "erase here" masks, scaled to 0/255
+         # so white marks the region the pipeline should regenerate
+         combined = ((synthesis_masks + inpaint_masks) > 0).astype(np.uint8) * 255
+         mask_pil = [Image.fromarray(m) for m in combined]
+         ref_pil = Image.fromarray(first_frame_ref)
+
+         output = self.pipe(
+             video=video_pil,
+             mask=mask_pil,
+             prompt=text_prompt,
+             negative_prompt="static, blurry, low quality",
+             reference_images=[ref_pil],
+             num_frames=T,
+             height=H,
+             width=W,
+             guidance_scale=5.0,
+             num_inference_steps=25,
+         ).frames[0]
+
+         result = np.stack([np.array(f) for f in output], axis=0)
+
+         # Restore original resolution
+         if orig_H != H or orig_W != W:
+             result = np.stack([cv2.resize(f, (orig_W, orig_H)) for f in result])
+
+         return result
+
+
+ class SimpleCompositeStage2:
+     """
+     Fallback Stage 2: simple alpha compositing.
+     No diffusion model needed.
+     Works for: clean background, simple objects.
+     Quality: low but fast for debugging the pipeline.
+     """
+
+     def synthesize(
+         self,
+         original_frames: np.ndarray,   # [T, H, W, 3]
+         synthesis_masks: np.ndarray,   # [T, H, W]
+         inpaint_masks: np.ndarray,     # [T, H, W]
+         object_crop: np.ndarray,       # [H_obj, W_obj, 3]
+         object_mask: np.ndarray,       # [H_obj, W_obj] binary
+     ) -> np.ndarray:
+         """
+         Composite object into new positions using simple alpha blending.
+         Useful for validating box trajectory before diffusion.
+         """
+         T, H, W, _ = original_frames.shape
+         result = original_frames.copy()
+
+         for t in range(T):
+             # Find box from synthesis mask
+             mask_t = synthesis_masks[t]
+             ys, xs = np.where(mask_t > 0.5)
+             if len(ys) == 0:
+                 continue
+
+             y1, y2 = ys.min(), ys.max()
+             x1, x2 = xs.min(), xs.max()
+             bh, bw = y2 - y1, x2 - x1
+
+             if bh <= 0 or bw <= 0:
+                 continue
+
+             # Resize object to target box size
+             obj_resized = cv2.resize(
+                 object_crop, (bw, bh),
+                 interpolation=cv2.INTER_LINEAR
+             )
+             mask_resized = cv2.resize(
+                 object_mask.astype(np.float32), (bw, bh),
+                 interpolation=cv2.INTER_LINEAR
+             )
+             mask_3ch = mask_resized[:, :, None]
+
+             # Erase original position (simple fill with nearby bg)
+             erase_mask = inpaint_masks[t]
+             if erase_mask.sum() > 0:
+                 result[t] = _inpaint_simple(result[t], erase_mask)
+
+             # Composite object at new position
+             roi = result[t, y1:y2, x1:x2]
+             result[t, y1:y2, x1:x2] = (
+                 obj_resized * mask_3ch + roi * (1 - mask_3ch)
+             ).astype(np.uint8)
+
+         return result
+
+
+ def _inpaint_simple(frame: np.ndarray, mask: np.ndarray) -> np.ndarray:
+     """Simple telea inpainting for object removal"""
+     mask_uint8 = (mask * 255).astype(np.uint8)
+     return cv2.inpaint(frame, mask_uint8, 3, cv2.INPAINT_TELEA)
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (183 Bytes)
 
utils/__pycache__/box_utils.cpython-312.pyc ADDED
Binary file (3.2 kB)
 
utils/__pycache__/video_utils.cpython-312.pyc ADDED
Binary file (2.6 kB)
 
utils/box_utils.py ADDED
@@ -0,0 +1,65 @@
+ # utils/box_utils.py
+ import numpy as np
+ from scipy.interpolate import interp1d
+
+ def interpolate_boxes(
+     keyboxes: dict,          # {frame_idx: [x1, y1, x2, y2]}
+     num_frames: int,
+     method: str = "linear"   # "linear" or "cubic"
+ ) -> np.ndarray:
+     """
+     Interpolate sparse keyboxes to dense per-frame boxes.
+     Returns: [T, 4] float32
+     """
+     frame_ids = sorted(keyboxes.keys())
+     boxes = np.array([keyboxes[i] for i in frame_ids], dtype=np.float32)
+
+     # A single keybox has nothing to interpolate — repeat it for every frame
+     # (this happens when only the object's current position is given)
+     if len(frame_ids) == 1:
+         return np.tile(boxes[0], (num_frames, 1))
+
+     # Interpolate each coordinate separately
+     result = np.zeros((num_frames, 4), dtype=np.float32)
+     t_query = np.arange(num_frames)
+
+     for coord in range(4):
+         f = interp1d(
+             frame_ids,
+             boxes[:, coord],
+             kind=method,
+             fill_value="extrapolate"
+         )
+         result[:, coord] = f(t_query)
+
+     return result.clip(0, None)   # boxes can't be negative
+
+ def box_to_mask(
+     box: np.ndarray,   # [x1, y1, x2, y2]
+     H: int,
+     W: int
+ ) -> np.ndarray:
+     """
+     Convert bounding box to binary mask [H, W]
+     """
+     mask = np.zeros((H, W), dtype=np.float32)
+     x1, y1, x2, y2 = np.asarray(box).astype(int)   # accept lists or arrays
+     x1, x2 = np.clip([x1, x2], 0, W)
+     y1, y2 = np.clip([y1, y2], 0, H)
+     mask[y1:y2, x1:x2] = 1.0
+     return mask
+
+ def boxes_to_mask_sequence(
+     boxes: np.ndarray,   # [T, 4]
+     H: int,
+     W: int
+ ) -> np.ndarray:
+     """
+     Returns: [T, H, W] binary masks
+     """
+     T = len(boxes)
+     masks = np.zeros((T, H, W), dtype=np.float32)
+     for t, box in enumerate(boxes):
+         masks[t] = box_to_mask(box, H, W)
+     return masks
+
+ def expand_box(box: np.ndarray, padding: int = 10) -> np.ndarray:
+     """Expand box by padding pixels on each side"""
+     x1, y1, x2, y2 = box
+     return np.array([x1 - padding, y1 - padding,
+                      x2 + padding, y2 + padding])
utils/video_utils.py ADDED
@@ -0,0 +1,42 @@
+ # utils/video_utils.py
+ import cv2
+ import numpy as np
+ import imageio
+ import torch
+
+ def load_video(path: str, max_frames: int = 81) -> np.ndarray:
+     """
+     Returns: [T, H, W, 3] uint8 RGB array
+     """
+     cap = cv2.VideoCapture(path)
+     frames = []
+     while len(frames) < max_frames:
+         ret, frame = cap.read()
+         if not ret:
+             break
+         frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+     cap.release()
+     return np.stack(frames)
+
+ def save_video(frames: np.ndarray, path: str, fps: int = 24):
+     """
+     frames: [T, H, W, 3] uint8 RGB
+     """
+     writer = imageio.get_writer(path, fps=fps)
+     for frame in frames:
+         writer.append_data(frame)
+     writer.close()
+
+ def frames_to_tensor(frames: np.ndarray) -> torch.Tensor:
+     """
+     [T, H, W, 3] uint8 → [T, 3, H, W] float32 in [-1, 1]
+     """
+     t = torch.from_numpy(frames).float() / 127.5 - 1.0
+     return t.permute(0, 3, 1, 2)
+
+ def tensor_to_frames(t: torch.Tensor) -> np.ndarray:
+     """
+     [T, 3, H, W] float32 in [-1, 1] → [T, H, W, 3] uint8
+     """
+     t = ((t + 1.0) * 127.5).clamp(0, 255)
+     return t.permute(0, 2, 3, 1).byte().numpy()
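
A round-trip sketch for these helpers (the input file name is a placeholder):

# sketch: load a clip, convert to a tensor and back, re-save it
from utils.video_utils import load_video, save_video, frames_to_tensor, tensor_to_frames

frames = load_video("input.mp4", max_frames=81)   # [T, H, W, 3] uint8
t = frames_to_tensor(frames)                      # [T, 3, H, W] float32 in [-1, 1]
restored = tensor_to_frames(t)                    # back to uint8 frames
save_video(restored, "roundtrip.mp4", fps=24)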
visualizer.py ADDED
@@ -0,0 +1,139 @@
+ # demo/visualizer.py
+ import numpy as np
+ import cv2
+ from typing import Optional
+
+ def draw_box_on_frame(
+     frame: np.ndarray,         # [H, W, 3] uint8 RGB
+     box: list,                 # [x1, y1, x2, y2]
+     color: tuple = (255, 255, 0),
+     label: str = "",
+     thickness: int = 2,
+     dashed: bool = False
+ ) -> np.ndarray:
+     """Draw a single bounding box on a frame"""
+     frame = frame.copy()
+     x1, y1, x2, y2 = [int(v) for v in box]
+
+     if dashed:
+         # Draw dashed rectangle manually
+         dash_len = 10
+         gap_len = 5
+         pts = [
+             ((x1, y1), (x2, y1)),   # top
+             ((x2, y1), (x2, y2)),   # right
+             ((x2, y2), (x1, y2)),   # bottom
+             ((x1, y2), (x1, y1)),   # left
+         ]
+         for (px1, py1), (px2, py2) in pts:
+             dx = px2 - px1
+             dy = py2 - py1
+             dist = max(abs(dx), abs(dy))
+             if dist == 0:
+                 continue
+             for i in range(0, dist, dash_len + gap_len):
+                 s = i / dist
+                 e = min(i + dash_len, dist) / dist
+                 sx = int(px1 + s * dx)
+                 sy = int(py1 + s * dy)
+                 ex = int(px1 + e * dx)
+                 ey = int(py1 + e * dy)
+                 cv2.line(frame, (sx, sy), (ex, ey), color, thickness)
+     else:
+         cv2.rectangle(frame, (x1, y1), (x2, y2), color, thickness)
+
+     if label:
+         cv2.putText(
+             frame, label,
+             (x1, max(y1 - 8, 12)),
+             cv2.FONT_HERSHEY_SIMPLEX,
+             0.6, color, 2
+         )
+
+     return frame
+
+
+ def draw_trajectory_on_frame(
+     frame: np.ndarray,
+     boxes: np.ndarray,    # [T, 4] — full trajectory
+     current_t: int,
+     color: tuple = (255, 200, 0)
+ ) -> np.ndarray:
+     """
+     Draw the motion path (center points) up to current frame.
+     Gives a visual "trail" showing where the object came from.
+     """
+     frame = frame.copy()
+     centers = np.stack([
+         (boxes[:, 0] + boxes[:, 2]) / 2,
+         (boxes[:, 1] + boxes[:, 3]) / 2
+     ], axis=1).astype(int)
+
+     # Draw path line
+     for i in range(1, current_t + 1):
+         alpha = i / (current_t + 1)   # fade older points
+         c = tuple(int(v * alpha) for v in color)
+         cv2.line(
+             frame,
+             tuple(centers[i-1]),
+             tuple(centers[i]),
+             c, 2
+         )
+
+     # Draw current center dot
+     cv2.circle(frame, tuple(centers[current_t]), 5, color, -1)
+
+     return frame
+
+
+ def create_comparison_strip(
+     original: np.ndarray,     # [T, H, W, 3]
+     result: np.ndarray,       # [T, H, W, 3]
+     pred_boxes: np.ndarray,   # [T, 4]
+     sample_ts: list = None    # which frames to show
+ ) -> np.ndarray:
+     """
+     Creates a comparison grid: one row per sampled frame,
+     each row showing Original | Result | Diff side by side.
+     """
+     T = len(original)
+     if sample_ts is None:
+         sample_ts = [0, T//4, T//2, 3*T//4, T-1]
+
+     rows = []
+     for t in sample_ts:
+         orig_t = original[t].copy()
+         res_t = result[t].copy()
+
+         # Draw box on result
+         res_t = draw_box_on_frame(
+             res_t, pred_boxes[t],
+             color=(0, 255, 0),
+             label=f"t={t}"
+         )
+
+         # Amplified diff
+         diff_t = np.abs(
+             orig_t.astype(np.int32) - result[t].astype(np.int32)
+         )
+         diff_t = (diff_t * 4).clip(0, 255).astype(np.uint8)
+
+         # Add labels
+         def add_label(img, text):
+             img = img.copy()
+             cv2.putText(img, text, (10, 25),
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.7,
+                         (255, 255, 255), 2)
+             cv2.putText(img, text, (10, 25),
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.7,
+                         (0, 0, 0), 1)
+             return img
+
+         orig_t = add_label(orig_t, "Original")
+         res_t = add_label(res_t, "Result")
+         diff_t = add_label(diff_t, "Diff x4")
+
+         row = np.concatenate([orig_t, res_t, diff_t], axis=1)
+         rows.append(row)
+
+     return np.concatenate(rows, axis=0)
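
A sketch of building the comparison strip offline from two saved clips (paths and boxes are hypothetical, and both clips are assumed to share the same resolution and length):

# sketch: offline before/after comparison image
import imageio
from utils.video_utils import load_video
from stage1_approx import stage1_linear
from visualizer import create_comparison_strip

original = load_video("input.mp4", max_frames=81)
result = load_video("motion_edit.mp4", max_frames=81)
boxes = stage1_linear({0: [100, 100, 200, 200], 80: [500, 200, 600, 300]}, len(original))
strip = create_comparison_strip(original, result, boxes)
imageio.imwrite("comparison.png", strip)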