| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import copy |
| from abc import ABC, abstractmethod |
| from collections import defaultdict |
| from dataclasses import dataclass, fields |
| from enum import Enum |
| from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from .batch_ops import batch_mul |
| from .log import log |
| from .lazy_config_init import instantiate |
|
|
|
|
class BaseConditionEntry(nn.Module):
    """Base class for one conditioning embedder.

    Stores three pieces of per-embedder configuration behind properties (the
    dropout rate, the batch key the input is read from, and a return-dict
    flag) and provides the default per-sample random dropout of the input.
    """

    def __init__(self):
        super().__init__()

        # Defaults: no dropout rate, no input key, return-dict flag off.
        self._dropout_rate = None
        self._input_key = None
        self._return_dict = False

    # --- dropout_rate -----------------------------------------------------
    @property
    def dropout_rate(self) -> Union[float, torch.Tensor]:
        """Rate used by ``random_dropout_input`` when no explicit rate is passed."""
        return self._dropout_rate

    @dropout_rate.setter
    def dropout_rate(self, value: Union[float, torch.Tensor]):
        self._dropout_rate = value

    @dropout_rate.deleter
    def dropout_rate(self):
        del self._dropout_rate

    # --- input_key --------------------------------------------------------
    @property
    def input_key(self) -> str:
        """Key associated with this embedder's input; ``None`` until assigned."""
        return self._input_key

    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    @input_key.deleter
    def input_key(self):
        del self._input_key

    # --- is_return_dict ---------------------------------------------------
    @property
    def is_return_dict(self) -> bool:
        """Stored boolean flag; nothing in this class reads it."""
        return self._return_dict

    @is_return_dict.setter
    def is_return_dict(self, value: bool):
        self._return_dict = value

    @is_return_dict.deleter
    def is_return_dict(self):
        del self._return_dict

    def random_dropout_input(
        self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
    ) -> torch.Tensor:
        """Apply a random per-sample keep/drop factor to ``in_tensor``.

        One Bernoulli sample is drawn per entry of the leading dimension with
        keep probability ``1 - rate`` and combined with ``in_tensor`` through
        ``batch_mul``.

        Parameters:
            in_tensor (torch.Tensor): Input whose leading dimension sets the number of samples drawn.
            dropout_rate (Optional[float]): Rate to use; ``self.dropout_rate`` is used when ``None``.
            key (Optional[str]): Ignored here; accepted so subclasses can branch on it.

        Returns:
            torch.Tensor: The result of ``batch_mul(mask, in_tensor)``.
        """
        del key  # not used by the base implementation
        rate = self.dropout_rate if dropout_rate is None else dropout_rate
        # The probabilities are built by a plain torch.ones(...) call and the
        # sampled mask is then cast with type_as to match in_tensor.
        keep_prob = (1.0 - rate) * torch.ones(in_tensor.shape[0])
        keep_mask = torch.bernoulli(keep_prob).type_as(in_tensor)
        # NOTE(review): batch_mul is a project helper; presumably it broadcasts
        # the per-sample mask over the remaining dims of in_tensor -- confirm.
        return batch_mul(keep_mask, in_tensor)

    def summary(self) -> str:
        """Placeholder for a textual description; the base implementation returns ``None``."""
|
|
|
|
class DataType(Enum):
    """Closed set of data kinds a condition can be tagged with."""

    # Member values are the lowercase strings below.
    IMAGE = "image"
    VIDEO = "video"
