| |
|
|
| import numpy as np |
| import torch |
| from torch.nn.functional import normalize |
|
|
| from . import get_model |
| from models.base import BaseModel |
| |
| |
| from models.voting import ( |
| argmax_xyr, |
| conv2d_fft_batchwise, |
| expectation_xyr, |
| log_softmax_spatial, |
| mask_yaw_prior, |
| nll_loss_xyr, |
| nll_loss_xyr_smoothed, |
| TemplateSampler, |
| UAVTemplateSampler, |
| UAVTemplateSamplerFast |
| ) |
| from .map_encoder import MapEncoder |
| from .metrics import AngleError, AngleRecall, Location2DError, Location2DRecall |
|
|
|
|
class MapLocNet(BaseModel):
    """Pose-scoring model that matches image features against map features.

    The image-encoder output is turned into templates by ``UAVTemplateSampler``
    (built from ``num_rotations``, so presumably one template per rotation
    bin). The templates are correlated with the ``MapEncoder`` output via
    ``conv2d_fft_batchwise``. The resulting score volume is optionally combined
    with the map log-prior and with masks, then normalized by
    ``log_softmax_spatial`` into log-probabilities over (u, v, rotation).
    """

    # "???" entries presumably mark mandatory values (OmegaConf convention) --
    # confirm in BaseModel. Keys tagged "unused here" are not read anywhere in
    # this class; presumably they are kept for config compatibility.
    default_conf = {
        "image_size": "???",  # unused here
        "val_citys": "???",  # unused here; spelling kept, configs may use it
        "image_encoder": "???",
        "map_encoder": "???",
        "bev_net": "???",  # unused here
        "latent_dim": "???",  # unused here
        "matching_dim": "???",  # unused here
        "scale_range": [0, 9],  # unused here
        "num_scale_bins": "???",  # unused here
        "z_min": None,  # unused here
        "z_max": "???",  # unused here
        "x_max": "???",  # unused here
        "pixel_per_meter": "???",
        "num_rotations": "???",
        # Learn a scalar t and multiply scores by exp(t).
        "add_temperature": False,
        # L2-normalize both feature maps along dim 1 before matching.
        "normalize_features": False,
        # padding_mode forwarded to conv2d_fft_batchwise.
        "padding_matching": "replicate",
        # Add the map encoder's "log_prior" to the scores when it is present.
        "apply_map_prior": True,
        # Use nll_loss_xyr_smoothed instead of nll_loss_xyr in loss().
        "do_label_smoothing": False,
        # Smoothing widths for nll_loss_xyr_smoothed (sigma_xy is divided by
        # pixel_per_meter before being passed).
        "sigma_xy": 1,
        "sigma_r": 2,
        # The options below only support the values enforced in _init.
        "depth_parameterization": "scale",
        "norm_depth_scores": False,
        "normalize_scores_by_dim": False,
        "normalize_scores_by_num_valid": True,
        "prior_renorm": True,
        "retrieval_dim": None,  # unused here
    }

    def _init(self, conf):
        """Build the encoders, the template sampler and the optional temperature.

        Args:
            conf: the model configuration (same keys as ``default_conf``).

        Raises:
            AssertionError: if an unsupported scoring option is configured.
        """
        # Only this combination of scoring options is implemented here.
        assert not self.conf.norm_depth_scores
        assert self.conf.depth_parameterization == "scale"
        assert not self.conf.normalize_scores_by_dim
        assert self.conf.normalize_scores_by_num_valid
        assert self.conf.prior_renorm

        Encoder = get_model(conf.image_encoder.get("name", "feature_extractor_v2"))
        self.image_encoder = Encoder(conf.image_encoder.backbone)
        self.map_encoder = MapEncoder(conf.map_encoder)

        self.template_sampler = UAVTemplateSampler(conf.num_rotations)

        if conf.add_temperature:
            # Scores are scaled by exp(temperature), so the initial value 0.0
            # leaves them unchanged.
            temperature = torch.nn.Parameter(torch.tensor(0.0))
            self.register_parameter("temperature", temperature)

    def exhaustive_voting(self, f_bev, f_map):
        """Correlate templates built from ``f_bev`` with ``f_map``.

        Args:
            f_bev: query features passed to ``self.template_sampler``
                (``_forward`` passes the image features here).
            f_map: map features to match against.

        Returns:
            The score tensor produced by ``conv2d_fft_batchwise``, multiplied
            by ``exp(self.temperature)`` when ``add_temperature`` is enabled.
        """
        if self.conf.normalize_features:
            # L2-normalize along dim 1 (presumably the channel dim).
            f_bev = normalize(f_bev, dim=1)
            f_map = normalize(f_map, dim=1)

        templates = self.template_sampler(f_bev)

        # Run the correlation in float32 even under mixed precision: autocast
        # is disabled and both inputs are cast explicitly.
        with torch.autocast("cuda", enabled=False):
            scores = conv2d_fft_batchwise(
                f_map.float(),
                templates.float(),
                padding_mode=self.conf.padding_matching,
            )
        if self.conf.add_temperature:
            scores = scores * torch.exp(self.temperature)
        return scores

    def _forward(self, data):
        """Score all poses for a batch and return the predictions.

        Args:
            data: batch dict consumed by the map and image encoders. Optional
                keys read here: ``"map_mask"`` (presumably boolean; cells
                where it is False get a score of -inf) and ``"yaw_prior"``
                (passed to ``mask_yaw_prior``).

        Returns:
            dict with the map-encoder output under ``"map"``, the score volume,
            its log-probabilities, the argmax and expectation poses (the full
            ``uvr`` plus their ``uv`` / ``yaw`` components), and the image
            features used for matching.
        """
        pred = {}
        pred_map = pred["map"] = self.map_encoder(data)
        f_map = pred_map["map_features"][0]

        # Only feature_maps[0] from the image encoder is matched.
        level = 0
        f_image = self.image_encoder(data)["feature_maps"][level]

        scores = self.exhaustive_voting(f_image, f_map)
        # Move axis 1 (the template axis, presumably rotations) to the end;
        # the masking/softmax/argmax helpers below operate on this layout.
        scores = scores.moveaxis(1, -1)
        if "log_prior" in pred_map and self.conf.apply_map_prior:
            # Broadcast the spatial prior over the trailing template axis.
            scores = scores + pred_map["log_prior"][0].unsqueeze(-1)
        if "map_mask" in data:
            scores.masked_fill_(~data["map_mask"][..., None], -np.inf)
        if "yaw_prior" in data:
            mask_yaw_prior(scores, data["yaw_prior"], self.conf.num_rotations)
        log_probs = log_softmax_spatial(scores)

        # Point estimates are computed without gradient tracking.
        with torch.no_grad():
            uvr_max = argmax_xyr(scores).to(scores)
            uvr_avg, _ = expectation_xyr(log_probs.exp())

        return {
            **pred,
            "scores": scores,
            "log_probs": log_probs,
            "uvr_max": uvr_max,
            "uv_max": uvr_max[..., :2],
            "yaw_max": uvr_max[..., 2],
            "uvr_expectation": uvr_avg,
            "uv_expectation": uvr_avg[..., :2],
            "yaw_expectation": uvr_avg[..., 2],
            "features_image": f_image,
        }

    def loss(self, pred, data):
        """Negative log-likelihood of the ground-truth pose.

        Ground truth is ``data["uv"]`` for the location and the last component
        of ``data["roll_pitch_yaw"]`` for the yaw. With ``do_label_smoothing``,
        the smoothed NLL is used with ``sigma_xy / pixel_per_meter`` and
        ``sigma_r`` as smoothing widths and ``data["map_mask"]`` (if any) as
        the mask.

        Returns:
            dict with ``"total"`` and ``"nll"`` (the same tensor). When
            training with a learned temperature, it also contains
            ``"temperature"``, expanded to ``len(nll)`` entries.
        """
        xy_gt = data["uv"]
        yaw_gt = data["roll_pitch_yaw"][..., -1]
        if self.conf.do_label_smoothing:
            nll = nll_loss_xyr_smoothed(
                pred["log_probs"],
                xy_gt,
                yaw_gt,
                self.conf.sigma_xy / self.conf.pixel_per_meter,
                self.conf.sigma_r,
                mask=data.get("map_mask"),
            )
        else:
            nll = nll_loss_xyr(pred["log_probs"], xy_gt, yaw_gt)
        loss = {"total": nll, "nll": nll}
        if self.training and self.conf.add_temperature:
            loss["temperature"] = self.temperature.expand(len(nll))
        return loss

    def metrics(self):
        """Return the evaluation metrics, keyed by name.

        The location metrics receive ``pixel_per_meter`` (presumably to report
        errors in meters, matching the ``*m`` key suffixes). Recall thresholds
        are 1, 3 and 5 for both location and yaw.
        """
        return {
            "xy_max_error": Location2DError("uv_max", self.conf.pixel_per_meter),
            "xy_expectation_error": Location2DError(
                "uv_expectation", self.conf.pixel_per_meter
            ),
            "yaw_max_error": AngleError("yaw_max"),
            "xy_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"),
            "xy_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"),
            "xy_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"),
            "yaw_recall_1°": AngleRecall(1.0, "yaw_max"),
            "yaw_recall_3°": AngleRecall(3.0, "yaw_max"),
            "yaw_recall_5°": AngleRecall(5.0, "yaw_max"),
        }
|
|