| from __future__ import annotations | |
| import torch | |
| from torch import nn | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.models.modeling_utils import ModelMixin | |
| from transformers.activations import ACT2FN | |
| class BitDanceProjector(ModelMixin, ConfigMixin): | |
| def __init__( | |
| self, | |
| in_dim: int, | |
| out_dim: int, | |
| hidden_act: str = "gelu_pytorch_tanh", | |
| ) -> None: | |
| super().__init__() | |
| self.activation_fn = ACT2FN[hidden_act] | |
| self.fc1 = nn.Linear(in_dim, out_dim) | |
| self.fc2 = nn.Linear(out_dim, out_dim) | |
| def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: | |
| hidden_states = hidden_states.to(self.fc1.weight.dtype) | |
| hidden_states = self.fc1(hidden_states) | |
| hidden_states = self.activation_fn(hidden_states) | |
| hidden_states = self.fc2(hidden_states) | |
| return hidden_states | |