| """Common training utilities for BitTransformer models.""" |
|
|
| from __future__ import annotations |
|
|
| from typing import Callable, Dict, List, Optional |
| import contextlib |
| import sys |
| import warnings |
| import math |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
|
|
| from .compression import compress_bits, pack_bits, unpack_bits |
| from .optimization import configure_optimizer |
| from .model import BitTransformerLM |
| from .utils import set_dropout |
| from .torch_utils import cpu_autocast |
|
|
|
|
def cosine_ramp(step: int, start: float, end: float, total_steps: int) -> float:
    """Interpolate from ``start`` to ``end`` along a half-cosine curve.

    ``step`` is the current position on the ramp and ``total_steps`` its
    length. The curve starts flat at ``start``, is steepest halfway, and
    flattens again approaching ``end``. When the ramp length is not positive,
    or ``step`` has reached/passed ``total_steps``, ``end`` is returned as-is.
    """
    finished = step >= total_steps or total_steps <= 0
    if finished:
        return end
    angle = math.pi * step / total_steps
    span = end - start
    return start + span * (1 - math.cos(angle)) / 2
|
|
|
|
def _autocast_context(amp: bool, device: torch.device):
    """Return the mixed-precision context manager for a forward pass.

    With ``amp`` disabled this is a no-op context. Otherwise bf16 CUDA autocast
    is used when the model's parameters live on a CUDA device, and the
    project's ``cpu_autocast`` helper for any other device. Selecting by the
    model's device (rather than ``torch.cuda.is_available()``) ensures a CPU
    model on a CUDA-capable machine still gets CPU autocast.
    """
    if not amp:
        return contextlib.nullcontext()
    if device.type == "cuda":
        return torch.autocast(device_type="cuda", dtype=torch.bfloat16)
    return cpu_autocast()


def _optimizer_update(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
) -> None:
    """Clip gradients to norm 1.0, step optimizer and scheduler, clear grads."""
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()


def _next_bit_loss(
    logits: torch.Tensor, target: torch.Tensor, accum_steps: int
) -> tuple:
    """Next-bit cross-entropy (scaled by ``1/accum_steps``) and accuracy.

    The last position's logits are dropped (``logits[:, :-1, :]``) and the rest
    are flattened to 2-way predictions; ``target`` must already be the
    flattened bits shifted by one position.
    """
    pred = logits[:, :-1, :].reshape(-1, 2)
    loss = F.cross_entropy(pred, target) / accum_steps
    acc = (pred.argmax(dim=-1) == target).float().mean().item()
    return loss, acc


def _mean(values: List[float]) -> float:
    """Arithmetic mean of ``values``, or ``0.0`` when the list is empty."""
    return float(sum(values) / len(values)) if values else 0.0


