| from transformers import AutoConfig, AutoModel, PretrainedConfig, CLIPTextConfig, CLIPVisionConfig, PreTrainedModel, CLIPTextModelWithProjection, CLIPVisionModelWithProjection |
| from transformers.utils import ModelOutput |
| import torch |
| import open_clip |
| from dataclasses import dataclass |
| import safetensors.torch |
| from peft import get_peft_config, get_peft_model, LoraConfig, TaskType |
| import os |
|
|
# File name under which OpenCLIPVisionEncoderOnly.save_pretrained stores its weights.
HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"
# File name for the prior model's safetensors weights
# (presumably for CustomPriorModel -- confirm against callers).
HF_SAFE_WEIGHTS_NAME_PRIOR = "prior_model.safetensors"
|
|
@dataclass
class PriorTransformerOutput(ModelOutput):
    """
    Result container returned by the prior model's forward pass.

    Args:
        predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
            CLIP image embedding predicted by the prior, conditioned on the
            CLIP text embedding it was given as input.
    """

    predicted_image_embedding: torch.FloatTensor
|
|
@dataclass
class TextEncoderOutput(ModelOutput):
    """
    Output container for the text encoders in this module, following the
    Hugging Face ``ModelOutput`` convention.

    Attributes:
        text_embeds (torch.FloatTensor): Text embedding produced by the
            encoder for each input prompt. Defaults to ``None``.
        last_hidden_state (torch.FloatTensor): Last hidden state returned by
            the encoder (optionally projected by the caller). Defaults to
            ``None``.
    """

    text_embeds: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None
|
|
class CLIPTextEncoderOnlyConfig(CLIPTextConfig):
    """
    Configuration for ``CLIPTextEncoderOnly``.

    Adds a handful of wrapper-specific options on top of ``CLIPTextConfig``;
    every other keyword argument is forwarded to ``CLIPTextConfig`` untouched.

    :param model_name: Name or path of the underlying CLIP checkpoint.
    :param pretrained: Whether the wrapper should load pretrained weights.
    :param frozen: Flag stating that the encoder weights should not be trained.
    :param lora: Optional LoRA hyper-parameters (``lora_r``, ``lora_alpha``,
        ``lora_dropout``); a falsy value means no LoRA.
    """

    model_type = "clip_custom_text_model"

    def __init__(
        self,
        model_name: str = None,
        pretrained: bool = True,
        frozen: bool = False,
        lora: dict = None,
        **kwargs,
    ):
        # Wrapper-specific fields are stored first, then the CLIP text config
        # constructor consumes the remaining keyword arguments.
        self.model_name, self.pretrained = model_name, pretrained
        self.frozen, self.lora = frozen, lora
        super().__init__(**kwargs)
