from transformers import OlmoModel, OlmoPreTrainedModel, GenerationMixin, AutoConfig, AutoModelForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
import torch
from peft import PeftModel, PeftConfig
import logging
from contextlib import contextmanager
from types import SimpleNamespace
# The custom model for using Olmo with a sequence classification task
device = "cuda" if torch.cuda.is_available() else "cpu"
class OlmoForSequenceClassification(OlmoPreTrainedModel, GenerationMixin):
    def __init__(self, config):
        super().__init__(config)
        self.model = OlmoModel(config)
        self.num_labels = config.num_labels
        self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.LongTensor | None = None,
        **kwargs,
    ) -> SequenceClassifierOutputWithPast:
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs,
        )
        logits = self.classifier(outputs.last_hidden_state)
        pooled_logits = logits[:, -1]  # NOTE: tokenizer.padding_side must be 'left'
        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                pooled_logits=pooled_logits,
                config=self.config,
            )
        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
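
# A minimal, hedged sketch of how the head is wired: it builds a small randomly
# initialised model from a config (all sizes here are illustrative, not taken from any
# real checkpoint) and shows that the output is one score vector per sequence, pooled
# from the last token, which is why inputs must be left-padded.
def _example_random_olmo_classifier():
    from transformers import OlmoConfig
    cfg = OlmoConfig(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_labels=2,
    )
    model = OlmoForSequenceClassification(cfg)
    dummy_ids = torch.randint(0, cfg.vocab_size, (2, 8))
    out = model(input_ids=dummy_ids, attention_mask=torch.ones_like(dummy_ids))
    return out.logits.shape  # torch.Size([2, 2]): (batch_size, num_labels)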
# The function for loading a fully fine-tuned model
def get_fulltuning_model(model_path, model_type="olmo"):
    if model_type == "olmo":
        model = OlmoForSequenceClassification.from_pretrained(
            model_path,
            trust_remote_code=True,
            torch_dtype=torch.float32,
        ).to(device)
    elif model_type == "pythia":
        cfg = AutoConfig.from_pretrained(model_path, num_labels=3)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            config=cfg,
            torch_dtype=torch.float32,
        ).to(device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")
    model.eval()
    return model
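
# Hedged usage sketch for a fully fine-tuned checkpoint. The path and example text are
# placeholders, not part of this repository; left padding is required because the
# classifier pools the logits of the final token.
def _example_fulltuned_inference(model_path="path/to/fulltuned-checkpoint"):
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = get_fulltuning_model(model_path, model_type="olmo")
    batch = tokenizer(["An example sentence."], return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        logits = model(**batch).logits
    return logits.argmax(dim=-1)  # predicted class ids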
# The following function is used to suppress a "missing or unexpected params" warning.
# This warning is no reason for concern. It stems from the fact that the model is first loaded
# without a classifier head, which is added afterwards.
class DropLoadReport(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        return "LOAD REPORT" not in record.getMessage()

@contextmanager
def suppress_load_report_only():
    f = DropLoadReport()
    names = [
        "transformers.modeling_utils",
        "transformers.modeling_tf_pytorch_utils",
        "transformers",
    ]
    loggers = [logging.getLogger(n) for n in names]
    for lg in loggers:
        lg.addFilter(f)
    try:
        yield
    finally:
        for lg in loggers:
            lg.removeFilter(f)
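
# Hedged example of the intended use: wrap any from_pretrained call whose load report is
# expected to mention the (later re-attached) classifier head. The checkpoint path is a
# placeholder.
def _example_quiet_load(model_path="path/to/base-checkpoint"):
    with suppress_load_report_only():
        return OlmoForSequenceClassification.from_pretrained(model_path, trust_remote_code=True)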
# The function for loading a soft-prompt (PEFT) model
def get_peft_model(model_path, model_type="olmo"):
    peft_config = PeftConfig.from_pretrained(model_path)
    if model_type == "olmo":
        config = AutoConfig.from_pretrained(
            peft_config.base_model_name_or_path,
            trust_remote_code=True,
            num_labels=2,
        )
        with suppress_load_report_only():
            base = OlmoForSequenceClassification.from_pretrained(
                peft_config.base_model_name_or_path,
                trust_remote_code=True,
                torch_dtype=torch.float32,
                config=config,
            ).to(device)
    elif model_type == "pythia":
        config = AutoConfig.from_pretrained(
            peft_config.base_model_name_or_path,
            num_labels=2,
        )
        with suppress_load_report_only():
            base = AutoModelForSequenceClassification.from_pretrained(
                peft_config.base_model_name_or_path,
                config=config,
                torch_dtype=torch.float32,
            ).to(device)
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")
    with suppress_load_report_only():
        model = PeftModel.from_pretrained(base, model_path).to(device)
    # Substring check is robust to str(enum) differences across Python/peft versions.
    model.is_prefix_tuning = "PREFIX_TUNING" in str(peft_config.peft_type)
    # helpful for batching / last-token pooling
    if getattr(model.config, "pad_token_id", None) is None and getattr(model.config, "eos_token_id", None) is not None:
        model.config.pad_token_id = model.config.eos_token_id
    if hasattr(model, "base_model") and hasattr(model.base_model, "config"):
        if getattr(model.base_model.config, "pad_token_id", None) is None and getattr(model.base_model.config, "eos_token_id", None) is not None:
            model.base_model.config.pad_token_id = model.base_model.config.eos_token_id
    model.eval()
    return model
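
# Hedged usage sketch: the adapter path is a placeholder. The is_prefix_tuning flag set
# above determines whether the model can be called directly or should go through
# forward_peft_seqcls (defined below).
def _example_load_peft(adapter_path="path/to/peft-adapter"):
    model = get_peft_model(adapter_path, model_type="olmo")
    print("prefix tuning:", model.is_prefix_tuning)
    return model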
# This function handles the forward pass for prefix-tuned models, which need the learned
# prompt key/values injected manually; other models are simply called as-is.
def forward_peft_seqcls(model, **inputs):
    if not getattr(model, "is_prefix_tuning", False):
        return model(**inputs, use_cache=False)
    input_ids = inputs.get("input_ids", None)
    attention_mask = inputs.get("attention_mask", None)
    inputs_embeds = inputs.get("inputs_embeds", None)
    labels = inputs.get("labels", None)
    output_attentions = inputs.get("output_attentions", None)
    output_hidden_states = inputs.get("output_hidden_states", None)
    return_dict = inputs.get("return_dict", True)
    if input_ids is not None:
        batch_size = input_ids.shape[0]
    elif inputs_embeds is not None:
        batch_size = inputs_embeds.shape[0]
    else:
        raise ValueError("Either input_ids or inputs_embeds must be provided.")
    # Fetch the learned prefix key/values and extend the attention mask to cover them.
    past_key_values = model.get_prompt(batch_size)
    if attention_mask is not None:
        num_virtual_tokens = model.active_peft_config.num_virtual_tokens
        prefix_attention_mask = torch.ones(
            batch_size,
            num_virtual_tokens,
            device=attention_mask.device,
            dtype=attention_mask.dtype,
        )
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
    # Preferred path: let the wrapped classification model consume the prefix directly.
    try:
        return model.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=False,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
    except TypeError:
        # Fallback: run the backbone and the classifier head manually.
        pass
    transformer_backbone = model.base_model.get_submodule(model.transformer_backbone_name)
    outputs = transformer_backbone(
        input_ids=input_ids,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        past_key_values=past_key_values,
        use_cache=False,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
    )
    hidden_states = outputs[0]
    if "dropout" in [name for name, _ in model.base_model.named_children()]:
        hidden_states = model.base_model.dropout(hidden_states)
    cls_layer = model.base_model.get_submodule(model.cls_layer_name)
    token_logits = cls_layer(hidden_states)
    logits = token_logits[:, -1]  # last-token pooling; requires left padding
    return SimpleNamespace(
        logits=logits,
        hidden_states=getattr(outputs, "hidden_states", None),
        attentions=getattr(outputs, "attentions", None),
    )
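
# End-to-end hedged sketch tying the pieces together: tokenise with left padding (required
# by the last-token pooling above), load a PEFT adapter, and route the forward pass through
# forward_peft_seqcls so prefix-tuned adapters are handled too. The adapter path and the
# example sentence are placeholders.
def _example_peft_inference(adapter_path="path/to/peft-adapter"):
    from transformers import AutoTokenizer
    peft_config = PeftConfig.from_pretrained(adapter_path)
    tokenizer = AutoTokenizer.from_pretrained(
        peft_config.base_model_name_or_path, padding_side="left"
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = get_peft_model(adapter_path, model_type="olmo")
    batch = tokenizer(["An example sentence."], return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        out = forward_peft_seqcls(model, **batch)
    probs = torch.softmax(out.logits, dim=-1)
    return probs.argmax(dim=-1)  # predicted class ids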