# ------------------------------------------------------------------------
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Train a diffnext model."""

import json
import os

from diffnext.engine.train_engine import Trainer
from diffnext.engine.train_engine import engine_utils
from diffnext.utils import accelerate_utils
from diffnext.utils import omegaconf_utils


def prepare_checkpoints(config):
    """Prepare checkpoints for model resuming.

    Resolves ``config.experiment.resume_from_checkpoint`` (either the literal
    string ``"latest"`` or an explicit ``.../checkpoint-<iter>`` path) into a
    concrete checkpoint directory, and records the iteration to resume from in
    ``config.experiment.resume_iter``.

    Args:
        config (omegaconf.DictConfig)
            The model config.
    """
    config.experiment.setdefault("resume_from_checkpoint", "")
    ckpt_dir = os.path.abspath(os.path.join(config.experiment.output_dir, "checkpoints"))
    os.makedirs(ckpt_dir, exist_ok=True)
    resume_iter = 0
    if config.experiment.resume_from_checkpoint == "latest":
        # Pick the checkpoint with the highest iteration suffix, e.g. "checkpoint-5000".
        ckpts = [name for name in os.listdir(ckpt_dir) if name.startswith("checkpoint-")]
        if ckpts:
            resume_iter, ckpt = sorted((int(name.split("-")[-1]), name) for name in ckpts)[-1]
            config.experiment.resume_from_checkpoint = os.path.join(ckpt_dir, ckpt)
    elif config.experiment.resume_from_checkpoint:
        # BUGFIX: ``os.path.split(...)`` returns a (head, tail) tuple, which has no
        # ``.split`` method. Use ``os.path.basename`` to parse the iteration number
        # out of an explicit "checkpoint-<iter>" path.
        resume_iter = int(
            os.path.basename(config.experiment.resume_from_checkpoint).split("-")[-1]
        )
    config.experiment.resume_iter = resume_iter
    if resume_iter and not hasattr(config.model, "lora"):  # Override the pretrained path.
        config.pipeline.paths.pretrained_path = config.experiment.resume_from_checkpoint


def prepare_datasets(config, accelerator):
    """Prepare datasets for model training.

    Reads the dataset's ``METADATA`` json and configures the train dataloader:
    either a per-process "bucket" dataset (when the metadata carries a
    per-rank ``batch_size`` list) or the default DDP sharding.

    Args:
        config (omegaconf.DictConfig)
            The model config.
        accelerator (accelerate.Accelerator)
            The accelerator instance.
    """
    dataset = config.train_dataloader.params.dataset
    # Close the metadata file deterministically instead of leaking the handle.
    with open(os.path.join(dataset, "METADATA")) as f:
        metadata = json.load(f)
    config.train_dataloader.params.max_examples = metadata["entries"]
    if "batch_size" in metadata:
        # Bucketed layout: each rank reads "<dataset>/<rank:03d>" with its own
        # (possibly heterogeneous) batch size taken from the metadata list.
        batch_size = metadata["batch_size"][accelerator.process_index]
        bucket_dataset = dataset + "/" + str(accelerator.process_index).zfill(3)
        config.train_dataloader.params.dataset = bucket_dataset
        config.train_dataloader.params.batch_size = config.training.batch_size = batch_size
        if "num_metrics" in metadata:
            config.training.num_metrics = metadata["num_metrics"]
    elif "shard_id" not in config.train_dataloader.params:
        # By default, we use dataset shards across all processes.
        config.train_dataloader.params.update(accelerate_utils.get_ddp_shards(accelerator))


def run_train(config, accelerator, logger):
    """Start a model training task.

    Args:
        config (omegaconf.DictConfig)
            The model config.
        accelerator (accelerate.Accelerator)
            The accelerator instance.
        logger (logging.Logger)
            The logger instance.
    """
    trainer = Trainer(config, accelerator, logger)
    if accelerator.is_main_process:  # Configs have already been determined.
        config_path = os.path.join(config.experiment.output_dir, "config.yaml")
        omegaconf_utils.save_config(config, config_path)
    logger.info("#Params: %.2fM" % engine_utils.count_params(trainer.model))
    logger.info("Start training...")
    trainer.train_loop()
    # Idiom fix: use a plain ``if`` statement rather than a conditional
    # expression evaluated purely for its side effect.
    if trainer.ema:
        trainer.ema.update(trainer.model)
    trainer.save()


def main():
    """Main entry point: build config/accelerator/logger, seed, and train."""
    config = omegaconf_utils.get_config()
    accelerator = accelerate_utils.build_accelerator(config, log_with="wandb")
    accelerate_utils.build_wandb(config, accelerator=accelerator)
    logger = accelerate_utils.set_logger(config.experiment.output_dir, accelerator=accelerator)
    # Give every process a distinct, deterministic seed derived from the base seed.
    device_seed = config.training.seed + accelerator.process_index
    config.training.gpu_id, config.training.seed = accelerator.device.index, device_seed
    engine_utils.manual_seed(config.training.seed, (config.training.gpu_id, device_seed))
    # Idiom fix: sequence the two preparation calls as statements, not a
    # throwaway tuple expression.
    prepare_checkpoints(config)
    prepare_datasets(config, accelerator)
    logger.info(f"Config:\n{omegaconf_utils.config_to_yaml(config)}")
    run_train(config, accelerator, logger)


if __name__ == "__main__":
    main()