|
|
class CLIPTextEncoderOnly(PreTrainedModel):
    """
    CLIP text tower (``CLIPTextModelWithProjection``) wrapped as a
    ``PreTrainedModel``.

    The wrapped network is stored in ``self.model``. Depending on the config
    it is loaded from pretrained weights or built from the checkpoint's
    configuration only, optionally frozen, and optionally wrapped with LoRA
    adapters.
    """

    config_class = CLIPTextEncoderOnlyConfig

    def __init__(self, config):
        """
        Build the text encoder described by ``config``.

        :param config: A ``CLIPTextEncoderOnlyConfig``. ``config.model_name``
            selects the checkpoint, ``config.pretrained`` decides whether its
            weights are loaded, ``config.frozen`` disables gradients on the
            pretrained weights, and a truthy ``config.lora`` enables LoRA.
        """
        super().__init__(config)

        if config.pretrained:
            self.model = CLIPTextModelWithProjection.from_pretrained(config.model_name)
            if config.frozen:
                # Honour the config flag the same way CustomTextEncoderOnly
                # does; previously it was accepted but silently ignored here.
                for param in self.model.parameters():
                    param.requires_grad = False
        else:
            # Only the architecture is taken from the checkpoint; weights are
            # freshly initialised.
            base_cfg = CLIPTextConfig.from_pretrained(config.model_name)
            self.model = CLIPTextModelWithProjection(base_cfg)

        if config.lora:
            l_config = LoraConfig(
                r=self._lora_option(config.lora, "lora_r"),
                lora_alpha=self._lora_option(config.lora, "lora_alpha"),
                # The same list is shared with the vision wrappers, so it also
                # names "visual_projection", which this text tower lacks.
                target_modules=[
                    "k_proj",
                    "v_proj",
                    "q_proj",
                    "out_proj",
                    "fc1",
                    "fc2",
                    "visual_projection",
                    "text_projection",
                ],
                lora_dropout=self._lora_option(config.lora, "lora_dropout"),
                bias="lora_only",
            )
            self.model = get_peft_model(self.model, l_config)

    @staticmethod
    def _lora_option(lora, key):
        """
        Read one LoRA hyper-parameter from ``lora``.

        ``lora`` is annotated as a ``dict`` on the config (and is a plain dict
        once a config has been reloaded from JSON), while attribute-style
        containers are also supported so existing callers keep working.

        :param lora: Mapping or attribute-style object holding LoRA settings.
        :param key: Name of the setting, e.g. ``"lora_r"``.
        :return: The stored value.
        """
        if isinstance(lora, dict):
            return lora[key]
        return getattr(lora, key)

    def forward(self, input_ids, attention_mask=None, position_ids=None):
        """
        Forward pass of the model.

        :param input_ids: Indices of input sequence tokens in the vocabulary.
        :param attention_mask: Mask to avoid performing attention on padding token indices.
        :param position_ids: Optional position indices forwarded to the CLIP text model.
        :return: ``TextEncoderOutput`` carrying the wrapped model's
            ``text_embeds`` and ``last_hidden_state``.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=True,
        )
        return TextEncoderOutput(
            text_embeds=outputs.text_embeds,
            last_hidden_state=outputs.last_hidden_state,
        )
| |
|
|
class CustomTextEncoderOnlyConfig(CLIPTextConfig):
    """
    Configuration for ``CustomTextEncoderOnly``.

    Stores the wrapper-specific options below and forwards every other
    keyword argument to ``CLIPTextConfig``.

    :param model_name: Name or path of the backbone checkpoint.
    :param pretrained: Whether the wrapper should load pretrained weights.
    :param frozen: Flag stating that the backbone weights should not be trained.
    :param output_hidden_size: Output width of the projection layer(s) added
        on top of the backbone.
    :param last_hidden_state: Whether the last hidden state is projected as well.
    :param lora: Optional LoRA hyper-parameters (``lora_r``, ``lora_alpha``,
        ``lora_dropout``); a falsy value means no LoRA.
    """

    model_type = "whole_custom_text_model"

    def __init__(
        self,
        model_name: str = None,
        pretrained: bool = True,
        frozen: bool = False,
        output_hidden_size: int = 512,
        last_hidden_state: bool = False,
        lora: dict = None,
        **kwargs,
    ):
        # Wrapper-specific fields are stored first, then the parent config
        # constructor consumes the remaining keyword arguments.
        self.model_name, self.pretrained, self.frozen = model_name, pretrained, frozen
        self.output_hidden_size, self.last_hidden_state = output_hidden_size, last_hidden_state
        self.lora = lora
        super().__init__(**kwargs)
|
|
class CustomTextEncoderOnly(PreTrainedModel):
    """
    Generic ``AutoModel`` text backbone with linear projection head(s),
    wrapped as a ``PreTrainedModel``.

    ``fc1`` projects the backbone's second output to
    ``config.output_hidden_size``; when ``config.last_hidden_state`` is set,
    ``fc2`` projects the first output (the last hidden state) as well.
    """

    config_class = CustomTextEncoderOnlyConfig

    def __init__(self, config):
        """
        Build the backbone and projection layers described by ``config``.

        :param config: A ``CustomTextEncoderOnlyConfig``. ``config.model_name``
            selects the backbone, ``config.pretrained`` decides whether its
            weights are loaded, ``config.frozen`` disables gradients on the
            pretrained backbone, and a truthy ``config.lora`` enables LoRA.
        """
        super().__init__(config)

        self.last_hidden_state = config.last_hidden_state

        if config.pretrained:
            self.model = AutoModel.from_pretrained(config.model_name)
            if config.frozen:
                for param in self.model.parameters():
                    param.requires_grad = False
        else:
            # Auto classes cannot be instantiated directly (``AutoModel(...)``
            # raises); build a randomly initialised backbone from the
            # checkpoint's configuration, as the CLIP wrappers do.
            base_cfg = AutoConfig.from_pretrained(config.model_name)
            self.model = AutoModel.from_config(base_cfg)

        backbone_width = self.model.config.hidden_size
        self.fc1 = torch.nn.Linear(backbone_width, config.output_hidden_size)
        if config.last_hidden_state:
            self.fc2 = torch.nn.Linear(backbone_width, config.output_hidden_size)

        if config.lora:
            l_config = LoraConfig(
                task_type=TaskType.FEATURE_EXTRACTION,
                r=self._lora_option(config.lora, "lora_r"),
                lora_alpha=self._lora_option(config.lora, "lora_alpha"),
                lora_dropout=self._lora_option(config.lora, "lora_dropout"),
                bias="lora_only",
            )
            # Only the backbone is wrapped; fc1/fc2 stay plain Linear layers.
            self.model = get_peft_model(self.model, l_config)

    @staticmethod
    def _lora_option(lora, key):
        """
        Read one LoRA hyper-parameter from ``lora``.

        ``lora`` is annotated as a ``dict`` on the config (and is a plain dict
        once a config has been reloaded from JSON), while attribute-style
        containers are also supported so existing callers keep working.

        :param lora: Mapping or attribute-style object holding LoRA settings.
        :param key: Name of the setting, e.g. ``"lora_r"``.
        :return: The stored value.
        """
        if isinstance(lora, dict):
            return lora[key]
        return getattr(lora, key)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """
        Forward pass of the model.

        :param input_ids: Indices of input sequence tokens in the vocabulary.
        :param attention_mask: Mask to avoid performing attention on padding token indices.
        :param token_type_ids: Segment token indices to indicate first and second portions of the inputs.
        :return: ``TextEncoderOutput`` with the projected ``text_embeds`` and
            the (optionally projected) ``last_hidden_state``.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
        )
        # NOTE(review): assumes the backbone returns
        # (sequence_output, pooled_output, ...) like BERT -- confirm for other
        # backbones.
        text_embeds = self.fc1(outputs[1])
        if self.last_hidden_state:
            last_hidden_state = self.fc2(outputs[0])
        else:
            last_hidden_state = outputs[0]
        return TextEncoderOutput(text_embeds=text_embeds, last_hidden_state=last_hidden_state)