|
|
|
|
class TextAttr(BaseConditionEntry):
    """Condition entry that forwards a token tensor and its mask unchanged.

    The two inputs are returned under the ``crossattn_emb`` and
    ``crossattn_mask`` keys.
    """

    def __init__(self):
        super().__init__()

    def forward(self, token: torch.Tensor, mask: torch.Tensor):
        """Package ``token`` and ``mask`` into the cross-attention output dict."""
        return {"crossattn_emb": token, "crossattn_mask": mask}

    def random_dropout_input(
        self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
    ) -> torch.Tensor:
        """Apply the base-class dropout, except to inputs whose key contains ``"mask"``.

        Parameters:
            in_tensor (torch.Tensor): Input tensor for one of this entry's keys.
            dropout_rate (Optional[float]): Forwarded to the base implementation.
            key (Optional[str]): Name of the input; a key containing ``"mask"`` bypasses dropout.

        Returns:
            torch.Tensor: ``in_tensor`` itself for mask inputs, otherwise the base-class result.
        """
        is_mask_input = key is not None and "mask" in key
        if not is_mask_input:
            return super().random_dropout_input(in_tensor, dropout_rate, key)
        # Mask inputs are returned untouched.
        return in_tensor
|
|
|
|
@dataclass
class BaseVideoCondition:
    """Container for the condition values passed along for video data.

    ``crossattn_emb`` and ``crossattn_mask`` are required; every other field
    has a default (``data_type`` defaults to ``DataType.VIDEO``, the rest to
    ``None``).
    """

    crossattn_emb: torch.Tensor
    crossattn_mask: torch.Tensor
    data_type: DataType = DataType.VIDEO
    padding_mask: Optional[torch.Tensor] = None
    fps: Optional[torch.Tensor] = None
    num_frames: Optional[torch.Tensor] = None
    image_size: Optional[torch.Tensor] = None
    scalar_feature: Optional[torch.Tensor] = None

    def to_dict(self) -> Dict[str, Optional[torch.Tensor]]:
        """Return a shallow ``{field name: value}`` mapping of every dataclass field.

        Values are the attribute objects themselves (no copies). Note that the
        mapping also contains ``data_type``, whose value is a ``DataType``.
        """
        field_names = [spec.name for spec in fields(self)]
        return {name: getattr(self, name) for name in field_names}
|
|
|
|
@dataclass
class VideoExtendCondition(BaseVideoCondition):
    """``BaseVideoCondition`` extended with five optional tensors, all defaulting to ``None``."""

    # NOTE(review): the meaning of the fields below is not established by this
    # file; the names suggest video-conditioning inputs -- confirm with callers.
    video_cond_bool: Optional[torch.Tensor] = None
    gt_latent: Optional[torch.Tensor] = None
    condition_video_indicator: Optional[torch.Tensor] = None

    condition_video_input_mask: Optional[torch.Tensor] = None

    condition_video_augment_sigma: Optional[torch.Tensor] = None
