| """ |
| Main entry point to run the experiments. Contains general setup and the proper training code. |
| """ |

import argparse
import datetime as dt
import json
import os
import random
import sys
import textwrap
import time
from collections.abc import Callable
from contextlib import AbstractContextManager, nullcontext
from functools import partial
from typing import Any, Literal, Optional

import torch
from torch import nn
from torch.amp import GradScaler, autocast
from tqdm import tqdm
from transformers import GenerationConfig, set_seed
from utils import (
    FILE_NAME_TRAIN_PARAMS,
    BucketIterator,
    TrainResult,
    TrainStatus,
    get_accuracy,
    get_base_model_info,
    get_dataset_info,
    get_file_size,
    get_model,
    get_optimizer_and_scheduler,
    get_peft_branch,
    get_tokenizer,
    get_train_config,
    init_accelerator,
    log_results,
    validate_experiment_path,
)

from data import get_train_valid_test_datasets, get_wiki_small
from peft import AdaLoraConfig, PeftConfig
from peft.utils import CONFIG_NAME, infer_device


# bucket_factor argument for the BucketIterator used on the training data
BUCKET_FACTOR = 20
# empty the accelerator cache every this many training steps
ACCELERATOR_EMPTY_CACHE_SCHEDULE = 10

# disable torchinductor caching
os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"


def get_generation_config(*, seq_len, generate_kwargs) -> GenerationConfig:
    # drop entries whose value is None
    generate_kwargs = {k: v for k, v in generate_kwargs.items() if v is not None}
    if ("max_length" in generate_kwargs) and ("max_new_tokens" in generate_kwargs):
        # if both are passed, keep only max_length, set to the tighter of the two bounds
        new_max_length = min(generate_kwargs["max_new_tokens"] + seq_len, generate_kwargs["max_length"])
        del generate_kwargs["max_new_tokens"]
        generate_kwargs["max_length"] = new_max_length
    generation_config = GenerationConfig(**generate_kwargs)
    return generation_config


def evaluate(model, tokenizer, ds, batch_size, generate_kwargs, use_tqdm: bool = False) -> tuple[list[str], list[str]]:
    generate_kwargs = generate_kwargs.copy()
    generate_kwargs["pad_token_id"] = tokenizer.eos_token_id
    with torch.inference_mode():
        predictions = []
        responses = []
        pbar = range(0, len(ds), batch_size)
        if use_tqdm:
            pbar = tqdm(pbar)
        for j in pbar:
            sliced = ds[j : j + batch_size]
            responses += sliced.pop("response")
            batch = tokenizer.pad(sliced, return_tensors="pt", padding_side="left").to(model.device)
            seq_len = batch["input_ids"].shape[1]
            generation_config = get_generation_config(seq_len=seq_len, generate_kwargs=generate_kwargs)
            outputs = model.generate(**batch, generation_config=generation_config)
            predictions += tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return predictions, responses


@torch.inference_mode()
def calculate_mean_per_token_loss(model, tokenizer, rows: list[str], batch_size: int, max_length: int) -> float:
    """Calculate the mean loss per token on the given dataset.

    Useful to determine general model performance before and after training to get an estimate of the magnitude of
    'forgetting'. Note that for Wikipedia data, since the information density is quite high, the loss can be
    surprisingly large.

    """
    losses: list[float] = []
    for j in range(0, len(rows), batch_size):
        sliced = rows[j : j + batch_size]
        batch = tokenizer(sliced, truncation=True, max_length=max_length)
        batch = tokenizer.pad(batch, return_tensors="pt", padding_side="left").to(model.device)
        outputs = model(**batch, pad_token_id=tokenizer.eos_token_id)
        logits = outputs.logits
        for logit, target, mask in zip(logits, batch["input_ids"], batch["attention_mask"]):
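            # left-sided padding means the last `num_tokens` positions of each row are the real (non-padding) tokens,
            # so the loss is computed only over that suffix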
            num_tokens = mask.sum()
            token_losses = torch.nn.functional.cross_entropy(
                logit[-num_tokens:], target[-num_tokens:], reduction="none"
            )
            losses.extend(loss.item() for loss in token_losses)
    return torch.tensor(losses).mean().item()


class DummyGradScaler:
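    """No-op stand-in for torch.amp.GradScaler, used when AMP is disabled."""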

    def scale(self, loss):
        return loss

    def unscale_(self, optimizer):
        pass

    def step(self, optimizer):
        optimizer.step()

    def update(self):
        pass


def train(
    *,
    model: nn.Module,
    max_steps: int,
    batch_size: int,
    batch_size_eval: int,
    tokenizer: Any,
    accelerator_memory_init: int,
    eval_steps: int,
    generation_kwargs: dict[str, Any],
    grad_norm_clip: float,
    optimizer_type: str,
    optimizer_kwargs: dict[str, Any],
    query_template: str,
    lr_scheduler_arg: Optional[Literal["cosine"]],
    use_amp: bool,
    is_adalora: bool,
) -> TrainResult:
    accelerator_memory_allocated_log = []
    accelerator_memory_reserved_log = []
    losses = []
    durations = []
    metrics = []
    sample = 0
    total_samples = 0
    total_tokens = []

    device_type = infer_device()
    torch_accelerator_module = getattr(torch, device_type, torch.cuda)
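    # with AMP enabled, use a real GradScaler and autocast; otherwise fall back to the no-op stand-ins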
    if use_amp:
        grad_scaler: GradScaler | DummyGradScaler = GradScaler(device=device_type)
        autocast_ctx: Callable[[], AbstractContextManager[Any]] = partial(autocast, device_type=device_type)
    else:
        grad_scaler = DummyGradScaler()
        autocast_ctx = nullcontext

    optimizer, lr_scheduler = get_optimizer_and_scheduler(
        model,
        optimizer_type=optimizer_type,
        max_steps=max_steps,
        lr_scheduler_arg=lr_scheduler_arg,
        **optimizer_kwargs,
    )
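
    # PEFT models expose get_nb_trainable_parameters(); for full fine-tuning, every parameter is trainable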
    if hasattr(model, "get_nb_trainable_parameters"):
        num_trainable_params, num_params = model.get_nb_trainable_parameters()
    else:
        num_params = model.num_parameters()
        num_trainable_params = num_params
    print_verbose(
        f"trainable params: {num_trainable_params:,d} || all params: {num_params:,d} || "
        f"trainable: {100 * num_trainable_params / num_params:.4f}%"
    )

    status = TrainStatus.FAILED
    tic_train = time.perf_counter()
    eval_time = 0.0
    error_msg = ""

    rows_wiki = get_wiki_small()
    model.eval()
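    # measure the per-token loss on a small Wikipedia sample before training; comparing it with the loss after
    # training gives a rough estimate of how much general knowledge is forgotten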
    wiki_loss_before = calculate_mean_per_token_loss(
        model=model, tokenizer=tokenizer, rows=rows_wiki, batch_size=batch_size, max_length=768
    )
    model.train()

    ds_train, ds_valid, ds_test = get_train_valid_test_datasets(
        tokenizer=tokenizer, query_template=query_template, print_fn=print_verbose
    )
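
    # the BucketIterator (from the utils module) yields the training batches; the raw "response" strings are only
    # needed for evaluation, so they are dropped here (bucketing presumably groups samples of similar length to
    # reduce padding)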
    iterator_train = BucketIterator(
        ds_train,
        batch_size=batch_size,
        bucket_factor=BUCKET_FACTOR,
        delete_cols=["response"],
    )

    try:
        pbar = tqdm(range(1, max_steps + 1))
        for step, batch in zip(pbar, iterator_train):
            tic = time.perf_counter()

            tokens_per_sample = [len(i) for i in batch["input_ids"]]
            total_tokens.append(sum(tokens_per_sample) + len(tokens_per_sample))
            batch = tokenizer.pad(batch, return_tensors="pt").to(model.device)
            actual_batch_size = len(batch["input_ids"])
            total_samples += actual_batch_size
            sample += batch_size
            if sample >= len(ds_train):
                sample = 0

            labels = batch["input_ids"].clone()
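            # mask out all positions beyond the actual sample tokens so that padding does not contribute to the loss;
            # -100 is the ignore index of the cross-entropy loss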
            for i, num_tokens in enumerate(tokens_per_sample):
                labels[i, num_tokens + 1 :] = -100
            batch["labels"] = labels
            num_items_in_batch = batch["attention_mask"].sum().item()
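
            # training step: forward pass under autocast (a nullcontext when AMP is off), scaled backward pass,
            # optional gradient clipping, then optimizer and LR scheduler steps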
            optimizer.zero_grad()
            with autocast_ctx():
                outputs = model(**batch, num_items_in_batch=num_items_in_batch)
            loss = outputs.loss
            grad_scaler.scale(loss).backward()
            if grad_norm_clip:
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip)
            grad_scaler.step(optimizer)
            grad_scaler.update()
            lr_scheduler.step()

            if is_adalora:
                model.base_model.update_and_allocate(step)

            losses.append(loss.item())
            pbar.set_postfix({"loss": loss.item()})
            accelerator_memory_allocated_log.append(
                torch_accelerator_module.memory_allocated() - accelerator_memory_init
            )
            accelerator_memory_reserved_log.append(
                torch_accelerator_module.memory_reserved() - accelerator_memory_init
            )
            toc = time.perf_counter()
            durations.append(toc - tic)
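
            # periodically evaluate on the validation set and log the training metrics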
            if step % eval_steps == 0:
                tic_eval = time.perf_counter()
                loss_avg = sum(losses[-eval_steps:]) / eval_steps
                memory_allocated_avg = sum(accelerator_memory_allocated_log[-eval_steps:]) / eval_steps
                memory_reserved_avg = sum(accelerator_memory_reserved_log[-eval_steps:]) / eval_steps
                token_sum = sum(total_tokens[-eval_steps:])
                dur_train = sum(durations[-eval_steps:])
                tokens_per_sec = token_sum / dur_train

                model.eval()
                predictions, responses = evaluate(
                    model=model,
                    tokenizer=tokenizer,
                    ds=ds_valid,
                    batch_size=batch_size_eval,
                    generate_kwargs={**generation_kwargs},
                )
                model.train()

                example = random.choice(predictions)
                example = textwrap.shorten(example, width=750)
                example = textwrap.indent(example, " ")
                print_verbose(f"\nExample prediction:\n{example}\n")
                accuracy = get_accuracy(predictions=predictions, responses=responses)
                num_tokens_generated = sum(sum(mask) for mask in tokenizer(predictions)["attention_mask"])

                toc_eval = time.perf_counter()
                dur_eval = toc_eval - tic_eval
                eval_time += toc_eval - tic_eval
                elapsed = time.perf_counter() - tic_train

                metrics.append(
                    {
                        "step": step,
                        "valid accuracy": accuracy,
                        "train loss": loss_avg,
                        "train samples": total_samples,
                        "train time": dur_train,
                        "eval time": dur_eval,
                        "tokens / sec": tokens_per_sec,
                        "mem allocated avg": memory_allocated_avg,
                        "mem reserved avg": memory_reserved_avg,
                        "elapsed time": elapsed,
                    }
                )

                log_dict = {
                    "step": f"{step:5d}",
                    "samples": f"{total_samples:7d}",
                    "lr": f"{lr_scheduler.get_last_lr()[0]:.2e}",
                    "loss avg": f"{loss_avg:.4f}",
                    "valid acc": f"{accuracy:.3f}",
                    "gen valid tokens": num_tokens_generated,
                    "train time": f"{dur_train:.1f}s",
                    "eval time": f"{dur_eval:.1f}s",
                    "train tokens / sec": f"{tokens_per_sec:.0f}",
                    "mem allocated": f"{memory_allocated_avg:.0f}",
                    "mem reserved": f"{memory_reserved_avg:.0f}",
                    "elapsed time": f"{elapsed // 60:.0f}min {elapsed % 60:.0f}s",
                }
                print_verbose(json.dumps(log_dict))

            if step % ACCELERATOR_EMPTY_CACHE_SCHEDULE == 0:
                torch_accelerator_module.empty_cache()

        print_verbose(f"Training finished after {max_steps} steps, evaluation on test set follows.")

        model.eval()
        predictions, responses = evaluate(
            model=model,
            tokenizer=tokenizer,
            ds=ds_test,
            batch_size=batch_size_eval,
            generate_kwargs={**generation_kwargs, "pad_token_id": tokenizer.eos_token_id},
            use_tqdm=len(ds_test) > 100,
        )
        accuracy = get_accuracy(predictions=predictions, responses=responses)
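        # measure the Wikipedia loss again; the difference to the pre-training loss serves as a rough estimate of
        # catastrophic forgetting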
        wiki_loss_after = calculate_mean_per_token_loss(
            model=model, tokenizer=tokenizer, rows=rows_wiki, batch_size=batch_size, max_length=768
        )
        forgetting = wiki_loss_after - wiki_loss_before
        metrics.append(
            {
                "step": step,
                "test accuracy": accuracy,
                "train loss": sum(losses[-eval_steps:]) / eval_steps,
                "train samples": total_samples,
                "train total tokens": sum(total_tokens),
                "forgetting": forgetting,
            }
        )
        print_verbose(f"Test accuracy: {accuracy:.3f}")

    except KeyboardInterrupt:
        print_verbose("canceled training")
        status = TrainStatus.CANCELED
        error_msg = "manually canceled"
    except torch.OutOfMemoryError as exc:
        print_verbose("out of memory error encountered")
        status = TrainStatus.CANCELED
        error_msg = str(exc)
    except Exception as exc:
        print_verbose(f"encountered an error: {exc}")
        status = TrainStatus.CANCELED
        error_msg = str(exc)

    toc_train = time.perf_counter()
    train_time = toc_train - tic_train - eval_time

    if status != TrainStatus.CANCELED:
        status = TrainStatus.SUCCESS
    train_result = TrainResult(
        status=status,
        train_time=train_time,
        accelerator_memory_reserved_log=accelerator_memory_reserved_log,
        losses=losses,
        metrics=metrics,
        error_msg=error_msg,
        num_trainable_params=num_trainable_params,
        num_total_params=num_params,
    )
    return train_result


def main(*, path_experiment: str, experiment_name: str, clean: bool) -> None:
    tic_total = time.perf_counter()
    start_date = dt.datetime.now(tz=dt.timezone.utc).replace(microsecond=0).isoformat()

    peft_branch = get_peft_branch()
    if peft_branch == "main":
        print_verbose("===== This experiment is categorized as a MAIN run because the PEFT branch is 'main' ======")
    else:
        print_verbose(
            f"===== This experiment is categorized as a TEST run because the PEFT branch is '{peft_branch}' ======"
        )
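
    # if the experiment directory contains a PEFT config, train a PEFT model; otherwise perform full fine-tuning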
    peft_config: Optional[PeftConfig] = None
    if os.path.exists(os.path.join(path_experiment, CONFIG_NAME)):
        peft_config = PeftConfig.from_pretrained(path_experiment)
    else:
        print_verbose(f"Could not find PEFT config at {path_experiment}, performing FULL FINETUNING")
    path_train_config = os.path.join(path_experiment, FILE_NAME_TRAIN_PARAMS)
    train_config = get_train_config(path_train_config)
    set_seed(train_config.seed)

    accelerator_memory_init = init_accelerator()
    tokenizer = get_tokenizer(model_id=train_config.model_id, max_seq_length=train_config.max_seq_length)

    model_info = get_base_model_info(train_config.model_id)
    metamath_info = get_dataset_info("meta-math/MetaMathQA")
    gsm8k_info = get_dataset_info("openai/gsm8k")
    model = get_model(
        model_id=train_config.model_id,
        dtype=train_config.dtype,
        compile=train_config.compile,
        attn_implementation=train_config.attn_implementation,
        peft_config=peft_config,
        autocast_adapter_dtype=train_config.autocast_adapter_dtype,
    )
    print_verbose(model)

    train_result = train(
        model=model,
        max_steps=train_config.max_steps,
        batch_size=train_config.batch_size,
        batch_size_eval=train_config.batch_size_eval,
        tokenizer=tokenizer,
        accelerator_memory_init=accelerator_memory_init,
        eval_steps=train_config.eval_steps,
        generation_kwargs=train_config.generation_kwargs,
        grad_norm_clip=train_config.grad_norm_clip,
        optimizer_type=train_config.optimizer_type,
        optimizer_kwargs=train_config.optimizer_kwargs,
        query_template=train_config.query_template,
        lr_scheduler_arg=train_config.lr_scheduler,
        use_amp=train_config.use_amp,
        is_adalora=isinstance(peft_config, AdaLoraConfig),
    )

    if train_result.status == TrainStatus.FAILED:
        print_verbose("Training failed, not logging results")
        sys.exit(1)

    file_size = get_file_size(
        model,
        peft_config=peft_config,
        clean=clean,
        print_fn=print_verbose,
    )

    time_total = time.perf_counter() - tic_total
    log_results(
        experiment_name=experiment_name,
        train_result=train_result,
        accelerator_memory_init=accelerator_memory_init,
        time_total=time_total,
        file_size=file_size,
        model_info=model_info,
        datasets_info={"metamath": metamath_info, "gsm8k": gsm8k_info},
        start_date=start_date,
        train_config=train_config,
        peft_config=peft_config,
        print_fn=print_verbose,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output")
    parser.add_argument("path_experiment", type=str, help="Path to the experiment directory")
    parser.add_argument(
        "--clean",
        action="store_true",
        help="Delete training artifacts after run finishes (logs are still saved)",
    )
    args = parser.parse_args()

    experiment_name = validate_experiment_path(args.path_experiment)

    if args.verbose:

        def print_verbose(*args, **kwargs) -> None:
            kwargs["file"] = sys.stderr
            print(*args, **kwargs)

    else:

        def print_verbose(*args, **kwargs) -> None:
            pass

    main(
        path_experiment=args.path_experiment,
        experiment_name=experiment_name,
        clean=args.clean,
    )