|
|
class CLIPVisionEncoderOnlyConfig(PretrainedConfig):
    """
    Configuration for ``CLIPVisionEncoderOnly``.

    Stores the wrapper-specific options below and forwards every other
    keyword argument to ``PretrainedConfig``.

    :param model_name: Name or path of the underlying CLIP checkpoint.
    :param pretrained: Whether the wrapper should load pretrained weights.
    :param frozen: Flag stating that the encoder weights should not be trained.
    :param lora: Optional LoRA hyper-parameters (``lora_r``, ``lora_alpha``,
        ``lora_dropout``); a falsy value means no LoRA.
    """

    model_type = "clip_custom_vision_model"

    def __init__(
        self,
        model_name: str = None,
        pretrained: bool = True,
        frozen: bool = False,
        lora: dict = None,
        **kwargs,
    ):
        # Wrapper-specific fields are stored first, then the base config
        # constructor consumes the remaining keyword arguments.
        self.model_name, self.pretrained = model_name, pretrained
        self.frozen, self.lora = frozen, lora
        super().__init__(**kwargs)
|
|
class CLIPVisionEncoderOnly(PreTrainedModel):
    """
    CLIP vision tower (``CLIPVisionModelWithProjection``) wrapped as a
    ``PreTrainedModel``.

    The wrapped network is stored in ``self.model``. Depending on the config
    it is loaded from pretrained weights or built from the checkpoint's
    configuration only, optionally frozen, and optionally wrapped with LoRA
    adapters.
    """

    config_class = CLIPVisionEncoderOnlyConfig

    def __init__(self, config):
        """
        Build the vision encoder described by ``config``.

        :param config: A ``CLIPVisionEncoderOnlyConfig``. ``config.model_name``
            selects the checkpoint, ``config.pretrained`` decides whether its
            weights are loaded, ``config.frozen`` disables gradients on the
            pretrained weights, and a truthy ``config.lora`` enables LoRA.
        """
        super().__init__(config)

        if config.pretrained:
            self.model = CLIPVisionModelWithProjection.from_pretrained(config.model_name)
            if config.frozen:
                # Honour the config flag the same way CustomTextEncoderOnly
                # does; previously it was accepted but silently ignored here.
                for param in self.model.parameters():
                    param.requires_grad = False
        else:
            # Only the architecture is taken from the checkpoint; weights are
            # freshly initialised.
            base_cfg = CLIPVisionConfig.from_pretrained(config.model_name)
            self.model = CLIPVisionModelWithProjection(base_cfg)

        if config.lora:
            l_config = LoraConfig(
                r=self._lora_option(config.lora, "lora_r"),
                lora_alpha=self._lora_option(config.lora, "lora_alpha"),
                # The same list is shared with the text wrapper, so it also
                # names "text_projection", which this vision tower lacks.
                target_modules=[
                    "k_proj",
                    "v_proj",
                    "q_proj",
                    "out_proj",
                    "fc1",
                    "fc2",
                    "visual_projection",
                    "text_projection",
                ],
                lora_dropout=self._lora_option(config.lora, "lora_dropout"),
                bias="lora_only",
            )
            self.model = get_peft_model(self.model, l_config)

    @staticmethod
    def _lora_option(lora, key):
        """
        Read one LoRA hyper-parameter from ``lora``.

        ``lora`` is annotated as a ``dict`` on the config (and is a plain dict
        once a config has been reloaded from JSON), while attribute-style
        containers are also supported so existing callers keep working.

        :param lora: Mapping or attribute-style object holding LoRA settings.
        :param key: Name of the setting, e.g. ``"lora_r"``.
        :return: The stored value.
        """
        if isinstance(lora, dict):
            return lora[key]
        return getattr(lora, key)

    def forward(self, data):
        """
        Forward pass of the model.

        :param data: Mapping of keyword arguments expanded into the wrapped
            CLIP vision model call.
        :return: The ``image_embeds`` field of the wrapped model's output.
        """
        return self.model(**data).image_embeds

    def parameters(self, recurse: bool = True):
        """
        Return an iterator over the wrapped model's parameters.

        :param recurse: Forwarded to ``torch.nn.Module.parameters`` so the
            override keeps the base-class signature (callers passing
            ``recurse=...`` used to fail with ``TypeError``).
        """
        return self.model.parameters(recurse=recurse)
