File size: 8,927 Bytes
b6ff324
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# ------------------------------------------------------------------------
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Custom trainer focused on data parallelism specialization."""

import collections
import os
import shutil

import torch

from diffnext.engine import engine_utils
from diffnext.engine.model_ema import ModelEMA
from diffnext.pipelines.builder import build_pipeline
from diffnext.pipelines.builder import get_pipeline_path
from diffnext.utils import accelerate_utils
from diffnext.utils import profiler
from diffnext.utils.omegaconf_utils import config_to_class
from diffnext.utils.omegaconf_utils import config_to_object


class Trainer(object):
    """Schedule the iterative model training.

    The trainer builds the pipeline, model, optimizer and LR scheduler from the
    config, then drives the step loop: forward/backward with optional gradient
    accumulation, optimizer update, metric smoothing/logging, EMA update and
    periodic checkpointing.
    """

    def __init__(self, config, accelerator, logger):
        """Create a trainer instance.

        Args:
            config
                The experiment config. Read through attribute access and ``.get``
                (``training``, ``pipeline``, ``optimizer``, ``lr_scheduler``,
                ``experiment``, ``model`` and the optional ``ema`` sections).
            accelerator
                The object used to prepare the model/optimizer, run backward,
                gather tensors and send tracker logs.
                NOTE(review): presumably an ``accelerate.Accelerator`` - confirm.
            logger
                The logger providing ``info``.

        Raises:
            RuntimeError: If sequence parallelism is requested but the pipeline
                model has no ``configure_sequence_parallel`` method.
        """
        self.config, self.accelerator, self.logger = config, accelerator, logger
        self.dtype = accelerate_utils.precision_to_dtype(config.training.mixed_precision)
        self.train_dataloader = config_to_object(config.train_dataloader)
        self.pipe_path = get_pipeline_path(**config.pipeline.paths)
        self.pipe = build_pipeline(self.pipe_path, config_to_class(config.pipeline), self.dtype)
        self.pipe = self.pipe.to(device=engine_utils.get_device(config.training.gpu_id))
        # The EMA wraps the raw pipeline model and is created before ``configure_model``.
        self.ema = ModelEMA(self.pipe.model, **config.ema.params) if "ema" in config else None
        self.model = self.pipe.configure_model(config, accelerator, logger)
        if config.optimizer.get("param_groups", True):
            param_groups = engine_utils.get_param_groups(self.model)
        else:
            # Fall back to a flat list of the trainable parameters.
            param_groups = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = config_to_object(config.optimizer, params=param_groups)
        self.scheduler = config_to_object(config.lr_scheduler)
        self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
        if config.training.get("sequence_parallel_size", 1) > 1:
            if not hasattr(self.pipe.model, "configure_sequence_parallel"):
                raise RuntimeError("Model does not support sequence parallelism.")
            self.pipe.model.configure_sequence_parallel(self.model.seq_parallel_group)
        self.metrics = collections.OrderedDict()
        # ``resume_iter`` is optional; read it the same way ``train_loop`` does.
        if self.ema and config.experiment.get("resume_iter", 0) > 0:
            ckpt = config.experiment.resume_from_checkpoint
            # Mirrors the directory layout written by ``save``.
            ema_ckpt = ckpt.replace("checkpoints", "ema_checkpoints")
            ema_weights = os.path.join(ema_ckpt, config.model.name, "diffusion_pytorch_model.bin")
            engine_utils.load_weights(self.ema.model, ema_weights)

    @property
    def global_step(self) -> int:
        """Return the global iteration step.

        The value is the scheduler's ``_step_count``, which ``train_loop``
        seeds with the resume iteration before the loop starts.

        Returns:
            int: The global step.
        """
        return self.scheduler._step_count

    def save(self):
        """Save the checkpoint of current iterative step.

        Only the main process writes, and an existing checkpoint directory for
        the current step is never overwritten. When EMA is enabled, the EMA
        state dict and a copy of ``config.json`` are written to the parallel
        ``ema_checkpoints`` tree.
        """
        step_dir = "checkpoint-{}/{}".format(self.global_step, self.config.model.name)
        ckpt_dir = os.path.join(self.config.experiment.output_dir, "checkpoints", step_dir)
        if not self.accelerator.is_main_process or os.path.exists(ckpt_dir):
            return
        self.model.save_pretrained(ckpt_dir, safe_serialization=False)
        self.logger.info("Wrote snapshot to: {:s}".format(ckpt_dir))
        if self.ema is None:
            return
        # Keep the plain string replace: the resume path in ``__init__`` derives
        # the EMA location with the same substitution, so both must agree.
        ema_dir = ckpt_dir.replace("checkpoints", "ema_checkpoints")
        os.makedirs(ema_dir, exist_ok=True)
        shutil.copy(os.path.join(ckpt_dir, "config.json"), os.path.join(ema_dir, "config.json"))
        torch.save(
            self.ema.model.state_dict(),
            os.path.join(ema_dir, "diffusion_pytorch_model.bin"),
        )

    def add_metrics(self, stats):
        """Add or update the metrics.

        Args:
            stats (Dict)
                The current iteration stats. Each entry of ``stats["metrics"]``
                is pushed into a per-key ``profiler.SmoothedValue``.
        """
        for k, v in stats["metrics"].items():
            if k not in self.metrics:
                self.metrics[k] = profiler.SmoothedValue()
            self.metrics[k].update(v)

    def log_metrics(self, stats):
        """Send metrics to available trackers.

        Logs the iteration header and every smoothed metric, forwards the
        medians plus ``lr`` and ``time`` to the accelerator trackers, then
        clears the smoothed metrics so the next window starts empty.

        Args:
            stats (Dict)
                The current iteration stats (``step``, ``lr`` and ``time``).
        """
        iter_template = "Iteration %d, lr = %.8f, time = %.2fs"
        metric_template = " " * 4 + "Train net output({}): {}"
        self.logger.info(iter_template % (stats["step"], stats["lr"], stats["time"]))
        for k, v in self.metrics.items():
            self.logger.info(metric_template.format(k, v))
        tracker_logs = {k: v.median for k, v in self.metrics.items()}
        tracker_logs.update({"lr": stats["lr"], "time": stats["time"]})
        self.accelerator.log(tracker_logs, step=stats["step"])
        self.metrics.clear()

    def run_model(self, inputs, metrics, accum_steps=1):
        """Run multiple model steps.

        Args:
            inputs (Dict)
                The model inputs. If empty, a new batch is drawn from the
                train dataloader for every accumulation step; otherwise the
                given inputs are reused for every step.
            metrics (Dict)
                The current iteration metrics, updated in place with each
                value divided by ``accum_steps``.
            accum_steps (int, optional, defaults to 1)
                The gradient accumulation steps.

        Raises:
            RuntimeError: If the model returns no loss tensor requiring grad.
        """
        for _ in range(accum_steps):
            # Use a separate name so an empty ``inputs`` stays empty and each
            # accumulation step draws its own batch instead of reusing the first.
            batch = inputs if inputs else self.train_dataloader.next()[0]
            outputs, losses = self.model(batch), []
            for k, v in outputs.items():
                if "loss" not in k and "metric" not in k:
                    continue
                if isinstance(v, torch.Tensor) and v.requires_grad:
                    losses.append(v)
                if k.startswith("metric/"):  # Custom metrics.
                    metrics[k[len("metric/") :]] += float(v.mean()) / accum_steps
                elif f"metric/{k}" not in outputs:  # Legacy metrics.
                    metrics[k] += float(self.accelerator.gather(v).mean()) / accum_steps
            if not losses:
                raise RuntimeError("Model outputs contain no loss that requires grad.")
            # Seed the sum with the first loss so the result stays a tensor.
            total_loss = sum(losses[1:], losses[0])
            # NOTE(review): the context is entered but never exited; presumably this
            # only advances the accelerator's accumulation bookkeeping - confirm.
            self.accelerator.accumulate().__enter__()
            self.accelerator.backward(total_loss)

    def run_step(self, inputs, accum_steps=1) -> dict:
        """Run single iteration step.

        Args:
            inputs (Dict)
                The model inputs.
            accum_steps (int, optional, defaults to 1)
                The gradient accumulation steps.

        Returns:
            Dict: The current iteration stats (``step``, ``lr``, ``time`` and
            the key-sorted ``metrics``).
        """
        stats = {"step": self.global_step}
        metrics = collections.defaultdict(float)
        timer = profiler.Timer().tic()
        stats["lr"] = self.scheduler.get_lr()
        # The learning rate is written into the optimizer by hand, with an
        # optional per-group ``lr_scale`` multiplier.
        for group in self.optimizer.param_groups:
            group["lr"] = stats["lr"] * group.get("lr_scale", 1.0)
        self.run_model(inputs, metrics, accum_steps)
        self.optimizer.step()
        self.optimizer.zero_grad(set_to_none=True)
        self.scheduler.step()
        stats["time"] = timer.toc()
        stats["metrics"] = collections.OrderedDict(sorted(metrics.items()))
        return stats

    def train_loop(self):
        """Training loop.

        Runs steps until ``max_train_steps`` is reached, logging, updating the
        EMA and saving at their configured intervals, then flushes the
        remaining metrics and shuts the accelerator down.
        """
        timer = profiler.Timer()
        max_steps = self.config.training.max_train_steps
        accum_steps = self.config.training.gradient_accumulation_steps
        log_every = self.config.experiment.log_every
        save_every = self.config.experiment.save_every
        # ``data_every < 1`` keeps ``inputs`` empty so ``run_model`` draws the data.
        data_every, inputs = self.config.training.get("data_every", -1), {}
        self.scheduler._step_count = self.config.experiment.get("resume_iter", 0)
        stats = None  # Stays None if the loop body never runs.
        while self.global_step < max_steps:
            if data_every >= 1 and self.global_step % data_every == 0:
                inputs = self.train_dataloader.next()[0]
            with timer.tic_and_toc():
                stats = self.run_step(inputs, accum_steps)
            self.add_metrics(stats)
            # ``stats["step"]`` is the step before the update; ``global_step`` below
            # is the step after it.
            if stats["step"] % log_every == 0:
                self.log_metrics(stats)
            if self.global_step % (10 * log_every) == 0:
                self.logger.info(profiler.get_progress(timer, self.global_step, max_steps))
            if self.ema and self.global_step % self.ema.update_every == 0:
                self.ema.update(self.model)
            if self.global_step % save_every == 0:
                self.save()
        if stats is not None:
            # Flush whatever metrics remain, reported at the final step.
            self.log_metrics({**stats, "step": self.global_step})
        self.accelerator.wait_for_everyone()
        self.accelerator.end_training()