File size: 4,235 Bytes
b6ff324
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# ------------------------------------------------------------------------
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Accelerate utilities."""

import atexit
import functools
import logging
import os
import sys
import time

import accelerate
import torch
import wandb

from diffnext.utils.omegaconf_utils import flatten_omega_conf

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs


def build_accelerator(config, **kwargs) -> accelerate.Accelerator:
    """Create and configure an ``accelerate.Accelerator``.

    Args:
        config: Configuration object; reads ``config.training.mixed_precision``
            and ``config.training.gradient_accumulation_steps``.
        kwargs: Optional ``log_with`` (tracker backend) and
            ``deepspeed_log_lvl`` (deepspeed logger level, default "WARNING").

    Returns:
        The configured accelerator instance.
    """
    # Enable unused-parameter detection for plain DDP training.
    handlers = [DistributedDataParallelKwargs(find_unused_parameters=True)]
    accelerator = accelerate.Accelerator(
        log_with=kwargs.get("log_with"),
        mixed_precision=config.training.mixed_precision,
        gradient_accumulation_steps=config.training.gradient_accumulation_steps,
        kwargs_handlers=handlers,
    )
    plugin = accelerator.state.deepspeed_plugin
    if hasattr(plugin, "deepspeed_config"):
        import deepspeed

        deepspeed.logger.setLevel(kwargs.get("deepspeed_log_lvl", "WARNING"))
        # Dummy size to avoid the raised errors.
        plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 1
    return accelerator


def build_wandb(config, accelerator):
    """Initialize the wandb tracker on the main process.

    No-op unless ``config`` has a "wandb" section and this process is the
    accelerator's main process. Reuses ``config.wandb.run_id`` when present
    (resuming a run), otherwise generates and stores a fresh run id.
    """
    if "wandb" not in config or not accelerator.is_main_process:
        return
    config.wandb = config.wandb or type(config)({})
    previous_id = config.wandb.get("run_id", None)
    run_id = previous_id if previous_id else wandb.util.generate_id()
    config.wandb.run_id = run_id
    tracker_kwargs = {
        "id": run_id,
        "name": config.experiment.name,
        "resume": previous_id is not None,
        "config": dict(flatten_omega_conf(config, True)),
    }
    accelerator.init_trackers(config.experiment.project, init_kwargs={"wandb": tracker_kwargs})


def get_ddp_shards(accelerator) -> dict:
    """Return the shard arguments for simple DDP.

    Maps the accelerator's process index/count onto the
    ``shard_id``/``num_shards`` keyword names used by the data loaders.
    """
    shard_id = accelerator.process_index
    num_shards = accelerator.num_processes
    return dict(shard_id=shard_id, num_shards=num_shards)


def precision_to_dtype(precision="bf16") -> torch.dtype:
    """Map a mixed-precision tag ("fp16"/"bf16"/"fp32") to a torch dtype.

    The tag is case-insensitive; unknown tags fall back to ``torch.float32``.
    """
    mapping = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}
    return mapping.get(precision.lower(), torch.float32)


@functools.lru_cache()
def set_logger(output_dir=None, name="diffnext", level="INFO", accelerator=None):
    """Set up and return a cached logger.

    A stdout handler is always attached. When ``output_dir`` is set (and this
    is the main process when an ``accelerator`` is given), log records are also
    appended to ``<output_dir>/logs/<timestamp>.log``. Results are memoized via
    ``lru_cache`` so repeated calls with the same arguments do not stack
    duplicate handlers.

    Args:
        output_dir: Optional directory that receives a ``logs/`` subfolder.
        name: Logger name passed to ``logging.getLogger``.
        level: Level applied to the logger and each handler.
        accelerator: Optional accelerator; restricts file logging to the main
            process and wraps the result in a ``MultiProcessAdapter``.

    Returns:
        A ``logging.Logger``, or an ``accelerate.logging.MultiProcessAdapter``
        when ``accelerator`` is provided.
    """

    @functools.lru_cache(maxsize=None)
    def cached_log_stream(filename):
        """Open *filename* once for appending; close it at interpreter exit."""
        f = open(filename, "a")
        atexit.register(f.close)
        return f

    logger = logging.getLogger(name)
    # Avoid duplicate records being emitted through the root logger.
    logger.propagate = False
    logger.setLevel(level)
    fmt = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
    formatter = logging.Formatter(fmt, datefmt="%m/%d %H:%M:%S")
    ch = logging.StreamHandler(sys.stdout)
    ch.setLevel(level)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    # Only the main process writes log files when running distributed.
    if accelerator and not accelerator.is_main_process:
        output_dir = ""
    if output_dir:
        os.makedirs(os.path.join(output_dir, "logs"), exist_ok=True)
        log_file = time.strftime("%Y%m%d_%H%M%S", time.localtime(time.time())) + ".log"
        fh = logging.StreamHandler(cached_log_stream(os.path.join(output_dir, "logs", log_file)))
        fh.setLevel(level)
        fh.setFormatter(formatter)
        logger.addHandler(fh)
    return accelerate.logging.MultiProcessAdapter(logger, {}) if accelerator else logger