def train_loop(
    model: BitTransformerLM,
    data: torch.Tensor,
    *,
    epochs: int = 1,
    extra_steps: int = 0,
    compress_prob: float = 0.5,
    direct_prob: float = 0.0,
    batch_size: int = 8,
    num_workers: int = 0,
    accum_steps: int = 1,
    amp: bool = False,
    compile_model: bool = False,
    log: bool = False,
    forward_kwargs: Optional[Dict] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    diffusion: bool = False,
    noise_fn: Optional[Callable[[], float]] = None,
    diffusion_curriculum: bool = False,
    compress_warmup: int = 0,
) -> List[Dict[str, float]]:
    """Generic training loop supporting optional compression and diffusion.

    Each batch (outside diffusion mode) is routed to one of three paths by a
    single uniform draw ``r``:

    * ``r < direct_prob`` -- "direct": every row is round-tripped through
      ``pack_bits``/``unpack_bits``, truncated/padded to at most the model's
      positional-encoding length, and fed to the model as a plain bit tensor.
    * ``r < direct_prob + p_c`` -- "compressed": every row goes through
      ``compress_bits`` and the batch is run via ``model.forward_compressed``.
      ``p_c`` ramps from ``0`` to ``compress_prob`` with ``cosine_ramp`` over
      the first ``compress_warmup`` optimizer updates.
    * otherwise -- "raw": the batch is fed to ``model`` unchanged.

    All three use a next-bit cross-entropy objective. ``forward_kwargs`` is
    forwarded to every non-diffusion forward call.

    When ``diffusion`` is ``True`` the loop performs denoising training instead.
    Batches are noised by randomly flipping bits with a probability given by
    ``noise_fn`` (defaulting to a uniform draw in ``[0, 0.5]``). When
    ``diffusion_curriculum`` is ``True`` the noise probability decreases
    linearly from ``0.5`` to ``0.0`` over the training epochs. The model is
    then trained to recover the clean sequence using full-context attention
    (``causal=False``).

    Gradients are accumulated over ``accum_steps`` micro-batches before each
    optimizer update (with gradient-norm clipping at 1.0). A micro-batch
    remainder at the end of an epoch is carried into the next update. After
    each epoch (non-diffusion only) ``extra_steps`` additional raw steps are
    taken on that epoch's last batch. ``amp`` enables bf16 autocast on CUDA
    models and ``cpu_autocast`` otherwise.

    Existing ``optimizer`` and ``scheduler`` instances may be supplied together
    to allow integration with long-running training sessions; otherwise new
    ones are created via ``configure_optimizer`` (lr ``1e-3``). Supplying only
    one of the two emits a warning and both are recreated.

    Returns:
        One dict per epoch with keys ``raw_loss``, ``raw_acc``,
        ``compressed_loss``, ``compressed_acc``, ``direct_loss`` and
        ``compression_ratio`` (the mean over compressed *and* direct batches).
        Each value is ``0.0`` when no batch of that kind ran in the epoch.
    """
    if compile_model and sys.version_info < (3, 12) and torch.__version__ >= "2.1":
        model = torch.compile(model)
    elif compile_model:
        warnings.warn("torch.compile skipped: requires torch>=2.1 and Python<3.12")

    model.train()
    set_dropout(model, 0.1)

    device = next(model.parameters()).device
    fwd_kwargs = forward_kwargs or {}
    loader = DataLoader(
        data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )
    steps_per_epoch = max(1, len(loader))
    total_updates = math.ceil(epochs * (steps_per_epoch + extra_steps) / accum_steps)
    if optimizer is None or scheduler is None:
        if (optimizer is None) != (scheduler is None):
            warnings.warn(
                "train_loop: optimizer and scheduler must be supplied together; "
                "creating new ones via configure_optimizer"
            )
        optimizer, scheduler = configure_optimizer(
            model, lr=1e-3, total_steps=total_updates
        )
    # Start from clean gradients so anything already sitting on the parameters
    # (e.g. left by a previous session sharing this optimizer) is not folded
    # into the first update.
    optimizer.zero_grad()
    metrics: List[Dict[str, float]] = []

    # Counts optimizer updates (not micro-batches); drives the compress warmup.
    global_step = 0
    for epoch in range(epochs):
        raw_losses: List[float] = []
        raw_accs: List[float] = []
        comp_losses: List[float] = []
        comp_accs: List[float] = []
        comp_ratios: List[float] = []
        direct_losses: List[float] = []

        last_batch = None
        for step, batch in enumerate(loader):
            batch = batch.to(device)
            # Keep the device-resident copy: the extra steps below reuse it.
            last_batch = batch

            if diffusion:
                if diffusion_curriculum:
                    flip_prob = 0.5 * (1 - epoch / max(1, epochs - 1))
                else:
                    flip_prob = (
                        noise_fn() if noise_fn is not None else float(torch.rand(()) * 0.5)
                    )
                noise = (torch.rand_like(batch.float()) < flip_prob).long()
                noisy = batch ^ noise
                with _autocast_context(amp, device):
                    logits, _ = model(noisy, causal=False)
                # Denoising predicts every position, so no next-bit shift here.
                pred = logits.reshape(-1, 2)
                target = batch.reshape(-1)
                loss = F.cross_entropy(pred, target) / accum_steps
                acc = (pred.argmax(dim=-1) == target).float().mean().item()
                raw_losses.append(loss.item() * accum_steps)
                raw_accs.append(acc)
                loss.backward()
                if (step + 1) % accum_steps == 0:
                    _optimizer_update(model, optimizer, scheduler)
                    global_step += 1
                continue

            cur_compress = cosine_ramp(global_step, 0.0, compress_prob, compress_warmup)
            r = torch.rand(())
            key = "raw"
            ratio = 1.0
            target = batch[:, 1:].reshape(-1)

            if r < direct_prob:
                packed = [pack_bits(row.to(torch.uint8)) for row in batch]
                unpacked = [unpack_bits(pk, n_bits=batch.size(1)) for pk in packed]
                # Never exceed the positional-encoding table of the model.
                max_len = min(
                    max(u.numel() for u in unpacked),
                    model.pos_enc.pe.size(0),
                )
                padded = [
                    F.pad(u[:max_len], (0, max_len - min(u.numel(), max_len)))
                    for u in unpacked
                ]
                dc_batch = torch.stack(padded).long().to(device)
                with _autocast_context(amp, device):
                    logits, _ = model(dc_batch, **fwd_kwargs)
                ratio = sum(pk.numel() for pk in packed) / batch.numel()
                target = dc_batch[:, 1:].reshape(-1)
                key = "direct"
            elif r < direct_prob + cur_compress:
                comp_batch = [compress_bits(row.to(torch.uint8)) for row in batch]
                with _autocast_context(amp, device):
                    logits, _ = model.forward_compressed(comp_batch, **fwd_kwargs)
                ratio = sum(c.numel() for c in comp_batch) / batch.numel()
                key = "compressed"
            else:
                with _autocast_context(amp, device):
                    logits, _ = model(batch, **fwd_kwargs)

            loss, acc = _next_bit_loss(logits, target, accum_steps)
            loss.backward()
            if (step + 1) % accum_steps == 0:
                _optimizer_update(model, optimizer, scheduler)
                global_step += 1

            batch_loss = loss.item() * accum_steps
            if key == "compressed":
                comp_losses.append(batch_loss)
                comp_accs.append(acc)
                comp_ratios.append(ratio)
            elif key == "direct":
                direct_losses.append(batch_loss)
                comp_ratios.append(ratio)
            else:
                raw_losses.append(batch_loss)
                raw_accs.append(acc)

        if extra_steps > 0 and last_batch is not None and not diffusion:
            target = last_batch[:, 1:].reshape(-1)
            for extra_step in range(extra_steps):
                with _autocast_context(amp, device):
                    logits, _ = model(last_batch, **fwd_kwargs)
                loss, acc = _next_bit_loss(logits, target, accum_steps)
                loss.backward()
                if (extra_step + 1) % accum_steps == 0:
                    _optimizer_update(model, optimizer, scheduler)
                    # Count updates, matching the main loop above.
                    global_step += 1
                raw_losses.append(loss.item() * accum_steps)
                raw_accs.append(acc)

        m = {
            "raw_loss": _mean(raw_losses),
            "raw_acc": _mean(raw_accs),
            "compressed_loss": _mean(comp_losses),
            "compressed_acc": _mean(comp_accs),
            "direct_loss": _mean(direct_losses),
            "compression_ratio": _mean(comp_ratios),
        }
        metrics.append(m)

        if log:
            print(
                f"Epoch {epoch} "
                f"raw_loss={m['raw_loss']:.4f} acc={m['raw_acc']:.3f} | "
                f"compressed_loss={m['compressed_loss']:.4f} acc={m['compressed_acc']:.3f} "
                f"direct_loss={m['direct_loss']:.4f} ratio={m['compression_ratio']:.2f}"
            )

    return metrics
|
|
# Public API: only ``train_loop`` is exported by ``from ... import *``;
# ``cosine_ramp`` stays importable by explicit name.
__all__ = ["train_loop"]
|
|