| from transformers import OlmoModel, OlmoPreTrainedModel, GenerationMixin, AutoConfig, AutoModelForSequenceClassification |
| from transformers.modeling_outputs import SequenceClassifierOutputWithPast |
| import torch |
|
|
| from peft import PeftModel, PeftConfig |
|
|
| from transformers import AutoConfig |
|
|
| import logging |
| from contextlib import contextmanager |
| from types import SimpleNamespace |
|
|
| |
|
|
# Module-level default compute device: first CUDA GPU when available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
class OlmoForSequenceClassification(OlmoPreTrainedModel, GenerationMixin):
    """OLMo backbone with a linear sequence-classification head.

    Pools the hidden state of the last non-padding token of each sequence
    (mirroring HF's causal ``*ForSequenceClassification`` models) and
    projects it to ``config.num_labels`` logits.
    """

    def __init__(self, config):
        super().__init__(config)
        self.model = OlmoModel(config)
        self.num_labels = config.num_labels
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing (HF convention).
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs,
    ) -> SequenceClassifierOutputWithPast:
        """Run the backbone and classify each sequence.

        Args:
            input_ids: Token ids of shape (batch, seq_len).
            attention_mask: Optional mask; 1 for real tokens, 0 for padding.
            labels: Optional labels for computing the classification loss.
            **kwargs: Forwarded verbatim to the OLMo backbone.

        Returns:
            SequenceClassifierOutputWithPast whose ``logits`` has shape
            (batch, num_labels).
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        logits = self.classifier(outputs.last_hidden_state)

        # BUGFIX: the previous `logits[:, -1]` pooled the *last position*,
        # which is a pad token for right-padded batches. Pool the last
        # non-pad token instead, following HF's LlamaForSequenceClassification.
        # Fall back to the last position when pad_token_id is unset or
        # input_ids is not given.
        pad_token_id = getattr(self.config, "pad_token_id", None)
        if pad_token_id is None or input_ids is None:
            pooled_logits = logits[:, -1]
        else:
            # Index of the first pad token minus one; the modulo maps the
            # "no pad token present" case (argmax == 0 -> -1) to seq_len - 1.
            sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
            sequence_lengths = sequence_lengths % input_ids.shape[-1]
            sequence_lengths = sequence_lengths.to(logits.device)
            batch_indices = torch.arange(logits.shape[0], device=logits.device)
            pooled_logits = logits[batch_indices, sequence_lengths]

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                pooled_logits=pooled_logits,
                config=self.config,
            )

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
| |
|
|
def get_fulltuning_model(model_path, model_type="olmo"):
    """Load a fully fine-tuned sequence-classification checkpoint.

    Args:
        model_path: Local path or hub id of the fine-tuned checkpoint.
        model_type: "olmo" (custom OlmoForSequenceClassification) or
            "pythia" (generic AutoModelForSequenceClassification with
            3 labels).

    Returns:
        The model in eval mode, moved to the module-level ``device``.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """
    if model_type == "olmo":
        model = OlmoForSequenceClassification.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.float32,
        ).to(device)
    elif model_type == "pythia":
        cfg = AutoConfig.from_pretrained(model_path, num_labels=3)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            config=cfg,
            torch_dtype=torch.float32,
        ).to(device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    # BUGFIX: eval() was previously called only on the "olmo" branch, so the
    # pythia model was returned in train mode (dropout active at inference).
    model.eval()
    return model
|
|
| |
| |
| |
|
|
class DropLoadReport(logging.Filter):
    """Logging filter that silences records mentioning "LOAD REPORT"."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Keep the record only when its rendered message does not
        # contain the load-report marker.
        message = record.getMessage()
        return message.find("LOAD REPORT") == -1
|
|
@contextmanager
def suppress_load_report_only():
    """Context manager that hides "LOAD REPORT" log lines from transformers.

    Attaches a DropLoadReport filter to the transformers loggers on entry
    and always detaches it on exit, leaving all other log output intact.
    """
    drop_filter = DropLoadReport()

    logger_names = (
        "transformers.modeling_utils",
        "transformers.modeling_tf_pytorch_utils",
        "transformers",
    )
    targets = [logging.getLogger(name) for name in logger_names]

    for target in targets:
        target.addFilter(drop_filter)
    try:
        yield
    finally:
        # Detach even if the body raised, so the filter never leaks.
        for target in targets:
            target.removeFilter(drop_filter)