|
|
|
|
class OpenCLIPVisionEncoderOnly(torch.nn.Module):
    """
    Visual tower of an ``open_clip`` model loaded from the Hugging Face hub.

    The tower is stored in ``self.model`` and can optionally be frozen and/or
    wrapped with LoRA adapters.
    """

    def __init__(self, model_name: str, pretrained: bool = True, frozen: bool = False, lora: dict = None):
        """
        Load the visual tower of ``hf-hub:{model_name}``.

        :param model_name: Hub repository name of the open_clip checkpoint.
        :param pretrained: Must be truthy; building an untrained tower is not
            supported.
        :param frozen: When true, gradients are disabled on the loaded weights
            (previously this flag was accepted but ignored).
        :param lora: Optional LoRA hyper-parameters (``lora_r``,
            ``lora_alpha``, ``lora_dropout``) as a dict or attribute-style
            object; a falsy value means no LoRA.
        :raises NotImplementedError: If ``pretrained`` is falsy.
        """
        super().__init__()
        if not pretrained:
            # ``raise NotImplemented`` is a TypeError at runtime because
            # NotImplemented is not an exception class.
            raise NotImplementedError(
                "OpenCLIPVisionEncoderOnly only supports pretrained=True"
            )

        # The second returned value is not needed here and is discarded.
        model, _ = open_clip.create_model_from_pretrained(f"hf-hub:{model_name}")
        self.model = model.visual

        if frozen:
            for param in self.model.parameters():
                param.requires_grad = False

        if lora:
            l_config = LoraConfig(
                r=self._lora_option(lora, "lora_r"),
                lora_alpha=self._lora_option(lora, "lora_alpha"),
                # NOTE(review): these names follow Hugging Face CLIP layer
                # naming; confirm they match module names in the open_clip
                # visual tower.
                target_modules=[
                    "k_proj",
                    "v_proj",
                    "q_proj",
                    "out_proj",
                    "fc1",
                    "fc2",
                    "visual_projection",
                    "text_projection",
                ],
                lora_dropout=self._lora_option(lora, "lora_dropout"),
                bias="lora_only",
            )
            self.model = get_peft_model(self.model, l_config)

    @staticmethod
    def _lora_option(lora, key):
        """
        Read one LoRA hyper-parameter from ``lora``.

        ``lora`` is annotated as a ``dict`` but was originally read with
        attribute access; both container styles are supported.

        :param lora: Mapping or attribute-style object holding LoRA settings.
        :param key: Name of the setting, e.g. ``"lora_r"``.
        :return: The stored value.
        """
        if isinstance(lora, dict):
            return lora[key]
        return getattr(lora, key)

    def forward(self, image):
        """
        Forward pass of the model.

        :param image: Input passed straight to the visual tower.
        :return: Whatever the visual tower returns for ``image``.
        """
        return self.model(image)

    def save_pretrained(self, save_dir):
        """
        Write the visual tower's state dict to ``save_dir`` as safetensors.

        :param save_dir: Target directory, as ``str`` or ``os.PathLike``
            (``save_dir / name`` used to fail for plain strings). It is
            created if it does not exist yet.
        """
        os.makedirs(save_dir, exist_ok=True)
        tensors = self.model.state_dict()
        safetensors.torch.save_file(tensors, os.path.join(save_dir, HF_SAFE_WEIGHTS_NAME))
