| | """ |
| | SAM 2 Few-Shot Learning Model |
| | |
| | This module implements a few-shot segmentation model that combines SAM 2 with CLIP |
| | for domain adaptation using minimal labeled examples. |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from typing import Dict, List, Optional, Tuple, Union |
| | import numpy as np |
| | from PIL import Image |
| | import clip |
| | from segment_anything_2 import sam_model_registry, SamPredictor |
| | from transformers import CLIPTextModel, CLIPTokenizer |
| |
|
| |
|
class SAM2FewShot(nn.Module):
    """
    SAM 2 Few-Shot Learning Model.

    Combines SAM 2 with CLIP for few-shot and zero-shot segmentation
    across different domains (satellite, fashion, robotics). CLIP scores
    text prompts and stored support examples against the query image;
    classes scoring above a threshold become SAM 2 point prompts.
    """

    def __init__(
        self,
        sam2_checkpoint: str,
        clip_model_name: str = "ViT-B/32",
        device: str = "cuda",
        prompt_engineering: bool = True,
        visual_similarity: bool = True,
        temperature: float = 0.1
    ):
        """
        Args:
            sam2_checkpoint: Path to the SAM 2 model checkpoint file.
            clip_model_name: CLIP variant to load via ``clip.load``.
            device: Torch device string used for both backbones.
            prompt_engineering: If True, propose classes from CLIP text similarity.
            visual_similarity: If True, propose classes from few-shot memory similarity.
            temperature: Scaling divisor applied to CLIP cosine similarities.
        """
        super().__init__()
        self.device = device
        self.temperature = temperature
        self.prompt_engineering = prompt_engineering
        self.visual_similarity = visual_similarity

        # SAM 2 backbone wrapped in the point/box prompt predictor.
        self.sam2 = sam_model_registry["vit_h"](checkpoint=sam2_checkpoint)
        self.sam2.to(device)
        self.sam2_predictor = SamPredictor(self.sam2)

        # Frozen CLIP used purely for similarity scoring (never fine-tuned here).
        self.clip_model, self.clip_preprocess = clip.load(clip_model_name, device=device)
        self.clip_model.eval()

        # Per-domain synonym expansions used for text prompt engineering.
        self.domain_prompts = {
            "satellite": {
                "building": ["building", "house", "structure", "rooftop"],
                "road": ["road", "street", "highway", "pavement"],
                "vegetation": ["vegetation", "forest", "trees", "green area"],
                "water": ["water", "lake", "river", "ocean", "pond"]
            },
            "fashion": {
                "shirt": ["shirt", "t-shirt", "blouse", "top"],
                "pants": ["pants", "trousers", "jeans", "legs"],
                "dress": ["dress", "gown", "outfit"],
                "shoes": ["shoes", "footwear", "sneakers", "boots"]
            },
            "robotics": {
                "robot": ["robot", "automation", "mechanical arm"],
                "tool": ["tool", "wrench", "screwdriver", "equipment"],
                "safety": ["safety equipment", "helmet", "vest", "protection"]
            }
        }

        # Memory bank: domain -> class -> list of support examples.
        self.few_shot_memory = {}

    @staticmethod
    def _to_uint8_rgb(image: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
        """Convert a [C, H, W] tensor (optionally batched) or HWC array to uint8 HWC.

        PIL's ``Image.fromarray`` and SAM's ``set_image`` need uint8 data;
        float inputs that look normalized (max <= 1.0) are rescaled to
        [0, 255], then everything is clipped and cast.  NOTE(review): the
        max<=1 heuristic assumes callers pass either normalized floats or
        byte-range values — confirm against the data pipeline.
        """
        if isinstance(image, torch.Tensor):
            if image.dim() == 4:
                image = image.squeeze(0)  # drop a leading batch dimension
            image = image.permute(1, 2, 0).detach().cpu().numpy()
        if image.dtype != np.uint8:
            if image.size and image.max() <= 1.0:
                image = image * 255.0
            image = np.clip(image, 0, 255).astype(np.uint8)
        return image

    def encode_text_prompts(self, domain: str, class_names: List[str]) -> torch.Tensor:
        """Encode CLIP text features for the given domain and classes.

        Each class is expanded into its domain synonyms when known, so the
        result has one L2-normalized row PER EXPANDED PROMPT — possibly more
        rows than ``len(class_names)``.  Callers that need per-class scores
        must encode one class at a time (see ``generate_sam2_prompts``).
        """
        prompts = []
        for class_name in class_names:
            if domain in self.domain_prompts and class_name in self.domain_prompts[domain]:
                prompts.extend(self.domain_prompts[domain][class_name])
            else:
                prompts.append(class_name)

        # Domain-specific prefixes steer CLIP toward the right visual context.
        if domain == "satellite":
            prompts = [f"satellite image of {p}" for p in prompts]
        elif domain == "fashion":
            prompts = [f"fashion item {p}" for p in prompts]
        elif domain == "robotics":
            prompts = [f"robotics environment {p}" for p in prompts]

        text_tokens = clip.tokenize(prompts).to(self.device)
        with torch.no_grad():
            text_features = self.clip_model.encode_text(text_tokens)
            text_features = F.normalize(text_features, dim=-1)

        return text_features

    def encode_image(self, image: "Union[torch.Tensor, np.ndarray, Image.Image]") -> torch.Tensor:
        """Encode an image into a single L2-normalized CLIP feature row [1, D]."""
        if not isinstance(image, Image.Image):
            # Accepts [C, H, W] tensors and HWC arrays; floats are rescaled.
            image = Image.fromarray(self._to_uint8_rgb(image))

        clip_image = self.clip_preprocess(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            image_features = self.clip_model.encode_image(clip_image)
            image_features = F.normalize(image_features, dim=-1)

        return image_features

    def compute_similarity(
        self,
        image_features: torch.Tensor,
        text_features: torch.Tensor
    ) -> torch.Tensor:
        """Return temperature-scaled similarity logits [num_images, num_texts].

        Inputs are expected to be L2-normalized, making the matmul a cosine
        similarity; dividing by ``temperature`` sharpens the distribution.
        """
        return torch.matmul(image_features, text_features.T) / self.temperature

    def add_few_shot_example(
        self,
        domain: str,
        class_name: str,
        image: torch.Tensor,
        mask: torch.Tensor
    ):
        """Add one (image, mask) support example to the memory bank.

        The image's CLIP features are computed once here so later queries
        only pay for encoding the query image.
        """
        bank = self.few_shot_memory.setdefault(domain, {}).setdefault(class_name, [])
        bank.append({
            'image_features': self.encode_image(image),
            'mask': mask,
            'image': image
        })

    def get_few_shot_similarity(
        self,
        query_image: torch.Tensor,
        domain: str,
        class_name: str
    ) -> torch.Tensor:
        """Return mean CLIP cosine similarity to stored support examples.

        Returns a zero tensor when no examples exist for (domain, class),
        which keeps downstream threshold checks well-defined.
        """
        if domain not in self.few_shot_memory or class_name not in self.few_shot_memory[domain]:
            return torch.zeros(1, device=self.device)

        query_features = self.encode_image(query_image)
        similarities = [
            F.cosine_similarity(query_features, example['image_features'], dim=-1)
            for example in self.few_shot_memory[domain][class_name]
        ]
        return torch.stack(similarities).mean()

    def generate_sam2_prompts(
        self,
        image: torch.Tensor,
        domain: str,
        class_names: List[str],
        use_few_shot: bool = True
    ) -> List[Dict]:
        """Generate SAM 2 point prompts from text and few-shot similarity.

        A class yields a prompt when its CLIP text similarity (max over its
        synonym prompts) exceeds 0.3, or its few-shot memory similarity
        exceeds 0.5.  The point is the image center; SAM refines from there.
        """
        prompts = []
        h, w = image.shape[-2:]
        center = [w // 2, h // 2]

        if self.prompt_engineering:
            image_features = self.encode_image(image)
            for class_name in class_names:
                # Encode this class's prompts alone so every similarity column
                # is known to belong to it.  A joint encoding of all classes
                # returns one column per expanded synonym, so column i would
                # NOT correspond to class i (the bug this replaces).
                text_features = self.encode_text_prompts(domain, [class_name])
                similarity = self.compute_similarity(image_features, text_features)
                confidence = similarity.max().item()
                if confidence > 0.3:
                    prompts.append({
                        'type': 'point',
                        'data': list(center),
                        'label': 1,
                        'class': class_name,
                        'confidence': confidence
                    })

        if use_few_shot and self.visual_similarity:
            for class_name in class_names:
                few_shot_sim = self.get_few_shot_similarity(image, domain, class_name)
                if few_shot_sim > 0.5:
                    prompts.append({
                        'type': 'point',
                        'data': list(center),
                        'label': 1,
                        'class': class_name,
                        'confidence': few_shot_sim.item()
                    })

        return prompts

    def segment(
        self,
        image: torch.Tensor,
        domain: str,
        class_names: List[str],
        use_few_shot: bool = True
    ) -> Dict[str, torch.Tensor]:
        """
        Perform few-shot/zero-shot segmentation.

        Args:
            image: Input image tensor [C, H, W]
            domain: Domain name (satellite, fashion, robotics)
            class_names: List of class names to segment
            use_few_shot: Whether to use few-shot examples

        Returns:
            Dictionary with masks for each class (highest-confidence prompt
            wins when a class is proposed by both the text and few-shot paths)
        """
        if isinstance(image, torch.Tensor):
            # SAM's predictor expects a uint8 HWC numpy image.
            image_np = self._to_uint8_rgb(image)
        else:
            image_np = image

        self.sam2_predictor.set_image(image_np)

        prompts = self.generate_sam2_prompts(image, domain, class_names, use_few_shot)

        results: Dict[str, torch.Tensor] = {}
        best_confidence: Dict[str, float] = {}

        for prompt in prompts:
            if prompt['type'] != 'point':
                continue
            class_name = prompt['class']
            confidence = prompt['confidence']
            # Gate before predicting: skip weak prompts (the original ran SAM
            # and then discarded the mask) and keep only the highest-confidence
            # mask per class instead of last-writer-wins.
            if confidence <= 0.3 or confidence <= best_confidence.get(class_name, float('-inf')):
                continue

            masks, scores, _logits = self.sam2_predictor.predict(
                point_coords=np.array([prompt['data']]),
                point_labels=np.array([prompt['label']]),
                multimask_output=True
            )

            # SAM returns several candidate masks; keep its best-scored one.
            best_mask_idx = int(np.argmax(scores))
            results[class_name] = torch.from_numpy(masks[best_mask_idx]).float()
            best_confidence[class_name] = confidence

        return results

    def forward(
        self,
        image: torch.Tensor,
        domain: str,
        class_names: List[str],
        use_few_shot: bool = True
    ) -> Dict[str, torch.Tensor]:
        """Forward pass; delegates to :meth:`segment`."""
        return self.segment(image, domain, class_names, use_few_shot)
| |
|
| |
|
class FewShotTrainer:
    """Trainer for episodic few-shot segmentation.

    Each ``train_step`` registers the support set in the model's few-shot
    memory, segments the query image, and applies a BCE loss against the
    ground-truth query mask.
    """

    def __init__(self, model: "SAM2FewShot", learning_rate: float = 1e-4):
        """
        Args:
            model: Few-shot segmentation model to train.
            learning_rate: AdamW learning rate.
        """
        self.model = model
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        self.criterion = nn.BCELoss()

    def train_step(
        self,
        support_images: List[torch.Tensor],
        support_masks: List[torch.Tensor],
        query_image: torch.Tensor,
        query_mask: torch.Tensor,
        domain: str,
        class_name: str
    ) -> float:
        """Run one episodic training step and return the scalar loss value.

        NOTE(review): support examples are appended to the model's memory on
        every call, so repeated steps over the same episode accumulate
        duplicates — clear or rotate the memory externally if that matters.
        """
        self.model.train()

        # Register this episode's support set in the model's memory bank.
        for img, mask in zip(support_images, support_masks):
            self.model.add_few_shot_example(domain, class_name, img, mask)

        predictions = self.model(query_image, domain, [class_name], use_few_shot=True)

        if class_name in predictions:
            pred_mask = predictions[class_name]
            loss = self.criterion(pred_mask, query_mask)
        else:
            # No confident prediction for this class: a zero loss keeps the
            # step well-defined rather than raising.
            loss = torch.tensor(0.0, device=self.model.device, requires_grad=True)

        self.optimizer.zero_grad()
        # Masks decoded from SAM's numpy output carry no autograd graph, and
        # calling backward() on a grad-less loss raises RuntimeError (the
        # original crashed here whenever a real prediction existed).  Only
        # backpropagate and step when gradients can actually flow.
        if loss.requires_grad:
            loss.backward()
            self.optimizer.step()

        return loss.item()