|
|
| import torch |
| import torch.nn as nn |
| import math |
| import re |
|
|
def build_layout_projector(projector_type: str = 'mlp2x_gelu',
                           mm_hidden_size: int = 4,
                           hidden_size: int = 4096) -> nn.Module:
    """Build a projector module selected by ``projector_type``.

    Called with no arguments it behaves exactly as before: a 2-layer GELU MLP
    mapping 4 input features to 4096.

    Args:
        projector_type: ``'mlp<N>x_gelu'`` for an MLP of N ``nn.Linear`` layers
            with ``nn.GELU`` between consecutive layers, or ``'identity'`` for
            a pass-through module.
        mm_hidden_size: ``in_features`` of the first linear layer.
        hidden_size: ``out_features`` of every linear layer (and
            ``in_features`` of every layer after the first).

    Returns:
        An ``nn.Sequential`` of Linear/GELU layers, or an ``IdentityMap``.

    Raises:
        ValueError: if ``projector_type`` is unrecognized, or names an MLP
            with depth < 1.
    """
    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        if mlp_depth < 1:
            # 'mlp0x_gelu' would otherwise silently produce a depth-1 MLP.
            raise ValueError(
                f'MLP projector depth must be >= 1, got: {projector_type}')
        modules = [nn.Linear(mm_hidden_size, hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(hidden_size, hidden_size))
        return nn.Sequential(*modules)

    if projector_type == 'identity':
        # IdentityMap returns its input unchanged, so the output size is
        # mm_hidden_size, not hidden_size.
        return IdentityMap()

    raise ValueError(f'Unknown projector type: {projector_type}')
|
|
|
|
class IdentityMap(nn.Module):
    """Pass-through projector: ``forward`` hands back its first argument as-is.

    Extra positional and keyword arguments are accepted so it can stand in for
    other projectors with richer ``forward`` signatures; they are ignored.
    """

    def __init__(self):
        super(IdentityMap, self).__init__()

    def forward(self, x, *args, **kwargs):
        """Return ``x`` itself (the same object, not a copy)."""
        del args, kwargs  # accepted only for call-signature compatibility
        return x

    @property
    def config(self):
        """A fresh dict identifying this projector's type."""
        return dict(mm_projector_type='identity')
|
|
class PLoRA(nn.Linear):
    """``nn.Linear`` plus a low-rank (LoRA) update applied only at masked positions.

    The base linear projection is applied to all of ``x``. Where ``im_mask``
    selects positions, ``Plora_B(Plora_A(dropout(x))) * lora_scaling`` is added
    on top.

    Args:
        in_features, out_features, bias, device, dtype: forwarded to
            ``nn.Linear``.
        lora_r: rank of the low-rank update. It is the divisor in
            ``lora_scaling``, so it must be non-zero.
        lora_alpha: numerator of ``lora_scaling = lora_alpha / lora_r``.
        lora_dropout: dropout probability on the LoRA input; values ``<= 0``
            disable dropout.
        lora_len: stored as ``self.lora_len``; not read by this class.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 device=None,
                 dtype=None,
                 lora_r=8,
                 lora_alpha=16,
                 lora_dropout=0.05,
                 lora_len=0,
                 **kwargs) -> None:
        # nn.Linear.__init__ invokes self.reset_parameters() before the LoRA
        # sub-layers exist; reset_parameters guards for that.
        super().__init__(in_features, out_features, bias, device, dtype)
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_len = lora_len
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            # nn.Identity instead of a lambda keeps the module picklable; it
            # has no parameters, so the state_dict is unchanged.
            self.lora_dropout = nn.Identity()
        self.lora_scaling = self.lora_alpha / self.lora_r

        self.Plora_A = nn.Linear(
            in_features, self.lora_r, bias=False, device=device, dtype=dtype)
        self.Plora_B = nn.Linear(
            self.lora_r, out_features, bias=False, device=device, dtype=dtype)

        # Re-run now that Plora_A/Plora_B exist so they get the LoRA init.
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the base weights and, once they exist, the LoRA weights.

        The base ``weight``/``bias`` get ``nn.Linear``'s default init.
        ``Plora_A`` gets the same Kaiming-uniform init as ``nn.Linear``.
        ``Plora_B`` is zeroed, so the LoRA branch initially adds nothing.
        """
        # Overriding without delegating would leave weight/bias as
        # uninitialized memory.
        super().reset_parameters()
        # The first call comes from nn.Linear.__init__, before the LoRA
        # layers are created.
        if hasattr(self, 'Plora_A') and hasattr(self, 'Plora_B'):
            nn.init.kaiming_uniform_(self.Plora_A.weight, a=math.sqrt(5))
            nn.init.zeros_(self.Plora_B.weight)

    def forward(self, x, im_mask=None):
        """Apply the base projection, adding the LoRA update where ``im_mask`` is set.

        ``im_mask`` is used as an index into both ``x`` and the output. It
        presumably is a boolean mask over the leading (e.g. batch, seq)
        dimensions of ``x`` — TODO confirm against callers. With
        ``im_mask=None`` only the base projection is applied.
        """
        res = super().forward(x)
        if im_mask is not None:
            if torch.sum(im_mask) > 0:
                part_x = x[im_mask]
                res[im_mask] += self.Plora_B(
                    self.Plora_A(
                        self.lora_dropout(part_x))) * self.lora_scaling
            else:
                # Adds exactly zero but still routes the output through
                # Plora_A/Plora_B, keeping them in the autograd graph.
                # Presumably this avoids "unused parameter" errors in
                # distributed training — confirm.
                part_x = x[:, :1]
                res[:, :1] += self.Plora_B(
                    self.Plora_A(self.lora_dropout(part_x))) * 0
        return res