File size: 3,506 Bytes
1146a67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from __future__ import annotations

from dataclasses import dataclass
from functools import lru_cache
from typing import Iterable, List

import torch


@dataclass
class Blip2TextStateEncoderConfig:
    """Configuration for ``Blip2TextStateEncoder``.

    Values are forwarded to the HuggingFace ``from_pretrained`` calls and
    the tokenizer invocation in ``Blip2TextStateEncoder``.
    """

    # HuggingFace hub id or local path of the BLIP-2 ITM checkpoint.
    model_name_or_path: str = "Salesforce/blip2-itm-vit-g"
    # Device the frozen model is moved to (e.g. "cpu", "cuda:0").
    device: str = "cpu"
    # Dtype used when loading the model weights.
    torch_dtype: torch.dtype = torch.float16
    # Max token length for tokenizer padding/truncation.
    max_length: int = 32


class Blip2TextStateEncoder:
    """
    Encode state texts into single vectors (``text_embeds``) with BLIP-2's
    ``Blip2TextModelWithProjection``.

    Design goals:
    - States are stored as human-readable strings in the data (e.g. "raw", "cooked").
    - At train/inference time those strings become state_features: (B, N, D_text).
    - A downstream InstanceFeatureExtractor then projects D_text to the DiT hidden_dim.
    """

    def __init__(self, cfg: Blip2TextStateEncoderConfig):
        self.cfg = cfg
        # Tokenizer/model are loaded lazily on first encode (see _lazy_init)
        # so construction is cheap and the `transformers` import is deferred.
        self._tokenizer = None
        self._model = None

    def _lazy_init(self) -> None:
        """Load the tokenizer and frozen BLIP-2 text model on first use (idempotent)."""
        if self._model is not None:
            return
        # Local import: `transformers` is only required if encoding is actually used.
        from transformers import AutoTokenizer, Blip2TextModelWithProjection

        self._tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name_or_path)
        self._model = Blip2TextModelWithProjection.from_pretrained(
            self.cfg.model_name_or_path,
            torch_dtype=self.cfg.torch_dtype,
        )
        # Frozen feature extractor: eval mode and no gradients anywhere.
        self._model.eval()
        for p in self._model.parameters():
            p.requires_grad_(False)
        self._model.to(device=self.cfg.device)

    @torch.inference_mode()
    def encode_texts(self, texts: List[str]) -> torch.Tensor:
        """Encode a flat list of strings into projected text embeddings.

        Args:
            texts: list of B strings to encode as one batch.

        Returns:
            Tensor of shape (B, D_text), float32, on CPU.
        """
        self._lazy_init()
        tok = self._tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.cfg.max_length,
            return_tensors="pt",
        )
        tok = {k: v.to(self.cfg.device) for k, v in tok.items()}
        out = self._model(**tok)
        # (B, D_text)
        # Return float32 on CPU so callers need not care about the model's
        # device or dtype (cfg.torch_dtype may be float16).
        return out.text_embeds.to(dtype=torch.float32, device="cpu")


def encode_state_text_tensor(
    state_texts: list[list[str]],
    model_name_or_path: str = "Salesforce/blip2-itm-vit-g",
    device: str = "cpu",
    torch_dtype: torch.dtype = torch.float16,
    max_length: int = 32,
) -> torch.Tensor:
    """
    Encode a nested list of state texts of shape (B, N) into a tensor of
    shape (B, N, D_text), float32, on CPU.

    Args:
        state_texts: nested list of strings, B rows of exactly N texts each.
        model_name_or_path: BLIP-2 checkpoint id/path forwarded to the encoder.
        device: device the encoder model runs on.
        torch_dtype: dtype used to load the encoder weights.
        max_length: tokenizer truncation length.

    Returns:
        Tensor of shape (B, N, D_text), float32, on CPU.

    Raises:
        ValueError: if state_texts is empty, not a nested list, ragged
            (rows of differing length), has empty rows, or contains non-str
            entries.
    """
    if not isinstance(state_texts, list) or not state_texts:
        raise ValueError("state_texts must be a non-empty nested list (B,N)")
    if not isinstance(state_texts[0], list):
        raise ValueError("state_texts must be nested list like [[...], [...]]")

    b = len(state_texts)
    n = len(state_texts[0])
    if n == 0:
        raise ValueError("state_texts rows must be non-empty")

    # Validate shape and element types up front — before loading the model —
    # so malformed input fails fast with a clear message instead of an opaque
    # torch.stack error, and collect all texts for encoding.
    all_texts = []
    for row in state_texts:
        if not isinstance(row, list) or len(row) != n:
            raise ValueError(f"all state_texts rows must be lists of length {n}")
        for t in row:
            if not isinstance(t, str):
                raise ValueError(f"state_text must be str, got: {type(t)}")
            all_texts.append(t)

    encoder = Blip2TextStateEncoder(
        Blip2TextStateEncoderConfig(
            model_name_or_path=model_name_or_path,
            device=device,
            torch_dtype=torch_dtype,
            max_length=max_length,
        )
    )

    # Encode each unique text once to avoid redundant forward passes;
    # sorting makes the encode batch (and any padding) deterministic.
    uniq = sorted(set(all_texts))
    emb = encoder.encode_texts(uniq)  # (U, D_text)
    table = {t: emb[i] for i, t in enumerate(uniq)}

    # Reassemble per-row embeddings into (B, N, D_text).
    out = torch.stack(
        [torch.stack([table[t] for t in row], dim=0) for row in state_texts],
        dim=0,
    )
    if out.shape[0] != b or out.shape[1] != n:
        raise RuntimeError(f"unexpected encoded shape: {tuple(out.shape)}")
    return out