| |
| |
|
|
| |
| |
|
|
| from typing import List, Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class ImageEncoder(nn.Module):
    """Chain a backbone ``trunk`` and a ``neck`` into a single image encoder.

    The trunk output is handed directly to the neck, which must return a
    ``(features, pos)`` pair of sliceable sequences. Optionally the last
    ``scalp`` entries of both sequences are dropped before the result is
    packaged into a dict.
    """

    def __init__(
        self,
        trunk: nn.Module,
        neck: nn.Module,
        scalp: int = 0,
    ):
        """Store the sub-modules and check that their channel lists agree.

        :param trunk: backbone module; must expose a ``channel_list`` attribute
        :param neck: neck module; must expose ``backbone_channel_list`` and
            return ``(features, pos)`` from its forward pass
        :param scalp: number of trailing levels to drop from the neck output;
            values ``<= 0`` leave the output untouched
        :raises AssertionError: if ``trunk.channel_list`` differs from
            ``neck.backbone_channel_list``
        """
        super().__init__()
        self.trunk = trunk
        self.neck = neck
        self.scalp = scalp
        assert (
            self.trunk.channel_list == self.neck.backbone_channel_list
        ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"

    def forward(self, sample: torch.Tensor) -> dict:
        """Run trunk then neck and package the (optionally trimmed) outputs.

        :param sample: input passed unchanged to ``self.trunk``
        :return: dict with keys ``"vision_features"`` (last remaining feature
            level), ``"vision_pos_enc"`` (remaining positional encodings) and
            ``"backbone_fpn"`` (remaining feature levels)
        :raises IndexError: if no feature level is left after trimming
        """
        features, pos = self.neck(self.trunk(sample))
        num_levels = len(features)
        if self.scalp > 0:
            # Drop the trailing `scalp` levels from both sequences in lockstep.
            # NOTE(review): with FpnNeck the trailing entries are the levels the
            # top-down pass starts from (presumably the coarsest) -- confirm for
            # other necks.
            features, pos = features[: -self.scalp], pos[: -self.scalp]

        if len(features) == 0:
            # Same exception type as the bare `features[-1]` lookup would give,
            # but with a message that points at the actual cause.
            raise IndexError(
                f"No feature levels left: the neck produced {num_levels} "
                f"level(s) and scalp={self.scalp} removed all of them."
            )

        return {
            "vision_features": features[-1],
            "vision_pos_enc": pos,
            "backbone_fpn": features,
        }
|
|
|
|
class FpnNeck(nn.Module):
    """
    A modified variant of Feature Pyramid Network (FPN) neck.

    Every backbone level is projected to ``d_model`` channels by its own
    lateral convolution. Levels listed in ``fpn_top_down_levels`` additionally
    receive the 2x-upsampled output of the level processed just before them
    (top-down pathway). There is no output convolution after the fusion, and
    the interpolation mode is configurable via ``fpn_interp_model`` (default
    ``"bilinear"``).
    """

    def __init__(
        self,
        position_encoding: nn.Module,
        d_model: int,
        backbone_channel_list: List[int],
        kernel_size: int = 1,
        stride: int = 1,
        padding: int = 0,
        fpn_interp_model: str = "bilinear",
        fuse_type: str = "sum",
        fpn_top_down_levels: Optional[List[int]] = None,
    ):
        """Initialize the neck.

        :param position_encoding: module called on each output feature map to
            produce its positional encoding
        :param d_model: number of output channels of every lateral convolution
        :param backbone_channel_list: input channels of the lateral
            convolutions; entry ``k`` is applied to input level
            ``len(backbone_channel_list) - 1 - k`` in ``forward``, i.e. the list
            is in the reverse order of the ``xs`` passed to ``forward``
        :param kernel_size: kernel size of the lateral convolutions
        :param stride: stride of the lateral convolutions
        :param padding: padding of the lateral convolutions
        :param fpn_interp_model: ``mode`` passed to ``F.interpolate`` when
            upsampling the top-down features
        :param fuse_type: ``"sum"`` adds lateral and top-down features,
            ``"avg"`` additionally halves the sum
        :param fpn_top_down_levels: indices of the levels that receive
            top-down features; ``None`` selects every level
        :raises AssertionError: if ``fuse_type`` is neither ``"sum"`` nor
            ``"avg"``
        """
        super().__init__()
        self.position_encoding = position_encoding
        self.convs = nn.ModuleList()
        self.backbone_channel_list = backbone_channel_list
        for dim in backbone_channel_list:
            # Keep the Sequential wrapper and the "conv" sub-module name: they
            # determine the parameter names (convs.<k>.conv.weight / .bias).
            current = nn.Sequential()
            current.add_module(
                "conv",
                nn.Conv2d(
                    in_channels=dim,
                    out_channels=d_model,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                ),
            )
            self.convs.append(current)

        self.fpn_interp_model = fpn_interp_model
        assert fuse_type in [
            "sum",
            "avg",
        ], f"fuse_type must be 'sum' or 'avg', got {fuse_type!r}"
        self.fuse_type = fuse_type

        if fpn_top_down_levels is None:
            # Default: every level takes part in the top-down pathway.
            fpn_top_down_levels = range(len(self.convs))
        self.fpn_top_down_levels = list(fpn_top_down_levels)

    def forward(self, xs: List[torch.Tensor]):
        """Fuse the backbone feature maps and compute positional encodings.

        Levels are visited from the last index down to index 0. The fused
        output of one level is upsampled by a factor of 2 and added to the
        lateral features of the next visited level, so consecutive levels are
        assumed to differ by 2x in spatial size.

        :param xs: one feature map per lateral convolution; ``xs[i]`` must
            have ``backbone_channel_list[len(xs) - 1 - i]`` channels
        :return: ``(out, pos)`` -- two lists indexed like ``xs`` holding the
            fused features and their positional encodings (cast to the dtype
            of the corresponding features)
        :raises AssertionError: if ``len(xs)`` differs from the number of
            lateral convolutions
        """
        num_levels = len(self.convs)
        out = [None] * num_levels
        pos = [None] * num_levels
        assert len(xs) == num_levels, (
            f"Expected {num_levels} feature maps (one per lateral conv), "
            f"got {len(xs)}"
        )

        prev_features = None
        n = num_levels - 1
        for i in range(n, -1, -1):
            x = xs[i]
            # convs are stored in the reverse order of xs (see __init__).
            lateral_features = self.convs[n - i](x)
            if i in self.fpn_top_down_levels and prev_features is not None:
                # The upsampling input is cast to float32.
                # NOTE(review): presumably for precision under reduced-precision
                # training -- confirm.
                top_down_features = F.interpolate(
                    prev_features.to(dtype=torch.float32),
                    scale_factor=2.0,
                    mode=self.fpn_interp_model,
                    # F.interpolate only accepts align_corners for the
                    # interpolating modes, so leave it unset for "nearest".
                    align_corners=(
                        None if self.fpn_interp_model == "nearest" else False
                    ),
                    antialias=False,
                )
                prev_features = lateral_features + top_down_features
                if self.fuse_type == "avg":
                    prev_features /= 2
            else:
                # First visited level, or a level excluded from the top-down
                # pathway: the lateral features are used as they are.
                prev_features = lateral_features
            x_out = prev_features
            out[i] = x_out
            pos[i] = self.position_encoding(x_out).to(x_out.dtype)

        return out, pos
|
|