| |
| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel |
| from transformers.activations import ACT2FN |
|
|
| from .configuration_downsampler import DownsamplerConfig |
|
|
|
|
class DownsamplerModel(PreTrainedModel):
    """Downsample visual features and project them to the LLM hidden size.

    Layers:

    * ``group_op``: an ``nn.Conv2d`` mapping ``config.visual_hidden_size``
      channels to ``config.llm_hidden_size`` channels with
      ``config.kernel_size`` / ``config.stride`` (default padding, i.e. 0).
    * ``linear_model``: ``config.depth - 1`` pairs of
      (``ACT2FN[config.hidden_act]``,
      ``nn.Linear(llm_hidden_size, llm_hidden_size)``) applied to the last
      axis. For ``depth <= 1`` it is an empty ``nn.Sequential`` (identity).

    Input and output are channels-last 4-D tensors: the input's last
    dimension is ``visual_hidden_size`` and the output's last dimension is
    ``llm_hidden_size``. NOTE(review): the two middle axes are presumably
    (height, width) -- confirm against callers.
    """

    _auto_class = 'AutoModel'
    config_class = DownsamplerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True

    def __init__(self, config: DownsamplerConfig) -> None:
        super().__init__(config)
        self.gradient_checkpointing = False

        # A single strided Conv2d both groups neighbouring positions and
        # projects channels to the LLM width.
        self.group_op = nn.Conv2d(
            in_channels=config.visual_hidden_size,
            out_channels=config.llm_hidden_size,
            bias=config.bias,
            kernel_size=config.kernel_size,
            stride=config.stride)

        # ``depth`` counts the Conv2d as the first layer, hence range(1, depth).
        modules = []
        for _ in range(1, config.depth):
            modules.append(ACT2FN[config.hidden_act])
            modules.append(
                nn.Linear(
                    config.llm_hidden_size,
                    config.llm_hidden_size,
                    bias=config.bias))
        self.linear_model = nn.Sequential(*modules)

    def enable_input_require_grads(self):
        """Make the tensor returned by ``forward`` require grad.

        Registers a forward hook on this module itself (it has no input
        embeddings to hook). The handle is stored as ``_require_grads_hook``,
        the attribute ``PreTrainedModel.disable_input_require_grads`` removes.
        """

        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        # Hooking ``self`` (not an inner layer) means the hook fires on the
        # final output, both on the plain path and after checkpointing.
        self._require_grads_hook = self.register_forward_hook(
            make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle gradient checkpointing on ``module`` if it is a downsampler.

        Legacy ``(module, value)`` signature, applied per-submodule by
        ``PreTrainedModel``; only ``DownsamplerModel`` instances are affected.
        """
        if isinstance(module, DownsamplerModel):
            module.gradient_checkpointing = value

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv downsampling and the linear stack (no checkpointing).

        Args:
            x: 4-D channels-last tensor; last dim is
                ``config.visual_hidden_size``.

        Returns:
            4-D channels-last tensor; last dim is ``config.llm_hidden_size``.
        """
        # Conv2d expects channels in dim 1: move the last axis there.
        x = x.permute(0, 3, 1, 2)
        x = self.group_op(x)
        # Back to channels-last so the Linear layers act on the channel axis.
        x = x.permute(0, 2, 3, 1)
        return self.linear_model(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample and project ``x`` (see ``_forward`` for shapes).

        In training mode with gradient checkpointing enabled, ``_forward`` is
        run through ``torch.utils.checkpoint.checkpoint`` so its activations
        are recomputed during backward instead of being kept.
        """
        if self.gradient_checkpointing and self.training:
            # Explicit import: ``import torch`` alone does not guarantee the
            # ``torch.utils.checkpoint`` submodule has been loaded.
            from torch.utils.checkpoint import checkpoint
            return checkpoint(self._forward, x)
        return self._forward(x)
|
|