# Uploaded via huggingface_hub by PencilHu (commit 1146a67, verified).
import os, torch
import wandb # 新增
from tqdm import tqdm
from accelerate import Accelerator
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger
def launch_training_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    learning_rate: float = 1e-5,
    weight_decay: float = 1e-2,
    num_workers: int = 1,
    save_steps: int = None,
    num_epochs: int = 1,
    args = None,
):
    """Run the diffusion training loop under HuggingFace Accelerate.

    Builds an AdamW optimizer and a constant LR scheduler, wraps everything
    with ``accelerator.prepare`` (device placement, DDP, gradient
    accumulation), then iterates the dataset for ``num_epochs`` epochs,
    delegating checkpointing to ``model_logger`` and optionally logging
    loss/LR to wandb.

    Args:
        accelerator: Accelerate handle used for preparation, gradient
            accumulation, backward, and process-rank checks.
        dataset: Training dataset; its ``load_from_cache`` attribute selects
            how each batch is passed to ``model`` (see loop body).
        model: Training module; calling it returns the scalar loss.
        model_logger: Logger invoked after every optimizer step, at each
            epoch end (when ``save_steps`` is None), and at training end.
        learning_rate: AdamW learning rate (overridden by ``args``).
        weight_decay: AdamW weight decay (overridden by ``args``).
        num_workers: DataLoader worker count (overridden by ``args``).
        save_steps: When set, ``model_logger`` saves every ``save_steps``
            steps instead of per epoch.
        num_epochs: Number of passes over the dataset.
        args: Optional parsed CLI namespace; when given, its fields take
            precedence over the keyword defaults above.
    """
    # CLI arguments (when present) override the keyword defaults.
    if args is not None:
        learning_rate = args.learning_rate
        weight_decay = args.weight_decay
        num_workers = args.dataset_num_workers
        save_steps = args.save_steps
        num_epochs = args.num_epochs
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    # No batch_size given (defaults to 1); collate_fn unwraps the single
    # sample so each "batch" is the raw dataset item.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        collate_fn=lambda x: x[0],
        num_workers=num_workers,
    )
    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )
    global_step = 0  # global step counter used for wandb logging
    for epoch_id in range(num_epochs):
        # Show tqdm only on the local main process to avoid duplicate
        # progress bars in multi-GPU runs.
        pbar = tqdm(
            dataloader,
            disable=not accelerator.is_local_main_process,
            desc=f"Epoch {epoch_id}",
        )
        for data in pbar:
            with accelerator.accumulate(model):
                # zero_grad here is safe with gradient accumulation because
                # the Accelerate-wrapped optimizer skips it while gradients
                # are still being accumulated.
                optimizer.zero_grad()
                # Cached datasets yield precomputed model inputs; otherwise
                # the raw sample is passed through directly.
                if dataset.load_from_cache:
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                model_logger.on_step_end(accelerator, model, save_steps)
                scheduler.step()
                global_step += 1
                # ============= wandb logging (main process only) =============
                if (
                    args is not None
                    and hasattr(args, "wandb_mode")
                    and args.wandb_mode != "disabled"
                    and accelerator.is_main_process
                ):
                    log_every = getattr(args, "wandb_log_every", 10)
                    if global_step % log_every == 0:
                        # Logging the current process's loss is sufficient
                        # here (no cross-process reduction).
                        loss_value = loss.detach().float().item()
                        try:
                            lr = scheduler.get_last_lr()[0]
                        except Exception:
                            lr = optimizer.param_groups[0]["lr"]
                        wandb.log(
                            {
                                "train/loss": loss_value,
                                "train/lr": lr,
                                "train/epoch": epoch_id,
                                "train/step": global_step,
                            }
                        )
                # =======================================================
        # Without step-based saving, checkpoint once per epoch instead.
        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)
    model_logger.on_training_end(accelerator, model, save_steps)
def launch_data_process_task(
    accelerator: Accelerator,
    dataset: torch.utils.data.Dataset,
    model: DiffusionTrainingModule,
    model_logger: ModelLogger,
    num_workers: int = 8,
    args = None,
):
    """Run ``model`` over ``dataset`` once and cache each output to disk.

    Each process writes its results under
    ``<model_logger.output_path>/<process_index>/<data_id>.pth`` so that
    multi-process runs do not collide.

    Args:
        accelerator: Accelerate handle used for preparation and to derive
            the per-process output subfolder.
        dataset: Dataset to process (iterated in order, no shuffling).
        model: Module applied to each sample; its output is saved verbatim.
        model_logger: Only ``output_path`` is used, as the cache root.
        num_workers: DataLoader worker count (overridden by ``args``).
        args: Optional parsed CLI namespace; when given,
            ``args.dataset_num_workers`` overrides ``num_workers``.
    """
    if args is not None:
        num_workers = args.dataset_num_workers
    # collate_fn unwraps the single-sample batch (batch_size defaults to 1).
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=num_workers,
    )
    model, dataloader = accelerator.prepare(model, dataloader)
    # The per-process output folder is loop-invariant: create it once up
    # front instead of re-joining paths and calling makedirs per batch.
    folder = os.path.join(model_logger.output_path, str(accelerator.process_index))
    os.makedirs(folder, exist_ok=True)
    for data_id, data in enumerate(tqdm(
        dataloader,
        disable=not accelerator.is_local_main_process,
        desc="Data process",
    )):
        with accelerator.accumulate(model):
            # Inference only — no gradients needed.
            with torch.no_grad():
                data = model(data)
                torch.save(data, os.path.join(folder, f"{data_id}.pth"))