|
|
class CustomPriorModel(torch.nn.Module):
    """
    Two-layer MLP prior: ``Linear(2 * in, mid) -> ReLU -> Linear(mid, out)``
    with ``mid = max(in, out)``.
    """

    def __init__(self, in_hidden_state, out_hidden_state):
        """
        Build the MLP.

        :param in_hidden_state: Base input width; the first layer expects
            inputs of width ``2 * in_hidden_state`` (presumably two
            concatenated embeddings -- confirm against callers).
        :param out_hidden_state: Width of the predicted embedding.
        """
        super().__init__()
        mid_hidden_state = max(in_hidden_state, out_hidden_state)

        self.fc1 = torch.nn.Linear(in_hidden_state * 2, mid_hidden_state)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(mid_hidden_state, out_hidden_state)

    def reinitialize_model(self):
        """
        Re-initialise every trainable parameter in place.

        Parameters with more than one dimension get Xavier-uniform values;
        one-dimensional parameters get normal values when their name contains
        ``"weight"`` and zeros otherwise. Parameters with
        ``requires_grad=False`` are left untouched.
        """
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue
            if param.dim() > 1:
                torch.nn.init.xavier_uniform_(param)
            elif "weight" in name:
                torch.nn.init.normal_(param)
            else:
                torch.nn.init.zeros_(param)

    def forward(self, feats):
        """
        Forward pass of the model.

        :param feats: Input features fed to the first linear layer.
        :return: ``PriorTransformerOutput`` whose ``predicted_image_embedding``
            is the MLP output.
        """
        hidden = self.relu(self.fc1(feats))
        return PriorTransformerOutput(predicted_image_embedding=self.fc2(hidden))

    def save_pretrained(self, save_dir):
        """
        Write this model's state dict to ``save_dir`` as safetensors.

        This used to be a silent no-op, so the prior's weights were never
        persisted. The file is named by ``HF_SAFE_WEIGHTS_NAME_PRIOR``.

        :param save_dir: Target directory, as ``str`` or ``os.PathLike``;
            created if it does not exist yet.
        """
        os.makedirs(save_dir, exist_ok=True)
        safetensors.torch.save_file(
            self.state_dict(),
            os.path.join(save_dir, HF_SAFE_WEIGHTS_NAME_PRIOR),
        )
