# ------------------------------------------------------------------------
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Accelerate utilities."""
import atexit
import functools
import logging
import os
import sys
import time
import accelerate
import torch
import wandb
from diffnext.utils.omegaconf_utils import flatten_omega_conf
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs


def build_accelerator(config, **kwargs) -> accelerate.Accelerator:
"""Build accelerator."""
kwargs_handlers = []
    # Enable unused-parameter detection for plain DDP.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
kwargs_handlers.append(ddp_kwargs)
accelerator = accelerate.Accelerator(
log_with=kwargs.get("log_with", None),
mixed_precision=config.training.mixed_precision,
gradient_accumulation_steps=config.training.gradient_accumulation_steps,
kwargs_handlers=kwargs_handlers,
)
if hasattr(accelerator.state.deepspeed_plugin, "deepspeed_config"):
import deepspeed
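        # Quiet DeepSpeed's verbose startup logging unless a different level is requested.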
deepspeed.logger.setLevel(kwargs.get("deepspeed_log_lvl", "WARNING"))
        # Set a dummy micro batch size to avoid the errors raised by DeepSpeed config checks.
accelerator.state.deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] = 1
return accelerator
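
# A minimal usage sketch (not part of this module; assumes an OmegaConf config providing
# `training.mixed_precision` and `training.gradient_accumulation_steps`):
#   accelerator = build_accelerator(config, log_with="wandb")
#   model, optimizer = accelerator.prepare(model, optimizer)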


def build_wandb(config, accelerator):
"""Build wandb for accelerator."""
if "wandb" not in config or not accelerator.is_main_process:
return
config.wandb = config.wandb or type(config)({})
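    # Reuse a stored run_id to resume a previous wandb run; otherwise generate a fresh one.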
old_run_id = config.wandb.get("run_id", None)
config.wandb.run_id = run_id = old_run_id or wandb.util.generate_id()
init_kwargs = dict(id=run_id, name=config.experiment.name, resume=old_run_id is not None)
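    # Flatten the nested OmegaConf config into plain key/value pairs that wandb can record.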
init_kwargs["config"] = {k: v for k, v in flatten_omega_conf(config, True)}
accelerator.init_trackers(config.experiment.project, init_kwargs={"wandb": init_kwargs})
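
# A minimal usage sketch (illustrative; assumes the config also carries `experiment.project`,
# `experiment.name`, and an optional `wandb.run_id` used to resume a previous run):
#   accelerator = build_accelerator(config, log_with="wandb")
#   build_wandb(config, accelerator)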


def get_ddp_shards(accelerator) -> dict:
"""Return the shard arguments for simple DDP."""
return {"shard_id": accelerator.process_index, "num_shards": accelerator.num_processes}


def precision_to_dtype(precision="bf16") -> torch.dtype:
"""Convert precision string to torch dtype."""
str_dict = {"fp16": "float16", "bf16": "bfloat16", "fp32": "float32"}
return getattr(torch, str_dict.get(precision.lower(), "float32"))
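
# For example, precision_to_dtype("bf16") returns torch.bfloat16; unknown strings fall back
# to torch.float32.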


@functools.lru_cache()
def set_logger(output_dir=None, name="diffnext", level="INFO", accelerator=None):
"""Set logger."""
@functools.lru_cache(maxsize=None)
def cached_log_stream(filename):
"""Register a cached filename."""
f = open(filename, "a")
atexit.register(f.close)
return f

    logger = logging.getLogger(name)
logger.propagate, _ = False, logger.setLevel(level)
fmt = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
formatter = logging.Formatter(fmt, datefmt="%m/%d %H:%M:%S")
ch = logging.StreamHandler(sys.stdout)
ch.setLevel(level), ch.setFormatter(formatter), logger.addHandler(ch)
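    # Only the main process writes a log file; other ranks keep just the stdout handler.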
output_dir = "" if (accelerator and not accelerator.is_main_process) else output_dir
if output_dir:
os.makedirs(os.path.join(output_dir, "logs"), exist_ok=True)
log_file = time.strftime("%Y%m%d_%H%M%S", time.localtime(time.time())) + ".log"
fh = logging.StreamHandler(cached_log_stream(os.path.join(output_dir, "logs", log_file)))
fh.setLevel(level), fh.setFormatter(formatter), logger.addHandler(fh)
return accelerate.logging.MultiProcessAdapter(logger, {}) if accelerator else logger
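
# A minimal usage sketch (the output directory is illustrative):
#   logger = set_logger(output_dir="./outputs", accelerator=accelerator)
#   logger.info("Start training...")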