linaa98 committed on
Commit
762f6ed
·
verified ·
1 Parent(s): 58278de

Upload 20 files

Browse files
utils/.DS_Store ADDED
Binary file (8.2 kB). View file
 
utils/__init__.py ADDED
File without changes
utils/io/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2021/5/17
3
+ # @Author : Lart Pang
4
+ # @GitHub : https://github.com/lartpang
5
+
6
+ from .image import read_color_array, read_gray_array
7
+ from .params import load_weight, save_weight
utils/io/image.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2021/5/17
3
+ # @Author : Lart Pang
4
+ # @GitHub : https://github.com/lartpang
5
+ import cv2
6
+ import numpy as np
7
+
8
+ from utils.ops import minmax
9
+
10
+
11
def read_gray_array(path, div_255=False, to_normalize=False, thr=-1, dtype=np.float32) -> np.ndarray:
    """Read a `.jpg`/`.png` file into a single-channel (grayscale) ndarray.

    1. read the binary image with the suffix `.jpg` or `.png` into a grayscale ndarray
    2. (div_255=True) divide the ndarray by 255
    3. (to_normalize=True) min-max rescale the ndarray to [0, 1]
    4. (thr >= 0) binarize the ndarray with `thr`
    5. return the ndarray cast to `dtype` (np.float32 by default)
    """
    assert path.endswith(".jpg") or path.endswith(".png"), path
    # `div_255` and `to_normalize` are mutually exclusive rescaling options.
    assert not div_255 or not to_normalize, path
    gray_array = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None instead of raising on a missing/unreadable file.
    assert gray_array is not None, f"Image Not Found: {path}"

    if div_255:
        gray_array = gray_array / 255

    if to_normalize:
        # minmax first divides by 255 (up_bound) and then stretches to [0, 1].
        gray_array = minmax(gray_array, up_bound=255)

    if thr >= 0:
        # produces a boolean mask; cast back to `dtype` below
        gray_array = gray_array > thr

    return gray_array.astype(dtype)
34
+
35
+
36
def read_color_array(path: str):
    """Read a `.jpg`/`.png` file and return it as an RGB ndarray."""
    assert path.endswith(".jpg") or path.endswith(".png")
    # cv2 loads images in BGR channel order; convert before returning.
    bgr_array = cv2.imread(path, cv2.IMREAD_COLOR)
    assert bgr_array is not None, f"Image Not Found: {path}"
    return cv2.cvtColor(bgr_array, cv2.COLOR_BGR2RGB)
utils/io/params.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2020/12/19
3
+ # @Author : Lart Pang
4
+ # @GitHub : https://github.com/lartpang
5
+
6
+ import os
7
+
8
+ import torch
9
+
10
+
11
def save_weight(save_path, model):
    """Serialize a state dict (or a module's state dict) to `save_path`.

    `model` may be a ready-made state dict, a DataParallel-style wrapper
    (unwrapped via `.module`), or a plain nn.Module.
    """
    print(f"Saving weight '{save_path}'")
    if isinstance(model, dict):
        model_state = model
    elif hasattr(model, "module"):
        model_state = model.module.state_dict()
    else:
        model_state = model.state_dict()
    torch.save(model_state, save_path)
    print(f"Saved weight '{save_path}' (only contain the net's weight)")
19
+
20
+
21
def load_weight(load_path, model, *, strict=True, skip_unmatched_shape=False):
    """Load a checkpoint into `model`, tolerating DataParallel key prefixes.

    Args:
        load_path: path of the checkpoint file (must exist).
        model: the target nn.Module.
        strict: forwarded to `model.load_state_dict`.
        skip_unmatched_shape: if True, silently skip checkpoint tensors whose
            shape differs from the model's parameter of the same name.
    """
    assert os.path.exists(load_path), load_path
    model_params = model.state_dict()
    for k, v in torch.load(load_path, map_location="cpu").items():
        # BUG FIX: DataParallel *prefixes* keys with "module.", so the check
        # must be startswith (the original used endswith and never stripped).
        if k.startswith("module."):
            k = k[7:]
        if skip_unmatched_shape and k in model_params and v.shape != model_params[k].shape:
            continue
        model_params[k] = v
    model.load_state_dict(model_params, strict=strict)
utils/ops/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2020/12/19
3
+ # @Author : Lart Pang
4
+ # @GitHub : https://github.com/lartpang
5
+
6
+
7
+ from .array_ops import *
8
+ from .tensor_ops import *
utils/ops/array_ops.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+
8
def minmax(data_array: np.ndarray, up_bound: float = None) -> np.ndarray:
    """Min-max rescale `data_array` to [0, 1].

    ::

        data_array = (data_array / up_bound)
        if min_value != max_value:
            data_array = (data_array - min_value) / (max_value - min_value)

    :param data_array: array to normalize (the input is not modified in place).
    :param up_bound: if not None, data_array will be divided by it before the minmax ops.
    :return: the rescaled array; a constant array is returned unchanged.
    """
    if up_bound is not None:
        data_array = data_array / up_bound
    max_value = data_array.max()
    min_value = data_array.min()
    if max_value != min_value:
        data_array = (data_array - min_value) / (max_value - min_value)
    return data_array


def clip_to_normalize(data_array: np.ndarray, clip_range: tuple = None) -> np.ndarray:
    """Clip `data_array` with `clip_range` and min-max rescale it towards [0, 1].

    NOTE: the 3-tuple form writes back through boolean masks, so it mutates
    `data_array` in place.

    :param data_array: the array to process.
    :param clip_range: None, (min, max), or (min, mid, max), all in [0, 1].
    :raises NotImplementedError: when clip_range has an unsupported length.
    """
    # BUG FIX: `sorted(None)` raises TypeError, so the original `clip_range is
    # None` branch (placed after the sort) was unreachable. Handle None first.
    if clip_range is None:
        return minmax(data_array)

    clip_range = sorted(clip_range)
    if len(clip_range) == 3:
        clip_min, clip_mid, clip_max = clip_range
        assert 0 <= clip_min < clip_mid < clip_max <= 1, clip_range
        # Normalize the regions below and above `clip_mid` independently:
        # the lower part stays within [0, max_lower], the upper part within
        # [min_lower, 1], preserving the split at the midpoint.
        lower_array = data_array[data_array < clip_mid]
        higher_array = data_array[data_array > clip_mid]
        if lower_array.size > 0:
            lower_array = np.clip(lower_array, a_min=clip_min, a_max=1)
            max_lower = lower_array.max()
            lower_array = minmax(lower_array) * max_lower
            data_array[data_array < clip_mid] = lower_array
        if higher_array.size > 0:
            higher_array = np.clip(higher_array, a_min=0, a_max=clip_max)
            min_lower = higher_array.min()
            higher_array = minmax(higher_array) * (1 - min_lower) + min_lower
            data_array[data_array > clip_mid] = higher_array
    elif len(clip_range) == 2:
        clip_min, clip_max = clip_range
        assert 0 <= clip_min < clip_max <= 1, clip_range
        # Only clip when both bounds are non-trivial, matching the original.
        if clip_min != 0 and clip_max != 1:
            data_array = np.clip(data_array, a_min=clip_min, a_max=clip_max)
        data_array = minmax(data_array)
    else:
        raise NotImplementedError
    return data_array
57
+
58
+
59
def save_array_as_image(data_array: np.ndarray, save_name: str, save_dir: str, to_minmax: bool = False):
    """
    save the ndarray as a image via cv2.imwrite

    Args:
        data_array: np.float32 array whose max value is less than or equal to 1
            (a np.uint8 array is written as-is)
        save_name: file name with its suffix
        save_dir: the dirname of the image path (created when missing)
        to_minmax: min-max rescale the array before saving
    """
    # exist_ok avoids the check-then-create race of the original
    # `if not os.path.exists(save_dir): os.makedirs(save_dir)` pattern.
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, save_name)
    if data_array.dtype != np.uint8:
        if data_array.max() > 1:
            # BUG FIX: corrected typo in the original message ("smoe" -> "some")
            raise Exception("the range of data_array has some errors")
        data_array = (data_array * 255).astype(np.uint8)
    if to_minmax:
        # back to [0, 1], stretch, and requantize to uint8
        data_array = minmax(data_array, up_bound=255)
        data_array = (data_array * 255).astype(np.uint8)
    cv2.imwrite(save_path, data_array)
80
+
81
+
82
def resize(image_array: np.ndarray, height, width, interpolation=cv2.INTER_LINEAR):
    """Resize `image_array` to (height, width); no-op when already that size."""
    src_h, src_w = image_array.shape[:2]
    if (src_h, src_w) == (height, width):
        return image_array
    # NOTE: cv2 expects dsize as (width, height), i.e. (x, y) order.
    return cv2.resize(image_array, dsize=(width, height), interpolation=interpolation)
