# 修改清单(前后对比)

## 范围

只做“能跑”的最小修复,尽量保留原本逻辑与结构。

## 1) `multi-shot/multi_view/datasets/videodataset.py`

**补齐未定义变量,保持原返回结构**

**Before**

```python
return {
    "global_caption": None,
    "shot_num": 3,
    "pre_shot_caption": ["xxx", "xxx", "xxx"],
    # "single_caption": meta_prompt["single_prompt"],
    "video": input_video,
    "ref_num": ID_num * 3, ###TODO: 先跑通 ID_num = 1 的情况
    "ID_num": ID_num,
    "ref_images": [[Image0, Image1, Image2]],
    "video_path": video_path
}
```

**After**

```python
ID_num = 1
Image0, Image1, Image2 = ref_images[:3]
return {
    "global_caption": None,
    "shot_num": 3,
    "pre_shot_caption": ["xxx", "xxx", "xxx"],
    # "single_caption": meta_prompt["single_prompt"],
    "video": input_video,
    "ref_num": ID_num * 3, ###TODO: 先跑通 ID_num = 1 的情况
    "ID_num": ID_num,
    "ref_images": [[Image0, Image1, Image2]],
    "video_path": video_path
}
```

## 2) `multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/pipelines/wan_video_new.py`

### 2.1 Prompt 编码(修复拼写/对象调用)

**Before**

```python
prompt = pip.text_encoder.process_prompt(prompt, positive=positive)
output = pip.text_encoder.tokenizer(prompt, return_mask=True, add_special_tokens=True)
ids = output['input_ids'].to(device)
mask = output['attention_mask'].to(device)
prompt_emb = self.text_encoder(ids, mask)
...
prompt_shot_all = pip.text_encoder.process_prompt(prompt_shot_all, positive=positive)
...
for shot_index, shot_cut_end in enmurate(shot_cut_ends):
    start_pos = shot_cut_starts[shot_index]
    end_pos = shot_cut_end
    shot_text = cleaned_prompt[start_pos: end_pos + 1].strip()
```

**After**

```python
prompt = pipe.text_encoder.process_prompt(prompt, positive=positive)
output = pipe.text_encoder.tokenizer(prompt, return_mask=True, add_special_tokens=True)
ids = output['input_ids'].to(device)
mask = output['attention_mask'].to(device)
prompt_emb = pipe.text_encoder(ids, mask)
...
prompt_shot_all = pipe.text_encoder.process_prompt(prompt_shot_all, positive=positive)
cleaned_prompt = prompt_shot_all
...
for shot_index, shot_cut_end in enumerate(shot_cut_ends):
    start_pos = shot_cut_starts[shot_index]
    end_pos = shot_cut_end
    shot_text = cleaned_prompt[start_pos: end_pos + 1].strip()
```

### 2.2 Shot mask 构造(修复未定义变量)

**Before**

```python
S_shots = len(shot_text_ranges[0]) ###TODO: 当前batch size 是 1
...
for sid, (s0, s1) in enumerate(shot_ranges):
    s0 = int(s0)
    s1 = int(s1)
    shot_table[sid, s0: s1 + 1] = True
...
allow_all = torch.cat([allow_shot, allow_ref_image], dim = 1)
assert allow_all.shape == x.shape[2] "The shape is something wrong"
```

**After**

```python
shot_ranges = shot_text_ranges[0]
if isinstance(shot_ranges, dict):
    shot_ranges = shot_ranges.get("shots", [])
S_shots = len(shot_ranges)
for sid, span in enumerate(shot_ranges):
    if span is None:
        continue
    s0, s1 = span
    s0 = int(s0)
    s1 = int(s1)
    shot_table[sid, s0: s1 + 1] = True
...
allow_all = torch.cat([allow_shot, allow_ref_image], dim = 1)
assert allow_all.shape[1] == S_q, "The shape is something wrong"
```

### 2.3 `shot_rope` 分支变量名冲突修复

> 注:下面摘录的前后代码中,可见差异只有续行合并;变量名相关的改动不在摘录范围内(推测位于 `...` 省略处),请对照源码确认。

**Before**

```python
for shot_index, num_frames in enumerate(shots_nums):
    f = num_frames
    rope_s = freq_s[shot_index] \
        .view(1, 1, 1, -1) \
        .expand(f, h, w, -1)
    ...
    freqs = freqs.reshape(f * h * w, 1, -1)
```

**After**

```python
for shot_index, num_frames in enumerate(shots_nums):
    f = num_frames
    rope_s = freq_s[shot_index].view(1, 1, 1, -1).expand(f, h, w, -1)
    ...
    freqs = freqs.reshape(f * h * w, 1, -1)
```

### 2.4 `model_fn_wan_video` 函数签名语法修复

**Before**

```python
ID_2_shot: None ######每个shot 中对应包含的ID是那几个,是一个list[ batch0: [shot0: [0,1], shot1:[2]], batch1:[]]
**kwargs,
```

**After**

```python
ID_2_shot=None, ######每个shot 中对应包含的ID是那几个,是一个list[ batch0: [shot0: [0,1], shot1:[2]], batch1:[]]
**kwargs,
```

### 2.5 `WanVideoUnit_SpeedControl` 缺失类补齐

**Before**

```python
WanVideoUnit_SpeedControl(), # 在 units 列表中引用,但类未定义
```

**After**