|
|
| |
|
|
def get_peft_model(model_path, model_type="olmo"):
    """Load a PEFT adapter checkpoint on top of its base model.

    Args:
        model_path: Path/hub id of the PEFT adapter directory (must contain
            an adapter config naming the base model).
        model_type: "olmo" or "pythia" — selects the base model class.
            Both branches configure the base model with 2 labels.

    Returns:
        The PeftModel in eval mode on the selected device, with an extra
        ``is_prefix_tuning`` attribute consumed by ``forward_peft_seqcls``.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """
    peft_config = PeftConfig.from_pretrained(model_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if model_type == "olmo":
        config = AutoConfig.from_pretrained(
            peft_config.base_model_name_or_path,
            trust_remote_code=True,
            num_labels=2,
        )
        with suppress_load_report_only():
            base = OlmoForSequenceClassification.from_pretrained(
                peft_config.base_model_name_or_path,
                trust_remote_code=True,
                torch_dtype=torch.float32,
                config=config,
            ).to(device)
    elif model_type == "pythia":
        config = AutoConfig.from_pretrained(
            peft_config.base_model_name_or_path,
            num_labels=2,
        )
        with suppress_load_report_only():
            base = AutoModelForSequenceClassification.from_pretrained(
                peft_config.base_model_name_or_path,
                config=config,
                torch_dtype=torch.float32,
            ).to(device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    with suppress_load_report_only():
        model = PeftModel.from_pretrained(base, model_path).to(device)

    # BUGFIX: the previous check `str(peft_config.peft_type) ==
    # "PeftType.PREFIX_TUNING"` is Python-version dependent — since 3.11,
    # str() of a str-mixin Enum member returns its *value*
    # ("PREFIX_TUNING"), so the flag was silently always False there.
    # PeftType subclasses str, so compare against the value directly;
    # this works on all Python versions.
    model.is_prefix_tuning = peft_config.peft_type == "PREFIX_TUNING"

    # Causal LMs often ship without a pad token; default it to EOS so that
    # batched classification works. PEFT keeps a separate config on the
    # wrapped base model, so mirror the fix there too.
    if getattr(model.config, "pad_token_id", None) is None and getattr(model.config, "eos_token_id", None) is not None:
        model.config.pad_token_id = model.config.eos_token_id
    if hasattr(model, "base_model") and hasattr(model.base_model, "config"):
        if getattr(model.base_model.config, "pad_token_id", None) is None and getattr(model.base_model.config, "eos_token_id", None) is not None:
            model.base_model.config.pad_token_id = model.base_model.config.eos_token_id

    model.eval()
    return model
|
|
| |
|
|
def forward_peft_seqcls(model, **inputs):
    """Forward pass for a PEFT sequence-classification model.

    For non-prefix-tuning adapters (e.g. LoRA) this simply delegates to the
    model's own forward. For prefix tuning it prepends the learned prefix
    key/value cache and extends the attention mask itself, because some
    backbones' ``*ForSequenceClassification.forward`` signatures do not
    accept ``past_key_values``.

    Args:
        model: A PeftModel carrying the ``is_prefix_tuning`` flag set by
            ``get_peft_model``.
        **inputs: Standard HF forward kwargs (input_ids, attention_mask,
            labels, inputs_embeds, output_attentions, output_hidden_states,
            return_dict).

    Returns:
        The model's usual output object on the delegating paths; on the
        manual fallback path, a SimpleNamespace with ``logits``,
        ``hidden_states`` and ``attentions`` — note the fallback computes
        no loss even when ``labels`` is provided.
    """
    if not getattr(model, "is_prefix_tuning", False):
        # Non-prefix adapters: the wrapped forward already handles everything.
        return model(**inputs, use_cache=False)

    input_ids = inputs.get("input_ids", None)
    attention_mask = inputs.get("attention_mask", None)
    inputs_embeds = inputs.get("inputs_embeds", None)
    labels = inputs.get("labels", None)
    output_attentions = inputs.get("output_attentions", None)
    output_hidden_states = inputs.get("output_hidden_states", None)
    return_dict = inputs.get("return_dict", True)

    if input_ids is not None:
        batch_size = input_ids.shape[0]
    elif inputs_embeds is not None:
        batch_size = inputs_embeds.shape[0]
    else:
        raise ValueError("Either input_ids or inputs_embeds must be provided.")

    # Learned prefix key/value states, expanded to one set per batch element.
    past_key_values = model.get_prompt(batch_size)

    if attention_mask is not None:
        # The virtual prefix tokens are always attended to: prepend a block
        # of ones so the mask length matches num_virtual_tokens + seq_len.
        num_virtual_tokens = model.active_peft_config.num_virtual_tokens
        prefix_attention_mask = torch.ones(
            batch_size,
            num_virtual_tokens,
            device=attention_mask.device,
            dtype=attention_mask.dtype,
        )
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

    try:
        # Fast path: the classification model's forward accepts
        # past_key_values, so let it compute pooling and loss itself.
        return model.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=False,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    except TypeError:
        # The forward signature rejected one of the kwargs (typically
        # past_key_values); drive the bare backbone manually below.
        # NOTE(review): a TypeError raised *inside* a compatible forward
        # would also be swallowed here — consider inspecting the signature
        # instead; confirm before tightening.
        pass

    # Manual path: run the transformer backbone directly with the prefix
    # cache, then apply the classification head ourselves.
    transformer_backbone = model.base_model.get_submodule(model.transformer_backbone_name)

    outputs = transformer_backbone(
        input_ids=input_ids,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        past_key_values=past_key_values,
        use_cache=False,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
    )

    # First element of a backbone output is the last hidden state.
    hidden_states = outputs[0]

    # Some classification heads include a dropout child module; apply it
    # to mirror the head's own forward (it is a no-op in eval mode).
    if "dropout" in [name for name, _ in model.base_model.named_children()]:
        hidden_states = model.base_model.dropout(hidden_states)

    # Per-token logits from the model's own classifier layer.
    cls_layer = model.base_model.get_submodule(model.cls_layer_name)
    token_logits = cls_layer(hidden_states)

    # Pool the final position. NOTE(review): this assumes left padding or
    # unpadded inputs; with right padding this picks a pad token's logits —
    # confirm against how callers batch their inputs.
    logits = token_logits[:, -1]

    return SimpleNamespace(
        logits=logits,
        hidden_states=getattr(outputs, "hidden_states", None),
        attentions=getattr(outputs, "attentions", None),
    )