from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    GenerationMixin,
    OlmoModel,
    OlmoPreTrainedModel,
)
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
import torch
from peft import PeftConfig, PeftModel, PeftType
import logging
from contextlib import contextmanager
from types import SimpleNamespace

device = "cuda" if torch.cuda.is_available() else "cpu"


# The custom model for using Olmo with a sequence classification task
class OlmoForSequenceClassification(OlmoPreTrainedModel, GenerationMixin):
    def __init__(self, config):
        super().__init__(config)
        self.model = OlmoModel(config)
        self.num_labels = config.num_labels
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs,
    ) -> SequenceClassifierOutputWithPast:
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        logits = self.classifier(outputs.last_hidden_state)
        # Pool on the last token; NOTE: tokenizer.padding_side must be 'left'
        pooled_logits = logits[:, -1]

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                pooled_logits=pooled_logits,
                config=self.config,
            )

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


# The function for loading a fully finetuned model
def get_fulltuning_model(model_path, model_type="olmo"):
    if model_type == "olmo":
        model = OlmoForSequenceClassification.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.float32,
        ).to(device)
        model.eval()
    elif model_type == "pythia":
        cfg = AutoConfig.from_pretrained(model_path, num_labels=3)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            config=cfg,
            torch_dtype=torch.float32,
        ).to(device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")
    return model

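
# A minimal usage sketch (not part of the experiment code; the checkpoint path is a
# placeholder): load a fully finetuned checkpoint and classify one sentence. The
# tokenizer must pad on the left so that the last-token pooling in
# OlmoForSequenceClassification.forward looks at a real token.
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("path/to/finetuned-olmo", padding_side="left")
#     if tokenizer.pad_token is None:
#         tokenizer.pad_token = tokenizer.eos_token
#     model = get_fulltuning_model("path/to/finetuned-olmo", model_type="olmo")
#     batch = tokenizer(["an example sentence"], return_tensors="pt", padding=True).to(device)
#     with torch.no_grad():
#         prediction = model(**batch).logits.argmax(dim=-1)
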
# The following filter and context manager are used to suppress a "missing or unexpected
# params" load report. This warning is no reason for concern: it stems from the fact that
# the model is first loaded without a classifier head, which is added afterwards.
class DropLoadReport(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        return "LOAD REPORT" not in record.getMessage()


@contextmanager
def suppress_load_report_only():
    f = DropLoadReport()
    names = [
        "transformers.modeling_utils",
        "transformers.modeling_tf_pytorch_utils",
        "transformers",
    ]
    loggers = [logging.getLogger(n) for n in names]
    for lg in loggers:
        lg.addFilter(f)
    try:
        yield
    finally:
        for lg in loggers:
            lg.removeFilter(f)


# The function for loading a soft-prompt (PEFT) model
def get_peft_model(model_path, model_type="olmo"):
    peft_config = PeftConfig.from_pretrained(model_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if model_type == "olmo":
        config = AutoConfig.from_pretrained(
            peft_config.base_model_name_or_path,
            trust_remote_code=True,
            num_labels=2,
        )
        with suppress_load_report_only():
            base = OlmoForSequenceClassification.from_pretrained(
                peft_config.base_model_name_or_path,
                trust_remote_code=True,
                torch_dtype=torch.float32,
                config=config,
            ).to(device)
    elif model_type == "pythia":
        config = AutoConfig.from_pretrained(
            peft_config.base_model_name_or_path,
            num_labels=2,
        )
        with suppress_load_report_only():
            base = AutoModelForSequenceClassification.from_pretrained(
                peft_config.base_model_name_or_path,
                config=config,
                torch_dtype=torch.float32,
            ).to(device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    with suppress_load_report_only():
        model = PeftModel.from_pretrained(base, model_path).to(device)

    model.is_prefix_tuning = peft_config.peft_type == PeftType.PREFIX_TUNING

    # Make sure a pad token is set; helpful for batching / last-token pooling
    if getattr(model.config, "pad_token_id", None) is None and getattr(model.config, "eos_token_id", None) is not None:
        model.config.pad_token_id = model.config.eos_token_id
    if hasattr(model, "base_model") and hasattr(model.base_model, "config"):
        if getattr(model.base_model.config, "pad_token_id", None) is None and getattr(model.base_model.config, "eos_token_id", None) is not None:
            model.base_model.config.pad_token_id = model.base_model.config.eos_token_id

    model.eval()
    return model


# This function helps when running prefix-finetuned models for sequence classification
def forward_peft_seqcls(model, **inputs):
    if not getattr(model, "is_prefix_tuning", False):
        return model(**inputs, use_cache=False)

    input_ids = inputs.get("input_ids", None)
    attention_mask = inputs.get("attention_mask", None)
    inputs_embeds = inputs.get("inputs_embeds", None)
    labels = inputs.get("labels", None)
    output_attentions = inputs.get("output_attentions", None)
    output_hidden_states = inputs.get("output_hidden_states", None)
    return_dict = inputs.get("return_dict", True)

    if input_ids is not None:
        batch_size = input_ids.shape[0]
    elif inputs_embeds is not None:
        batch_size = inputs_embeds.shape[0]
    else:
        raise ValueError("Either input_ids or inputs_embeds must be provided.")

    # Build the learned prefix key/value cache and extend the attention mask to
    # cover the virtual tokens.
    past_key_values = model.get_prompt(batch_size)
    if attention_mask is not None:
        num_virtual_tokens = model.active_peft_config.num_virtual_tokens
        prefix_attention_mask = torch.ones(
            batch_size,
            num_virtual_tokens,
            device=attention_mask.device,
            dtype=attention_mask.dtype,
        )
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

    # Prefer the classification model's own forward; some model classes do not
    # accept past_key_values, in which case we fall back to calling the backbone
    # and the classifier head separately.
    try:
        return model.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=False,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    except TypeError:
        pass

    transformer_backbone = model.base_model.get_submodule(model.transformer_backbone_name)
    outputs = transformer_backbone(
        input_ids=input_ids,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        past_key_values=past_key_values,
        use_cache=False,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
    )
    hidden_states = outputs[0]

    # Apply the dropout layer of the classification wrapper, if it has one
    if "dropout" in [name for name, _ in model.base_model.named_children()]:
        hidden_states = model.base_model.dropout(hidden_states)

    # Score every position with the classification head, then pool on the last token
    cls_layer = model.base_model.get_submodule(model.cls_layer_name)
    token_logits = cls_layer(hidden_states)
    logits = token_logits[:, -1]

    return SimpleNamespace(
        logits=logits,
        hidden_states=getattr(outputs, "hidden_states", None),
        attentions=getattr(outputs, "attentions", None),
    )

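
# A minimal end-to-end sketch of how the helpers above fit together (the adapter path and
# the example sentences are placeholders, not part of the experiment code).
# forward_peft_seqcls is only strictly needed for prefix-tuned adapters, but it is safe
# to call for the other PEFT types as well.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    adapter_path = "path/to/peft-adapter"  # placeholder: a directory saved via PeftModel.save_pretrained
    model = get_peft_model(adapter_path, model_type="olmo")

    # Last-token pooling in OlmoForSequenceClassification.forward assumes left padding.
    tokenizer = AutoTokenizer.from_pretrained(
        PeftConfig.from_pretrained(adapter_path).base_model_name_or_path,
        padding_side="left",
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    batch = tokenizer(
        ["a first example sentence", "a second, somewhat longer example sentence"],
        return_tensors="pt",
        padding=True,
    ).to(device)
    with torch.no_grad():
        outputs = forward_peft_seqcls(model, **batch)
    print(outputs.logits.argmax(dim=-1))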