# Modification List (Before/After Comparison)


## Scope
Only the minimal fixes needed to get things running; the original logic and structure are preserved as far as possible.


## 1) `multi-shot/multi_view/datasets/videodataset.py`
**Define the missing variables while keeping the original return structure**

**Before**
```python
return {
    "global_caption": None,
    "shot_num": 3,
    "pre_shot_caption": ["xxx", "xxx", "xxx"],
    # "single_caption": meta_prompt["single_prompt"],
    "video": input_video,
    "ref_num": ID_num * 3, ###TODO: get the ID_num = 1 case running first
    "ID_num": ID_num,
    "ref_images": [[Image0, Image1, Image2]],
    "video_path": video_path
}
```

**After**
```python
ID_num = 1
Image0, Image1, Image2 = ref_images[:3]
return {
    "global_caption": None,
    "shot_num": 3,
    "pre_shot_caption": ["xxx", "xxx", "xxx"],
    # "single_caption": meta_prompt["single_prompt"],
    "video": input_video,
    "ref_num": ID_num * 3, ###TODO: get the ID_num = 1 case running first
    "ID_num": ID_num,
    "ref_images": [[Image0, Image1, Image2]],
    "video_path": video_path
}
```
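The `ref_images[:3]` unpacking in the After snippet assumes the sample carries at least three reference frames. A standalone sketch of that step with an explicit guard (the function name and placeholder values are invented for illustration):

```python
def unpack_first_three(ref_images):
    # Fail with a clear message instead of an unpacking error downstream.
    if len(ref_images) < 3:
        raise ValueError(f"expected at least 3 reference images, got {len(ref_images)}")
    image0, image1, image2 = ref_images[:3]
    return image0, image1, image2

# Placeholder strings stand in for PIL images here:
first, second, third = unpack_first_three(["img_a", "img_b", "img_c", "img_d"])
print(first, second, third)  # → img_a img_b img_c
```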


## 2) `multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/pipelines/wan_video_new.py`
### 2.1 Prompt encoding (fix typos / wrong object references)
**Before**
```python
prompt = pip.text_encoder.process_prompt(prompt, positive=positive)
output = pip.text_encoder.tokenizer(prompt, return_mask=True, add_special_tokens=True)
ids = output['input_ids'].to(device)
mask = output['attention_mask'].to(device)
prompt_emb = self.text_encoder(ids, mask)
...
prompt_shot_all = pip.text_encoder.process_prompt(prompt_shot_all, positive=positive)
...
for shot_index, shot_cut_end in enmurate(shot_cut_ends):
    start_pos = shot_cut_starts[shot_index]
    end_pos = shot_cut_end
    shot_text = cleaned_prompt[start_pos: end_pos + 1].strip()
```

**After**
```python
prompt = pipe.text_encoder.process_prompt(prompt, positive=positive)
output = pipe.text_encoder.tokenizer(prompt, return_mask=True, add_special_tokens=True)
ids = output['input_ids'].to(device)
mask = output['attention_mask'].to(device)
prompt_emb = pipe.text_encoder(ids, mask)
...
prompt_shot_all = pipe.text_encoder.process_prompt(prompt_shot_all, positive=positive)
cleaned_prompt = prompt_shot_all
...
for shot_index, shot_cut_end in enumerate(shot_cut_ends):
    start_pos = shot_cut_starts[shot_index]
    end_pos = shot_cut_end
    shot_text = cleaned_prompt[start_pos: end_pos + 1].strip()
```
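The loop in the After snippet slices `cleaned_prompt` with inclusive end positions. A self-contained sketch of just that slicing logic (the prompt text and cut positions below are made up):

```python
def split_shots(cleaned_prompt, shot_cut_starts, shot_cut_ends):
    shots = []
    for shot_index, shot_cut_end in enumerate(shot_cut_ends):
        start_pos = shot_cut_starts[shot_index]
        # the end position is inclusive, hence the +1 in the slice
        shots.append(cleaned_prompt[start_pos: shot_cut_end + 1].strip())
    return shots

text = "shot one. shot two."
print(split_shots(text, [0, 10], [8, 18]))  # → ['shot one.', 'shot two.']
```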


### 2.2 Shot mask construction (fix undefined variables)
**Before**
```python
S_shots = len(shot_text_ranges[0]) ###TODO: current batch size is 1
...
for sid, (s0, s1) in enumerate(shot_ranges):
    s0 = int(s0)
    s1 = int(s1)
    shot_table[sid, s0: s1 + 1] = True
...
allow_all = torch.cat([allow_shot, allow_ref_image], dim = 1)
assert allow_all.shape == x.shape[2] "The shape is something wrong"
```

**After**
```python
shot_ranges = shot_text_ranges[0]
if isinstance(shot_ranges, dict):
    shot_ranges = shot_ranges.get("shots", [])
S_shots = len(shot_ranges)
for sid, span in enumerate(shot_ranges):
    if span is None:
        continue
    s0, s1 = span
    s0 = int(s0)
    s1 = int(s1)
    shot_table[sid, s0: s1 + 1] = True
...
allow_all = torch.cat([allow_shot, allow_ref_image], dim = 1)
assert allow_all.shape[1] == S_q, "The shape is something wrong"
```
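`shot_table` in the snippet above is a per-shot boolean mask over token positions (the snippet assumes it was allocated earlier). The same layout can be sketched torch-free with nested lists; the ranges and sequence length here are illustrative only:

```python
def build_shot_table(shot_ranges, seq_len):
    # One row per shot; True marks the token positions belonging to that shot.
    table = [[False] * seq_len for _ in shot_ranges]
    for sid, span in enumerate(shot_ranges):
        if span is None:  # tolerate missing ranges, as in the fix above
            continue
        s0, s1 = int(span[0]), int(span[1])
        for pos in range(s0, s1 + 1):  # inclusive end position
            table[sid][pos] = True
    return table

table = build_shot_table([(0, 2), None, (3, 4)], seq_len=5)
```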


### 2.3 Fix the variable-name conflict in the `shot_rope` branch
**Before**
```python
for shot_index, num_frames in enumerate(shots_nums):
    f = num_frames
    rope_s = freq_s[shot_index] \
        .view(1, 1, 1, -1) \
        .expand(f, h, w, -1)
...
freqs = freqs.reshape(f * h * w, 1, -1)
```

**After**
```python
for shot_index, num_frames in enumerate(shots_nums):
    f = num_frames
    rope_s = freq_s[shot_index].view(1, 1, 1, -1).expand(f, h, w, -1)
...
freqs = freqs.reshape(f * h * w, 1, -1)
```


### 2.4 Fix the `model_fn_wan_video` function-signature syntax
**Before**
```python
ID_2_shot: None ###### which IDs each shot contains; a list like [ batch0: [shot0: [0,1], shot1: [2]], batch1: [] ]
**kwargs,
```

**After**
```python
ID_2_shot=None, ###### which IDs each shot contains; a list like [ batch0: [shot0: [0,1], shot1: [2]], batch1: [] ]
**kwargs,
```
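For context on why the Before line breaks: in a `def`, `ID_2_shot: None` is only a type annotation, not a default value, so the parameter would still be required even where the syntax parses. A minimal self-contained comparison of the two forms:

```python
import inspect

def annotated_only(ID_2_shot: None, **kwargs):
    # `: None` is just an annotation; the parameter has no default and is required
    return ID_2_shot

def with_default(ID_2_shot=None, **kwargs):
    # `=None` is a real default, matching the After snippet
    return ID_2_shot

assert with_default() is None  # callable with no arguments
assert inspect.signature(annotated_only).parameters["ID_2_shot"].default is inspect.Parameter.empty
```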

### 2.5 Add the missing `WanVideoUnit_SpeedControl` class
**Before**
```python
WanVideoUnit_SpeedControl(), # referenced in the units list, but the class is not defined
```

**After**
```python
class WanVideoUnit_SpeedControl(PipelineUnit):
    def __init__(self):
        super().__init__(input_params=("motion_bucket_id",))

    def process(self, pipe: WanVideoPipeline, motion_bucket_id):
        if motion_bucket_id is None:
            return {}
        motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"motion_bucket_id": motion_bucket_id}
```

### 2.6 Route prompt processing through the prompter (fix missing `process_prompt`)
**Before**
```python
prompt = pipe.text_encoder.process_prompt(prompt, positive=positive)
output = pipe.text_encoder.tokenizer(prompt, return_mask=True, add_special_tokens=True)
...
prompt_shot_all = pipe.text_encoder.process_prompt(prompt_shot_all, positive=positive)
...
enc_output = pipe.text_encoder(
    text,
    return_mask=True,
    add_special_tokens=True,
    return_tensors="pt"
)
```

**After**
```python
prompt = pipe.prompter.process_prompt(prompt, positive=positive)
output = pipe.prompter.tokenizer(prompt, return_mask=True, add_special_tokens=True)
...
prompt_shot_all = pipe.prompter.process_prompt(prompt_shot_all, positive=positive)
...
enc_output = pipe.prompter.tokenizer(
    text,
    return_mask=True,
    add_special_tokens=True,
    return_tensors="pt"
)
```
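The underlying issue is that in this pipeline the prompt utilities (`process_prompt`, the tokenizer, `text_len`) live on a prompter wrapper rather than on the raw text-encoder module. A toy sketch of that split; the class bodies below are hypothetical and only illustrate the shape of the API, not DiffSynth's actual code:

```python
class TextEncoder:
    """Raw model: only runs the forward pass, no prompt utilities."""
    def __call__(self, ids, mask):
        return [float(i) for i in ids]  # stand-in for an embedding

class Prompter:
    """Wrapper that owns preprocessing and tokenization."""
    text_len = 512

    def process_prompt(self, prompt, positive=True):
        return prompt.strip()

    def tokenizer(self, prompt, return_mask=True, add_special_tokens=True):
        ids = list(range(len(prompt.split())))
        return ids, [1] * len(ids)

prompter = Prompter()
prompt = prompter.process_prompt("  a two-shot video  ")
ids, mask = prompter.tokenizer(prompt)
# The raw encoder has no process_prompt, which is why the Before code failed:
assert not hasattr(TextEncoder(), "process_prompt")
```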


### 2.7 Handle tokenizers that return either a tuple or a dict
**Before**
```python
output = pipe.prompter.tokenizer(prompt, return_mask=True, add_special_tokens=True)
ids = output['input_ids'].to(device)
mask = output['attention_mask'].to(device)
...
enc_output = pipe.prompter.tokenizer(..., return_mask=True, ...)
ids = enc_output['input_ids'].to(device)
mask = enc_output['attention_mask'].to(device)
```

**After**
```python
output = pipe.prompter.tokenizer(prompt, return_mask=True, add_special_tokens=True)
if isinstance(output, tuple):
    ids, mask = output
else:
    ids = output['input_ids']
    mask = output['attention_mask']
ids = ids.to(device)
mask = mask.to(device)
...
enc_output = pipe.prompter.tokenizer(..., return_mask=True, ...)
if isinstance(enc_output, tuple):
    ids, mask = enc_output
else:
    ids = enc_output['input_ids']
    mask = enc_output['attention_mask']
ids = ids.to(device)
mask = mask.to(device)
```
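Since the same tuple/dict branching appears twice above, it could be factored into one helper. A standalone sketch with plain lists standing in for tensors (the helper name is invented):

```python
def normalize_tokenizer_output(output):
    # Tuple-style tokenizers return (ids, mask); HF-style ones return a dict.
    if isinstance(output, tuple):
        ids, mask = output
    else:
        ids = output['input_ids']
        mask = output['attention_mask']
    return ids, mask

ids, mask = normalize_tokenizer_output(([1, 2, 3], [1, 1, 1]))
ids2, mask2 = normalize_tokenizer_output({"input_ids": [4, 5], "attention_mask": [1, 1]})
```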


### 2.8 Use the prompter's `text_len` (fix missing attribute)
**Before**
```python
pad_len = pipe.text_encoder.text_len - total_len
```

**After**
```python
pad_len = pipe.prompter.text_len - total_len
```

## 3) `multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/models/wan_video_dit.py`
### 3.1 Fix the ID token slice in `attention_per_batch_with_shots`
**Before**
```python
ID_token_start = shot_token_all_num + id_idx * pre_ID_token_num
ID_token_end = start + pre_ID_token_num
assert end <= k.shape[2], (
    f"ID token slice out of range: start={start}, end={end}, "
    f"K_len={k.shape[2]}"
)
id_token_k = k[bi, :, start:end, :]
id_token_v = v[bi, :, start:end, :]
```

**After**
```python
start = shot_token_all_num + id_idx * pre_id_token_num
if start >= k.shape[2]:
    continue
end = min(start + pre_id_token_num, k.shape[2])
id_token_k = k[bi, :, start:end, :]
id_token_v = v[bi, :, start:end, :]
```
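The span arithmetic in the After snippet can be checked in isolation. A small sketch using a plain integer for the key length (the values below are made up):

```python
def id_token_span(shot_token_all_num, id_idx, pre_id_token_num, k_len):
    # Mirrors the fix above: skip IDs whose tokens start beyond the key length,
    # and clamp the end so the slice never runs out of range.
    start = shot_token_all_num + id_idx * pre_id_token_num
    if start >= k_len:
        return None
    return start, min(start + pre_id_token_num, k_len)

print(id_token_span(100, 0, 16, k_len=110))  # → (100, 110), clamped at the key length
print(id_token_span(100, 1, 16, k_len=110))  # → None, the start is out of range
```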


### 3.2 Add an `attn_mask` parameter to `CrossAttention.forward`
**Before**
```python
def forward(self, x: torch.Tensor, y: torch.Tensor):
    ...
    x = self.attn(q, k, v)
```

**After**
```python
def forward(self, x: torch.Tensor, y: torch.Tensor, attn_mask=None):
    ...
    x = self.attn(q, k, v, attn_mask=attn_mask)
```

## 4) `multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/trainers/utils.py`
**Add an argument to match the pipeline**

**Before**
```python
# (no --shot_rope argument)
```

**After**
```python
parser.add_argument("--shot_rope", type=bool, default=False, help="Whether apply shot rope for multi-shot video")
```
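One caveat about the new flag: argparse applies `type=bool` to the raw command-line string, and `bool()` of any non-empty string is `True`, so `--shot_rope False` still enables the option. A quick demonstration (for a pure on/off switch, `action="store_true"` is the usual idiom):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--shot_rope", type=bool, default=False,
                    help="Whether apply shot rope for multi-shot video")

args = parser.parse_args([])
print(args.shot_rope)  # → False (the default)

args = parser.parse_args(["--shot_rope", "False"])
print(args.shot_rope)  # → True, because bool("False") is True
```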

## 5) New files
**`multi-shot/MULTI_SHOT_CORE_SUMMARY.md`**
- Before: file did not exist
- After: new summary document

**`multi-shot/MODIFICATION_LOG.md`**
- Before: file did not exist
- After: new modification list (this file)

## 6) `multi-shot/dry_run_train.py`
**Move the model to CUDA so it matches the input device**

**Before**
```python
device = "cuda" if torch.cuda.is_available() else "cpu"
model.pipe.device = device
model.pipe.torch_dtype = torch.bfloat16
```

**After**
```python
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.pipe.device = device
model.pipe.torch_dtype = torch.bfloat16
```

## Verification
```bash
python -m py_compile multi-shot/multi_view/datasets/videodataset.py
python -m py_compile multi-shot/multi_view/train.py
python -m py_compile multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/pipelines/wan_video_new.py
python -m py_compile multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/models/wan_video_dit.py
python -m py_compile multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/trainers/utils.py
```