Spaces:
Running
Running
| """ | |
| step2_encode_image.py | |
| ====================== | |
| STEP 2 β Encode an image through BLIP's Vision Transformer (ViT). | |
| Responsibilities: | |
| - Accept a PIL image. | |
| - Run it through the ViT image encoder. | |
| - Return encoder_hidden_states (197 patch tokens Γ 768 dim). | |
| - Also return the encoder_attention_mask. | |
| This is kept separate so the expensive ViT encode is run ONCE for | |
| however many words we later generate β zero redundant computation. | |
| Shape of the output: | |
| encoder_hidden_states : (1, 197, 768) | |
| encoder_mask : (1, 197) β all-ones (no padding) | |
| """ | |
import os
import sys

import torch
from PIL import Image

# Make the project root importable no matter where this script is launched
# from: this file lives two directories below the root.
_THIS_DIR = os.path.abspath(os.path.dirname(__file__))
_PROJECT_ROOT = os.path.abspath(os.path.join(_THIS_DIR, os.pardir, os.pardir))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)
# ─────────────────────────────────────────────────────────────────────────────
_ENCODE_SIZE = 224  # ViT expects 224×224 → 14×14 patches of size 16


def encode_image(model, processor, device, image_pil: "Image.Image", verbose: bool = True):
    """
    Encode a PIL image through BLIP's ViT backbone.

    The expensive ViT forward pass is isolated here so it runs ONCE per
    image, no matter how many words are generated from it afterwards.

    Args:
        model:     BlipForConditionalGeneration (only ``.vision_model`` is used).
        processor: BlipProcessor that turns the PIL image into pixel values.
        device:    torch.device the inputs (and the returned mask) live on.
        image_pil: Any PIL image (resized to 224×224 internally).
        verbose:   If True, print progress messages.

    Returns:
        image_224:      PIL image resized to 224×224.
        encoder_hidden: Tensor (1, 197, 768), detached, requires_grad=False —
                        1 [CLS] token + 196 patch tokens.
        encoder_mask:   Tensor (1, 197) of ones (no padding).
    """
    image_224 = image_pil.resize((_ENCODE_SIZE, _ENCODE_SIZE), Image.LANCZOS)
    inputs = processor(images=image_224, return_tensors="pt").to(device)

    if verbose:
        print(f"📷 Image resized to {_ENCODE_SIZE}×{_ENCODE_SIZE} and encoded through ViT …")

    # no_grad: inference only — skip autograd bookkeeping for the ViT forward.
    with torch.no_grad():
        vision_out = model.vision_model(pixel_values=inputs["pixel_values"])

    # Shape: (1, 197, 768) — 1 [CLS] token + 196 patch tokens.
    # .detach() already returns a requires_grad=False tensor, so the former
    # redundant .requires_grad_(False) call was dropped.
    encoder_hidden = vision_out[0].detach()
    encoder_mask = torch.ones(encoder_hidden.size()[:-1], dtype=torch.long, device=device)

    if verbose:
        print(f"✅ Encoder output shape: {encoder_hidden.shape} "
              f"(1 CLS + {encoder_hidden.shape[1] - 1} patches)")

    return image_224, encoder_hidden, encoder_mask