| """Raon-VisionEncoder model.""" |
|
|
| import importlib |
| import os |
| import sys |
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
| from transformers import PreTrainedModel |
|
|
| from .configuration_raonve import RaonVEConfig |
|
|
|
|
# Repo id (HF Hub name or local path) that _ensure_raon_package() should fetch
# the companion `raon_vision_encoder` code from.  None means "use the default
# hub repo".  Set by RaonVEModel.from_pretrained() before weights are loaded.
_raon_repo_id = None


def set_repo_id(repo_id):
    """Record the repo id (or local path) to download the model code from."""
    global _raon_repo_id
    _raon_repo_id = repo_id
|
|
def _ensure_raon_package():
    """Import raon_vision_encoder, downloading from HF Hub if needed.

    Returns the package's ``CustomTextCLIP`` class.  On the fast path the
    package is already importable; otherwise the code is fetched from the
    Hub repo recorded via :func:`set_repo_id` (or the default repo).
    """
    try:
        return importlib.import_module("raon_vision_encoder.clip").CustomTextCLIP
    except (ImportError, ModuleNotFoundError):
        pass

    # Package not importable yet: download just the code directory from the Hub
    # and make it importable by prepending the checkout to sys.path.
    from huggingface_hub import snapshot_download

    repo_id = _raon_repo_id if _raon_repo_id else "KRAFTON/Raon-VisionEncoder"
    repo_dir = snapshot_download(repo_id, allow_patterns=["raon_vision_encoder/**"])
    sys.path.insert(0, repo_dir)

    # Purge any stale or half-imported raon_vision_encoder modules so the
    # retry below resolves against the freshly downloaded checkout.
    stale = [name for name in sys.modules if name.startswith("raon_vision_encoder")]
    for name in stale:
        del sys.modules[name]

    return importlib.import_module("raon_vision_encoder.clip").CustomTextCLIP
|
|
|
|
class RaonVEPreTrainedModel(PreTrainedModel):
    """Base class hooking Raon-VisionEncoder models into the transformers API."""

    config_class = RaonVEConfig
    base_model_prefix = ""
    supports_gradient_checkpointing = True


    def _init_weights(self, module):
        # Intentionally a no-op: parameters are created by the wrapped
        # CustomTextCLIP (see RaonVEModel.__init__) and are expected to be
        # overwritten by checkpoint loading, so no random init is applied here.
        pass
|
|
|
|
class RaonVEModel(RaonVEPreTrainedModel):
    """CLIP-style dual encoder wrapping the raon_vision_encoder CustomTextCLIP.

    Exposes ``encode_image`` / ``encode_text`` returning L2-normalized
    features, and a ``forward`` that returns both plus the logit scale/bias.
    """

    config_class = RaonVEConfig

    # Attribute names copied verbatim from the config sub-objects into the
    # plain dicts CustomTextCLIP expects.
    _VISION_KEYS = (
        "image_size",
        "timm_model_name",
        "timm_model_pretrained",
        "timm_pool",
        "timm_proj",
    )
    _TEXT_KEYS = (
        "context_length",
        "vocab_size",
        "width",
        "heads",
        "layers",
        "mlp_ratio",
        "no_causal_mask",
        "proj_bias",
        "pool_type",
        "hf_tokenizer_name",
        "tokenizer_kwargs",
        "norm_kwargs",
        "act_kwargs",
    )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        # Remember where the weights come from so _ensure_raon_package() can
        # fetch the matching model code from the same repo if it is missing.
        set_repo_id(str(pretrained_model_name_or_path))
        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

    def __init__(self, config: RaonVEConfig):
        super().__init__(config)

        vc, tc = config.vision_config, config.text_config
        vision_cfg = {key: getattr(vc, key) for key in self._VISION_KEYS}
        text_cfg = {key: getattr(tc, key) for key in self._TEXT_KEYS}

        CustomTextCLIP = _ensure_raon_package()
        inner = CustomTextCLIP(
            embed_dim=config.embed_dim,
            vision_cfg=vision_cfg,
            text_cfg=text_cfg,
            init_logit_bias=config.init_logit_bias,
        )

        # Re-home the submodules/parameters onto this wrapper so state_dict
        # keys line up with the checkpoint ("visual.*", "text.*", ...).
        self.visual = inner.visual
        self.text = inner.text
        self.logit_scale = inner.logit_scale
        self.logit_bias = inner.logit_bias

        self.visual._setup_1d_forward()

        self.post_init()

    def encode_image(self, pixel_values, pixel_attention_mask=None, spatial_shapes=None):
        """Encode images to normalized feature vectors [B, 1152].
        Pass the output of processor(images=...) directly via **inputs.
        """
        extra = {}
        if pixel_attention_mask is not None:
            extra["patch_valid_mask"] = pixel_attention_mask
        if spatial_shapes is not None:
            extra["spatial_shapes"] = spatial_shapes
        features = self.visual(pixel_values, **extra)
        return F.normalize(features, dim=-1)

    def encode_text(self, input_ids):
        """Encode text to normalized feature vectors [B, 1152].
        Pass the output of processor(text=...) directly via **inputs.
        """
        return F.normalize(self.text(input_ids), dim=-1)

    def forward(self, pixel_values=None, input_ids=None, pixel_attention_mask=None, spatial_shapes=None):
        """Run either or both encoders; absent modalities yield None features."""
        image_features = (
            self.encode_image(
                pixel_values,
                pixel_attention_mask=pixel_attention_mask,
                spatial_shapes=spatial_shapes,
            )
            if pixel_values is not None
            else None
        )
        text_features = self.encode_text(input_ids) if input_ids is not None else None

        return {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": self.logit_scale,
            "logit_bias": self.logit_bias,
        }

    @staticmethod
    def get_processor(pretrained_model_name_or_path, **kwargs):
        """Get the processor for this model."""
        return RaonVEProcessor.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