|
|
|
|
class GeneralConditioner(nn.Module, ABC):
    """
    An abstract module that runs a collection of conditioning embedders over a batch and
    merges their outputs. Subclasses implement ``forward`` to wrap the merged outputs in a
    concrete condition type.

    Attributes:
        KEY2DIM (dict): A mapping from output keys to the dimension used for concatenation;
            keys not listed are concatenated along the last dimension.
        embedders (nn.ModuleDict): The instantiated embedders, keyed by embedder name.

    Parameters:
        emb_models (Union[List, Any]): Keyword arguments mapping embedder names to their
            configurations. Each configuration must expose ``obj`` (passed to ``instantiate``)
            and either ``input_key`` or ``input_keys``; ``dropout_rate`` is optional and
            defaults to 0.0.

    """

    KEY2DIM = {"crossattn_emb": 1, "crossattn_mask": 1}

    def __init__(self, **emb_models: Union[List, Any]):
        super().__init__()
        self.embedders = nn.ModuleDict()
        for n, (emb_name, embconfig) in enumerate(emb_models.items()):
            embedder = instantiate(embconfig.obj)
            # The message names the class that is actually checked.
            assert isinstance(
                embedder, BaseConditionEntry
            ), f"embedder model {embedder.__class__.__name__} has to inherit from BaseConditionEntry"
            embedder.dropout_rate = getattr(embconfig, "dropout_rate", 0.0)

            # An embedder reads either one batch entry (input_key) or several (input_keys).
            if hasattr(embconfig, "input_key"):
                embedder.input_key = embconfig.input_key
            elif hasattr(embconfig, "input_keys"):
                embedder.input_keys = embconfig.input_keys
            else:
                raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}")

            log.debug(f"Initialized embedder #{n}-{emb_name}: \n {embedder.summary()}")
            self.embedders[emb_name] = embedder

    @abstractmethod
    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> Any:
        """Should be implemented in subclasses to handle the condition datatype."""
        raise NotImplementedError

    def _forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> Dict:
        """
        Processes the input batch through all configured embedders, applying conditional dropout rates if specified.
        Output tensors for each key are concatenated along the dimensions specified in KEY2DIM.

        Parameters:
            batch (Dict): The input data batch to process.
            override_dropout_rate (Optional[Dict[str, float]]): Optional dictionary to override default dropout rates
                                                                per embedder key.

        Returns:
            Dict: A dictionary of output tensors concatenated by specified dimensions.

        Raises:
            AssertionError: If ``override_dropout_rate`` names an embedder that is not configured.
            KeyError: If an embedder has neither a non-None ``input_key`` nor ``input_keys``.

        Note:
            In case the network code is sensitive to the order of concatenation, you can either control the order via \
            config file or make sure the embedders return a unique key for each output.
        """
        output = defaultdict(list)
        if override_dropout_rate is None:
            override_dropout_rate = {}

        # Every override must refer to a configured embedder.
        for emb_name in override_dropout_rate:
            assert emb_name in self.embedders, f"invalid name found {emb_name}"

        for emb_name, embedder in self.embedders.items():
            # None makes random_dropout_input fall back to the embedder's own rate.
            rate = override_dropout_rate.get(emb_name, None)
            with torch.no_grad():
                if getattr(embedder, "input_key", None) is not None:
                    emb_out = embedder(embedder.random_dropout_input(batch[embedder.input_key], rate))
                elif hasattr(embedder, "input_keys"):
                    emb_out = embedder(
                        *[embedder.random_dropout_input(batch[k], rate, k) for k in embedder.input_keys]
                    )
                else:
                    # Without this branch emb_out would be unbound on the first embedder, or
                    # would still hold the previous embedder's output and be appended twice.
                    raise KeyError(
                        f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}"
                    )
            for k, v in emb_out.items():
                output[k].append(v)

        # Concatenate per output key; keys missing from KEY2DIM use the last dimension.
        return {k: torch.cat(v, dim=self.KEY2DIM.get(k, -1)) for k, v in output.items()}

    def get_condition_uncondition(
        self,
        data_batch: Dict,
    ) -> Tuple[Any, Any]:
        """
        Processes the provided data batch to generate conditioned and unconditioned outputs.

        This method manipulates dropout rates to simulate two scenarios:
        1. All conditions applied (conditioned)
        2. Conditions removed/reduced to minimum (unconditioned)

        This method sets dropout rates to zero for the conditioned scenario to fully apply
        embedders' effects. For unconditioned, it sets rates to 1 (or 0 if initial rate is
        insignificant) to minimize embedder influences.

        Parameters:
            data_batch (Dict): Input data batch containing all necessary information for
                            embedding processing.

        Returns:
            Tuple[Any, Any]: A tuple containing:
                - Outputs with all embedders fully applied (conditioned)
                - Outputs with embedders minimized/not applied (unconditioned)
        """
        cond_dropout_rates = {emb_name: 0.0 for emb_name in self.embedders}
        # Embedders configured with an effectively zero rate are kept in the unconditioned pass.
        dropout_rates = {
            emb_name: 1.0 if embedder.dropout_rate > 1e-4 else 0.0 for emb_name, embedder in self.embedders.items()
        }

        condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates)
        un_condition: Any = self(data_batch, override_dropout_rate=dropout_rates)
        return condition, un_condition

    def get_condition_with_negative_prompt(
        self,
        data_batch: Dict,
    ) -> Tuple[Any, Any]:
        """
        Similar functionality to get_condition_uncondition, but the unconditioned output is
        computed from the negative prompt instead of dropping the text input.

        ``TextAttr`` embedders keep a dropout rate of 0 in the unconditioned pass; all other
        embedders follow the same rule as in get_condition_uncondition. When the batch holds a
        tensor under ``neg_t5_text_embeddings``, a copy of the batch is made in which
        ``t5_text_embeddings`` / ``t5_text_mask`` are replaced by their ``neg_`` counterparts.
        Otherwise the unconditioned pass sees the original text inputs.

        Parameters:
            data_batch (Dict): Input data batch; it is not modified.

        Returns:
            Tuple[Any, Any]: ``(condition, un_condition)``.
        """
        cond_dropout_rates, uncond_dropout_rates = {}, {}
        for emb_name, embedder in self.embedders.items():
            cond_dropout_rates[emb_name] = 0.0
            if isinstance(embedder, TextAttr):
                # The text embedder must stay active so the negative prompt is used.
                uncond_dropout_rates[emb_name] = 0.0
            else:
                uncond_dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0

        # NOTE(review): a deep copy duplicates every tensor in the batch; a shallow copy would be
        # enough if embedders never mutate their inputs in place -- confirm before changing.
        data_batch_neg_prompt = copy.deepcopy(data_batch)
        if "neg_t5_text_embeddings" in data_batch_neg_prompt and isinstance(
            data_batch_neg_prompt["neg_t5_text_embeddings"], torch.Tensor
        ):
            data_batch_neg_prompt["t5_text_embeddings"] = data_batch_neg_prompt["neg_t5_text_embeddings"]
            data_batch_neg_prompt["t5_text_mask"] = data_batch_neg_prompt["neg_t5_text_mask"]

        condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates)
        un_condition: Any = self(data_batch_neg_prompt, override_dropout_rate=uncond_dropout_rates)

        return condition, un_condition
