| import os |
|
|
| import torch |
| from torch import Tensor, nn |
| from transformers import ( |
| CLIPTextModel, |
| CLIPTokenizer, |
| T5EncoderModel, |
| T5Tokenizer, |
| __version__, |
| ) |
| from transformers.utils.quantization_config import QuantoConfig, BitsAndBytesConfig |
|
|
# Hugging Face cache root. Expand "~" so the fallback is a usable absolute
# path — a literal "~" string would create a directory named "~" when joined
# and opened by downstream code.
CACHE_DIR = os.path.expanduser(os.environ.get("HF_HOME", "~/.cache/huggingface"))
|
|
|
|
| def auto_quantization_config( |
| quantization_dtype: str, |
| ) -> QuantoConfig | BitsAndBytesConfig: |
| if quantization_dtype == "qfloat8": |
| return QuantoConfig(weights="float8") |
| elif quantization_dtype == "qint4": |
| return BitsAndBytesConfig( |
| load_in_4bit=True, |
| bnb_4bit_compute_dtype=torch.bfloat16, |
| bnb_4bit_quant_type="nf4", |
| ) |
| elif quantization_dtype == "qint8": |
| return BitsAndBytesConfig(load_in_8bit=True, llm_int8_has_fp16_weight=False) |
| elif quantization_dtype == "qint2": |
| return QuantoConfig(weights="int2") |
| elif quantization_dtype is None or quantization_dtype == "bfloat16": |
| return None |
| else: |
| raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}") |
|
|
|
|
class HFEmbedder(nn.Module):
    """Thin wrapper around a Hugging Face text encoder (CLIP or T5).

    Tokenizes a batch of prompts to a fixed length and returns either the
    pooled CLIP embedding ("pooler_output") or the T5 last hidden state
    ("last_hidden_state"), with optional weight quantization and manual
    CPU offloading between uses.
    """

    def __init__(
        self,
        version: str,
        max_length: int,
        device: torch.device | int,
        quantization_dtype: str | None = None,
        offloading_device: torch.device | int | None = torch.device("cpu"),
        is_clip: bool = False,
        **hf_kwargs,
    ):
        """Load the tokenizer and encoder for ``version``.

        Args:
            version: HF hub repo id (or local path) of the encoder. Repos
                starting with "openai" are treated as CLIP automatically.
            max_length: Fixed token length used for padding/truncation.
            device: Compute device (torch.device or CUDA index).
            quantization_dtype: See :func:`auto_quantization_config`. None,
                "bfloat16", and "float16" load unquantized weights.
            offloading_device: Where :meth:`offload` parks the weights;
                None falls back to CPU.
            is_clip: Force CLIP handling for non-"openai" repo ids.
            **hf_kwargs: Extra kwargs forwarded to ``from_pretrained``.
        """
        super().__init__()
        # The annotation permits None, but torch.device(None) raises —
        # treat a missing offloading device as "park on CPU".
        if offloading_device is None:
            offloading_device = torch.device("cpu")
        self.offloading_device = (
            offloading_device
            if isinstance(offloading_device, torch.device)
            else torch.device(offloading_device)
        )
        self.device = (
            device if isinstance(device, torch.device) else torch.device(device)
        )
        self.is_clip = version.startswith("openai") or is_clip
        self.max_length = max_length
        # CLIP consumers want the pooled sentence vector; T5 consumers want
        # the full per-token sequence.
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        # Only the q* dtypes need a quantization config; None/"bfloat16"/
        # "float16" mean "load unquantized". ("float16" must be filtered here
        # because auto_quantization_config would reject it with ValueError.)
        auto_quant_config = (
            auto_quantization_config(quantization_dtype)
            if quantization_dtype is not None
            and quantization_dtype != "bfloat16"
            and quantization_dtype != "float16"
            else None
        )

        if isinstance(auto_quant_config, BitsAndBytesConfig):
            # bitsandbytes requires explicit placement at load time. Pass the
            # full torch.device rather than `.index`: the index is None for
            # index-less devices (torch.device("cpu"), torch.device("cuda")),
            # which would have produced an invalid device_map entry.
            hf_kwargs["device_map"] = {"": self.device}
        if auto_quant_config is not None:
            hf_kwargs["quantization_config"] = auto_quant_config

        if self.is_clip:
            self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
                version, max_length=max_length
            )
            self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(
                version,
                **hf_kwargs,
            )
        else:
            self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(
                version, max_length=max_length
            )
            self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(
                version,
                **hf_kwargs,
            )

    def offload(self):
        """Move encoder weights to the offloading device and free CUDA cache."""
        self.hf_module.to(device=self.offloading_device)
        torch.cuda.empty_cache()

    def cuda(self):
        """Move encoder weights back to the configured compute device.

        NOTE(review): this intentionally shadows ``nn.Module.cuda`` with
        different semantics (moves only ``hf_module``, to ``self.device``).
        """
        self.hf_module.to(device=self.device)

    def forward(self, text: list[str]) -> Tensor:
        """Encode a batch of strings.

        Args:
            text: Batch of prompts.

        Returns:
            Pooled CLIP embedding or T5 last hidden state, depending on
            the wrapped model (see ``self.output_key``).
        """
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        outputs = self.hf_module(
            input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
            # NOTE(review): attention_mask is deliberately None, so padding
            # tokens are attended to — confirm this matches how the encoder
            # was used during training of the downstream model.
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
|
|
|
|
if __name__ == "__main__":
    # Smoke test: encode a single prompt with the fp8-quantized T5 encoder.
    embedder = HFEmbedder(
        "city96/t5-v1_1-xxl-encoder-bf16",
        max_length=512,
        device=0,
        quantization_dtype="qfloat8",
    )
    embeddings = embedder(["hello"])
    print(embeddings)
|
|