# diffsynth/utils/blip2_state_text_encoder.py
from __future__ import annotations
from dataclasses import dataclass
from typing import List
import torch
@dataclass
class Blip2TextStateEncoderConfig:
model_name_or_path: str = "Salesforce/blip2-itm-vit-g"
device: str = "cpu"
torch_dtype: torch.dtype = torch.float16
max_length: int = 32
class Blip2TextStateEncoder:
"""
用 BLIP2 的 `Blip2TextModelWithProjection` 把状态文本编码为一个向量(text_embeds)。
设计目标:
- 状态在数据里用可读字符串(例如 "raw", "cooked")
- 训练/推理阶段把这些字符串变成 state_features: (B,N,D_text)
- 下游 InstanceFeatureExtractor 再把 D_text 投影到 DiT hidden_dim
"""
def __init__(self, cfg: Blip2TextStateEncoderConfig):
self.cfg = cfg
self._tokenizer = None
self._model = None
def _lazy_init(self):
if self._model is not None:
return
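        # Imported lazily so that importing this module does not require
        # transformers to be installed until the encoder is actually used.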
from transformers import AutoTokenizer, Blip2TextModelWithProjection
self._tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name_or_path)
self._model = Blip2TextModelWithProjection.from_pretrained(
self.cfg.model_name_or_path,
torch_dtype=self.cfg.torch_dtype,
)
self._model.eval()
for p in self._model.parameters():
p.requires_grad_(False)
self._model.to(device=self.cfg.device)
@torch.inference_mode()
def encode_texts(self, texts: List[str]) -> torch.Tensor:
self._lazy_init()
tok = self._tokenizer(
texts,
padding=True,
truncation=True,
max_length=self.cfg.max_length,
return_tensors="pt",
)
tok = {k: v.to(self.cfg.device) for k, v in tok.items()}
        out = self._model(**tok)
        emb = out.text_embeds
        if emb.dim() == 3:
            # Blip2TextModelWithProjection returns per-token projections
            # (B, L, D_text); keep the first token's projection so each text
            # maps to a single vector.
            emb = emb[:, 0]
        # (B, D_text)
        return emb.to(dtype=torch.float32, device="cpu")
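
# Minimal usage sketch for Blip2TextStateEncoder ("raw"/"cooked" are
# illustrative state strings):
#
#     encoder = Blip2TextStateEncoder(Blip2TextStateEncoderConfig())
#     feats = encoder.encode_texts(["raw", "cooked"])  # (2, D_text) float32, CPU
#
# Downstream, the class docstring assumes an InstanceFeatureExtractor that
# projects D_text to the DiT hidden_dim; a sketch of that step (the Linear
# layer here is an assumption, not part of this file):
#
#     proj = torch.nn.Linear(d_text, hidden_dim)
#     instance_features = proj(state_features)  # (B, N, D_text) -> (B, N, hidden_dim)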
def encode_state_text_tensor(
    state_texts: List[List[str]],
model_name_or_path: str = "Salesforce/blip2-itm-vit-g",
device: str = "cpu",
torch_dtype: torch.dtype = torch.float16,
max_length: int = 32,
) -> torch.Tensor:
"""
将嵌套 list 的 state_texts(B,N)编码成 tensor: (B,N,D_text) float32 on CPU。
"""
if not isinstance(state_texts, list) or not state_texts:
raise ValueError("state_texts must be a non-empty nested list (B,N)")
if not isinstance(state_texts[0], list):
raise ValueError("state_texts must be nested list like [[...], [...]]")
encoder = Blip2TextStateEncoder(
Blip2TextStateEncoderConfig(
model_name_or_path=model_name_or_path,
device=device,
torch_dtype=torch_dtype,
max_length=max_length,
)
)
    # Encode each unique text only once to avoid redundant forward passes.
all_texts = []
for row in state_texts:
for t in row:
if not isinstance(t, str):
raise ValueError(f"state_text must be str, got: {type(t)}")
all_texts.append(t)
uniq = sorted(set(all_texts))
emb = encoder.encode_texts(uniq) # (U, D)
table = {t: emb[i] for i, t in enumerate(uniq)}
    b = len(state_texts)
    n = len(state_texts[0])
    if any(len(row) != n for row in state_texts):
        raise ValueError("all rows of state_texts must have the same length N")
    out = torch.stack(
        [torch.stack([table[t] for t in row], dim=0) for row in state_texts],
        dim=0,
    )
# (B,N,D)
if out.shape[0] != b or out.shape[1] != n:
raise RuntimeError(f"unexpected encoded shape: {tuple(out.shape)}")
return out
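

# Quick smoke test: a minimal sketch, assuming the "Salesforce/blip2-itm-vit-g"
# weights are available locally or via the Hub. Encodes a tiny (B=2, N=2) batch
# of state strings and prints the resulting feature shape.
if __name__ == "__main__":
    feats = encode_state_text_tensor(
        [["raw", "cooked"], ["raw", "raw"]],
        device="cpu",
        torch_dtype=torch.float32,  # float32 is safer than float16 on CPU
    )
    print(feats.shape)  # expected: torch.Size([2, 2, D_text])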