89
+
90
+
91
def ms_resize(img, scales, base_h=None, base_w=None, interpolation=cv2.INTER_LINEAR):
    """Resize `img` to several scales of a base (h, w) and return the list."""
    assert isinstance(scales, (list, tuple))
    base_h = img.shape[0] if base_h is None else base_h
    base_w = img.shape[1] if base_w is None else base_w
    resized = []
    for s in scales:
        resized.append(resize(img, height=int(base_h * s), width=int(base_w * s), interpolation=interpolation))
    return resized
utils/ops/tensor_ops.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2020
3
+ # @Author : Lart Pang
4
+ # @GitHub : https://github.com/lartpang
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+
9
def rescale_2x(x: torch.Tensor, scale_factor=2):
    """Bilinearly upsample a 4D tensor by `scale_factor` (align_corners=False)."""
    return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)


def resize_to(x: torch.Tensor, tgt_hw: tuple):
    """Bilinearly resize a 4D tensor to the target (h, w) (align_corners=False)."""
    return F.interpolate(x, size=tgt_hw, mode="bilinear", align_corners=False)
15
+
16
+
17
def clip_grad(params, mode, clip_cfg: dict):
    """Clip gradients of `params` in place.

    mode "norm"  -> torch.nn.utils.clip_grad_norm_ (needs clip_cfg["max_norm"],
                    optional clip_cfg["norm_type"], default 2.0)
    mode "value" -> torch.nn.utils.clip_grad_value_ (needs clip_cfg["clip_value"])
    """
    if mode == "norm":
        if "max_norm" not in clip_cfg:
            raise ValueError("`clip_cfg` must contain `max_norm`.")
        max_norm = clip_cfg.get("max_norm")
        norm_type = clip_cfg.get("norm_type", 2.0)
        torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm, norm_type=norm_type)
    elif mode == "value":
        if "clip_value" not in clip_cfg:
            raise ValueError("`clip_cfg` must contain `clip_value`.")
        clip_value = clip_cfg.get("clip_value")
        torch.nn.utils.clip_grad_value_(params, clip_value=clip_value)
    else:
        raise NotImplementedError
utils/pipeline/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2021/5/31
3
+ # @Author : Lart Pang
4
+ # @GitHub : https://github.com/lartpang
5
+
6
+ from .optimizer import construct_optimizer
7
+ from .scaler import Scaler
8
+ from .scheduler import Scheduler
utils/pipeline/optimizer.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2020/12/19
3
+ # @Author : Lart Pang
4
+ # @GitHub : https://github.com/lartpang
5
+ import types
6
+
7
+ import torchvision.models
8
+ from torch import nn
9
+ from torch.optim import SGD, Adam, AdamW
10
+
11
+
12
def get_optimizer(mode, params, initial_lr, optim_cfg):
    """Build a torch optimizer for `params`.

    mode: one of "sgd", "adam", "adamw".
    optim_cfg: sgd requires "momentum" and "weight_decay"; the adam family
        falls back to torch defaults for every missing key.
    """
    if mode == "sgd":
        return SGD(
            params=params,
            lr=initial_lr,
            momentum=optim_cfg["momentum"],
            weight_decay=optim_cfg["weight_decay"],
            nesterov=optim_cfg.get("nesterov", False),
        )
    if mode in ("adam", "adamw"):
        # Adam and AdamW share the exact same keyword surface here.
        optim_cls = AdamW if mode == "adamw" else Adam
        return optim_cls(
            params=params,
            lr=initial_lr,
            betas=optim_cfg.get("betas", (0.9, 0.999)),
            weight_decay=optim_cfg.get("weight_decay", 0),
            amsgrad=optim_cfg.get("amsgrad", False),
        )
    raise NotImplementedError(mode)
40
+
41
+
42
def group_params(model: nn.Module, group_mode: str, initial_lr: float, optim_cfg: dict):
    """Split `model`'s parameters into optimizer param groups.

    Args:
        model: the model whose parameters are grouped.
        group_mode: one of "yolov5", "r3", "all", "finetune", "finetune2".
        initial_lr: base learning rate used to derive per-group lrs.
        optim_cfg: extra settings ("weight_decay" for "r3", optional
            "diff_factor" for "finetune").

    Returns:
        Either `model.parameters()` ("all") or a list of param-group dicts
        suitable for a torch optimizer.

    Raises:
        NotImplementedError: for an unknown `group_mode`.
    """
    if group_mode == "yolov5":
        # Reference snippet (kept verbatim) from the yolov5 training script
        # that this grouping strategy reproduces:
        """
        norm, weight, bias = [], [], []  # optimizer parameter groups
        for k, v in model.named_modules():
            if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
                bias.append(v.bias)  # biases
            if isinstance(v, nn.BatchNorm2d):
                norm.append(v.weight)  # no decay
            elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
                weight.append(v.weight)  # apply decay

        if opt.adam:
            optimizer = optim.Adam(norm, lr=hyp["lr0"], betas=(hyp["momentum"], 0.999))  # adjust beta1 to momentum
        else:
            optimizer = optim.SGD(norm, lr=hyp["lr0"], momentum=hyp["momentum"], nesterov=True)

        optimizer.add_param_group({"params": weight, "weight_decay": hyp["weight_decay"]})  # add weight with weight_decay
        optimizer.add_param_group({"params": bias})  # add bias (biases)
        """
        norm, weight, bias = [], [], []  # optimizer parameter groups
        for k, v in model.named_modules():
            if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
                bias.append(v.bias)  # conv bias and bn bias
            if isinstance(v, nn.BatchNorm2d):
                norm.append(v.weight)  # bn weight
            elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
                weight.append(v.weight)  # conv weight
        # bn weights and all biases get no weight decay; conv weights keep the
        # optimizer's default decay.
        params = [
            {"params": filter(lambda p: p.requires_grad, bias), "weight_decay": 0.0},
            {"params": filter(lambda p: p.requires_grad, norm), "weight_decay": 0.0},
            {"params": filter(lambda p: p.requires_grad, weight)},
        ]
    elif group_mode == "r3":
        params = [
            # Do not apply weight decay to bias parameters. Weight decay mainly
            # reduces overfitting by constraining the layer parameters (both
            # weight and bias); L2 regularization makes them smoother.
            {
                "params": [
                    param for name, param in model.named_parameters() if name[-4:] == "bias" and param.requires_grad
                ],
                "lr": 2 * initial_lr,
                "weight_decay": 0,
            },
            {
                "params": [
                    param for name, param in model.named_parameters() if name[-4:] != "bias" and param.requires_grad
                ],
                "lr": initial_lr,
                "weight_decay": optim_cfg["weight_decay"],
            },
        ]
    elif group_mode == "all":
        # single default group with the optimizer-level settings
        params = model.parameters()
    elif group_mode == "finetune":
        # unwrap DataParallel-style containers before probing the model API
        if hasattr(model, "module"):
            model = model.module
        assert hasattr(model, "get_grouped_params"), "Cannot get the method get_grouped_params of the model."
        # NOTE(review): get_grouped_params is a project-defined hook; here it
        # must return dict keys "pretrained"/"retrained" — confirm with the model.
        params_groups = model.get_grouped_params()
        params = [
            {
                "params": filter(lambda p: p.requires_grad, params_groups["pretrained"]),
                "lr": optim_cfg.get("diff_factor", 0.1) * initial_lr,
            },
            {
                "params": filter(lambda p: p.requires_grad, params_groups["retrained"]),
                "lr": initial_lr,
            },
        ]
    elif group_mode == "finetune2":
        if hasattr(model, "module"):
            model = model.module
        assert hasattr(model, "get_grouped_params"), "Cannot get the method get_grouped_params of the model."
        # Same hook as "finetune" but with three groups at fixed lr ratios
        # (0.1x backbone, 0.5x siamese, 1x retrained).
        params_groups = model.get_grouped_params()
        params = [
            {
                "params": filter(lambda p: p.requires_grad, params_groups["pretrained_backbone"]),
                "lr": 0.1 * initial_lr,
            },
            {
                "params": filter(lambda p: p.requires_grad, params_groups["pretrained_siamese"]),
                "lr": 0.5 * initial_lr,
            },
            {
                "params": filter(lambda p: p.requires_grad, params_groups["retrained"]),
                "lr": initial_lr,
            },
        ]
    else:
        raise NotImplementedError
    return params
134
+
135
+
136
def construct_optimizer(model, initial_lr, mode, group_mode, cfg):
    """Group `model`'s params, build the optimizer, and attach lr accessors."""
    grouped = group_params(model, group_mode=group_mode, initial_lr=initial_lr, optim_cfg=cfg)
    optimizer = get_optimizer(mode=mode, params=grouped, initial_lr=initial_lr, optim_cfg=cfg)
    # Bind the module-level helpers as instance methods so callers can do
    # optimizer.lr_groups() / optimizer.lr_string().
    optimizer.lr_groups = types.MethodType(get_lr_groups, optimizer)
    optimizer.lr_string = types.MethodType(get_lr_strings, optimizer)
    return optimizer
142
+
143
+
144
def get_lr_groups(self):
    """Return the learning rate of each param group as a list."""
    lrs = []
    for group in self.param_groups:
        lrs.append(group["lr"])
    return lrs


