| import kornia |
| import torch |
|
|
| from .configuration_disk import DiskConfig |
| from transformers.models.superpoint.modeling_superpoint import ( |
| SuperPointKeypointDescriptionOutput, |
| ) |
| from transformers import PreTrainedModel |
|
|
|
|
class DiskForKeypointDetection(PreTrainedModel):
    """Keypoint detector/descriptor wrapping ``kornia.feature.DISK``.

    Results are packed into a ``SuperPointKeypointDescriptionOutput`` (zero-padded
    per batch, with a validity mask). This lets the model be used where a
    SuperPoint-style detector output is expected.
    """

    config_class = DiskConfig
    # The model consumes images, not token ids (PreTrainedModel's default).
    main_input_name = "pixel_values"

    def __init__(self, config: DiskConfig):
        super().__init__(config)

        self.config = config
        # Descriptor size comes from the config; `forward` sizes its padded
        # descriptor tensor with the same value.
        self.model = kornia.feature.DISK(self.config.descriptor_decoder_dim)

    def forward(
        self, pixel_values: torch.Tensor
    ) -> SuperPointKeypointDescriptionOutput:
        """Detect keypoints and compute DISK descriptors for a batch of images.

        Args:
            pixel_values: Batch of images passed straight to ``kornia.feature.DISK``;
                presumably ``(batch, 3, height, width)`` RGB -- TODO confirm against
                the image processor. Only the last two dims are read here, as the
                image width and height.

        Returns:
            SuperPointKeypointDescriptionOutput. Every field is zero-padded along
            dim 1 to the largest keypoint count in the batch:
              - keypoints: ``(batch, max_kpts, 2)``. Column 0 is divided by the image
                width and column 1 by the height, giving relative coordinates
                (assuming kornia returns keypoints in (x, y) order).
              - scores: ``(batch, max_kpts)`` detection scores.
              - descriptors: ``(batch, max_kpts, descriptor_decoder_dim)``.
              - mask: ``(batch, max_kpts)`` int tensor. It is 1 for real keypoints
                and 0 for padding, matching the SuperPoint mask dtype.
        """
        detections = self.model(
            pixel_values,
            n=self.config.max_num_keypoints,
            window_size=self.config.nms_window_size,
            score_threshold=self.config.detection_threshold,
            pad_if_not_divisible=self.config.pad_if_not_divisible,
        )

        batch_size = len(detections)
        device = pixel_values.device
        # `default=0` keeps an empty batch from raising inside `max`.
        max_num_keypoints = max(
            (detection.keypoints.shape[0] for detection in detections), default=0
        )

        # DISK yields a variable number of keypoints per image, so results are
        # zero-padded into dense batch tensors and tracked with `mask`.
        keypoints = torch.zeros(batch_size, max_num_keypoints, 2, device=device)
        descriptors = torch.zeros(
            batch_size,
            max_num_keypoints,
            self.config.descriptor_decoder_dim,
            device=device,
        )
        scores = torch.zeros(batch_size, max_num_keypoints, device=device)
        mask = torch.zeros(
            batch_size, max_num_keypoints, dtype=torch.int, device=device
        )

        for i, detection in enumerate(detections):
            num_keypoints = detection.keypoints.shape[0]
            keypoints[i, :num_keypoints] = detection.keypoints
            descriptors[i, :num_keypoints] = detection.descriptors
            scores[i, :num_keypoints] = detection.detection_scores
            mask[i, :num_keypoints] = 1

        # Convert to coordinates relative to the input image size.
        height, width = pixel_values.shape[-2], pixel_values.shape[-1]
        keypoints[:, :, 0] = keypoints[:, :, 0] / width
        keypoints[:, :, 1] = keypoints[:, :, 1] / height

        return SuperPointKeypointDescriptionOutput(
            keypoints=keypoints,
            scores=scores,
            descriptors=descriptors,
            mask=mask,
        )
|
|