| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import os |
| import torch |
| import torch.nn as nn |
| import dino.vision_transformer as vits |
|
|
| from bkg_seg import compute_img_bkg_seg |
| from misc import batch_apply_bilateral_solver |
|
|
class FoundModel(nn.Module):
    """ViT encoder followed by a 1x1 convolutional decoder.

    The encoder, its embedding size and a ``hook_features`` dict are obtained
    from ``get_vit_encoder``. ``hook_features["qkv"]`` is rewritten by a
    forward hook every time the encoder runs, so every method that calls
    ``extract_feats`` reads the projections of the *most recent* encoder
    forward pass.
    """

    def __init__(
        self,
        vit_model="dino",
        vit_arch="vit_small",
        vit_patch_size=8,
        enc_type_feats="k",
        bkg_type_feats="k",
        bkg_th=0.3,
    ):
        """Build the encoder and the single-channel decoder.

        Args:
            vit_model: Forwarded to ``get_vit_encoder``.
            vit_arch: Forwarded to ``get_vit_encoder``.
            vit_patch_size: Patch size, forwarded to ``get_vit_encoder`` and
                used to size the feature map in ``forward_step``.
            enc_type_feats: Projection fed to the decoder. ``extract_feats``
                only understands ``"q"``, ``"k"`` and ``"v"``.
            bkg_type_feats: Projection used by ``compute_background_batch``
                (``"q"``, ``"k"`` or ``"v"``).
            bkg_th: Passed as ``th_bkg`` to ``compute_img_bkg_seg``.
        """
        super().__init__()

        # Encoder
        self.vit_encoder, self.initial_dim, self.hook_features = get_vit_encoder(
            vit_arch, vit_model, vit_patch_size, enc_type_feats
        )
        self.vit_patch_size = vit_patch_size
        self.enc_type_feats = enc_type_feats

        # Background-discovery settings
        self.bkg_type_feats = bkg_type_feats
        self.bkg_th = bkg_th

        # Decoder: 1x1 convolution mapping encoder features to one channel
        self.previous_dim = self.initial_dim
        self.decoder = nn.Conv2d(self.previous_dim, 1, (1, 1))

    def forward_step(self, batch, decoder=None, for_eval=False):
        """Run the encoder without gradients, then apply a decoder.

        Args:
            batch: 4-D image tensor; the last two dimensions are the spatial
                ones (the batch is sliced as ``batch[:, :, :w, :h]``).
            decoder: Optional callable used instead of ``self.decoder``.
            for_eval: When true the input is zero-padded up to a multiple of
                the patch size; otherwise it is cropped down to one.

        Returns:
            Tuple ``(preds, feats, (w_featmap, h_featmap), att)`` where
            ``feats`` are the decoder inputs, ``w_featmap`` / ``h_featmap``
            are the sizes of dimensions -2 / -1 in patches, and ``att`` is the
            output of ``vit_encoder.get_last_selfattention``.
        """
        patch = self.vit_patch_size

        # Make the spatial size a multiple of the patch size.
        if for_eval:
            batch = self.make_input_divisible(batch)
        else:
            # Cropping branch (presumably the training path - TODO confirm).
            w = batch.shape[-2] - batch.shape[-2] % patch
            h = batch.shape[-1] - batch.shape[-1] % patch
            batch = batch[:, :, :w, :h]

        w_featmap = batch.shape[-2] // patch
        h_featmap = batch.shape[-1] // patch

        with torch.no_grad():
            att = self.vit_encoder.get_last_selfattention(batch)

            # The hook now holds this batch's qkv projections.
            feats = self.extract_feats(dims=att.shape, type_feats=self.enc_type_feats)
            # Drop token 0 (presumably the class token) and lay the remaining
            # tokens out on the patch grid, channels first.
            feats = feats[:, 1:, :, :].reshape(att.shape[0], w_featmap, h_featmap, -1)
            feats = feats.permute(0, 3, 1, 2)

        if decoder is None:
            decoder = self.decoder
        preds = decoder(feats)

        return preds, feats, (w_featmap, h_featmap), att

    def make_input_divisible(self, x: torch.Tensor) -> torch.Tensor:
        """Zero-pad a 4-D tensor so its last two sizes divide the patch size.

        Padding is appended at the end of the last two dimensions only; an
        already-divisible input is returned with an unchanged shape.
        """
        _, _, height, width = x.shape
        patch = self.vit_patch_size
        pad_w = (patch - width % patch) % patch
        pad_h = (patch - height % patch) % patch

        # F.pad order is (last-dim left, last-dim right, prev-dim left, prev-dim right).
        return nn.functional.pad(x, (0, pad_w, 0, pad_h), value=0)

    def compute_background_batch(
        self,
        att,
        shape_f,
    ):
        """Compute the background mask for the batch that produced ``att``.

        Relies on ``self.hook_features`` still holding the projections of the
        same encoder forward pass that returned ``att``.

        Args:
            att: 4-D attention tensor; its first three sizes are read as
                (images, heads, tokens).
            shape_f: Pair ``(w_f, h_f)`` giving the feature-map size.

        Returns:
            Whatever ``compute_img_bkg_seg`` returns for this batch.
        """
        w_f, h_f = shape_f
        nb_im, nh, nb_tokens = att.shape[0], att.shape[1], att.shape[2]

        feats = self.extract_feats(dims=att.shape, type_feats=self.bkg_type_feats)
        # Merge the head and per-head axes into one feature axis per token.
        feats = feats.reshape(nb_im, nb_tokens, -1)

        return compute_img_bkg_seg(
            att,
            feats,
            (w_f, h_f),
            th_bkg=self.bkg_th,
            dim=self.initial_dim // nh,  # per-head feature size
        )

    def get_bkg_pseudo_labels_batch(
        self,
        att,
        shape_f,
        data,
        use_bilateral_solver=True,
        shape=None,
    ):
        """Turn the background mask into object pseudo-labels.

        Args:
            att: Attention tensor, see ``compute_background_batch``.
            shape_f: Feature-map size, see ``compute_background_batch``.
            data: Forwarded to ``batch_apply_bilateral_solver``.
            use_bilateral_solver: Refine the mask with the bilateral solver.
            shape: Forwarded to ``batch_apply_bilateral_solver``.

        Returns:
            ``(pseudo_labels, cnt_bs)`` from ``batch_apply_bilateral_solver``
            when the solver is used, otherwise ``(obj_mask, 0)``.
        """
        bkg_mask_pred = self.compute_background_batch(att, shape_f)

        # Object mask is the logical complement of the background mask.
        obj_mask = (~bkg_mask_pred.bool()).float()

        if not use_bilateral_solver:
            return obj_mask, 0

        pseudo_labels, cnt_bs = batch_apply_bilateral_solver(data, obj_mask, shape)
        return pseudo_labels, cnt_bs

    @torch.no_grad()
    def decoder_load_weights(self, weights_path):
        """Load decoder weights from ``weights_path`` and switch it to eval.

        The file must contain a dict with a ``"decoder"`` state dict, as
        written by ``decoder_save_weights``.
        """
        print(f"Loading model from weights {weights_path}.")

        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        state_dict = torch.load(weights_path, map_location=torch.device("cpu"))

        self.decoder.load_state_dict(state_dict["decoder"])
        self.decoder.eval()

    @torch.no_grad()
    def decoder_save_weights(self, save_dir, n_iter):
        """Save the decoder state dict as ``decoder_weights_niter{n_iter}.pt``.

        ``save_dir`` is not created here.
        """
        state_dict = {"decoder": self.decoder.state_dict()}
        fname = os.path.join(save_dir, f"decoder_weights_niter{n_iter}.pt")
        torch.save(state_dict, fname)
        print(f"\n----"
              f"\nModel saved at {fname}"
              )

    @torch.no_grad()
    def extract_feats(self, dims, type_feats="k"):
        """Return the q, k or v projection captured by the forward hook.

        Args:
            dims: 4-tuple whose first three entries are
                ``(nb_im, nh, nb_tokens)``; the fourth is ignored.
            type_feats: ``"q"``, ``"k"`` or ``"v"``.

        Returns:
            Float tensor of shape ``(nb_im, nb_tokens, nh, head_dim)``.

        Raises:
            ValueError: If ``type_feats`` is not ``"q"``, ``"k"`` or ``"v"``.
        """
        nb_im, nh, nb_tokens, _ = dims

        # Split the fused projection into (3, nb_im, nh, nb_tokens, head_dim);
        # the trailing -1 lets reshape infer the per-head size.
        qkv = (
            self.hook_features["qkv"]
            .reshape(nb_im, nb_tokens, 3, nh, -1)
            .permute(2, 0, 3, 1, 4)
        )

        if type_feats == "q":
            selected = qkv[0]
        elif type_feats == "k":
            selected = qkv[1]
        elif type_feats == "v":
            selected = qkv[2]
        else:
            raise ValueError("Unknown features")

        return selected.transpose(1, 2).float()
