import os
import torch
import wandb  # added for Weights & Biases logging
from tqdm import tqdm
from accelerate import Accelerator
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger


def launch_training_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-2,
    num_workers: int = 1,
    save_steps: int = None,
    num_epochs: int = 1,
    args = None,
):
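    """Run a full training loop for `model` on `dataset` under Accelerate.

    The model, optimizer, dataloader and constant-LR scheduler are wrapped with
    `accelerator.prepare`, the loop runs for `num_epochs`, and checkpointing is
    delegated to `model_logger`. When `args` is provided, its fields override
    the keyword defaults, and losses are optionally reported to wandb from the
    main process.
    """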
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        collate_fn=lambda x: x[0],
        num_workers=num_workers,
    )
    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )
    global_step = 0  # global step counter for wandb logging
    for epoch_id in range(num_epochs):
        # Show the progress bar only on the local main process to avoid
        # duplicated bars in multi-GPU runs.
        pbar = tqdm(
            dataloader,
            disable=not accelerator.is_local_main_process,
            desc=f"Epoch {epoch_id}",
        )
        for data in pbar:
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                if dataset.load_from_cache:
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                model_logger.on_step_end(accelerator, model, save_steps)
                scheduler.step()
                global_step += 1
                # ============= wandb logging (main process only) =============
                if (
                    args is not None
                    and hasattr(args, "wandb_mode")
                    and args.wandb_mode != "disabled"
                    and accelerator.is_main_process
                ):
                    log_every = getattr(args, "wandb_log_every", 10)
                    if global_step % log_every == 0:
                        # The loss from the current process is sufficient here;
                        # no cross-process reduction is needed for logging.
                        loss_value = loss.detach().float().item()
                        try:
                            lr = scheduler.get_last_lr()[0]
                        except Exception:
                            lr = optimizer.param_groups[0]["lr"]
                        wandb.log(
                            {
                                "train/loss": loss_value,
                                "train/lr": lr,
                                "train/epoch": epoch_id,
                                "train/step": global_step,
                            }
                        )
                # ==============================================================
        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)
    model_logger.on_training_end(accelerator, model, save_steps)


def launch_data_process_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    num_workers: int = 8,
    args = None,
):
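    """Run `model` once over `dataset` without gradients to pre-process data.

    Each process writes its outputs to
    `<model_logger.output_path>/<process_index>/<data_id>.pth`, so files from
    different GPUs do not collide.
    """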
    if args is not None:
        num_workers = args.dataset_num_workers
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=num_workers,
    )
    model, dataloader = accelerator.prepare(model, dataloader)
    for data_id, data in enumerate(tqdm(
        dataloader,
        disable=not accelerator.is_local_main_process,
        desc="Data process",
    )):
        with accelerator.accumulate(model):
            with torch.no_grad():
                folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
                os.makedirs(folder, exist_ok=True)
                save_path = os.path.join(
                    model_logger.output_path,
                    str(accelerator.process_index),
                    f"{data_id}.pth",
                )
                data = model(data)
                torch.save(data, save_path)
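

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes an
# argparse-style `args` namespace exposing the fields read above
# (learning_rate, weight_decay, dataset_num_workers, save_steps, num_epochs,
# wandb_mode, wandb_log_every) plus a hypothetical `wandb_project` field, and
# that `dataset`, `model`, and `model_logger` are constructed by the caller.
# ---------------------------------------------------------------------------
def example_launch(args, dataset, model, model_logger):
    accelerator = Accelerator()
    # `wandb.log` in the training loop assumes an active run, so initialise
    # wandb once on the main process before training starts.
    if accelerator.is_main_process and getattr(args, "wandb_mode", "disabled") != "disabled":
        wandb.init(
            project=getattr(args, "wandb_project", "diffusion-training"),
            mode=args.wandb_mode,
            config=vars(args),
        )
    launch_training_task(accelerator, dataset, model, model_logger, args=args)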