from __future__ import annotations

from dataclasses import dataclass
from functools import lru_cache
from typing import Iterable, List

import torch


@dataclass
class Blip2TextStateEncoderConfig:
    """Settings consumed by :class:`Blip2TextStateEncoder`."""

    # Checkpoint identifier or path handed to ``from_pretrained``.
    model_name_or_path: str = "Salesforce/blip2-itm-vit-g"
    # Device the model and the tokenized inputs are moved to.
    device: str = "cpu"
    # Dtype the model weights are loaded with.
    # NOTE(review): float16 together with the default device "cpu" may be
    # unsupported or slow on some torch builds -- confirm for your setup.
    torch_dtype: torch.dtype = torch.float16
    # Truncation length (in tokens) passed to the tokenizer.
    max_length: int = 32


class Blip2TextStateEncoder:
    """Encode state strings into one vector each (``text_embeds``).

    Uses BLIP-2's ``Blip2TextModelWithProjection``.

    Design goals:
    - States are stored in the data as readable strings
      (for example "raw", "cooked").
    - At training / inference time those strings are turned into
      ``state_features`` of shape (B, N, D_text).
    - A downstream ``InstanceFeatureExtractor`` then projects D_text to the
      DiT hidden_dim.
    """

    def __init__(self, cfg: Blip2TextStateEncoderConfig):
        """Store the config; tokenizer and model are loaded on first use."""
        self.cfg = cfg
        self._tokenizer = None
        self._model = None

    def _lazy_init(self):
        """Load tokenizer and model once; do nothing if already loaded."""
        if self._model is not None:
            return
        # Imported here so that importing this module does not require
        # ``transformers`` until an encoder is actually used.
        from transformers import AutoTokenizer, Blip2TextModelWithProjection

        self._tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name_or_path)
        self._model = Blip2TextModelWithProjection.from_pretrained(
            self.cfg.model_name_or_path,
            torch_dtype=self.cfg.torch_dtype,
        )
        # The model is used as a frozen feature extractor.
        self._model.eval()
        for p in self._model.parameters():
            p.requires_grad_(False)
        self._model.to(device=self.cfg.device)

    @torch.inference_mode()
    def encode_texts(self, texts: List[str]) -> torch.Tensor:
        """Return the projected text embeddings for ``texts``.

        Args:
            texts: Strings to encode; they are padded to a common length and
                truncated to ``cfg.max_length`` tokens.

        Returns:
            The model's ``text_embeds`` converted to float32 on the CPU,
            one row per input text: (B, D_text).
        """
        self._lazy_init()
        tok = self._tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.cfg.max_length,
            return_tensors="pt",
        )
        tok = {k: v.to(self.cfg.device) for k, v in tok.items()}
        out = self._model(**tok)
        # (B, D_text)
        return out.text_embeds.to(dtype=torch.float32, device="cpu")


def _build_encoder(model_name_or_path, device, torch_dtype, max_length) -> Blip2TextStateEncoder:
    """Create a fresh (not yet loaded) encoder for the given settings."""
    return Blip2TextStateEncoder(
        Blip2TextStateEncoderConfig(
            model_name_or_path=model_name_or_path,
            device=device,
            torch_dtype=torch_dtype,
            max_length=max_length,
        )
    )


# Encoders are reused across calls so the checkpoint is not reloaded every
# time. The cache is kept small because each entry keeps a loaded model alive.
_get_cached_encoder = lru_cache(maxsize=2)(_build_encoder)


def _validate_state_texts(state_texts: list) -> None:
    """Raise ``ValueError`` unless ``state_texts`` is a rectangular (B, N) list of str."""
    if not isinstance(state_texts, list) or not state_texts:
        raise ValueError("state_texts must be a non-empty nested list (B,N)")
    for row in state_texts:
        # Check every row, not only the first: a bare string row would
        # otherwise be iterated character by character.
        if not isinstance(row, list):
            raise ValueError("state_texts must be nested list like [[...], [...]]")
    width = len(state_texts[0])
    if width == 0:
        raise ValueError("state_texts rows must be non-empty")
    for row in state_texts:
        if len(row) != width:
            raise ValueError(
                f"state_texts rows must all have the same length, got {len(row)} and {width}"
            )
        for t in row:
            if not isinstance(t, str):
                raise ValueError(f"state_text must be str, got: {type(t)}")


def encode_state_text_tensor(
    state_texts: list,
    model_name_or_path: str = "Salesforce/blip2-itm-vit-g",
    device: str = "cpu",
    torch_dtype: torch.dtype = torch.float16,
    max_length: int = 32,
    encoder: Blip2TextStateEncoder | None = None,
) -> torch.Tensor:
    """Encode a nested (B, N) list of state strings into a (B, N, D_text) tensor.

    Args:
        state_texts: Non-empty list of equally long, non-empty lists of str.
        model_name_or_path: Checkpoint used when no ``encoder`` is given.
        device: Device used when no ``encoder`` is given.
        torch_dtype: Model dtype used when no ``encoder`` is given.
        max_length: Tokenizer truncation length used when no ``encoder`` is given.
        encoder: Optional encoder to use. When omitted, an encoder built from
            the four settings above is taken from a small module-level cache,
            so repeated calls with the same settings share one loaded model.

    Returns:
        Tensor of shape (B, N, D_text) built from the encoder's output
        (float32 on the CPU for :class:`Blip2TextStateEncoder`).

    Raises:
        ValueError: If ``state_texts`` is not a non-empty nested list, a row is
            not a list, rows are empty or differ in length, or an element is
            not a str.
    """
    # Validate before touching the model so bad input fails fast and cheaply.
    _validate_state_texts(state_texts)

    if encoder is None:
        try:
            encoder = _get_cached_encoder(model_name_or_path, device, torch_dtype, max_length)
        except TypeError:
            # An unhashable setting cannot be used as a cache key; fall back
            # to a one-off encoder.
            encoder = _build_encoder(model_name_or_path, device, torch_dtype, max_length)

    # Encode each distinct string once, then gather rows by index.
    uniq = sorted({t for row in state_texts for t in row})
    emb = encoder.encode_texts(uniq)  # (U, D)
    position = {t: i for i, t in enumerate(uniq)}
    idx = torch.tensor(
        [[position[t] for t in row] for row in state_texts],
        dtype=torch.long,
        device=emb.device,
    )
    return emb[idx]  # (B, N, D)