|
|
|
|
def get_vit_encoder(vit_arch, vit_model, vit_patch_size, enc_type_feats):
    """Build a frozen, pretrained DINO ViT and hook its last qkv projection.

    All arguments are validated before any network access, so an unsupported
    configuration fails immediately with ``ValueError`` instead of an
    ``UnboundLocalError`` after (or without) a checkpoint download.

    Args:
        vit_arch: ``"vit_small"`` or ``"vit_base"``.
        vit_model: Only ``"dino"`` is supported.
        vit_patch_size: ``8`` or ``16``.
        enc_type_feats: One of ``"k"``, ``"q"``, ``"v"``, ``"qkv"``, ``"mlp"``.

    Returns:
        Tuple ``(vit_encoder, initial_dim, hook_features)``. ``vit_encoder``
        is in eval mode with all parameters frozen, ``initial_dim`` is the
        embedding size (384 or 768), and ``hook_features["qkv"]`` is
        overwritten with the output of the last block's ``attn.qkv`` layer on
        every forward pass.

    Raises:
        ValueError: For an unsupported ``enc_type_feats``, ``vit_model`` or
            ``(vit_arch, vit_patch_size)`` combination.
    """
    if enc_type_feats not in ("k", "q", "v", "qkv", "mlp"):
        raise ValueError("Not implemented.")

    # A CLIP ViT-B/16 checkpoint path used to be listed here, but no CLIP
    # encoder was ever constructed, so that configuration could only crash.
    if vit_model != "dino":
        raise ValueError(
            f"Unsupported vit_model {vit_model!r}: only 'dino' is implemented."
        )

    # (arch, patch size) -> (checkpoint path under the DINO bucket, embed dim)
    checkpoints = {
        ("vit_small", 16): (
            "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
            384,
        ),
        ("vit_small", 8): (
            "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth",
            384,
        ),
        ("vit_base", 16): (
            "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
            768,
        ),
        ("vit_base", 8): (
            "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth",
            768,
        ),
    }
    config = checkpoints.get((vit_arch, vit_patch_size))
    if config is None:
        raise ValueError(
            f"Unsupported encoder configuration: vit_arch={vit_arch!r}, "
            f"vit_patch_size={vit_patch_size!r}."
        )
    url, initial_dim = config

    vit_encoder = vits.__dict__[vit_arch](patch_size=vit_patch_size, num_classes=0)

    # The encoder is used as a fixed feature extractor.
    for param in vit_encoder.parameters():
        param.requires_grad = False
    vit_encoder.eval()

    # map_location="cpu" keeps loading independent of the device the
    # checkpoint was saved from; load_state_dict copies into the module.
    state_dict = torch.hub.load_state_dict_from_url(
        url="https://dl.fbaipublicfiles.com/dino/" + url,
        map_location="cpu",
    )
    vit_encoder.load_state_dict(state_dict, strict=True)

    hook_features = {}

    def hook_fn_forward_qkv(module, inputs, output):
        # Keep only the latest forward pass.
        hook_features["qkv"] = output

    vit_encoder.blocks[-1].attn.qkv.register_forward_hook(hook_fn_forward_qkv)

    return vit_encoder, initial_dim, hook_features