def get_lr_strings(self):
    """Return the per-group learning rates as a comma-joined string."""
    return ",".join(f"{group['lr']:.3e}" for group in self.param_groups)
150
+
151
+
152
+ if __name__ == "__main__":
153
+ model = torchvision.models.vgg11_bn()
154
+ norm, weight, bias = [], [], [] # optimizer parameter groups
155
+ for k, v in model.named_modules():
156
+ if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
157
+ bias.append(v.bias) # biases
158
+ if isinstance(v, nn.BatchNorm2d):
159
+ norm.append(v.weight) # no decay
160
+ elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
161
+ weight.append(v.weight) # apply decay
162
+
163
+ optimizer = Adam(norm, lr=0.001, betas=(0.98, 0.999)) # adjust beta1 to momentum
164
+ # optimizer = optim.SGD(norm, lr=hyp["lr0"], momentum=hyp["momentum"], nesterov=True)
165
+
166
+ optimizer.add_param_group({"params": weight, "weight_decay": 1e-4}) # add weight with weight_decay
167
+ optimizer.add_param_group({"params": bias}) # add bias (biases)
168
+
169
+ print(optimizer)
utils/pipeline/scaler.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from itertools import chain
3
+
4
+ from torch.cuda.amp import GradScaler, autocast
5
+
6
+ from .. import ops
7
+
8
+
9
class Scaler:
    """Bundle AMP autocast/GradScaler with an optimizer and optional grad clipping.

    `calculate_grad` + `update_grad` together implement one (optionally fp16)
    backward/step cycle. With use_fp16=False both autocast and the scaler are
    disabled and the class degenerates to plain backward/step behavior.
    """

    def __init__(
        self, optimizer, use_fp16=False, *, set_to_none=False, clip_grad=False, clip_mode=None, clip_cfg=None
    ) -> None:
        self.optimizer = optimizer
        # forwarded to optimizer.zero_grad after each step
        self.set_to_none = set_to_none
        self.autocast = autocast(enabled=use_fp16)
        self.scaler = GradScaler(enabled=use_fp16)

        if clip_grad:
            # ops.clip_grad validates clip_cfg against clip_mode at call time
            self.grad_clip_ops = partial(ops.clip_grad, mode=clip_mode, clip_cfg=clip_cfg)
        else:
            self.grad_clip_ops = None

    def calculate_grad(self, loss):
        # Scale the loss before backward; when clipping, unscale first so the
        # clip thresholds apply to the true gradient magnitudes.
        self.scaler.scale(loss).backward()
        if self.grad_clip_ops is not None:
            self.scaler.unscale_(self.optimizer)
            self.grad_clip_ops(chain(*[group["params"] for group in self.optimizer.param_groups]))

    def update_grad(self):
        # scaler.step skips the update when inf/nan gradients were found and
        # scaler.update adjusts the scale accordingly; gradients are cleared last.
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=self.set_to_none)

    def state_dict(self):
        r"""
        Returns the state of the scaler as a :class:`dict`. It contains five entries:

        * ``"scale"`` - a Python float containing the current scale
        * ``"growth_factor"`` - a Python float containing the current growth factor
        * ``"backoff_factor"`` - a Python float containing the current backoff factor
        * ``"growth_interval"`` - a Python int containing the current growth interval
        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.

        If this instance is not enabled, returns an empty dict.

        .. note::
            If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
            should be called after :meth:`update`.
        """
        return self.scaler.state_dict()

    def load_state_dict(self, state_dict):
        r"""
        Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.

        Args:
            state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
        """
        self.scaler.load_state_dict(state_dict)
utils/pipeline/scheduler.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2020/12/19
3
+ # @Author : Lart Pang
4
+ # @GitHub : https://github.com/lartpang
5
+
6
+ import copy
7
+ import math
8
+ import os.path
9
+ import warnings
10
+ from bisect import bisect_right
11
+
12
+ import matplotlib
13
+ import numpy as np
14
+ import torch.optim
15
+ from adjustText import adjust_text
16
+
17
+ matplotlib.use("Agg")
18
+ from matplotlib import pyplot as plt
19
+
20
+ # helper function ----------------------------------------------------------------------
21
+
22
+
23
def linear_increase(low_bound, up_bound, percentage):
    """Interpolate linearly: low_bound + percentage * (up_bound - low_bound)."""
    assert 0 <= percentage <= 1, f"percentage({percentage}) must be in [0, 1]"
    span = up_bound - low_bound
    return low_bound + span * percentage


def cos_anneal(low_bound, up_bound, percentage):
    """Cosine-annealed interpolation: up_bound at percentage=0, low_bound at 1."""
    assert 0 <= percentage <= 1, f"percentage({percentage}) must be in [0, 1]"
    weight = (1 + math.cos(math.pi * percentage)) / 2.0
    return linear_increase(low_bound, up_bound, percentage=weight)


def poly_anneal(low_bound, up_bound, percentage, lr_decay):
    """Polynomially-annealed interpolation with exponent `lr_decay`."""
    assert 0 <= percentage <= 1, f"percentage({percentage}) must be in [0, 1]"
    weight = (1 - percentage) ** lr_decay
    return linear_increase(low_bound, up_bound, percentage=weight)


def linear_anneal(low_bound, up_bound, percentage):
    """Linearly-annealed interpolation: up_bound at percentage=0, low_bound at 1."""
    assert 0 <= percentage <= 1, f"percentage({percentage}) must be in [0, 1]"
    return linear_increase(low_bound, up_bound, percentage=1 - percentage)
44
+
45
+
46
+ # coefficient function ----------------------------------------------------------------------
47
+
48
+
49
def get_f3_coef_func(num_iters):
    """Triangular coefficient schedule used by F3Net (0 -> 1 -> 0).

    :param num_iters: The number of iterations for the total process.
    :return: The function for generating the coefficient.
    """

    def get_f3_coef(curr_idx):
        assert 0 <= curr_idx <= num_iters
        progress = (curr_idx + 1) / (num_iters + 1)
        # maps progress 0..0.5..1 onto coef 0..1..0
        return 1 - abs(progress * 2 - 1)

    return get_f3_coef
