# ------------------------------------------------------------------------
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Train a diffnext model."""
import json
import os
from diffnext.engine.train_engine import Trainer
from diffnext.engine.train_engine import engine_utils
from diffnext.utils import accelerate_utils
from diffnext.utils import omegaconf_utils


def prepare_checkpoints(config):
    """Prepare checkpoints for model resuming.

    Args:
        config (omegaconf.DictConfig)
            The model config.
    """
    config.experiment.setdefault("resume_from_checkpoint", "")
    ckpt_dir = os.path.abspath(os.path.join(config.experiment.output_dir, "checkpoints"))
    resume_iter, _ = 0, os.makedirs(ckpt_dir, exist_ok=True)
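    # A value of "latest" is resolved to the "checkpoint-<iter>" folder with the
    # highest iteration number under "<output_dir>/checkpoints".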
    if config.experiment.resume_from_checkpoint == "latest":
        ckpts = [_ for _ in os.listdir(ckpt_dir) if _.startswith("checkpoint-")]
        if ckpts:
            resume_iter, ckpt = sorted((int(_.split("-")[-1]), _) for _ in ckpts)[-1]
            config.experiment.resume_from_checkpoint = os.path.join(ckpt_dir, ckpt)
    elif config.experiment.resume_from_checkpoint:
        resume_iter = int(os.path.basename(config.experiment.resume_from_checkpoint).split("-")[-1])
    config.experiment.resume_iter = resume_iter
    if resume_iter and not hasattr(config.model, "lora"):  # Override the pretrained path.
        config.pipeline.paths.pretrained_path = config.experiment.resume_from_checkpoint


def prepare_datasets(config, accelerator):
    """Prepare datasets for model training.

    Args:
        config (omegaconf.DictConfig)
            The model config.
        accelerator (accelerate.Accelerator)
            The accelerator instance.
    """
    dataset = config.train_dataloader.params.dataset
    metadata = json.load(open(os.path.join(dataset, "METADATA")))
    config.train_dataloader.params.max_examples = metadata["entries"]
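    # If METADATA provides per-process batch sizes, each rank reads its own
    # zero-padded bucket subdirectory (e.g., "000", "001") with its own batch size;
    # otherwise the dataset is sharded across processes via DDP shard ids.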
if "batch_size" in metadata:
batch_size = metadata["batch_size"][accelerator.process_index]
bucket_dataset = dataset + "/" + str(accelerator.process_index).zfill(3)
config.train_dataloader.params.dataset = bucket_dataset
config.train_dataloader.params.batch_size = config.training.batch_size = batch_size
if "num_metrics" in metadata:
config.training.num_metrics = metadata["num_metrics"]
elif "shard_id" not in config.train_dataloader.params:
# By default, we use dataset shards across all processes.
config.train_dataloader.params.update(accelerate_utils.get_ddp_shards(accelerator))


def run_train(config, accelerator, logger):
    """Start a model training task.

    Args:
        config (omegaconf.DictConfig)
            The model config.
        accelerator (accelerate.Accelerator)
            The accelerator instance.
        logger (logging.Logger)
            The logger instance.
    """
    trainer = Trainer(config, accelerator, logger)
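    # Only the main process writes the fully resolved config snapshot to disk.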
    if accelerator.is_main_process:  # Configs have already been determined.
        config_path = os.path.join(config.experiment.output_dir, "config.yaml")
        omegaconf_utils.save_config(config, config_path)
    logger.info("#Params: %.2fM" % engine_utils.count_params(trainer.model))
    logger.info("Start training...")
    trainer.train_loop()
    trainer.ema.update(trainer.model) if trainer.ema else None  # Sync EMA weights before the final save.
    trainer.save()


def main():
    """Main entry point."""
    config = omegaconf_utils.get_config()
    accelerator = accelerate_utils.build_accelerator(config, log_with="wandb")
    accelerate_utils.build_wandb(config, accelerator=accelerator)
    logger = accelerate_utils.set_logger(config.experiment.output_dir, accelerator=accelerator)
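    # Offset the base seed by the process index so every rank draws a different
    # random stream while the run as a whole remains reproducible.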
    device_seed = config.training.seed + accelerator.process_index
    config.training.gpu_id, config.training.seed = accelerator.device.index, device_seed
    engine_utils.manual_seed(config.training.seed, (config.training.gpu_id, device_seed))
    prepare_checkpoints(config), prepare_datasets(config, accelerator)
    logger.info(f"Config:\n{omegaconf_utils.config_to_yaml(config)}")
    run_train(config, accelerator, logger)
if __name__ == "__main__":
main()