|
|
|
|
@dataclass
class CosmosCondition:
    """Container for condition values: two required tensors plus two optional ones.

    ``crossattn_emb`` and ``crossattn_mask`` must be supplied; ``padding_mask``
    and ``scalar_feature`` default to ``None``.
    """

    crossattn_emb: torch.Tensor
    crossattn_mask: torch.Tensor
    padding_mask: Optional[torch.Tensor] = None
    scalar_feature: Optional[torch.Tensor] = None

    def to_dict(self) -> Dict[str, Optional[torch.Tensor]]:
        """Return a shallow ``{field name: value}`` mapping in field-declaration order."""
        as_mapping = {}
        for spec in fields(self):
            # Values are the attribute objects themselves, not copies.
            as_mapping[spec.name] = getattr(self, spec.name)
        return as_mapping
|
|
|
|
class VideoConditioner(GeneralConditioner):
    """Conditioner whose ``forward`` returns a ``BaseVideoCondition``."""

    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> BaseVideoCondition:
        """Run the parent class's ``_forward`` and wrap its output dict.

        Parameters:
            batch (Dict): Input data batch handed to the embedders.
            override_dropout_rate (Optional[Dict[str, float]]): Per-embedder dropout overrides.

        Returns:
            BaseVideoCondition: Built by passing the merged outputs as keyword arguments.
        """
        merged_outputs = super()._forward(batch, override_dropout_rate)
        # Every output key must therefore be a BaseVideoCondition field name.
        condition = BaseVideoCondition(**merged_outputs)
        return condition
|
|
|
|
class VideoExtendConditioner(GeneralConditioner):
    """Conditioner whose ``forward`` returns a ``VideoExtendCondition``."""

    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> VideoExtendCondition:
        """Run the parent class's ``_forward`` and wrap its output dict.

        Parameters:
            batch (Dict): Input data batch handed to the embedders.
            override_dropout_rate (Optional[Dict[str, float]]): Per-embedder dropout overrides.

        Returns:
            VideoExtendCondition: Built by passing the merged outputs as keyword arguments.
        """
        merged_outputs = super()._forward(batch, override_dropout_rate)
        # Every output key must therefore be a VideoExtendCondition field name.
        condition = VideoExtendCondition(**merged_outputs)
        return condition
|
|