```python
class WanVideoUnit_SpeedControl(PipelineUnit):
    def __init__(self):
        super().__init__(input_params=("motion_bucket_id",))

    def process(self, pipe: WanVideoPipeline, motion_bucket_id):
        if motion_bucket_id is None:
            return {}
        motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=pipe.torch_dtype, device=pipe.device)
        return {"motion_bucket_id": motion_bucket_id}
```

### 2.6 Prompt 处理使用 prompter(修复 `process_prompt` 缺失)

**Before**

```python
prompt = pipe.text_encoder.process_prompt(prompt, positive=positive)
output = pipe.text_encoder.tokenizer(prompt, return_mask=True, add_special_tokens=True)
...
prompt_shot_all = pipe.text_encoder.process_prompt(prompt_shot_all, positive=positive)
...
enc_output = pipe.text_encoder(
    text,
    return_mask=True,
    add_special_tokens=True,
    return_tensors="pt"
)
```

**After**

```python
prompt = pipe.prompter.process_prompt(prompt, positive=positive)
output = pipe.prompter.tokenizer(prompt, return_mask=True, add_special_tokens=True)
...
prompt_shot_all = pipe.prompter.process_prompt(prompt_shot_all, positive=positive)
...
enc_output = pipe.prompter.tokenizer(
    text,
    return_mask=True,
    add_special_tokens=True,
    return_tensors="pt"
)
```

### 2.7 兼容 tokenizer 返回 tuple / dict

**Before**

```python
output = pipe.prompter.tokenizer(prompt, return_mask=True, add_special_tokens=True)
ids = output['input_ids'].to(device)
mask = output['attention_mask'].to(device)
...
enc_output = pipe.prompter.tokenizer(..., return_mask=True, ...)
ids = enc_output['input_ids'].to(device)
mask = enc_output['attention_mask'].to(device)
```

**After**

```python
output = pipe.prompter.tokenizer(prompt, return_mask=True, add_special_tokens=True)
if isinstance(output, tuple):
    ids, mask = output
else:
    ids = output['input_ids']
    mask = output['attention_mask']
ids = ids.to(device)
mask = mask.to(device)
...
enc_output = pipe.prompter.tokenizer(..., return_mask=True, ...)
if isinstance(enc_output, tuple):
    ids, mask = enc_output
else:
    ids = enc_output['input_ids']
    mask = enc_output['attention_mask']
ids = ids.to(device)
mask = mask.to(device)
```

### 2.8 使用 prompter 的 `text_len`(修复属性缺失)

**Before**

```python
pad_len = pipe.text_encoder.text_len - total_len
```

**After**

```python
pad_len = pipe.prompter.text_len - total_len
```

## 3) `multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/models/wan_video_dit.py`

### 3.1 `attention_per_batch_with_shots` 中 ID token slice 修复

**Before**

```python
ID_token_start = shot_token_all_num + id_idx * pre_ID_token_num
ID_token_end = start + pre_ID_token_num
assert end <= k.shape[2], (
    f"ID token slice out of range: start={start}, end={end}, "
    f"K_len={k.shape[2]}"
)
id_token_k = k[bi, :, start:end, :]
id_token_v = v[bi, :, start:end, :]
```

**After**

```python
start = shot_token_all_num + id_idx * pre_id_token_num
if start >= k.shape[2]:
    continue
end = min(start + pre_id_token_num, k.shape[2])
id_token_k = k[bi, :, start:end, :]
id_token_v = v[bi, :, start:end, :]
```

### 3.2 `CrossAttention.forward` 增加 `attn_mask`

**Before**

```python
def forward(self, x: torch.Tensor, y: torch.Tensor):
    ...
    x = self.attn(q, k, v)
```

**After**

```python
def forward(self, x: torch.Tensor, y: torch.Tensor, attn_mask=None):
    ...
    x = self.attn(q, k, v, attn_mask=attn_mask)
```

## 4) `multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/trainers/utils.py`

**新增参数以匹配 pipeline**

**Before**

```python
# (no --shot_rope argument)
```

**After**

```python
parser.add_argument("--shot_rope", type=bool, default=False, help="Whether apply shot rope for multi-shot video")
```

> 注意:argparse 的 `type=bool` 会把任何非空字符串(包括 `"False"`)解析为 `True`,因此 `--shot_rope False` 仍会得到 `True`;如需真正的开关语义,建议后续改为 `action="store_true"`。

## 5) 新增文件

**`multi-shot/MULTI_SHOT_CORE_SUMMARY.md`**

- Before: 文件不存在
- After: 新增总结文档

**`multi-shot/MODIFICATION_LOG.md`**

- Before: 文件不存在
- After: 新增修改清单(本文件)

## 6) `multi-shot/dry_run_train.py`

**将模型显式移动到所选设备(CUDA 可用时为 CUDA,否则为 CPU),以匹配输入设备**

**Before**

```python
device = "cuda" if torch.cuda.is_available() else "cpu"
model.pipe.device = device
model.pipe.torch_dtype = torch.bfloat16
```

**After**

```python
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.pipe.device = device
model.pipe.torch_dtype = torch.bfloat16
```

## 验证

```bash
python -m py_compile multi-shot/multi_view/datasets/videodataset.py
python -m py_compile multi-shot/multi_view/train.py
python -m py_compile multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/pipelines/wan_video_new.py
python -m py_compile multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/models/wan_video_dit.py
python -m py_compile multi-shot/multi_view/DiffSynth-Studio-main/diffsynth/trainers/utils.py
```