62
+
63
+
64
def get_step_coef_func(gamma, milestones):
    """
    lr = baselr * gamma ** 0 if curr_idx < milestones[0]
    lr = baselr * gamma ** 1 if milestones[0] <= epoch < milestones[1]
    ...

    :param gamma: multiplicative decay factor per stage.
    :param milestones: list/tuple of boundary indices, or an int period.
    :return: The function for generating the coefficient.
    """
    if isinstance(milestones, (tuple, list)):
        boundaries = sorted(milestones)
        # number of crossed boundaries determines the exponent
        return lambda curr_idx: gamma ** bisect_right(boundaries, curr_idx)
    if isinstance(milestones, int):
        return lambda curr_idx: gamma ** ((curr_idx + 1) // milestones)
    raise ValueError(f"milestones only can be list/tuple/int, but now it is {type(milestones)}")
81
+
82
+
83
def get_cos_coef_func(half_cycle, min_coef, max_coef=1):
    """
    :param half_cycle: The number of iterations in a half cycle.
    :param min_coef: The minimum coefficient of the learning rate.
    :param max_coef: The maximum coefficient of the learning rate.
    :return: The function for generating the coefficient.
    """

    def get_cos_coef(curr_idx):
        # fold the index into [0, half_cycle] so the schedule repeats
        folded_idx = curr_idx % (half_cycle + 1)
        return cos_anneal(low_bound=min_coef, up_bound=max_coef, percentage=folded_idx / half_cycle)

    return get_cos_coef


def get_fatcos_coef_func(start_iter, half_cycle, min_coef, max_coef=1):
    """
    Like get_cos_coef_func, but the schedule stays flat at max_coef for the
    first `start_iter` iterations before the cosine cycles begin.

    :param start_iter: iterations to hold at max_coef before annealing.
    :param half_cycle: The number of iterations in a half cycle.
    :param min_coef: The minimum coefficient of the learning rate.
    :param max_coef: The maximum coefficient of the learning rate.
    :return: The function for generating the coefficient.
    """

    def get_cos_coef(curr_idx):
        shifted_idx = max(0, curr_idx - start_iter)
        folded_idx = shifted_idx % (half_cycle + 1)
        return cos_anneal(low_bound=min_coef, up_bound=max_coef, percentage=folded_idx / half_cycle)

    return get_cos_coef
114
+
115
+
116
def get_poly_coef_func(num_iters, lr_decay, min_coef, max_coef=1):
    """
    :param num_iters: The number of iterations for the polynomial descent process.
    :param lr_decay: The decay item of the polynomial descent process.
    :param min_coef: The minimum coefficient of the learning rate.
    :param max_coef: The maximum coefficient of the learning rate.
    :return: The function for generating the coefficient.
    """

    def get_poly_coef(curr_idx):
        assert 0 <= curr_idx <= num_iters, (curr_idx, num_iters)
        progress = curr_idx / num_iters
        return poly_anneal(low_bound=min_coef, up_bound=max_coef, percentage=progress, lr_decay=lr_decay)

    return get_poly_coef
130
+
131
+
132
+ # coefficient entry function ----------------------------------------------------------------------
133
+
134
+
135
def get_scheduler_coef_func(mode, num_iters, cfg):
    """Build the lr-coefficient function for the main (post-warmup) phase.

    the region is a closed interval: [0, num_iters]

    :param mode: one of Scheduler.supported_scheduler_modes.
    :param num_iters: last valid index of the schedule (inclusive).
    :param cfg: mode-specific settings; NOTE(review): for "fatcos" it is read
        via attribute access (cfg.start_percent), so it must be an
        attribute-style dict (e.g. EasyDict) — confirm with callers.
    :return: a function mapping curr_idx -> coefficient.
    """
    assert num_iters > 0
    # a zero/None floor would freeze the lr at the end of annealing
    min_coef = cfg.get("min_coef", 1e-6)
    if min_coef is None or min_coef == 0:
        warnings.warn(f"The min_coef ({min_coef}) of the scheduler will be replaced with 1e-6")
        min_coef = 1e-6

    if mode == "step":
        coef_func = get_step_coef_func(gamma=cfg["gamma"], milestones=cfg["milestones"])
    elif mode == "cos":
        # half_cycle is given as a count of iterations; the coef funcs work
        # with the last index of the half cycle, hence the -1.
        if half_cycle := cfg.get("half_cycle"):
            half_cycle -= 1
        else:
            half_cycle = num_iters
        if (num_iters - half_cycle) % (half_cycle + 1) != 0:
            # idx starts from 0
            percentage = ((num_iters - half_cycle) % (half_cycle + 1)) / (half_cycle + 1) * 100
            warnings.warn(
                f"The final annealing process ({percentage:.3f}%) is not complete. "
                f"Please pay attention to the generated 'lr_coef_curve.png'."
            )
        coef_func = get_cos_coef_func(half_cycle=half_cycle, min_coef=min_coef)
    elif mode == "fatcos":
        # hold the coefficient flat for the first start_percent of the run
        assert 0 <= cfg.start_percent < 1, cfg.start_percent
        start_iter = int(cfg.start_percent * num_iters)

        num_iters -= start_iter
        if half_cycle := cfg.get("half_cycle"):
            half_cycle -= 1
        else:
            half_cycle = num_iters
        if (num_iters - half_cycle) % (half_cycle + 1) != 0:
            # idx starts from 0
            percentage = ((num_iters - half_cycle) % (half_cycle + 1)) / (half_cycle + 1) * 100
            warnings.warn(
                f"The final annealing process ({percentage:.3f}%) is not complete. "
                f"Please pay attention to the generated 'lr_coef_curve.png'."
            )
        coef_func = get_fatcos_coef_func(start_iter=start_iter, half_cycle=half_cycle, min_coef=min_coef)
    elif mode == "poly":
        coef_func = get_poly_coef_func(num_iters=num_iters, lr_decay=cfg["lr_decay"], min_coef=min_coef)
    elif mode == "constant":
        coef_func = lambda x: cfg.get("coef", 1)
    elif mode == "f3":
        coef_func = get_f3_coef_func(num_iters=num_iters)
    else:
        raise NotImplementedError(f"{mode} must be in {Scheduler.supported_scheduler_modes}")
    return coef_func
186
+
187
+
188
def get_warmup_coef_func(num_iters, min_coef, max_coef=1, mode="linear"):
    """Build the lr-coefficient function for the warmup phase.

    the region is a closed interval: [0, num_iters]

    :param num_iters: number of warmup iterations (must be > 0).
    :param min_coef: coefficient at the start of warmup.
    :param max_coef: coefficient reached at the end of warmup.
    :param mode: "cos" or "linear".
    """
    assert num_iters > 0
    if mode == "cos":
        anneal_func = cos_anneal
    elif mode == "linear":
        anneal_func = linear_anneal
    else:
        raise NotImplementedError(f"{mode} must be in {Scheduler.supported_warmup_modes}")

    def get_warmup_coef(curr_idx):
        # percentage runs 1 -> 0, so the *anneal* (decreasing) functions yield
        # a coefficient that RISES from min_coef to max_coef during warmup.
        return anneal_func(low_bound=min_coef, up_bound=max_coef, percentage=1 - curr_idx / num_iters)

    return get_warmup_coef
204
+
205
+
206
+ # main class ----------------------------------------------------------------------
207
+
208
+
209
class Scheduler:
    """Warmup + main-phase learning-rate scheduler driven by an lr coefficient.

    The schedule is expressed as a coefficient in (0, 1] that multiplies each
    param group's initial lr: warmup covers iterations [0, num_warmup_iters)
    and the main schedule covers the rest, stepped per batch or per epoch.
    """

    supported_scheduler_modes = ("step", "cos", "fatcos", "poly", "constant", "f3")
    supported_warmup_modes = ("cos", "linear")

    def __init__(self, optimizer, num_iters, epoch_length, scheduler_cfg, step_by_batch=True):
        """A customized wrapper of the scheduler.

        Args:
            optimizer (): Optimizer.
            num_iters (int): The total number of the iterations.
            epoch_length (int): The number of the iterations of one epoch.
            scheduler_cfg (dict): The config of the scheduler.
            step_by_batch (bool, optional): The mode of updating the scheduler. Defaults to True.

        Raises:
            NotImplementedError:
        """
        self.optimizer = optimizer
        self.num_iters = num_iters
        self.epoch_length = epoch_length
        self.step_by_batch = step_by_batch

        # deep copy so later mutation by the caller cannot affect the schedule
        self.scheduler_cfg = copy.deepcopy(scheduler_cfg)
        self.mode = scheduler_cfg["mode"]
        if self.mode not in self.supported_scheduler_modes:
            raise NotImplementedError(
                f"{self.mode} is not implemented. Has been supported: {self.supported_scheduler_modes}"
            )
        warmup_cfg = scheduler_cfg.get("warmup", None)

        num_warmup_iters = 0
        if warmup_cfg is not None and isinstance(warmup_cfg, dict):
            num_warmup_iters = warmup_cfg["num_iters"]
            if num_warmup_iters > 0:
                print("Will using warmup")
                self.warmup_coef_func = get_warmup_coef_func(
                    num_warmup_iters,
                    min_coef=warmup_cfg.get("initial_coef", 0.01),
                    mode=warmup_cfg.get("mode", "linear"),
                )
        self.num_warmup_iters = num_warmup_iters

        if step_by_batch:
            num_scheduler_iters = num_iters - num_warmup_iters
        else:
            num_scheduler_iters = (num_iters - num_warmup_iters) // epoch_length
        # the region is a closed interval
        self.lr_coef_func = get_scheduler_coef_func(
            mode=self.mode, num_iters=num_scheduler_iters - 1, cfg=scheduler_cfg["cfg"]
        )
        self.num_scheduler_iters = num_scheduler_iters

        # last_lr_coef caches the most recent coefficient for the epoch mode,
        # where intermediate iterations reuse it (see get_lr_coef).
        self.last_lr_coef = 0
        self.initial_lrs = None

    def __repr__(self):
        formatted_string = [
            f"{self.__class__.__name__}: (\n",
            f"num_iters: {self.num_iters}\n",
            f"epoch_length: {self.epoch_length}\n",
            f"warmup_iter: [0, {self.num_warmup_iters})\n",
            f"scheduler_iter: [{self.num_warmup_iters}, {self.num_iters - 1}]\n",
            f"mode: {self.mode}\n",
            f"scheduler_cfg: {self.scheduler_cfg}\n",
            f"initial_lrs: {self.initial_lrs}\n",
            f"step_by_batch: {self.step_by_batch}\n)",
        ]
        return " ".join(formatted_string)

    def record_lrs(self, param_groups):
        # Snapshot the base lrs; update() scales these, never the current lrs.
        self.initial_lrs = [g["lr"] for g in param_groups]

    def update(self, coef: float):
        # Apply coef to every param group's recorded base lr.
        assert self.initial_lrs is not None, "Please run .record_lrs(optimizer) first."
        for curr_group, initial_lr in zip(self.optimizer.param_groups, self.initial_lrs):
            curr_group["lr"] = coef * initial_lr

    def step(self, curr_idx):
        # Advance the schedule to iteration `curr_idx` and update the lrs.
        if curr_idx < self.num_warmup_iters:
            # get maximum value (1.0) when curr_idx == self.num_warmup_iters
            self.update(coef=self.get_lr_coef(curr_idx))
        else:
            # Start from a value lower than 1 (curr_idx == self.num_warmup_iters)
            if self.step_by_batch:
                self.update(coef=self.get_lr_coef(curr_idx))
            else:
                # epoch mode: only refresh at epoch boundaries
                if curr_idx % self.epoch_length == 0:
                    self.update(coef=self.get_lr_coef(curr_idx))

    def get_lr_coef(self, curr_idx):
        # Returns the coefficient for `curr_idx`; in epoch mode, off-boundary
        # indices fall back to the cached last_lr_coef.
        coef = None
        if curr_idx < self.num_warmup_iters:
            coef = self.warmup_coef_func(curr_idx)
        else:
            # when curr_idx == self.num_warmup_iters, coef == 1.0
            # down from the largest coef (1.0)
            if self.step_by_batch:
                coef = self.lr_coef_func(curr_idx - self.num_warmup_iters)
            else:
                if curr_idx % self.epoch_length == 0 or curr_idx == self.num_warmup_iters:
                    # Right after warmup, before the epoch-based adjustment has
                    # started, the coefficient must be reset to its maximum.
                    coef = self.lr_coef_func((curr_idx - self.num_warmup_iters) // self.epoch_length)
        if coef is not None:
            self.last_lr_coef = coef
        return self.last_lr_coef

    def plot_lr_coef_curve(self, save_path=""):
        """Plot the full coefficient curve; save to `save_path`/lr_coef.png if given."""
        plt.rc("xtick", labelsize="small")
        plt.rc("ytick", labelsize="small")

        fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(8, 4), dpi=600)
        # give plot a title
        ax.set_title("Learning Rate Coefficient Curve")
        # make axis labels
        ax.set_xlabel("Iteration")
        ax.set_ylabel("Coefficient")

        # NOTE: calling get_lr_coef here also mutates self.last_lr_coef.
        x_data = np.arange(self.num_iters)
        y_data = np.array([self.get_lr_coef(x) for x in x_data])

        # set lim
        x_min, x_max = 0, self.num_iters - 1
        dx = self.num_iters * 0.1
        ax.set_xlim(x_min - dx, x_max + 2 * dx)

        y_min, y_max = y_data.min(), y_data.max()
        dy = (y_data.max() - y_data.min()) * 0.1
        ax.set_ylim((y_min - dy, y_max + dy))

        if self.step_by_batch:
            marker_on = [0, -1]
            key_point_xs = [0, self.num_iters - 1]
            # mark all local extrema of the curve as key points
            for idx in range(1, len(y_data) - 1):
                prev_y = y_data[idx - 1]
                curr_y = y_data[idx]
                next_y = y_data[idx + 1]
                if (
                    (curr_y > prev_y and curr_y >= next_y)
                    or (curr_y >= prev_y and curr_y > next_y)
                    or (curr_y <= prev_y and curr_y < next_y)
                    or (curr_y < prev_y and curr_y <= next_y)
                ):
                    marker_on.append(idx)
                    key_point_xs.append(idx)

            marker_on = sorted(set(marker_on))
            key_point_xs = sorted(set(key_point_xs))
            key_point_ys = []

            texts = []
            for x in key_point_xs:
                y = y_data[x]
                key_point_ys.append(y)

                texts.append(ax.text(x=x, y=y, s=f"({x:d},{y:.3e})"))
            # adjust_text (third-party) untangles overlapping annotations
            adjust_text(texts, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.3"))

            # set ticks
            ax.set_xticks(key_point_xs)
            # ax.set_yticks(key_point_ys)

            ax.plot(x_data, y_data, marker="o", markevery=marker_on)
        else:
            ax.plot(x_data, y_data)

        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)
        ax.spines["left"].set_visible(True)
        ax.spines["bottom"].set_visible(True)

        plt.tight_layout()
        if save_path:
            fig.savefig(os.path.join(save_path, "lr_coef.png"))
        plt.close()
383
+
384
+
385
+ if __name__ == "__main__":
386
+ model = torch.nn.Conv2d(10, 10, 3, 1, 1)
387
+ sche = Scheduler(
388
+ optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
389
+ num_iters=30300,
390
+ epoch_length=505,
391
+ scheduler_cfg=dict(
392
+ warmup=dict(
393
+ num_iters=6060,
394
+ initial_coef=0.01,
395
+ mode="cos",
396
+ ),
397
+ mode="cos",
398
+ cfg=dict(
399
+ half_cycle=6060,
400
+ lr_decay=0.9,
401
+ min_coef=0.001,
402
+ ),
403
+ ),
404
+ step_by_batch=True,
405
+ )
406
+ print(sche)
407
+ sche.plot_lr_coef_curve(
408
+ # save_path="/home/lart/Coding/SOD.torch",
409
+ show=True,
410
+ )
utils/pt_utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch import nn
8
+ from torch.backends import cudnn
9
+
10
+ LOGGER = logging.getLogger("main")
11
+
12
+
13
def customized_worker_init_fn(worker_id):
    """Re-seed NumPy inside a DataLoader worker from torch's worker seed.

    Gives each worker a distinct but reproducible numpy RNG stream.
    """
    derived_seed = torch.initial_seed() % (2**32)
    np.random.seed(derived_seed)
17
+
18
def set_seed_for_lib(seed):
    """Seed every RNG the project touches: python, numpy, torch (CPU+CUDA)."""
    os.environ["PYTHONHASHSEED"] = str(seed)

    random.seed(seed)
    np.random.seed(seed)

    # torch CPU generator plus the current and all CUDA device generators.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
+
27
+
28
def initialize_seed_cudnn(seed, deterministic):
    """Configure global seeding and the cudnn determinism flags.

    Args:
        seed (int): used as-is when >= 0, otherwise a random seed is drawn.
        deterministic (bool): when True, force deterministic cudnn kernels
            (disables the benchmark autotuner).
    """
    assert isinstance(deterministic, bool) and isinstance(seed, int)

    if seed < 0:
        seed = np.random.randint(2**32)
        LOGGER.info(f"We will use a random seed {seed}")
    else:
        LOGGER.info(f"We will use a fixed seed {seed}")
    set_seed_for_lib(seed)

    if deterministic:
        LOGGER.info("We will not use `torch.backends.cudnn.benchmark`")
    else:
        LOGGER.info("We will use `torch.backends.cudnn.benchmark`")
    cudnn.enabled = True
    cudnn.benchmark = not deterministic
    cudnn.deterministic = deterministic
43
+
44
+
45
def to_device(data, device="cuda"):
    """Recursively move tensors (possibly nested in list/tuple/dict) onto `device`.

    Raises:
        TypeError: for any leaf that is not a torch.Tensor.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device=device, non_blocking=True)
    if isinstance(data, dict):
        return {name: to_device(item, device) for name, item in data.items()}
    if isinstance(data, (tuple, list)):
        return [to_device(item, device) for item in data]
    raise TypeError(f"Unsupported type {type(data)}. Only support Tensor or tuple/list/dict containing Tensors.")
54
+
55
+
56
def frozen_bn_stats(model, freeze_affine=False):
    """Put every BatchNorm layer of `model` into eval mode.

    Args:
        model: module whose BN layers should stop updating running stats.
        freeze_affine: additionally stop gradients for the BN affine params.
    """
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    for module in model.modules():
        if not isinstance(module, bn_types):
            continue
        module.eval()
        if freeze_affine:
            module.requires_grad_(False)
utils/py_utils.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import copy
3
+ import logging
4
+ import os
5
+ import shutil
6
+ from collections import OrderedDict, abc
7
+ from datetime import datetime
8
+
9
+ LOGGER = logging.getLogger("main")
10
+
11
+
12
def construct_path(output_dir: str, exp_name: str) -> dict:
    """Build all output paths for a fresh experiment run.

    A new ``exp_<idx>`` folder is chosen under ``<output_dir>/<exp_name>`` by
    incrementing ``idx`` until a non-existing directory name is found. Nothing
    is created on disk here; see `pre_mkdir`.

    Args:
        output_dir: root directory for all experiments.
        exp_name: experiment (project) name.

    Returns:
        dict mapping semantic names (tb/save/pth/log/...) to concrete paths.
    """
    proj_root = os.path.join(output_dir, exp_name)

    # Pick the first unused experiment index.
    # NOTE(review): exists-then-use is racy if several runs launch at once.
    exp_idx = 0
    exp_output_dir = os.path.join(proj_root, f"exp_{exp_idx}")
    while os.path.exists(exp_output_dir):
        exp_idx += 1
        exp_output_dir = os.path.join(proj_root, f"exp_{exp_idx}")

    tb_path = os.path.join(exp_output_dir, "tb")
    save_path = os.path.join(exp_output_dir, "pre")
    pth_path = os.path.join(exp_output_dir, "pth")

    final_full_model_path = os.path.join(pth_path, "checkpoint_final.pth")
    final_state_path = os.path.join(pth_path, "state_final.pth")

    # One log file per day; "%Y-%m-%d" matches the old str(datetime.now())[:10].
    log_path = os.path.join(exp_output_dir, f"log_{datetime.now():%Y-%m-%d}.txt")
    # Plain strings here: the old code used f-strings with no placeholders.
    cfg_copy_path = os.path.join(exp_output_dir, "config.py")
    trainer_copy_path = os.path.join(exp_output_dir, "trainer.txt")
    excel_path = os.path.join(exp_output_dir, "results.xlsx")

    return {
        "output_dir": output_dir,
        "pth_log": exp_output_dir,
        "tb": tb_path,
        "save": save_path,
        "pth": pth_path,
        "final_full_net": final_full_model_path,
        "final_state_net": final_state_path,
        "log": log_path,
        "cfg_copy": cfg_copy_path,
        "excel": excel_path,
        "trainer_copy": trainer_copy_path,
    }
47
+
48
+
49
def construct_exp_name(model_name: str, cfg: dict):
    """Compose a compact experiment name from the model name plus selected
    config entries, joined with underscores (e.g. ``Model_BS16_LR0.05_E30``).

    NOTE(review): `cfg` must support attribute access into nested dicts
    (`config.train.get(...)`), i.e. an addict/EasyDict-style mapping —
    confirm against the project's config loader.
    """
    # bs_16_lr_0.05_e30_noamp_2gpu_noms_352
    # "config path" (slash-separated keys) -> short alias used in the name.
    focus_item = OrderedDict(
        {
            "train/batch_size": "bs",
            "train/lr": "lr",
            "train/num_epochs": "e",
            "train/num_iters": "i",
            "train/data/shape/h": "h",
            "train/data/shape/w": "w",
            "train/optimizer/mode": "opm",
            "train/optimizer/group_mode": "opgm",
            "train/scheduler/mode": "sc",
            "train/scheduler/warmup/num_iters": "wu",
            "train/use_amp": "amp",
        }
    )
    config = copy.deepcopy(cfg)

    def _format_item(_i):
        # Normalize one config value for inclusion in the name; the literal
        # string "false" marks values that should be dropped entirely.
        if isinstance(_i, bool):
            _i = "" if _i else "false"
        elif isinstance(_i, (int, float)):
            if _i == 0:
                _i = "false"
        elif isinstance(_i, (list, tuple)):
            _i = "" if _i else "false"  # only checks whether it is non-empty
        elif isinstance(_i, str):
            if "_" in _i:
                _i = _i.replace("_", "").lower()
        elif _i is None:
            _i = "none"
        # else: other types and values will be returned directly
        return _i

    # Keep either the epoch count or the iteration count in the name,
    # depending on whether training is epoch-based.
    if (epoch_based := config.train.get("epoch_based", None)) is not None and (not epoch_based):
        focus_item.pop("train/num_epochs")
    else:
        # epoch-based training is the default
        focus_item.pop("train/num_iters")

    exp_names = [model_name]
    for key, alias in focus_item.items():
        item = get_value_recurse(keys=key.split("/"), info=config)
        formatted_item = _format_item(item)
        if formatted_item == "false":
            continue  # skip disabled/zero/empty entries
        exp_names.append(f"{alias.upper()}{formatted_item}")

    # Optional free-form suffix taken from cfg["info"].
    info = config.get("info", None)
    if info:
        exp_names.append(f"INFO{info.lower()}")

    return "_".join(exp_names)
103
+
104
+
105
def pre_mkdir(path_config):
    """Create the output folders and the log file ahead of time.

    Creating them up-front avoids triggering file-creation events later,
    when they would otherwise appear on demand.
    """
    # Experiment root plus an initial timestamped line in the log file.
    check_mkdir(path_config["pth_log"])
    make_log(path_config["log"], f"=== log {datetime.now()} ===")

    # Folders for predictions ("save") and model weights ("pth").
    for folder_key in ("save", "pth"):
        check_mkdir(path_config[folder_key])
113
+
114
+
115
def check_mkdir(dir_name, delete_if_exists=False):
    """Ensure `dir_name` exists, optionally recreating it from scratch.

    Args:
        dir_name: directory path to create (parents are created as needed).
        delete_if_exists: if True and the directory already exists, remove it
            and all of its contents before recreating it empty.
    """
    if delete_if_exists and os.path.exists(dir_name):
        print(f"{dir_name} will be re-created!!!")
        shutil.rmtree(dir_name)
    # exist_ok avoids the race between a separate exists() check and makedirs().
    os.makedirs(dir_name, exist_ok=True)
123
+
124
+
125
def make_log(path, context):
    """Append one line of text to the log file at `path`.

    An explicit UTF-8 encoding keeps the log readable regardless of the
    platform's default locale encoding.
    """
    with open(path, "a", encoding="utf-8") as log:
        log.write(f"{context}\n")
128
+
129
+
130
def iterate_nested_sequence(nested_sequence):
    """Flatten an arbitrarily nested sequence of numbers.

    Supports any mix of list/tuple/range containers with int/float leaves.
    Avoid very deep nesting: the traversal is recursive and bounded by
    Python's default recursion limit.

    Example::

        for x in iterate_nested_sequence([[1, (2, 3)], range(3, 10), 0]):
            print(x)
        # -> 1 2 3 3 4 5 6 7 8 9 0

    :param nested_sequence: nested containers of numbers
    :return: generator yielding the scalar leaves in order
    """
    for element in nested_sequence:
        if isinstance(element, (list, tuple, range)):
            yield from iterate_nested_sequence(element)
        elif isinstance(element, (int, float)):
            yield element
        else:
            raise NotImplementedError
162
+
163
+
164
# Private sentinel so a stored value equal to the old "NoKey" marker string
# is not mistaken for a missing key.
_MISSING = object()


def get_value_recurse(keys: list, info: dict):
    """Look up a nested value by descending through `info` along `keys`.

    Args:
        keys: key path, e.g. ``["train", "lr"]``.
        info: (nested) mapping to search.

    Raises:
        KeyError: if any key along the path is absent.
    """
    curr_key, sub_keys = keys[0], keys[1:]

    sub_info = info.get(curr_key, _MISSING)
    if sub_info is _MISSING:
        raise KeyError(f"{curr_key} must be contained in {info}")

    if sub_keys:
        return get_value_recurse(keys=sub_keys, info=sub_info)
    return sub_info
174
+
175
+
176
def mapping_to_str(mapping: abc.Mapping, *, prefix: str = " ", lvl: int = 0, max_lvl: int = 1) -> str:
    """
    Print the structural information of the dict.

    Renders `mapping` as an indented ``{ key: value }`` listing. Nested
    mappings are expanded recursively until `max_lvl`, beyond which they are
    shown with their plain ``str()`` representation.

    :param prefix: indentation unit, repeated once per nesting level
    :param lvl: current nesting level (callers normally leave this at 0)
    :param max_lvl: deepest level that is still expanded
    """
    sub_lvl = lvl + 1
    cur_prefix = prefix * lvl
    sub_prefix = prefix * sub_lvl

    if lvl == max_lvl:
        # Deepest allowed level: fall back to the plain repr of the mapping.
        sub_items = str(mapping)
    else:
        sub_items = ["{"]
        for k, v in mapping.items():
            sub_item = sub_prefix + k + ": "
            if isinstance(v, abc.Mapping):
                sub_item += mapping_to_str(v, prefix=prefix, lvl=sub_lvl, max_lvl=max_lvl)
            else:
                sub_item += str(v)
            sub_items.append(sub_item)
        sub_items.append(cur_prefix + "}")
        sub_items = "\n".join(sub_items)
    return sub_items
utils/recorder/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .counter import TrainingCounter
4
+ from .group_metric_caller import GroupedMetricRecorder
5
+ from .logger import TBLogger
6
+ from .meter_recorder import AvgMeter, HistoryBuffer
7
+ from .visualize_results import plot_results
utils/recorder/counter.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+
4
class TrainingCounter:
    """Track epoch and iteration progress of a training loop.

    Exactly one of `num_epochs` (when ``epoch_based=True``) or
    `num_total_iters` (when ``epoch_based=False``) must be supplied; the
    other quantity is derived from `epoch_length`.
    """

    def __init__(self, epoch_length, epoch_based=True, *, num_epochs=None, num_total_iters=None) -> None:
        self.num_inner_iters = epoch_length  # iterations per epoch
        self._iter_counter = 0  # global iteration index
        self._epoch_counter = 0  # epoch index

        if epoch_based:
            assert num_epochs is not None
            self.num_epochs = num_epochs
            self.num_total_iters = num_epochs * epoch_length
        else:
            assert num_total_iters is not None
            self.num_total_iters = num_total_iters
            # Round up: a partial final epoch still counts as an epoch.
            self.num_epochs = math.ceil(num_total_iters / epoch_length)

    def set_start_epoch(self, start_epoch):
        """Resume counting from the beginning of `start_epoch`."""
        self._epoch_counter = start_epoch
        self._iter_counter = start_epoch * self.num_inner_iters

    def set_start_iterations(self, start_iteration):
        """Resume counting from a global iteration index."""
        self._iter_counter = start_iteration
        self._epoch_counter = start_iteration // self.num_inner_iters

    def every_n_epochs(self, n: int) -> bool:
        """True at the end of every n-th epoch; always False for n <= 0."""
        if n <= 0:
            return False
        return (self._epoch_counter + 1) % n == 0

    def every_n_iters(self, n: int) -> bool:
        """True at every n-th iteration; always False for n <= 0."""
        if n <= 0:
            return False
        return (self._iter_counter + 1) % n == 0

    def is_first_epoch(self) -> bool:
        return self._epoch_counter == 0

    def is_last_epoch(self) -> bool:
        return self._epoch_counter + 1 == self.num_epochs

    def is_first_inner_iter(self) -> bool:
        """True on the first iteration within the current epoch."""
        return self._iter_counter % self.num_inner_iters == 0

    def is_last_inner_iter(self) -> bool:
        """True on the last iteration within the current epoch."""
        return (self._iter_counter + 1) % self.num_inner_iters == 0

    def is_first_total_iter(self) -> bool:
        return self._iter_counter == 0

    def is_last_total_iter(self) -> bool:
        return self._iter_counter + 1 == self.num_total_iters

    def update_iter_counter(self):
        self._iter_counter += 1

    def update_epoch_counter(self):
        self._epoch_counter += 1

    def reset_iter_all_counter(self):
        """Reset both counters back to zero."""
        self._iter_counter = 0
        self._epoch_counter = 0

    @property
    def curr_iter(self):
        return self._iter_counter

    @property
    def next_iter(self):
        return self._iter_counter + 1

    @property
    def curr_epoch(self):
        return self._epoch_counter

    @property
    def curr_percent(self):
        """Fraction of the total iterations completed so far."""
        return self._iter_counter / self.num_total_iters
utils/recorder/group_metric_caller.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2021/1/4
3
+ # @Author : Lart Pang
4
+ # @GitHub : https://github.com/lartpang
5
+
6
+ from collections import OrderedDict
7
+
8
+ import numpy as np
9
+ import py_sod_metrics
10
+
11
+
12
def ndarray_to_basetype(data):
    """Convert ndarray values (bare, or inside a list/tuple/dict) into plain
    Python types; single-element arrays collapse to a scalar."""

    def _unwrap(array):
        as_list = array.tolist()
        if isinstance(as_list, list) and len(as_list) == 1:
            return as_list[0]
        return as_list

    if isinstance(data, dict):
        return {key: _unwrap(value) for key, value in data.items()}
    if isinstance(data, (tuple, list)):
        return [_unwrap(value) for value in data]
    assert isinstance(data, np.ndarray)
    return _unwrap(data)
29
+
30
+
31
def round_w_zero_padding(x, bit_width):
    """Round `x` to `bit_width` decimals and render it with trailing zeros.

    e.g. ``round_w_zero_padding(0.5, 3) -> "0.500"``.

    Fixed-point formatting also repairs the old integer-input corner case,
    where a value whose repr had no decimal point was padded with bare
    digits (``2 -> "20"``) instead of being rendered as ``"2.000"``.
    """
    return f"{round(x, bit_width):.{bit_width}f}"
35
+
36
+
37
# Metrics that each come with their own standalone recorder class in
# py_sod_metrics. NOTE(review): "INDIVADUAL" is a typo for "INDIVIDUAL",
# kept as-is because other code references this name.
INDIVADUAL_METRIC_MAPPING = {
    "sm": py_sod_metrics.Smeasure,
    "wfm": py_sod_metrics.WeightedFmeasure,
    "mae": py_sod_metrics.MAE,
    "em": py_sod_metrics.Emeasure,
}
# Threshold-style binary-classification metrics; these are all evaluated
# through one shared py_sod_metrics.FmeasureV2 instance, configured with the
# handler class and keyword arguments listed here.
BINARY_CLASSIFICATION_METRIC_MAPPING = {
    "fmeasure": {
        "handler": py_sod_metrics.FmeasureHandler,
        "kwargs": dict(with_dynamic=True, with_adaptive=True, with_binary=False, beta=0.3),
    },
    "iou": {
        "handler": py_sod_metrics.IOUHandler,
        "kwargs": dict(with_dynamic=True, with_adaptive=True, with_binary=False),
    },
    "dice": {
        "handler": py_sod_metrics.DICEHandler,
        # dice registers the dynamic (per-threshold) variant only.
        "kwargs": dict(with_dynamic=True, with_adaptive=False, with_binary=False),
    },
}
57
+
58
+
59
class ImageMetricRecorder:
    """Accumulate SOD metrics over (prediction, mask) pairs, image by image.

    Individual metrics (sm/wfm/mae/em) each own a recorder object; the
    threshold-based ones (fmeasure/iou/dice) share one FmeasureV2 instance
    with one handler registered per metric.
    """

    supported_metrics = sorted(INDIVADUAL_METRIC_MAPPING.keys()) + sorted(BINARY_CLASSIFICATION_METRIC_MAPPING.keys())

    def __init__(self, metric_names=("sm", "wfm", "mae", "fmeasure", "em")):
        # An empty/None metric_names means "use everything supported".
        if not metric_names:
            metric_names = self.supported_metrics
        assert all([m in self.supported_metrics for m in metric_names]), f"Only support: {self.supported_metrics}"

        self.metric_objs = {}
        has_existed = False
        for metric_name in metric_names:
            if metric_name in INDIVADUAL_METRIC_MAPPING:
                self.metric_objs[metric_name] = INDIVADUAL_METRIC_MAPPING[metric_name]()
            else:  # metric_name in BINARY_CLASSIFICATION_METRIC_MAPPING
                if not has_existed:  # only init once
                    self.metric_objs["fmeasurev2"] = py_sod_metrics.FmeasureV2()
                    has_existed = True
                metric_handler = BINARY_CLASSIFICATION_METRIC_MAPPING[metric_name]
                self.metric_objs["fmeasurev2"].add_handler(
                    handler_name=metric_name, metric_handler=metric_handler["handler"](**metric_handler["kwargs"])
                )

    def step(self, pre: np.ndarray, gt: np.ndarray, gt_path: str):
        """Feed one prediction/ground-truth pair.

        Both arrays must share the same shape and be uint8; `gt_path` is only
        used to make the assertion messages traceable.
        """
        assert pre.shape == gt.shape, (pre.shape, gt.shape, gt_path)
        assert pre.dtype == gt.dtype == np.uint8, (pre.dtype, gt.dtype, gt_path)

        for m_obj in self.metric_objs.values():
            m_obj.step(pre, gt)

    def get_all_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
        """Collect all accumulated results.

        Returns:
            dict with "sequential" (per-threshold curves) and "numerical"
            (scalar summaries; curve metrics get max/avg/adp variants).
        """
        sequential_results = {}
        numerical_results = {}
        for m_name, m_obj in self.metric_objs.items():
            info = m_obj.get_results()
            if m_name == "fmeasurev2":
                # One entry per registered handler (fmeasure/iou/dice/...).
                for _name, results in info.items():
                    dynamic_results = results.get("dynamic")
                    adaptive_results = results.get("adaptive")
                    if dynamic_results is not None:
                        # np.flip reverses the per-threshold curve — presumably
                        # so the threshold axis ascends; confirm against
                        # py_sod_metrics' curve ordering.
                        sequential_results[_name] = np.flip(dynamic_results)
                        numerical_results[f"max{_name}"] = dynamic_results.max()
                        numerical_results[f"avg{_name}"] = dynamic_results.mean()
                    if adaptive_results is not None:
                        numerical_results[f"adp{_name}"] = adaptive_results
            else:
                results = info[m_name]
                if m_name in ("wfm", "sm", "mae"):
                    numerical_results[m_name] = results
                elif m_name == "em":
                    sequential_results[m_name] = np.flip(results["curve"])
                    numerical_results.update(
                        {
                            "maxem": results["curve"].max(),
                            "avgem": results["curve"].mean(),
                            "adpem": results["adp"],
                        }
                    )
                else:
                    raise NotImplementedError(m_name)

        # Only the scalar summaries are rounded; curves stay full precision.
        if num_bits is not None and isinstance(num_bits, int):
            numerical_results = {k: v.round(num_bits) for k, v in numerical_results.items()}
        if not return_ndarray:
            sequential_results = ndarray_to_basetype(sequential_results)
            numerical_results = ndarray_to_basetype(numerical_results)
        return {"sequential": sequential_results, "numerical": numerical_results}

    def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
        """Return only the scalar ("numerical") summaries."""
        return self.get_all_results(num_bits=num_bits, return_ndarray=return_ndarray)["numerical"]
129
+
130
+
131
class GroupedMetricRecorder(object):
    """Metric recorder for datasets that are split into named groups.

    Each group owns its own ImageMetricRecorder; `show()` averages results
    across groups (curves are averaged pointwise before max/avg are taken).
    """

    supported_metrics = ["mae", "em", "sm", "wfm"] + sorted(BINARY_CLASSIFICATION_METRIC_MAPPING.keys())

    def __init__(self, group_names=None, metric_names=("sm", "wfm", "mae", "fmeasure", "em")):
        # group_names may be None: recorders are then created lazily in step().
        self.group_names = group_names
        self.metric_names = metric_names
        self.zero()

    def zero(self):
        """Reset all per-group recorders."""
        self.metric_recorders = {}
        if self.group_names is not None:
            self.metric_recorders.update(
                {n: ImageMetricRecorder(metric_names=self.metric_names) for n in self.group_names}
            )

    def step(self, group_name: str, pre: np.ndarray, gt: np.ndarray, gt_path: str):
        """Feed one prediction/gt pair into the recorder for `group_name`,
        creating that group's recorder on first use."""
        if group_name not in self.metric_recorders:
            self.metric_recorders[group_name] = ImageMetricRecorder(metric_names=self.metric_names)
        self.metric_recorders[group_name].step(pre, gt, gt_path)

    def show(self, num_bits: int = 3, return_group: bool = False):
        """Average the accumulated metrics over all groups.

        Returns the overall numerical results; with `return_group=True`
        additionally returns the per-group numerical results.
        """
        # Full-precision ndarray results per group; rounding happens last.
        groups_metrics = {
            n: r.get_all_results(num_bits=None, return_ndarray=True) for n, r in self.metric_recorders.items()
        }

        # Regroup: metric_type ("sequential"/"numerical") -> name -> [per-group arrays].
        results = {}
        for group_metrics in groups_metrics.values():
            for metric_type, metric_group in group_metrics.items():  # sequential and numerical
                results.setdefault(metric_type, {})
                for metric_name, metric_array in metric_group.items():
                    results[metric_type].setdefault(metric_name, []).append(metric_array)

        numerical_results = {}
        for metric_type, metric_group in results.items():
            for metric_name, metric_array in metric_group.items():
                metric_array = np.mean(np.vstack(metric_array), axis=0)  # average over all groups

                if metric_name in BINARY_CLASSIFICATION_METRIC_MAPPING or metric_name == "em":
                    if metric_type == "sequential":
                        # max/avg are recomputed from the group-averaged curve.
                        numerical_results[f"max{metric_name}"] = metric_array.max()
                        numerical_results[f"avg{metric_name}"] = metric_array.mean()
                else:
                    if metric_type == "numerical":
                        if metric_name.startswith(("max", "avg")):
                            # these metrics (maxfm, avgfm, maxem, avgem) will be recomputed within the group
                            continue
                        numerical_results[metric_name] = metric_array

        numerical_results = ndarray_to_basetype(numerical_results)
        numerical_results = {k: round(v, num_bits) for k, v in numerical_results.items()}
        numerical_results = self.sort_results(numerical_results)
        if not return_group:
            return numerical_results

        # Per-group scalar summaries, rounded and sorted the same way.
        group_numerical_results = {}
        for group_name, group_metric in groups_metrics.items():
            group_metric = {k: v.round(num_bits) for k, v in group_metric["numerical"].items()}
            group_metric = ndarray_to_basetype(group_metric)
            group_numerical_results[group_name] = self.sort_results(group_metric)
        return numerical_results, group_numerical_results

    def sort_results(self, results: dict) -> OrderedDict:
        """for a single group of metrics"""
        sorted_results = OrderedDict()
        # Sorting keys by their reversed text clusters variants of the same
        # metric (max*/avg*/adp*) next to each other; the outer loop then
        # orders clusters by self.metric_names priority.
        all_keys = sorted(results.keys(), key=lambda item: item[::-1])
        for name in self.metric_names:
            for key in all_keys:
                if key.endswith(name):
                    sorted_results[key] = results[key]
        return sorted_results
utils/recorder/logger.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.tensorboard import SummaryWriter
2
+
3
+
4
class TBLogger:
    """Thin, lazily-initialized wrapper around a TensorBoard SummaryWriter."""

    def __init__(self, tb_root):
        # Directory for the event files. The writer itself is created on
        # first use so constructing a TBLogger touches no files.
        self.tb_root = tb_root
        self.tb = None

    def write_to_tb(self, name, data, curr_iter):
        """Log a scalar under ``data/<name>`` at step `curr_iter`.

        A tuple/list of scalars is logged element-wise as
        ``data/<name>_<idx>``.
        """
        assert self.tb_root is not None

        if self.tb is None:
            self.tb = SummaryWriter(self.tb_root)

        if not isinstance(data, (tuple, list)):
            self.tb.add_scalar(f"data/{name}", data, curr_iter)
        else:
            for idx, data_item in enumerate(data):
                self.tb.add_scalar(f"data/{name}_{idx}", data_item, curr_iter)

    def close_tb(self):
        """Close the writer if it was ever created."""
        if self.tb is not None:
            self.tb.close()
utils/recorder/meter_recorder.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ from collections import deque
3
+
4
+
5
class AvgMeter(object):
    """Track the latest value and the running (weighted) average of a scalar."""

    __slots__ = ["value", "sum", "count"]

    def __init__(self):
        self.value = 0  # most recently recorded value
        self.sum = 0  # weighted sum of all recorded values
        self.count = 0  # total weight recorded so far

    def reset(self):
        """Clear all statistics."""
        self.value = 0
        self.sum = 0
        self.count = 0

    def update(self, value, num=1):
        """Record `value` with weight `num` (e.g. the batch size)."""
        self.value = value
        self.sum += value * num
        self.count += num

    @property
    def avg(self):
        # Guard the empty meter instead of raising ZeroDivisionError (this
        # also keeps __repr__ safe before the first update).
        if self.count == 0:
            return 0
        return self.sum / self.count

    def __repr__(self) -> str:
        return f"{self.avg:.5f}"
29
+
30
+
31
class HistoryBuffer:
    """Track a series of scalars, exposing the average over a recent window
    as well as the global average/sum of everything ever recorded.

    Args:
        window_size (int): The maximal number of values that can
            be stored in the buffer. Defaults to 20.

    Example::

        >>> his_buf = HistoryBuffer()
        >>> his_buf.update(0.1)
        >>> his_buf.update(0.2)
        >>> his_buf.avg
        0.15
    """

    def __init__(self, window_size: int = 20) -> None:
        # The deque evicts the oldest entry automatically once full.
        self._history = deque(maxlen=window_size)
        self._count: int = 0
        self._sum: float = 0
        self.reset()

    def reset(self):
        """Drop the window and the global statistics."""
        self._history.clear()
        self._count = 0
        self._sum = 0

    def update(self, value: float, num: int = 1) -> None:
        """Add a new scalar value. If the length of queue exceeds ``window_size``,
        the oldest element will be removed from the queue.
        """
        self._history.append(value)
        self._count += num
        self._sum += value * num

    @property
    def latest(self) -> float:
        """The latest value of the queue."""
        return self._history[-1]

    @property
    def avg(self) -> float:
        """The average over the window."""
        window = self._history
        return sum(window) / len(window) if window else 0

    @property
    def global_avg(self) -> float:
        """The global average of the queue."""
        return self._sum / self._count if self._count else 0

    @property
    def global_sum(self) -> float:
        """The global sum of the queue."""
        return self._sum
utils/recorder/visualize_results.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import matplotlib
5
+
6
+ matplotlib.use("Agg")
7
+ import numpy as np
8
+ import torchvision.transforms.functional as tv_tf
9
+ from torchvision.utils import make_grid
10
+
11
+
12
def plot_results(data_container, save_path, base_size=256, is_rgb=True):
    """Plot the results corresponding to the batched images based on the `make_grid` method from `torchvision`.

    Each entry of `data_container` becomes one horizontal strip (its batch
    laid out in a single row); the strips are resized to a common height,
    labeled with their dict key, stacked vertically, and written to
    `save_path` as a single image.

    Args:
        data_container (dict): Dict containing data you want to plot.
            NOTE(review): values appear to be torch tensors shaped (B,C,H,W)
            or (B,H,W) — confirm with callers.
        save_path (str): Path of the exported image.
        base_size (int): target height of each strip after resizing.
        is_rgb (bool): treat the assembled grid as RGB and convert it to BGR
            before handing it to cv2.imwrite.
    """
    font_cfg = dict(fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, thickness=2)

    grids = []
    for subplot_id, (name, data) in enumerate(data_container.items()):
        # Promote (B,H,W) to (B,1,H,W) so make_grid sees a channel axis.
        if data.ndim == 3:
            data = data.unsqueeze(1)

        grid = make_grid(data, nrow=data.shape[0], padding=2, normalize=False)
        grid = np.array(tv_tf.to_pil_image(grid.float()))
        h, w = grid.shape[:2]
        # Scale the whole strip so that its height equals base_size.
        ratio = base_size / h
        grid = cv2.resize(grid, dsize=None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)

        # Stamp the dict key into the top-left corner of the strip.
        (text_w, text_h), baseline = cv2.getTextSize(text=name, **font_cfg)
        text_xy = 20, 20 + text_h // 2 + baseline
        cv2.putText(grid, text=name, org=text_xy, color=(255, 255, 255), **font_cfg)

        grids.append(grid)
    grids = np.concatenate(grids, axis=0)  # H,W,C

    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    if is_rgb:
        grids = cv2.cvtColor(grids, cv2.COLOR_RGB2BGR)
    cv2.imwrite(save_path, grids)