from typing import Iterable, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from .download import default_cache_dir

ImageType = Union[np.ndarray, torch.Tensor, Image.Image]
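# numpy arrays and tensors passed as images are treated as uint8 HWC pixel data,
# while PIL images are used as-is (see _image_to_pil below).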


class ImageCLIP(nn.Module):
    """
    A wrapper around a pre-trained CLIP model that automatically handles
    batches of texts, images, and embeddings.
    """

    def __init__(
        self,
        device: torch.device,
        dtype: Optional[torch.dtype] = torch.float32,
        ensure_used_params: bool = True,
        clip_name: str = "ViT-L/14",
        cache_dir: Optional[str] = None,
    ):
        super().__init__()

        assert clip_name in ["ViT-L/14", "ViT-B/32"]

        self.device = device
        self.ensure_used_params = ensure_used_params

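        # Import clip lazily so that the dependency is only required when this
        # wrapper is actually constructed.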
        import clip

        self.clip_model, self.preprocess = clip.load(
            clip_name, device=device, download_root=cache_dir or default_cache_dir()
        )
        self.clip_name = clip_name

        if dtype is not None:
            self.clip_model.to(dtype)
        self._tokenize = clip.tokenize

    @property
    def feature_dim(self) -> int:
        # Dimensionality of the final CLIP embedding.
        if self.clip_name == "ViT-L/14":
            return 768
        else:
            return 512

    @property
    def grid_size(self) -> int:
        # Side length of the ViT patch grid: 224x224 inputs with 14x14 patches
        # give a 16x16 grid, while 32x32 patches give 7x7.
        if self.clip_name == "ViT-L/14":
            return 16
        else:
            return 7

    @property
    def grid_feature_dim(self) -> int:
        # Width of the vision transformer (per-patch feature size).
        if self.clip_name == "ViT-L/14":
            return 1024
        else:
            return 768

    def forward(
        self,
        batch_size: int,
        images: Optional[Iterable[Optional[ImageType]]] = None,
        texts: Optional[Iterable[Optional[str]]] = None,
        embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """
        Generate a batch of embeddings from a mixture of images, texts,
        precomputed embeddings, and possibly empty values.

        For each batch element, at most one of images, texts, and embeddings
        should be non-None; modalities cannot be mixed within a single element.
        If no modality is provided, a zero embedding is used for that element.
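
        Example (illustrative sketch; assumes `model` is an ImageCLIP instance
        and `pil_img` is a PIL.Image):

            embs = model(
                batch_size=3,
                images=[pil_img, None, None],
                texts=[None, "a red chair", None],
            )
            # embs has shape [3, model.feature_dim]; the last row is all zeros.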
| """ |
| image_seq = [None] * batch_size if images is None else list(images) |
| text_seq = [None] * batch_size if texts is None else list(texts) |
| embedding_seq = [None] * batch_size if embeddings is None else list(embeddings) |
| assert len(image_seq) == batch_size, "number of images should match batch size" |
| assert len(text_seq) == batch_size, "number of texts should match batch size" |
| assert len(embedding_seq) == batch_size, "number of embeddings should match batch size" |

        if self.ensure_used_params:
            return self._static_multimodal_embed(
                images=image_seq, texts=text_seq, embeddings=embedding_seq
            )

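        # Fast path: encode only the entries that were actually provided, batching
        # images and texts into one CLIP call each; missing entries stay zero.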
        result = torch.zeros((batch_size, self.feature_dim), device=self.device)
        index_images = []
        index_texts = []
        for i, (image, text, emb) in enumerate(zip(image_seq, text_seq, embedding_seq)):
            assert (
                sum([int(image is not None), int(text is not None), int(emb is not None)]) < 2
            ), "only one modality may be non-None per batch element"
            if image is not None:
                index_images.append((i, image))
            elif text is not None:
                index_texts.append((i, text))
            elif emb is not None:
                result[i] = emb.to(result)

        if len(index_images):
            embs = self.embed_images((img for _, img in index_images))
            for (i, _), emb in zip(index_images, embs):
                result[i] = emb.to(result)
        if len(index_texts):
            embs = self.embed_text((text for _, text in index_texts))
            for (i, _), emb in zip(index_texts, embs):
                result[i] = emb.to(result)

        return result

    def _static_multimodal_embed(
        self,
        images: List[Optional[ImageType]],
        texts: List[Optional[str]],
        embeddings: List[Optional[torch.Tensor]],
    ) -> torch.Tensor:
        """
        Like forward(), but always runs all encoders to ensure that
        the forward graph looks the same on every rank.
        """
        image_emb = self.embed_images(images)
        text_emb = self.embed_text(t if t else "" for t in texts)
        joined_embs = torch.stack(
            [
                emb.to(device=self.device, dtype=torch.float32)
                if emb is not None
                else torch.zeros(self.feature_dim, device=self.device)
                for emb in embeddings
            ],
            dim=0,
        )

        image_flag = torch.tensor([x is not None for x in images], device=self.device)[
            :, None
        ].expand_as(image_emb)
        text_flag = torch.tensor([x is not None for x in texts], device=self.device)[
            :, None
        ].expand_as(image_emb)
        emb_flag = torch.tensor([x is not None for x in embeddings], device=self.device)[
            :, None
        ].expand_as(image_emb)

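        # Every encoder output contributes to the result (scaled by a 0/1 flag), and
        # the logit_scale term touches one extra parameter, so the autograd graph
        # uses the same set of parameters no matter which modalities were provided.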
        return (
            image_flag.float() * image_emb
            + text_flag.float() * text_emb
            + emb_flag.float() * joined_embs
            + self.clip_model.logit_scale * 0
        )

    def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        """
        :param xs: N images, stored as numpy arrays, tensors, or PIL images.
                   None entries are embedded as black placeholder images.
        :return: an [N x D] tensor of unit-norm features.
        """
        clip_inputs = self.images_to_tensor(xs)
        results = self.clip_model.encode_image(clip_inputs).float()
        return results / torch.linalg.norm(results, dim=-1, keepdim=True)

    def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:
        """
        Embed text prompts as an [N x D] tensor of unit-norm features.
        """
        enc = self.clip_model.encode_text(
            self._tokenize(list(prompts), truncate=True).to(self.device)
        ).float()
        return enc / torch.linalg.norm(enc, dim=-1, keepdim=True)

    def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        """
        Embed images into latent grids.

        :param xs: an iterable of images to embed.
        :return: a tensor of shape [N x C x L], where L = self.grid_size**2.
        """
        if self.ensure_used_params:
            extra_value = 0.0
            for p in self.parameters():
                extra_value = extra_value + p.mean() * 0.0
        else:
            extra_value = 0.0

        x = self.images_to_tensor(xs).to(self.clip_model.dtype)

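        # Re-run the CLIP VisionTransformer forward pass, stopping before the final
        # layer norm and projection so that per-patch token features are preserved.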
        vt = self.clip_model.visual
        x = vt.conv1(x)  # [N, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # [N, width, grid ** 2]
        x = x.permute(0, 2, 1)  # [N, grid ** 2, width]
        x = torch.cat(
            [
                vt.class_embedding.to(x.dtype)
                + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
                x,
            ],
            dim=1,
        )  # [N, grid ** 2 + 1, width]
        x = x + vt.positional_embedding.to(x.dtype)
        x = vt.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = vt.transformer(x)
        x = x.permute(1, 2, 0)  # LND -> NDL

        # Drop the class token so only per-patch features are returned.
        return x[..., 1:].contiguous().float() + extra_value
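    # Preprocess a batch of (possibly None) images into a single CLIP input tensor
    # on the target device.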
    def images_to_tensor(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        return torch.stack([self.preprocess(_image_to_pil(x)) for x in xs], dim=0).to(self.device)


class FrozenImageCLIP:
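    """
    A convenience wrapper around ImageCLIP that freezes the CLIP weights and runs
    its embedding helpers under torch.no_grad().
    """
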
    def __init__(self, device: torch.device, **kwargs):
        self.model = ImageCLIP(device, dtype=None, ensure_used_params=False, **kwargs)
        for parameter in self.model.parameters():
            parameter.requires_grad_(False)

    @property
    def feature_dim(self) -> int:
        return self.model.feature_dim

    @property
    def grid_size(self) -> int:
        return self.model.grid_size

    @property
    def grid_feature_dim(self) -> int:
        return self.model.grid_feature_dim

    def __call__(
        self,
        batch_size: int,
        images: Optional[Iterable[Optional[ImageType]]] = None,
        texts: Optional[Iterable[Optional[str]]] = None,
        embeddings: Optional[Iterable[Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        # No torch.no_grad() here: the CLIP weights are already frozen, but this
        # leaves room for gradients to flow back into the `embeddings` argument.
        return self.model(batch_size=batch_size, images=images, texts=texts, embeddings=embeddings)

    def embed_images(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        with torch.no_grad():
            return self.model.embed_images(xs)

    def embed_text(self, prompts: Iterable[str]) -> torch.Tensor:
        with torch.no_grad():
            return self.model.embed_text(prompts)

    def embed_images_grid(self, xs: Iterable[Optional[ImageType]]) -> torch.Tensor:
        with torch.no_grad():
            return self.model.embed_images_grid(xs)


def _image_to_pil(obj: Optional[ImageType]) -> Image.Image:
    if obj is None:
        return Image.fromarray(np.zeros([64, 64, 3], dtype=np.uint8))
    if isinstance(obj, np.ndarray):
        return Image.fromarray(obj.astype(np.uint8))
    elif isinstance(obj, torch.Tensor):
        return Image.fromarray(obj.detach().cpu().numpy().astype(np.uint8))
    else:
        return obj
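

# Illustrative usage sketch (assumes the `clip` package is installed and that
# `example.png` is a hypothetical local image file):
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     clip_model = FrozenImageCLIP(device)
#     img = Image.open("example.png")
#     image_emb = clip_model.embed_images([img])         # [1, clip_model.feature_dim]
#     text_emb = clip_model.embed_text(["a red chair"])  # [1, clip_model.feature_dim]
#     grid = clip_model.embed_images_grid([img])         # [1, grid_feature_dim, grid_size ** 2]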