|
|
|
class RaonVEProcessor:
    """Image and text processor for Raon-VisionEncoder.

    Preprocesses images into NaFlex patch sequences and tokenizes text.

    Args:
        max_num_patches: Maximum number of patches per image (controls resolution).
            Higher values preserve more detail. Default: 256.
    """


    DEFAULT_MAX_PATCHES = 256


    def __init__(self, patch_size=16, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711), tokenizer=None):
        # Default mean/std are the OpenAI-CLIP normalization statistics.
        from torchvision import transforms as T
        self.patch_size = patch_size
        self.mean, self.std = mean, std
        self.tokenizer = tokenizer
        # ToTensor scales pixels to [0, 1]; Normalize then applies
        # (x - mean) / std per channel.
        self._post = T.Compose([T.ToTensor(), T.Normalize(mean=list(mean), std=list(std))])


    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a processor from a local model directory or an HF Hub repo id.

        Reads config.json to recover the patch size (parsed from a "patchN"
        token in vision_config.timm_model_name) and, when text_config names a
        tokenizer, instantiates the package's HFTokenizer for it.
        """
        import json
        from pathlib import Path as _Path
        if _Path(pretrained_model_name_or_path).is_dir():
            cfg_path = _Path(pretrained_model_name_or_path) / "config.json"
        else:
            # Not a local directory: treat it as a Hub repo id.
            from huggingface_hub import hf_hub_download
            cfg_path = hf_hub_download(pretrained_model_name_or_path, "config.json")
        with open(cfg_path) as f:
            cfg = json.load(f)
        v = cfg.get("vision_config", {}); t = cfg.get("text_config", {})
        # Default to 16 if the timm model name carries no "patchN" component.
        ps = 16
        for part in v.get("timm_model_name", "").split("_"):
            if part.startswith("patch") and part[5:].isdigit():
                ps = int(part[5:]); break
        tokenizer = None
        if t.get("hf_tokenizer_name"):
            # Tokenizer class lives in the downloadable package, not this file.
            _ensure_raon_package()
            tok_mod = importlib.import_module("raon_vision_encoder.tokenizer")
            tokenizer = tok_mod.HFTokenizer(
                t["hf_tokenizer_name"], context_length=t.get("context_length", 64),
                tokenizer_mode=t.get("tokenizer_mode"), **t.get("tokenizer_kwargs", {}),
            )
        return cls(patch_size=ps, tokenizer=tokenizer)


    def __call__(self, images=None, text=None, max_num_patches=None, return_tensors="pt"):
        """Process images and/or text.

        Args:
            images: PIL Image or list of PIL Images.
            text: String or list of strings.
            max_num_patches: Resolution budget (default: 256). Higher = more detail.

        Returns:
            Dict with 'pixel_values', 'pixel_attention_mask', 'spatial_shapes' for images
            and/or 'input_ids' for text.
        """
        # NOTE(review): return_tensors is accepted for transformers-API parity;
        # outputs are always PyTorch tensors regardless of its value.
        from PIL import Image
        result = {}
        if images is not None:
            mnp = max_num_patches or self.DEFAULT_MAX_PATCHES
            _ensure_raon_package()
            transform_mod = importlib.import_module("raon_vision_encoder.transform")
            get_size = transform_mod.get_image_size_for_max_num_patches
            imgs = [images] if isinstance(images, Image.Image) else images
            ps = self.patch_size
            all_p, all_m, all_s = [], [], []
            for img in imgs:
                img = img.convert("RGB")
                w, h = img.size
                # Resize target chosen to fit the mnp patch budget; presumably
                # (th, tw) come back as multiples of ps — confirm against
                # get_image_size_for_max_num_patches in the package.
                th, tw = get_size(h, w, ps, mnp)
                t = self._post(img.resize((tw, th), Image.BICUBIC))
                gh, gw = th // ps, tw // ps
                n = gh * gw
                # (3, th, tw) -> (n, 3*ps*ps): cut a gh x gw grid of ps x ps
                # patches, then flatten each patch with channels leading.
                patches = t.reshape(3, gh, ps, gw, ps).permute(1,3,0,2,4).reshape(n, 3*ps*ps)
                # Pad every sequence to mnp patches; mask marks the real ones.
                padded = torch.zeros(mnp, ps*ps*3); padded[:n] = patches
                mask = torch.zeros(mnp, dtype=torch.bool); mask[:n] = True
                all_p.append(padded); all_m.append(mask)
                all_s.append(torch.tensor([gh, gw]))
            result["pixel_values"] = torch.stack(all_p)
            result["pixel_attention_mask"] = torch.stack(all_m)
            result["spatial_shapes"] = torch.stack(all_s)
        if text is not None:
            if self.tokenizer is None:
                raise RuntimeError("Tokenizer not initialized.")
            # Tokenizer always receives a list; single strings are wrapped.
            result["input_ids"] = self.tokenizer([text] if isinstance(text, str) else text)
        return result
|
|