| |
| |
|
|
|
|
def test_text_model(register=False, upload=False):
    """
    Manual smoke test for ``CLIPTextEncoderOnly`` against the Hugging Face hub.

    :param register: When true, register the config/model with the Auto
        classes and mark them for auto-class export.
    :param upload: When true, build the wrapper around
        ``openai/clip-vit-base-patch32`` and push it to the hub.
    """
    if register:
        AutoConfig.register("clip_custom_text_model", CLIPTextEncoderOnlyConfig)
        AutoModel.register(CLIPTextEncoderOnlyConfig, CLIPTextEncoderOnly)
        CLIPTextEncoderOnlyConfig.register_for_auto_class()
        CLIPTextEncoderOnly.register_for_auto_class("AutoModel")

    if upload:
        cfg = CLIPTextEncoderOnlyConfig(
            model_name="openai/clip-vit-base-patch32",
            pretrained=True,
            lora=None,
        )
        encoder = CLIPTextEncoderOnly(cfg)
        encoder.push_to_hub("test-text-hf-upload")

    # NOTE(review): source indentation was lost; the round-trip load is
    # assumed to run regardless of ``upload`` -- confirm.
    encoder = CLIPTextEncoderOnly.from_pretrained("mpatel57/test-text-hf-upload", force_download=True)
|
|
def test_custom_text_model(register=False, upload=False):
    """
    Manual smoke test for ``CustomTextEncoderOnly`` against the Hugging Face hub.

    :param register: When true, register the config/model with the Auto
        classes and mark them for auto-class export.
    :param upload: When true, build the wrapper around
        ``google-bert/bert-base-uncased`` and push it to the hub.
    """
    if register:
        AutoConfig.register("whole_custom_text_model", CustomTextEncoderOnlyConfig)
        AutoModel.register(CustomTextEncoderOnlyConfig, CustomTextEncoderOnly)
        CustomTextEncoderOnlyConfig.register_for_auto_class()
        CustomTextEncoderOnly.register_for_auto_class("AutoModel")

    if upload:
        cfg = CustomTextEncoderOnlyConfig(
            model_name="google-bert/bert-base-uncased",
            pretrained=True,
            frozen=False,
            output_hidden_size=512,
            last_hidden_state=False,
            lora=None,
        )
        encoder = CustomTextEncoderOnly(cfg)
        # NOTE(review): same repo name as test_text_model uses, so the two
        # uploads overwrite each other -- confirm this is intended.
        encoder.push_to_hub("test-text-hf-upload")

    # NOTE(review): source indentation was lost; the round-trip load is
    # assumed to run regardless of ``upload`` -- confirm.
    encoder = CustomTextEncoderOnly.from_pretrained("mpatel57/test-text-hf-upload", force_download=True)
|
|
def test_vision_model(register=False, upload=False):
    """
    Manual smoke test for ``CLIPVisionEncoderOnly`` against the Hugging Face hub.

    :param register: When true, register the config/model with the Auto
        classes and mark them for auto-class export.
    :param upload: When true, build the wrapper around
        ``openai/clip-vit-base-patch32`` and push it to the hub.
    """
    if register:
        AutoConfig.register("clip_custom_vision_model", CLIPVisionEncoderOnlyConfig)
        AutoModel.register(CLIPVisionEncoderOnlyConfig, CLIPVisionEncoderOnly)
        CLIPVisionEncoderOnlyConfig.register_for_auto_class()
        CLIPVisionEncoderOnly.register_for_auto_class("AutoModel")

    if upload:
        cfg = CLIPVisionEncoderOnlyConfig(
            model_name="openai/clip-vit-base-patch32",
            pretrained=True,
            lora=None,
        )
        encoder = CLIPVisionEncoderOnly(cfg)
        encoder.push_to_hub("test-vision-hf-upload")

    # NOTE(review): source indentation was lost; the round-trip load is
    # assumed to run regardless of ``upload`` -- confirm.
    encoder = CLIPVisionEncoderOnly.from_pretrained("mpatel57/test-vision-hf-upload", force_download=True)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the custom text-encoder hub smoke test with
    # uploading enabled and without registering the Auto classes.
    test_custom_text_model(register=False, upload=True)
| |
| |
|
|