# Raon-VisionEncoder / modeling_raonve.py
"""Raon-VisionEncoder model."""
import importlib
import os
import sys
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from .configuration_raonve import RaonVEConfig
_raon_repo_id = None


def set_repo_id(repo_id):
    global _raon_repo_id
    _raon_repo_id = repo_id


def _ensure_raon_package():
    """Import raon_vision_encoder, downloading it from the HF Hub if needed."""
    try:
        clip_mod = importlib.import_module("raon_vision_encoder.clip")
        return clip_mod.CustomTextCLIP
    except ImportError:  # also covers ModuleNotFoundError, its subclass
        pass

    from huggingface_hub import snapshot_download

    repo_id = _raon_repo_id or "KRAFTON/Raon-VisionEncoder"
    repo_dir = snapshot_download(repo_id, allow_patterns=["raon_vision_encoder/**"])
    sys.path.insert(0, repo_dir)
    # Evict any partially imported raon_vision_encoder modules so the freshly
    # downloaded package on sys.path is picked up.
    for key in list(sys.modules.keys()):
        if key.startswith("raon_vision_encoder"):
            del sys.modules[key]
    clip_mod = importlib.import_module("raon_vision_encoder.clip")
    return clip_mod.CustomTextCLIP
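
# A minimal sketch (comments only, not executed) of resolving the inner CLIP
# implementation by hand; set_repo_id and _ensure_raon_package are defined
# above, and "my-org/raon-mirror" is a hypothetical mirror repo id:
#
#     set_repo_id("my-org/raon-mirror")        # optional: override the default repo
#     CustomTextCLIP = _ensure_raon_package()  # downloads the package if missing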


class RaonVEPreTrainedModel(PreTrainedModel):
    config_class = RaonVEConfig
    base_model_prefix = ""
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        # No-op: weights come from the inner CustomTextCLIP and the checkpoint.
        pass


class RaonVEModel(RaonVEPreTrainedModel):
    config_class = RaonVEConfig

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        # Remember which repo to fetch the raon_vision_encoder package from.
        set_repo_id(str(pretrained_model_name_or_path))
        return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

    def __init__(self, config: RaonVEConfig):
        super().__init__(config)
        vision_cfg = {
            "image_size": config.vision_config.image_size,
            "timm_model_name": config.vision_config.timm_model_name,
            "timm_model_pretrained": config.vision_config.timm_model_pretrained,
            "timm_pool": config.vision_config.timm_pool,
            "timm_proj": config.vision_config.timm_proj,
        }
        text_cfg = {
            "context_length": config.text_config.context_length,
            "vocab_size": config.text_config.vocab_size,
            "width": config.text_config.width,
            "heads": config.text_config.heads,
            "layers": config.text_config.layers,
            "mlp_ratio": config.text_config.mlp_ratio,
            "no_causal_mask": config.text_config.no_causal_mask,
            "proj_bias": config.text_config.proj_bias,
            "pool_type": config.text_config.pool_type,
            "hf_tokenizer_name": config.text_config.hf_tokenizer_name,
            "tokenizer_kwargs": config.text_config.tokenizer_kwargs,
            "norm_kwargs": config.text_config.norm_kwargs,
            "act_kwargs": config.text_config.act_kwargs,
        }
        CustomTextCLIP = _ensure_raon_package()
        inner = CustomTextCLIP(
            embed_dim=config.embed_dim,
            vision_cfg=vision_cfg,
            text_cfg=text_cfg,
            init_logit_bias=config.init_logit_bias,
        )
        self.visual = inner.visual
        self.text = inner.text
        self.logit_scale = inner.logit_scale
        self.logit_bias = inner.logit_bias
        # Enable NaFlex (variable-resolution patch sequences) by default.
        self.visual._setup_1d_forward()
        self.post_init()

    def encode_image(self, pixel_values, pixel_attention_mask=None, spatial_shapes=None):
        """Encode images to L2-normalized feature vectors of shape [B, 1152].

        Pass the output of processor(images=...) directly via **inputs.
        """
        kwargs = {}
        if pixel_attention_mask is not None:
            kwargs["patch_valid_mask"] = pixel_attention_mask
        if spatial_shapes is not None:
            kwargs["spatial_shapes"] = spatial_shapes
        features = self.visual(pixel_values, **kwargs)
        return F.normalize(features, dim=-1)

    def encode_text(self, input_ids):
        """Encode text to L2-normalized feature vectors of shape [B, 1152].

        Pass the output of processor(text=...) directly via **inputs.
        """
        features = self.text(input_ids)
        return F.normalize(features, dim=-1)

    def forward(
        self,
        pixel_values=None,
        input_ids=None,
        pixel_attention_mask=None,
        spatial_shapes=None,
    ):
        image_features = None
        text_features = None
        if pixel_values is not None:
            image_features = self.encode_image(
                pixel_values,
                pixel_attention_mask=pixel_attention_mask,
                spatial_shapes=spatial_shapes,
            )
        if input_ids is not None:
            text_features = self.encode_text(input_ids)
        return {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": self.logit_scale,
            "logit_bias": self.logit_bias,
        }

    @staticmethod
    def get_processor(pretrained_model_name_or_path, **kwargs):
        """Get the processor for this model."""
        return RaonVEProcessor.from_pretrained(pretrained_model_name_or_path, **kwargs)
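
# Each tower can also be used on its own, e.g. to embed an image gallery for
# retrieval. A minimal sketch, assuming model and processor were loaded as
# above (the variable names and the patch budget of 576 are illustrative):
#
#     inputs = processor(images=gallery_images, max_num_patches=576)
#     with torch.no_grad():
#         gallery = model.encode_image(**inputs)  # [N, 1152], L2-normalized
#         query = model.encode_text(processor(text="a red bicycle")["input_ids"])
#     scores = query @ gallery.T                  # cosine similarities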


class RaonVEProcessor:
    """Image and text processor for Raon-VisionEncoder.

    Preprocesses images into NaFlex patch sequences and tokenizes text. The
    resolution budget is set per call via the max_num_patches argument of
    __call__ (default: 256); higher values preserve more detail.
    """

    DEFAULT_MAX_PATCHES = 256

    def __init__(
        self,
        patch_size=16,
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
        tokenizer=None,
    ):
        from torchvision import transforms as T

        self.patch_size = patch_size
        self.mean, self.std = mean, std
        self.tokenizer = tokenizer
        self._post = T.Compose([T.ToTensor(), T.Normalize(mean=list(mean), std=list(std))])

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        import json
        from pathlib import Path

        if Path(pretrained_model_name_or_path).is_dir():
            cfg_path = Path(pretrained_model_name_or_path) / "config.json"
        else:
            from huggingface_hub import hf_hub_download

            cfg_path = hf_hub_download(pretrained_model_name_or_path, "config.json")
        with open(cfg_path) as f:
            cfg = json.load(f)
        v = cfg.get("vision_config", {})
        t = cfg.get("text_config", {})
        # Infer the patch size from the timm model name, e.g. "vit_..._patch16_...".
        ps = 16
        for part in v.get("timm_model_name", "").split("_"):
            if part.startswith("patch") and part[5:].isdigit():
                ps = int(part[5:])
                break
        tokenizer = None
        if t.get("hf_tokenizer_name"):
            _ensure_raon_package()
            tok_mod = importlib.import_module("raon_vision_encoder.tokenizer")
            tokenizer = tok_mod.HFTokenizer(
                t["hf_tokenizer_name"],
                context_length=t.get("context_length", 64),
                tokenizer_mode=t.get("tokenizer_mode"),
                **t.get("tokenizer_kwargs", {}),
            )
        return cls(patch_size=ps, tokenizer=tokenizer)

    def __call__(self, images=None, text=None, max_num_patches=None, return_tensors="pt"):
        """Process images and/or text.

        Args:
            images: PIL Image or list of PIL Images.
            text: String or list of strings.
            max_num_patches: Resolution budget (default: 256). Higher = more detail.

        Returns:
            Dict with 'pixel_values', 'pixel_attention_mask', and 'spatial_shapes'
            for images, and/or 'input_ids' for text.
        """
        from PIL import Image

        result = {}
        if images is not None:
            mnp = max_num_patches or self.DEFAULT_MAX_PATCHES
            _ensure_raon_package()
            transform_mod = importlib.import_module("raon_vision_encoder.transform")
            get_size = transform_mod.get_image_size_for_max_num_patches
            imgs = [images] if isinstance(images, Image.Image) else images
            ps = self.patch_size
            all_p, all_m, all_s = [], [], []
            for img in imgs:
                img = img.convert("RGB")
                w, h = img.size
                # Pick a target size whose patch grid fits the budget of mnp patches.
                th, tw = get_size(h, w, ps, mnp)
                t = self._post(img.resize((tw, th), Image.BICUBIC))
                gh, gw = th // ps, tw // ps
                n = gh * gw
                # Patchify: [C, gh*ps, gw*ps] -> [C, gh, ps, gw, ps]
                #   -> [gh, gw, C, ps, ps] -> [n, C*ps*ps]
                patches = t.reshape(3, gh, ps, gw, ps).permute(1, 3, 0, 2, 4).reshape(n, 3 * ps * ps)
                # Pad every image to the same sequence length and mark valid patches.
                padded = torch.zeros(mnp, 3 * ps * ps)
                padded[:n] = patches
                mask = torch.zeros(mnp, dtype=torch.bool)
                mask[:n] = True
                all_p.append(padded)
                all_m.append(mask)
                all_s.append(torch.tensor([gh, gw]))
            result["pixel_values"] = torch.stack(all_p)
            result["pixel_attention_mask"] = torch.stack(all_m)
            result["spatial_shapes"] = torch.stack(all_s)
        if text is not None:
            if self.tokenizer is None:
                raise RuntimeError(
                    "Tokenizer not initialized; load the processor with from_pretrained()."
                )
            result["input_ids"] = self.tokenizer([text] if isinstance(text, str) else text)
        return result
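

if __name__ == "__main__":
    # End-to-end usage sketch: match one image against two captions. Requires
    # network access to the "KRAFTON/Raon-VisionEncoder" hub repo; the
    # solid-color test image and the captions are placeholders.
    from PIL import Image

    repo = "KRAFTON/Raon-VisionEncoder"
    model = RaonVEModel.from_pretrained(repo).eval()
    processor = RaonVEModel.get_processor(repo)

    image = Image.new("RGB", (640, 480), "white")  # stand-in for a real photo
    inputs = processor(images=image, text=["a plain white image", "a black cat"])

    with torch.no_grad():
        out = model(**inputs)

    # SigLIP-style pairwise logits: scaled cosine similarity plus, if the
    # checkpoint defines one, a learned bias.
    logits = out["image_features"] @ out["text_features"].T * out["logit_scale"].exp()
    if out["logit_bias"] is not None:
        logits = logits + out["logit_bias"]
    print(torch.sigmoid(logits))  # match probability per image-text pair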