Upload 20 files
- utils/.DS_Store +0 -0
- utils/__init__.py +0 -0
- utils/io/__init__.py +7 -0
- utils/io/image.py +41 -0
- utils/io/params.py +30 -0
- utils/ops/__init__.py +8 -0
- utils/ops/array_ops.py +97 -0
- utils/ops/tensor_ops.py +31 -0
- utils/pipeline/__init__.py +8 -0
- utils/pipeline/optimizer.py +169 -0
- utils/pipeline/scaler.py +59 -0
- utils/pipeline/scheduler.py +410 -0
- utils/pt_utils.py +66 -0
- utils/py_utils.py +197 -0
- utils/recorder/__init__.py +7 -0
- utils/recorder/counter.py +75 -0
- utils/recorder/group_metric_caller.py +200 -0
- utils/recorder/logger.py +23 -0
- utils/recorder/meter_recorder.py +91 -0
- utils/recorder/visualize_results.py +43 -0
utils/.DS_Store
ADDED
Binary file (8.2 kB)
utils/__init__.py
ADDED
File without changes
utils/io/__init__.py
ADDED
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-
# @Time : 2021/5/17
# @Author : Lart Pang
# @GitHub : https://github.com/lartpang

from .image import read_color_array, read_gray_array
from .params import load_weight, save_weight
utils/io/image.py
ADDED
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
# @Time : 2021/5/17
# @Author : Lart Pang
# @GitHub : https://github.com/lartpang
import cv2
import numpy as np

from utils.ops import minmax


def read_gray_array(path, div_255=False, to_normalize=False, thr=-1, dtype=np.float32) -> np.ndarray:
    """
    1. read the binary image with the suffix `.jpg` or `.png`
       into a grayscale ndarray
    2. (to_normalize=True) rescale the ndarray to [0, 1]
    3. (thr >= 0) binarize the ndarray with `thr`
    4. return a gray ndarray (np.float32)
    """
    assert path.endswith(".jpg") or path.endswith(".png"), path
    assert not div_255 or not to_normalize, path
    gray_array = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    assert gray_array is not None, f"Image Not Found: {path}"

    if div_255:
        gray_array = gray_array / 255

    if to_normalize:
        gray_array = minmax(gray_array, up_bound=255)

    if thr >= 0:
        gray_array = gray_array > thr

    return gray_array.astype(dtype)


def read_color_array(path: str):
    assert path.endswith(".jpg") or path.endswith(".png")
    bgr_array = cv2.imread(path, cv2.IMREAD_COLOR)
    assert bgr_array is not None, f"Image Not Found: {path}"
    rgb_array = cv2.cvtColor(bgr_array, cv2.COLOR_BGR2RGB)
    return rgb_array
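Usage sketch (illustrative, not part of the uploaded files): assuming the repo root is on the Python path, the two readers above can be driven as below; the sample paths are placeholders.

# Illustrative only: the paths are hypothetical placeholders.
from utils.io import read_color_array, read_gray_array

rgb = read_color_array("sample/image.jpg")                                # HxWx3, RGB order
mask = read_gray_array("sample/mask.png", to_normalize=True)              # HxW float32 in [0, 1]
binary_mask = read_gray_array("sample/mask.png", div_255=True, thr=0.5)   # 0/1 float32
print(rgb.shape, mask.min(), mask.max(), binary_mask.dtype)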
utils/io/params.py
ADDED
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# @Time : 2020/12/19
# @Author : Lart Pang
# @GitHub : https://github.com/lartpang

import os

import torch


def save_weight(save_path, model):
    print(f"Saving weight '{save_path}'")
    if isinstance(model, dict):
        model_state = model
    else:
        model_state = model.module.state_dict() if hasattr(model, "module") else model.state_dict()
    torch.save(model_state, save_path)
    print(f"Saved weight '{save_path}' (only contains the model's weights)")


def load_weight(load_path, model, *, strict=True, skip_unmatched_shape=False):
    assert os.path.exists(load_path), load_path
    model_params = model.state_dict()
    for k, v in torch.load(load_path, map_location="cpu").items():
        if k.startswith("module."):
            # strip the "module." prefix added by DataParallel/DistributedDataParallel
            k = k[7:]
        if skip_unmatched_shape and k in model_params and v.shape != model_params[k].shape:
            continue
        model_params[k] = v
    model.load_state_dict(model_params, strict=strict)
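A quick round-trip sketch for save_weight/load_weight (illustrative only): the tiny model and the temporary checkpoint path are placeholders.

# Illustrative only: save a state dict and restore it into a freshly built model.
import torch
from torch import nn

from utils.io import load_weight, save_weight

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
save_weight("/tmp/demo_state.pth", model)    # saves model.state_dict()

restored = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
load_weight("/tmp/demo_state.pth", restored, strict=True)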
utils/ops/__init__.py
ADDED
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
# @Time : 2020/12/19
# @Author : Lart Pang
# @GitHub : https://github.com/lartpang


from .array_ops import *
from .tensor_ops import *
utils/ops/array_ops.py
ADDED
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
import os

import cv2
import numpy as np


def minmax(data_array: np.ndarray, up_bound: float = None) -> np.ndarray:
    """
    ::

        data_array = (data_array / up_bound)
        if min_value != max_value:
            data_array = (data_array - min_value) / (max_value - min_value)

    :param data_array:
    :param up_bound: if not None, data_array will be divided by it before the minmax op.
    :return:
    """
    if up_bound is not None:
        data_array = data_array / up_bound
    max_value = data_array.max()
    min_value = data_array.min()
    if max_value != min_value:
        data_array = (data_array - min_value) / (max_value - min_value)
    return data_array


def clip_to_normalize(data_array: np.ndarray, clip_range: tuple = None) -> np.ndarray:
    if clip_range is None:
        # no clipping requested, just rescale to [0, 1]
        # (handled before sorted(), which would fail on None)
        return minmax(data_array)

    clip_range = sorted(clip_range)
    if len(clip_range) == 3:
        clip_min, clip_mid, clip_max = clip_range
        assert 0 <= clip_min < clip_mid < clip_max <= 1, clip_range
        lower_array = data_array[data_array < clip_mid]
        higher_array = data_array[data_array > clip_mid]
        if lower_array.size > 0:
            lower_array = np.clip(lower_array, a_min=clip_min, a_max=1)
            max_lower = lower_array.max()
            lower_array = minmax(lower_array) * max_lower
            data_array[data_array < clip_mid] = lower_array
        if higher_array.size > 0:
            higher_array = np.clip(higher_array, a_min=0, a_max=clip_max)
            min_lower = higher_array.min()
            higher_array = minmax(higher_array) * (1 - min_lower) + min_lower
            data_array[data_array > clip_mid] = higher_array
    elif len(clip_range) == 2:
        clip_min, clip_max = clip_range
        assert 0 <= clip_min < clip_max <= 1, clip_range
        if clip_min != 0 and clip_max != 1:
            data_array = np.clip(data_array, a_min=clip_min, a_max=clip_max)
        data_array = minmax(data_array)
    else:
        raise NotImplementedError
    return data_array


def save_array_as_image(data_array: np.ndarray, save_name: str, save_dir: str, to_minmax: bool = False):
    """
    save the ndarray as an image

    Args:
        data_array: np.float32, the max value is less than or equal to 1
        save_name: with a specific suffix
        save_dir: the dirname of the image path
        to_minmax: minmax the array
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    save_path = os.path.join(save_dir, save_name)
    if data_array.dtype != np.uint8:
        if data_array.max() > 1:
            raise Exception("the range of data_array has some errors")
        data_array = (data_array * 255).astype(np.uint8)
    if to_minmax:
        data_array = minmax(data_array, up_bound=255)
        data_array = (data_array * 255).astype(np.uint8)
    cv2.imwrite(save_path, data_array)


def resize(image_array: np.ndarray, height, width, interpolation=cv2.INTER_LINEAR):
    h, w = image_array.shape[:2]
    if h == height and w == width:
        return image_array

    resized_image_array = cv2.resize(image_array, dsize=(width, height), interpolation=interpolation)
    return resized_image_array


def ms_resize(img, scales, base_h=None, base_w=None, interpolation=cv2.INTER_LINEAR):
    assert isinstance(scales, (list, tuple))
    if base_h is None:
        base_h = img.shape[0]
    if base_w is None:
        base_w = img.shape[1]
    return [resize(img, height=int(base_h * s), width=int(base_w * s), interpolation=interpolation) for s in scales]
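A short sketch of the array helpers on random data (illustrative only, nothing here depends on a real image file).

# Illustrative only: exercise minmax / clip_to_normalize / ms_resize on random data.
import numpy as np

from utils.ops import clip_to_normalize, minmax, ms_resize

pred = np.random.rand(240, 320).astype(np.float32)

normed = minmax(pred)                                             # rescaled to [0, 1]
clipped = clip_to_normalize(pred.copy(), clip_range=(0.1, 0.9))   # clip then rescale
pyramid = ms_resize(pred, scales=(0.5, 1.0, 1.5))                 # list of resized copies
print(normed.min(), normed.max(), [p.shape for p in pyramid])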
utils/ops/tensor_ops.py
ADDED
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
# @Time : 2020
# @Author : Lart Pang
# @GitHub : https://github.com/lartpang
import torch
import torch.nn.functional as F


def rescale_2x(x: torch.Tensor, scale_factor=2):
    return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)


def resize_to(x: torch.Tensor, tgt_hw: tuple):
    return F.interpolate(x, size=tgt_hw, mode="bilinear", align_corners=False)


def clip_grad(params, mode, clip_cfg: dict):
    if mode == "norm":
        if "max_norm" not in clip_cfg:
            raise ValueError("`clip_cfg` must contain `max_norm`.")
        torch.nn.utils.clip_grad_norm_(
            params,
            max_norm=clip_cfg.get("max_norm"),
            norm_type=clip_cfg.get("norm_type", 2.0),
        )
    elif mode == "value":
        if "clip_value" not in clip_cfg:
            raise ValueError("`clip_cfg` must contain `clip_value`.")
        torch.nn.utils.clip_grad_value_(params, clip_value=clip_cfg.get("clip_value"))
    else:
        raise NotImplementedError
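A minimal sketch of the tensor helpers (illustrative only): the toy convolution and tensor shapes are placeholders.

# Illustrative only: resize a feature map and clip the gradients of a toy model.
import torch
from torch import nn

from utils.ops import clip_grad, resize_to

feat = torch.randn(2, 16, 22, 22)
up = resize_to(feat, tgt_hw=(44, 44))       # bilinear resize to 44x44

model = nn.Conv2d(3, 8, 3, padding=1)
loss = model(torch.randn(1, 3, 32, 32)).mean()
loss.backward()
clip_grad(model.parameters(), mode="norm", clip_cfg=dict(max_norm=5.0))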
utils/pipeline/__init__.py
ADDED
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
# @Time : 2021/5/31
# @Author : Lart Pang
# @GitHub : https://github.com/lartpang

from .optimizer import construct_optimizer
from .scaler import Scaler
from .scheduler import Scheduler
utils/pipeline/optimizer.py
ADDED
@@ -0,0 +1,169 @@
# -*- coding: utf-8 -*-
# @Time : 2020/12/19
# @Author : Lart Pang
# @GitHub : https://github.com/lartpang
import types

import torchvision.models
from torch import nn
from torch.optim import SGD, Adam, AdamW


def get_optimizer(mode, params, initial_lr, optim_cfg):
    if mode == "sgd":
        optimizer = SGD(
            params=params,
            lr=initial_lr,
            momentum=optim_cfg["momentum"],
            weight_decay=optim_cfg["weight_decay"],
            nesterov=optim_cfg.get("nesterov", False),
        )
    elif mode == "adamw":
        optimizer = AdamW(
            params=params,
            lr=initial_lr,
            betas=optim_cfg.get("betas", (0.9, 0.999)),
            weight_decay=optim_cfg.get("weight_decay", 0),
            amsgrad=optim_cfg.get("amsgrad", False),
        )
    elif mode == "adam":
        optimizer = Adam(
            params=params,
            lr=initial_lr,
            betas=optim_cfg.get("betas", (0.9, 0.999)),
            weight_decay=optim_cfg.get("weight_decay", 0),
            amsgrad=optim_cfg.get("amsgrad", False),
        )
    else:
        raise NotImplementedError(mode)
    return optimizer


def group_params(model: nn.Module, group_mode: str, initial_lr: float, optim_cfg: dict):
    if group_mode == "yolov5":
        """
        norm, weight, bias = [], [], []  # optimizer parameter groups
        for k, v in model.named_modules():
            if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
                bias.append(v.bias)  # biases
            if isinstance(v, nn.BatchNorm2d):
                norm.append(v.weight)  # no decay
            elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
                weight.append(v.weight)  # apply decay

        if opt.adam:
            optimizer = optim.Adam(norm, lr=hyp["lr0"], betas=(hyp["momentum"], 0.999))  # adjust beta1 to momentum
        else:
            optimizer = optim.SGD(norm, lr=hyp["lr0"], momentum=hyp["momentum"], nesterov=True)

        optimizer.add_param_group({"params": weight, "weight_decay": hyp["weight_decay"]})  # add weight with weight_decay
        optimizer.add_param_group({"params": bias})  # add bias (biases)
        """
        norm, weight, bias = [], [], []  # optimizer parameter groups
        for k, v in model.named_modules():
            if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
                bias.append(v.bias)  # conv bias and bn bias
            if isinstance(v, nn.BatchNorm2d):
                norm.append(v.weight)  # bn weight
            elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
                weight.append(v.weight)  # conv weight
        params = [
            {"params": filter(lambda p: p.requires_grad, bias), "weight_decay": 0.0},
            {"params": filter(lambda p: p.requires_grad, norm), "weight_decay": 0.0},
            {"params": filter(lambda p: p.requires_grad, weight)},
        ]
    elif group_mode == "r3":
        params = [
            # Do not apply weight decay to the bias parameters. Weight decay mainly reduces
            # overfitting by constraining the layer parameters (both weight and bias); the
            # L2 regularization keeps them smoother.
            {
                "params": [
                    param for name, param in model.named_parameters() if name[-4:] == "bias" and param.requires_grad
                ],
                "lr": 2 * initial_lr,
                "weight_decay": 0,
            },
            {
                "params": [
                    param for name, param in model.named_parameters() if name[-4:] != "bias" and param.requires_grad
                ],
                "lr": initial_lr,
                "weight_decay": optim_cfg["weight_decay"],
            },
        ]
    elif group_mode == "all":
        params = model.parameters()
    elif group_mode == "finetune":
        if hasattr(model, "module"):
            model = model.module
        assert hasattr(model, "get_grouped_params"), "Cannot get the method get_grouped_params of the model."
        params_groups = model.get_grouped_params()
        params = [
            {
                "params": filter(lambda p: p.requires_grad, params_groups["pretrained"]),
                "lr": optim_cfg.get("diff_factor", 0.1) * initial_lr,
            },
            {
                "params": filter(lambda p: p.requires_grad, params_groups["retrained"]),
                "lr": initial_lr,
            },
        ]
    elif group_mode == "finetune2":
        if hasattr(model, "module"):
            model = model.module
        assert hasattr(model, "get_grouped_params"), "Cannot get the method get_grouped_params of the model."
        params_groups = model.get_grouped_params()
        params = [
            {
                "params": filter(lambda p: p.requires_grad, params_groups["pretrained_backbone"]),
                "lr": 0.1 * initial_lr,
            },
            {
                "params": filter(lambda p: p.requires_grad, params_groups["pretrained_siamese"]),
                "lr": 0.5 * initial_lr,
            },
            {
                "params": filter(lambda p: p.requires_grad, params_groups["retrained"]),
                "lr": initial_lr,
            },
        ]
    else:
        raise NotImplementedError
    return params


def construct_optimizer(model, initial_lr, mode, group_mode, cfg):
    params = group_params(model, group_mode=group_mode, initial_lr=initial_lr, optim_cfg=cfg)
    optimizer = get_optimizer(mode=mode, params=params, initial_lr=initial_lr, optim_cfg=cfg)
    optimizer.lr_groups = types.MethodType(get_lr_groups, optimizer)
    optimizer.lr_string = types.MethodType(get_lr_strings, optimizer)
    return optimizer


def get_lr_groups(self):
    return [group["lr"] for group in self.param_groups]


def get_lr_strings(self):
    return ",".join([f"{group['lr']:.3e}" for group in self.param_groups])


if __name__ == "__main__":
    model = torchvision.models.vgg11_bn()
    norm, weight, bias = [], [], []  # optimizer parameter groups
    for k, v in model.named_modules():
        if hasattr(v, "bias") and isinstance(v.bias, nn.Parameter):
            bias.append(v.bias)  # biases
        if isinstance(v, nn.BatchNorm2d):
            norm.append(v.weight)  # no decay
        elif hasattr(v, "weight") and isinstance(v.weight, nn.Parameter):
            weight.append(v.weight)  # apply decay

    optimizer = Adam(norm, lr=0.001, betas=(0.98, 0.999))  # adjust beta1 to momentum
    # optimizer = optim.SGD(norm, lr=hyp["lr0"], momentum=hyp["momentum"], nesterov=True)

    optimizer.add_param_group({"params": weight, "weight_decay": 1e-4})  # add weight with weight_decay
    optimizer.add_param_group({"params": bias})  # add bias (biases)

    print(optimizer)
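A sketch of building an optimizer through construct_optimizer (illustrative only): the cfg dict carries just the keys the "sgd" branch reads, and the model is a stock torchvision network rather than one from this repo.

# Illustrative only: grouped SGD for a stock torchvision model.
import torchvision.models

from utils.pipeline import construct_optimizer

model = torchvision.models.vgg11_bn()
optimizer = construct_optimizer(
    model,
    initial_lr=0.05,
    mode="sgd",
    group_mode="yolov5",   # bias/bn weights without decay, conv weights with decay
    cfg=dict(momentum=0.9, weight_decay=5e-4, nesterov=True),
)
print(optimizer.lr_string())   # three groups, each starting at 5.000e-02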
utils/pipeline/scaler.py
ADDED
@@ -0,0 +1,59 @@
from functools import partial
from itertools import chain

from torch.cuda.amp import GradScaler, autocast

from .. import ops


class Scaler:
    def __init__(
        self, optimizer, use_fp16=False, *, set_to_none=False, clip_grad=False, clip_mode=None, clip_cfg=None
    ) -> None:
        self.optimizer = optimizer
        self.set_to_none = set_to_none
        self.autocast = autocast(enabled=use_fp16)
        self.scaler = GradScaler(enabled=use_fp16)

        if clip_grad:
            self.grad_clip_ops = partial(ops.clip_grad, mode=clip_mode, clip_cfg=clip_cfg)
        else:
            self.grad_clip_ops = None

    def calculate_grad(self, loss):
        self.scaler.scale(loss).backward()
        if self.grad_clip_ops is not None:
            self.scaler.unscale_(self.optimizer)
            self.grad_clip_ops(chain(*[group["params"] for group in self.optimizer.param_groups]))

    def update_grad(self):
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=self.set_to_none)

    def state_dict(self):
        r"""
        Returns the state of the scaler as a :class:`dict`. It contains five entries:

        * ``"scale"`` - a Python float containing the current scale
        * ``"growth_factor"`` - a Python float containing the current growth factor
        * ``"backoff_factor"`` - a Python float containing the current backoff factor
        * ``"growth_interval"`` - a Python int containing the current growth interval
        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.

        If this instance is not enabled, returns an empty dict.

        .. note::
            If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
            should be called after :meth:`update`.
        """
        return self.scaler.state_dict()

    def load_state_dict(self, state_dict):
        r"""
        Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.

        Args:
            state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
        """
        self.scaler.load_state_dict(state_dict)
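A sketch of one AMP training step driven by the Scaler wrapper (illustrative only): it assumes a CUDA device is available, and the toy model, batch, and loss are placeholders.

# Illustrative only: one AMP training step with gradient clipping.
import torch
from torch import nn

from utils.pipeline import Scaler

model = nn.Conv2d(3, 1, 3, padding=1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = Scaler(optimizer, use_fp16=True, clip_grad=True, clip_mode="norm", clip_cfg=dict(max_norm=5.0))

image = torch.randn(2, 3, 64, 64, device="cuda")
target = torch.rand(2, 1, 64, 64, device="cuda")

with scaler.autocast:
    loss = nn.functional.binary_cross_entropy_with_logits(model(image), target)
scaler.calculate_grad(loss)   # scaled backward + optional unscale/clip
scaler.update_grad()          # optimizer step, scaler update, zero_grad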
utils/pipeline/scheduler.py
ADDED
@@ -0,0 +1,410 @@
# -*- coding: utf-8 -*-
# @Time : 2020/12/19
# @Author : Lart Pang
# @GitHub : https://github.com/lartpang

import copy
import math
import os.path
import warnings
from bisect import bisect_right

import matplotlib
import numpy as np
import torch.optim
from adjustText import adjust_text

matplotlib.use("Agg")
from matplotlib import pyplot as plt

# helper function ----------------------------------------------------------------------


def linear_increase(low_bound, up_bound, percentage):
    """low_bound + [0, 1] * (up_bound - low_bound)"""
    assert 0 <= percentage <= 1, f"percentage({percentage}) must be in [0, 1]"
    return low_bound + (up_bound - low_bound) * percentage


def cos_anneal(low_bound, up_bound, percentage):
    assert 0 <= percentage <= 1, f"percentage({percentage}) must be in [0, 1]"
    cos_percentage = (1 + math.cos(math.pi * percentage)) / 2.0
    return linear_increase(low_bound, up_bound, percentage=cos_percentage)


def poly_anneal(low_bound, up_bound, percentage, lr_decay):
    assert 0 <= percentage <= 1, f"percentage({percentage}) must be in [0, 1]"
    poly_percentage = pow((1 - percentage), lr_decay)
    return linear_increase(low_bound, up_bound, percentage=poly_percentage)


def linear_anneal(low_bound, up_bound, percentage):
    assert 0 <= percentage <= 1, f"percentage({percentage}) must be in [0, 1]"
    return linear_increase(low_bound, up_bound, percentage=1 - percentage)


# coefficient function ----------------------------------------------------------------------


def get_f3_coef_func(num_iters):
    """
    F3Net

    :param num_iters: The number of iterations for the total process.
    :return:
    """

    def get_f3_coef(curr_idx):
        assert 0 <= curr_idx <= num_iters
        return 1 - abs((curr_idx + 1) / (num_iters + 1) * 2 - 1)

    return get_f3_coef


def get_step_coef_func(gamma, milestones):
    """
    lr = baselr * gamma ** 0 if curr_idx < milestones[0]
    lr = baselr * gamma ** 1 if milestones[0] <= epoch < milestones[1]
    ...

    :param gamma:
    :param milestones:
    :return: The function for generating the coefficient.
    """
    if isinstance(milestones, (tuple, list)):
        milestones = list(sorted(milestones))
        return lambda curr_idx: gamma ** bisect_right(milestones, curr_idx)
    elif isinstance(milestones, int):
        return lambda curr_idx: gamma ** ((curr_idx + 1) // milestones)
    else:
        raise ValueError(f"milestones only can be list/tuple/int, but now it is {type(milestones)}")


def get_cos_coef_func(half_cycle, min_coef, max_coef=1):
    """
    :param half_cycle: The number of iterations in a half cycle.
    :param min_coef: The minimum coefficient of the learning rate.
    :param max_coef: The maximum coefficient of the learning rate.
    :return: The function for generating the coefficient.
    """

    def get_cos_coef(curr_idx):
        recomputed_idx = curr_idx % (half_cycle + 1)
        # recomputed \in [0, half_cycle]
        return cos_anneal(low_bound=min_coef, up_bound=max_coef, percentage=recomputed_idx / half_cycle)

    return get_cos_coef


def get_fatcos_coef_func(start_iter, half_cycle, min_coef, max_coef=1):
    """
    :param start_iter: The number of iterations before the cosine annealing starts.
    :param half_cycle: The number of iterations in a half cycle.
    :param min_coef: The minimum coefficient of the learning rate.
    :param max_coef: The maximum coefficient of the learning rate.
    :return: The function for generating the coefficient.
    """

    def get_cos_coef(curr_idx):
        curr_idx = max(0, curr_idx - start_iter)
        recomputed_idx = curr_idx % (half_cycle + 1)
        # recomputed \in [0, half_cycle]
        return cos_anneal(low_bound=min_coef, up_bound=max_coef, percentage=recomputed_idx / half_cycle)

    return get_cos_coef


def get_poly_coef_func(num_iters, lr_decay, min_coef, max_coef=1):
    """
    :param num_iters: The number of iterations for the polynomial descent process.
    :param lr_decay: The decay item of the polynomial descent process.
    :param min_coef: The minimum coefficient of the learning rate.
    :param max_coef: The maximum coefficient of the learning rate.
    :return: The function for generating the coefficient.
    """

    def get_poly_coef(curr_idx):
        assert 0 <= curr_idx <= num_iters, (curr_idx, num_iters)
        return poly_anneal(low_bound=min_coef, up_bound=max_coef, percentage=curr_idx / num_iters, lr_decay=lr_decay)

    return get_poly_coef


# coefficient entry function ----------------------------------------------------------------------


def get_scheduler_coef_func(mode, num_iters, cfg):
    """
    the region is a closed interval: [0, num_iters]
    """
    assert num_iters > 0
    min_coef = cfg.get("min_coef", 1e-6)
    if min_coef is None or min_coef == 0:
        warnings.warn(f"The min_coef ({min_coef}) of the scheduler will be replaced with 1e-6")
        min_coef = 1e-6

    if mode == "step":
        coef_func = get_step_coef_func(gamma=cfg["gamma"], milestones=cfg["milestones"])
    elif mode == "cos":
        if half_cycle := cfg.get("half_cycle"):
            half_cycle -= 1
        else:
            half_cycle = num_iters
        if (num_iters - half_cycle) % (half_cycle + 1) != 0:
            # idx starts from 0
            percentage = ((num_iters - half_cycle) % (half_cycle + 1)) / (half_cycle + 1) * 100
            warnings.warn(
                f"The final annealing process ({percentage:.3f}%) is not complete. "
                f"Please pay attention to the generated 'lr_coef_curve.png'."
            )
        coef_func = get_cos_coef_func(half_cycle=half_cycle, min_coef=min_coef)
    elif mode == "fatcos":
        assert 0 <= cfg.start_percent < 1, cfg.start_percent
        start_iter = int(cfg.start_percent * num_iters)

        num_iters -= start_iter
        if half_cycle := cfg.get("half_cycle"):
            half_cycle -= 1
        else:
            half_cycle = num_iters
        if (num_iters - half_cycle) % (half_cycle + 1) != 0:
            # idx starts from 0
            percentage = ((num_iters - half_cycle) % (half_cycle + 1)) / (half_cycle + 1) * 100
            warnings.warn(
                f"The final annealing process ({percentage:.3f}%) is not complete. "
                f"Please pay attention to the generated 'lr_coef_curve.png'."
            )
        coef_func = get_fatcos_coef_func(start_iter=start_iter, half_cycle=half_cycle, min_coef=min_coef)
    elif mode == "poly":
        coef_func = get_poly_coef_func(num_iters=num_iters, lr_decay=cfg["lr_decay"], min_coef=min_coef)
    elif mode == "constant":
        coef_func = lambda x: cfg.get("coef", 1)
    elif mode == "f3":
        coef_func = get_f3_coef_func(num_iters=num_iters)
    else:
        raise NotImplementedError(f"{mode} must be in {Scheduler.supported_scheduler_modes}")
    return coef_func


def get_warmup_coef_func(num_iters, min_coef, max_coef=1, mode="linear"):
    """
    the region is a closed interval: [0, num_iters]
    """
    assert num_iters > 0
    if mode == "cos":
        anneal_func = cos_anneal
    elif mode == "linear":
        anneal_func = linear_anneal
    else:
        raise NotImplementedError(f"{mode} must be in {Scheduler.supported_warmup_modes}")

    def get_warmup_coef(curr_idx):
        return anneal_func(low_bound=min_coef, up_bound=max_coef, percentage=1 - curr_idx / num_iters)

    return get_warmup_coef


# main class ----------------------------------------------------------------------


class Scheduler:
    supported_scheduler_modes = ("step", "cos", "fatcos", "poly", "constant", "f3")
    supported_warmup_modes = ("cos", "linear")

    def __init__(self, optimizer, num_iters, epoch_length, scheduler_cfg, step_by_batch=True):
        """A customized wrapper of the scheduler.

        Args:
            optimizer (): Optimizer.
            num_iters (int): The total number of the iterations.
            epoch_length (int): The number of the iterations of one epoch.
            scheduler_cfg (dict): The config of the scheduler.
            step_by_batch (bool, optional): The mode of updating the scheduler. Defaults to True.

        Raises:
            NotImplementedError:
        """
        self.optimizer = optimizer
        self.num_iters = num_iters
        self.epoch_length = epoch_length
        self.step_by_batch = step_by_batch

        self.scheduler_cfg = copy.deepcopy(scheduler_cfg)
        self.mode = scheduler_cfg["mode"]
        if self.mode not in self.supported_scheduler_modes:
            raise NotImplementedError(
                f"{self.mode} is not implemented. Has been supported: {self.supported_scheduler_modes}"
            )
        warmup_cfg = scheduler_cfg.get("warmup", None)

        num_warmup_iters = 0
        if warmup_cfg is not None and isinstance(warmup_cfg, dict):
            num_warmup_iters = warmup_cfg["num_iters"]
            if num_warmup_iters > 0:
                print("Will use warmup")
                self.warmup_coef_func = get_warmup_coef_func(
                    num_warmup_iters,
                    min_coef=warmup_cfg.get("initial_coef", 0.01),
                    mode=warmup_cfg.get("mode", "linear"),
                )
        self.num_warmup_iters = num_warmup_iters

        if step_by_batch:
            num_scheduler_iters = num_iters - num_warmup_iters
        else:
            num_scheduler_iters = (num_iters - num_warmup_iters) // epoch_length
        # the region is a closed interval
        self.lr_coef_func = get_scheduler_coef_func(
            mode=self.mode, num_iters=num_scheduler_iters - 1, cfg=scheduler_cfg["cfg"]
        )
        self.num_scheduler_iters = num_scheduler_iters

        self.last_lr_coef = 0
        self.initial_lrs = None

    def __repr__(self):
        formatted_string = [
            f"{self.__class__.__name__}: (\n",
            f"num_iters: {self.num_iters}\n",
            f"epoch_length: {self.epoch_length}\n",
            f"warmup_iter: [0, {self.num_warmup_iters})\n",
            f"scheduler_iter: [{self.num_warmup_iters}, {self.num_iters - 1}]\n",
            f"mode: {self.mode}\n",
            f"scheduler_cfg: {self.scheduler_cfg}\n",
            f"initial_lrs: {self.initial_lrs}\n",
            f"step_by_batch: {self.step_by_batch}\n)",
        ]
        return " ".join(formatted_string)

    def record_lrs(self, param_groups):
        self.initial_lrs = [g["lr"] for g in param_groups]

    def update(self, coef: float):
        assert self.initial_lrs is not None, "Please run .record_lrs(optimizer) first."
        for curr_group, initial_lr in zip(self.optimizer.param_groups, self.initial_lrs):
            curr_group["lr"] = coef * initial_lr

    def step(self, curr_idx):
        if curr_idx < self.num_warmup_iters:
            # get maximum value (1.0) when curr_idx == self.num_warmup_iters
            self.update(coef=self.get_lr_coef(curr_idx))
        else:
            # Start from a value lower than 1 (curr_idx == self.num_warmup_iters)
            if self.step_by_batch:
                self.update(coef=self.get_lr_coef(curr_idx))
            else:
                if curr_idx % self.epoch_length == 0:
                    self.update(coef=self.get_lr_coef(curr_idx))

    def get_lr_coef(self, curr_idx):
        coef = None
        if curr_idx < self.num_warmup_iters:
            coef = self.warmup_coef_func(curr_idx)
        else:
            # when curr_idx == self.num_warmup_iters, coef == 1.0
            # down from the largest coef (1.0)
            if self.step_by_batch:
                coef = self.lr_coef_func(curr_idx - self.num_warmup_iters)
            else:
                if curr_idx % self.epoch_length == 0 or curr_idx == self.num_warmup_iters:
                    # Right after warmup ends the epoch-based schedule has not stepped yet,
                    # so the coefficient must be reset to its maximum at this point.
                    coef = self.lr_coef_func((curr_idx - self.num_warmup_iters) // self.epoch_length)
        if coef is not None:
            self.last_lr_coef = coef
        return self.last_lr_coef

    def plot_lr_coef_curve(self, save_path=""):
        plt.rc("xtick", labelsize="small")
        plt.rc("ytick", labelsize="small")

        fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(8, 4), dpi=600)
        # give plot a title
        ax.set_title("Learning Rate Coefficient Curve")
        # make axis labels
        ax.set_xlabel("Iteration")
        ax.set_ylabel("Coefficient")

        x_data = np.arange(self.num_iters)
        y_data = np.array([self.get_lr_coef(x) for x in x_data])

        # set lim
        x_min, x_max = 0, self.num_iters - 1
        dx = self.num_iters * 0.1
        ax.set_xlim(x_min - dx, x_max + 2 * dx)

        y_min, y_max = y_data.min(), y_data.max()
        dy = (y_data.max() - y_data.min()) * 0.1
        ax.set_ylim((y_min - dy, y_max + dy))

        if self.step_by_batch:
            marker_on = [0, -1]
            key_point_xs = [0, self.num_iters - 1]
            for idx in range(1, len(y_data) - 1):
                prev_y = y_data[idx - 1]
                curr_y = y_data[idx]
                next_y = y_data[idx + 1]
                if (
                    (curr_y > prev_y and curr_y >= next_y)
                    or (curr_y >= prev_y and curr_y > next_y)
                    or (curr_y <= prev_y and curr_y < next_y)
                    or (curr_y < prev_y and curr_y <= next_y)
                ):
                    marker_on.append(idx)
                    key_point_xs.append(idx)

            marker_on = sorted(set(marker_on))
            key_point_xs = sorted(set(key_point_xs))
            key_point_ys = []

            texts = []
            for x in key_point_xs:
                y = y_data[x]
                key_point_ys.append(y)

                texts.append(ax.text(x=x, y=y, s=f"({x:d},{y:.3e})"))
            adjust_text(texts, arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=-0.3"))

            # set ticks
            ax.set_xticks(key_point_xs)
            # ax.set_yticks(key_point_ys)

            ax.plot(x_data, y_data, marker="o", markevery=marker_on)
        else:
            ax.plot(x_data, y_data)

        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)
        ax.spines["left"].set_visible(True)
        ax.spines["bottom"].set_visible(True)

        plt.tight_layout()
        if save_path:
            fig.savefig(os.path.join(save_path, "lr_coef.png"))
        plt.close()


if __name__ == "__main__":
    model = torch.nn.Conv2d(10, 10, 3, 1, 1)
    sche = Scheduler(
        optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
        num_iters=30300,
        epoch_length=505,
        scheduler_cfg=dict(
            warmup=dict(
                num_iters=6060,
                initial_coef=0.01,
                mode="cos",
            ),
            mode="cos",
            cfg=dict(
                half_cycle=6060,
                lr_decay=0.9,
                min_coef=0.001,
            ),
        ),
        step_by_batch=True,
    )
    print(sche)
    # plot_lr_coef_curve only accepts save_path; pass a directory to save the figure.
    sche.plot_lr_coef_curve(
        # save_path="/home/lart/Coding/SOD.torch",
    )
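A sketch of how the Scheduler is meant to be driven inside a training loop (illustrative only): record the base learning rates once, then call step once per iteration. The optimizer, iteration counts, and the "poly" config below are placeholders, not settings from this repo.

# Illustrative only: per-iteration LR scheduling with warmup + polynomial decay.
import torch

from utils.pipeline import Scheduler

model = torch.nn.Conv2d(3, 8, 3, padding=1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

num_iters, epoch_length = 1000, 100
scheduler = Scheduler(
    optimizer,
    num_iters=num_iters,
    epoch_length=epoch_length,
    scheduler_cfg=dict(
        warmup=dict(num_iters=100, initial_coef=0.01, mode="linear"),
        mode="poly",
        cfg=dict(lr_decay=0.9, min_coef=1e-4),
    ),
    step_by_batch=True,
)
scheduler.record_lrs(optimizer.param_groups)   # remember the base LR of each group

for curr_iter in range(num_iters):
    # ... forward/backward/optimizer.step() would go here ...
    scheduler.step(curr_iter)                  # rescale each group's LR by the current coefficient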
utils/pt_utils.py
ADDED
@@ -0,0 +1,66 @@
import logging
import os
import random

import numpy as np
import torch
from torch import nn
from torch.backends import cudnn

LOGGER = logging.getLogger("main")


def customized_worker_init_fn(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)


def set_seed_for_lib(seed):
    random.seed(seed)
    np.random.seed(seed)

    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def initialize_seed_cudnn(seed, deterministic):
    assert isinstance(deterministic, bool) and isinstance(seed, int)
    if seed >= 0:
        LOGGER.info(f"We will use a fixed seed {seed}")
    else:
        seed = np.random.randint(2**32)
        LOGGER.info(f"We will use a random seed {seed}")
    set_seed_for_lib(seed)
    if not deterministic:
        LOGGER.info("We will use `torch.backends.cudnn.benchmark`")
    else:
        LOGGER.info("We will not use `torch.backends.cudnn.benchmark`")
    cudnn.enabled = True
    cudnn.benchmark = not deterministic
    cudnn.deterministic = deterministic


def to_device(data, device="cuda"):
    if isinstance(data, (tuple, list)):
        return [to_device(item, device) for item in data]
    elif isinstance(data, dict):
        return {name: to_device(item, device) for name, item in data.items()}
    elif isinstance(data, torch.Tensor):
        return data.to(device=device, non_blocking=True)
    else:
        raise TypeError(f"Unsupported type {type(data)}. Only support Tensor or tuple/list/dict containing Tensors.")


def frozen_bn_stats(model, freeze_affine=False):
    """
    Set all the bn layers to eval mode.
    Args:
        model (model): model to set bn layers to eval mode.
    """
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()
            if freeze_affine:
                m.requires_grad_(False)
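A minimal sketch of the reproducibility and device helpers (illustrative only): the batch layout is made up, and moving it to "cuda" assumes a GPU is available.

# Illustrative only: fix the seeds, then move a nested batch onto the GPU.
import torch

from utils.pt_utils import initialize_seed_cudnn, to_device

initialize_seed_cudnn(seed=42, deterministic=True)   # cudnn.benchmark off, deterministic on

batch = {
    "image": torch.randn(4, 3, 64, 64),
    "targets": [torch.rand(4, 1, 64, 64), torch.rand(4, 1, 64, 64)],
}
batch = to_device(batch, device="cuda")              # recursively .to("cuda", non_blocking=True)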
utils/py_utils.py
ADDED
@@ -0,0 +1,197 @@
# -*- coding: utf-8 -*-
import copy
import logging
import os
import shutil
from collections import OrderedDict, abc
from datetime import datetime

LOGGER = logging.getLogger("main")


def construct_path(output_dir: str, exp_name: str) -> dict:
    proj_root = os.path.join(output_dir, exp_name)
    exp_idx = 0
    exp_output_dir = os.path.join(proj_root, f"exp_{exp_idx}")
    while os.path.exists(exp_output_dir):
        exp_idx += 1
        exp_output_dir = os.path.join(proj_root, f"exp_{exp_idx}")

    tb_path = os.path.join(exp_output_dir, "tb")
    save_path = os.path.join(exp_output_dir, "pre")
    pth_path = os.path.join(exp_output_dir, "pth")

    final_full_model_path = os.path.join(pth_path, "checkpoint_final.pth")
    final_state_path = os.path.join(pth_path, "state_final.pth")

    log_path = os.path.join(exp_output_dir, f"log_{str(datetime.now())[:10]}.txt")
    cfg_copy_path = os.path.join(exp_output_dir, f"config.py")
    trainer_copy_path = os.path.join(exp_output_dir, f"trainer.txt")
    excel_path = os.path.join(exp_output_dir, f"results.xlsx")

    path_config = {
        "output_dir": output_dir,
        "pth_log": exp_output_dir,
        "tb": tb_path,
        "save": save_path,
        "pth": pth_path,
        "final_full_net": final_full_model_path,
        "final_state_net": final_state_path,
        "log": log_path,
        "cfg_copy": cfg_copy_path,
        "excel": excel_path,
        "trainer_copy": trainer_copy_path,
    }

    return path_config


def construct_exp_name(model_name: str, cfg: dict):
    # bs_16_lr_0.05_e30_noamp_2gpu_noms_352
    focus_item = OrderedDict(
        {
            "train/batch_size": "bs",
            "train/lr": "lr",
            "train/num_epochs": "e",
            "train/num_iters": "i",
            "train/data/shape/h": "h",
            "train/data/shape/w": "w",
            "train/optimizer/mode": "opm",
            "train/optimizer/group_mode": "opgm",
            "train/scheduler/mode": "sc",
            "train/scheduler/warmup/num_iters": "wu",
            "train/use_amp": "amp",
        }
    )
    config = copy.deepcopy(cfg)

    def _format_item(_i):
        if isinstance(_i, bool):
            _i = "" if _i else "false"
        elif isinstance(_i, (int, float)):
            if _i == 0:
                _i = "false"
        elif isinstance(_i, (list, tuple)):
            _i = "" if _i else "false"  # only check whether it is non-empty
        elif isinstance(_i, str):
            if "_" in _i:
                _i = _i.replace("_", "").lower()
        elif _i is None:
            _i = "none"
        # else: other types and values will be returned directly
        return _i

    if (epoch_based := config.train.get("epoch_based", None)) is not None and (not epoch_based):
        focus_item.pop("train/num_epochs")
    else:
        # epoch-based by default
        focus_item.pop("train/num_iters")

    exp_names = [model_name]
    for key, alias in focus_item.items():
        item = get_value_recurse(keys=key.split("/"), info=config)
        formatted_item = _format_item(item)
        if formatted_item == "false":
            continue
        exp_names.append(f"{alias.upper()}{formatted_item}")

    info = config.get("info", None)
    if info:
        exp_names.append(f"INFO{info.lower()}")

    return "_".join(exp_names)


def pre_mkdir(path_config):
    # Create the log file in advance to avoid triggering a file-creation event
    # when it would otherwise be created automatically.
    check_mkdir(path_config["pth_log"])
    make_log(path_config["log"], f"=== log {datetime.now()} ===")

    # Create the folders for predictions and checkpoints in advance.
    check_mkdir(path_config["save"])
    check_mkdir(path_config["pth"])


def check_mkdir(dir_name, delete_if_exists=False):
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)
    else:
        if delete_if_exists:
            print(f"{dir_name} will be re-created!!!")
            shutil.rmtree(dir_name)
            os.makedirs(dir_name)


def make_log(path, context):
    with open(path, "a") as log:
        log.write(f"{context}\n")


def iterate_nested_sequence(nested_sequence):
    """
    Currently supports multi-level nesting of list/tuple/int/float/range(); do not nest too
    deeply, or Python's default maximum recursion depth may be exceeded.

    Example
    ::

        for x in iterate_nested_sequence([[1, (2, 3)], range(3, 10), 0]):
            print(x)

        1
        2
        3
        3
        4
        5
        6
        7
        8
        9
        0

    :param nested_sequence: a nested sequence
    :return: generator
    """
    for item in nested_sequence:
        if isinstance(item, (int, float)):
            yield item
        elif isinstance(item, (list, tuple, range)):
            yield from iterate_nested_sequence(item)
        else:
            raise NotImplementedError


def get_value_recurse(keys: list, info: dict):
    curr_key, sub_keys = keys[0], keys[1:]

    if (sub_info := info.get(curr_key, "NoKey")) == "NoKey":
        raise KeyError(f"{curr_key} must be contained in {info}")

    if sub_keys:
        return get_value_recurse(keys=sub_keys, info=sub_info)
    else:
        return sub_info


def mapping_to_str(mapping: abc.Mapping, *, prefix: str = " ", lvl: int = 0, max_lvl: int = 1) -> str:
    """
    Print the structural information of the dict.
    """
    sub_lvl = lvl + 1
    cur_prefix = prefix * lvl
    sub_prefix = prefix * sub_lvl

    if lvl == max_lvl:
        sub_items = str(mapping)
    else:
        sub_items = ["{"]
        for k, v in mapping.items():
            sub_item = sub_prefix + k + ": "
            if isinstance(v, abc.Mapping):
                sub_item += mapping_to_str(v, prefix=prefix, lvl=sub_lvl, max_lvl=max_lvl)
            else:
                sub_item += str(v)
            sub_items.append(sub_item)
        sub_items.append(cur_prefix + "}")
        sub_items = "\n".join(sub_items)
    return sub_items
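A small sketch of the path and config helpers (illustrative only): the output directory, experiment name, and toy config are placeholders, and plain dicts are used since these helpers only need mapping access. (construct_exp_name additionally expects an attribute-style config object, so it is not shown here.)

# Illustrative only: build the experiment directory layout and inspect a nested config.
from utils.py_utils import construct_path, get_value_recurse, mapping_to_str

path_config = construct_path(output_dir="./output", exp_name="MyModel_BS16_LR0.05")
print(path_config["pth"])   # e.g. ./output/MyModel_BS16_LR0.05/exp_0/pth

cfg = {"train": {"batch_size": 16, "optimizer": {"mode": "sgd"}}}
print(get_value_recurse(keys=["train", "optimizer", "mode"], info=cfg))   # "sgd"
print(mapping_to_str(cfg, max_lvl=2))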
utils/recorder/__init__.py
ADDED
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-

from .counter import TrainingCounter
from .group_metric_caller import GroupedMetricRecorder
from .logger import TBLogger
from .meter_recorder import AvgMeter, HistoryBuffer
from .visualize_results import plot_results
utils/recorder/counter.py
ADDED
@@ -0,0 +1,75 @@
import math


class TrainingCounter:
    def __init__(self, epoch_length, epoch_based=True, *, num_epochs=None, num_total_iters=None) -> None:
        self.num_inner_iters = epoch_length
        self._iter_counter = 0
        self._epoch_counter = 0

        if epoch_based:
            assert num_epochs is not None
            self.num_epochs = num_epochs
            self.num_total_iters = num_epochs * epoch_length
        else:
            assert num_total_iters is not None
            self.num_total_iters = num_total_iters
            self.num_epochs = math.ceil(num_total_iters / epoch_length)

    def set_start_epoch(self, start_epoch):
        self._epoch_counter = start_epoch
        self._iter_counter = start_epoch * self.num_inner_iters

    def set_start_iterations(self, start_iteration):
        self._iter_counter = start_iteration
        self._epoch_counter = start_iteration // self.num_inner_iters

    def every_n_epochs(self, n: int) -> bool:
        return (self._epoch_counter + 1) % n == 0 if n > 0 else False

    def every_n_iters(self, n: int) -> bool:
        return (self._iter_counter + 1) % n == 0 if n > 0 else False

    def is_first_epoch(self) -> bool:
        return self._epoch_counter == 0

    def is_last_epoch(self) -> bool:
        return self._epoch_counter == self.num_epochs - 1

    def is_first_inner_iter(self) -> bool:
        return self._iter_counter % self.num_inner_iters == 0

    def is_last_inner_iter(self) -> bool:
        return (self._iter_counter + 1) % self.num_inner_iters == 0

    def is_first_total_iter(self) -> bool:
        return self._iter_counter == 0

    def is_last_total_iter(self) -> bool:
        return self._iter_counter == self.num_total_iters - 1

    def update_iter_counter(self):
        self._iter_counter += 1

    def update_epoch_counter(self):
        self._epoch_counter += 1

    def reset_iter_all_counter(self):
        self._iter_counter = 0
        self._epoch_counter = 0

    @property
    def curr_iter(self):
        return self._iter_counter

    @property
    def next_iter(self):
        return self._iter_counter + 1

    @property
    def curr_epoch(self):
        return self._epoch_counter

    @property
    def curr_percent(self):
        return self._iter_counter / self.num_total_iters
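A minimal usage sketch for TrainingCounter (the epoch/iteration numbers below are placeholders, not values from the repository):

# Drive an epoch-based loop with TrainingCounter; 2 epochs x 4 iterations per epoch are placeholder values.
from utils.recorder import TrainingCounter

counter = TrainingCounter(epoch_length=4, epoch_based=True, num_epochs=2)
for _ in range(counter.num_total_iters):
    # ... one forward/backward step would run here ...
    if counter.every_n_iters(2):
        print(f"finished iter {counter.next_iter}/{counter.num_total_iters}")
    if counter.is_last_inner_iter():
        counter.update_epoch_counter()
    counter.update_iter_counter()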
utils/recorder/group_metric_caller.py
ADDED
@@ -0,0 +1,200 @@
# -*- coding: utf-8 -*-
# @Time : 2021/1/4
# @Author : Lart Pang
# @GitHub : https://github.com/lartpang

from collections import OrderedDict

import numpy as np
import py_sod_metrics


def ndarray_to_basetype(data):


    def _to_list_or_scalar(item):
        listed_item = item.tolist()
        if isinstance(listed_item, list) and len(listed_item) == 1:
            listed_item = listed_item[0]
        return listed_item

    if isinstance(data, (tuple, list)):
        results = [_to_list_or_scalar(item) for item in data]
    elif isinstance(data, dict):
        results = {k: _to_list_or_scalar(item) for k, item in data.items()}
    else:
        assert isinstance(data, np.ndarray)
        results = _to_list_or_scalar(data)
    return results


def round_w_zero_padding(x, bit_width):
    x = str(round(x, bit_width))
    x += "0" * (bit_width - len(x.split(".")[-1]))
    return x


INDIVADUAL_METRIC_MAPPING = {
    "sm": py_sod_metrics.Smeasure,
    "wfm": py_sod_metrics.WeightedFmeasure,
    "mae": py_sod_metrics.MAE,
    "em": py_sod_metrics.Emeasure,
}
BINARY_CLASSIFICATION_METRIC_MAPPING = {
    "fmeasure": {
        "handler": py_sod_metrics.FmeasureHandler,
        "kwargs": dict(with_dynamic=True, with_adaptive=True, with_binary=False, beta=0.3),
    },
    "iou": {
        "handler": py_sod_metrics.IOUHandler,
        "kwargs": dict(with_dynamic=True, with_adaptive=True, with_binary=False),
    },
    "dice": {
        "handler": py_sod_metrics.DICEHandler,
        "kwargs": dict(with_dynamic=True, with_adaptive=False, with_binary=False),
    },
}


class ImageMetricRecorder:
    supported_metrics = sorted(INDIVADUAL_METRIC_MAPPING.keys()) + sorted(BINARY_CLASSIFICATION_METRIC_MAPPING.keys())

    def __init__(self, metric_names=("sm", "wfm", "mae", "fmeasure", "em")):

        if not metric_names:
            metric_names = self.supported_metrics
        assert all([m in self.supported_metrics for m in metric_names]), f"Only support: {self.supported_metrics}"

        self.metric_objs = {}
        has_existed = False
        for metric_name in metric_names:
            if metric_name in INDIVADUAL_METRIC_MAPPING:
                self.metric_objs[metric_name] = INDIVADUAL_METRIC_MAPPING[metric_name]()
            else:  # metric_name in BINARY_CLASSIFICATION_METRIC_MAPPING
                if not has_existed:  # only init once
                    self.metric_objs["fmeasurev2"] = py_sod_metrics.FmeasureV2()
                    has_existed = True
                metric_handler = BINARY_CLASSIFICATION_METRIC_MAPPING[metric_name]
                self.metric_objs["fmeasurev2"].add_handler(
                    handler_name=metric_name, metric_handler=metric_handler["handler"](**metric_handler["kwargs"])
                )

    def step(self, pre: np.ndarray, gt: np.ndarray, gt_path: str):
        assert pre.shape == gt.shape, (pre.shape, gt.shape, gt_path)
        assert pre.dtype == gt.dtype == np.uint8, (pre.dtype, gt.dtype, gt_path)

        for m_obj in self.metric_objs.values():
            m_obj.step(pre, gt)

    def get_all_results(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
        sequential_results = {}
        numerical_results = {}
        for m_name, m_obj in self.metric_objs.items():
            info = m_obj.get_results()
            if m_name == "fmeasurev2":
                for _name, results in info.items():
                    dynamic_results = results.get("dynamic")
                    adaptive_results = results.get("adaptive")
                    if dynamic_results is not None:
                        sequential_results[_name] = np.flip(dynamic_results)
                        numerical_results[f"max{_name}"] = dynamic_results.max()
                        numerical_results[f"avg{_name}"] = dynamic_results.mean()
                    if adaptive_results is not None:
                        numerical_results[f"adp{_name}"] = adaptive_results
            else:
                results = info[m_name]
                if m_name in ("wfm", "sm", "mae"):
                    numerical_results[m_name] = results
                elif m_name == "em":
                    sequential_results[m_name] = np.flip(results["curve"])
                    numerical_results.update(
                        {
                            "maxem": results["curve"].max(),
                            "avgem": results["curve"].mean(),
                            "adpem": results["adp"],
                        }
                    )
                else:
                    raise NotImplementedError(m_name)

        if num_bits is not None and isinstance(num_bits, int):
            numerical_results = {k: v.round(num_bits) for k, v in numerical_results.items()}
        if not return_ndarray:
            sequential_results = ndarray_to_basetype(sequential_results)
            numerical_results = ndarray_to_basetype(numerical_results)
        return {"sequential": sequential_results, "numerical": numerical_results}

    def show(self, num_bits: int = 3, return_ndarray: bool = False) -> dict:
        return self.get_all_results(num_bits=num_bits, return_ndarray=return_ndarray)["numerical"]


class GroupedMetricRecorder(object):
    supported_metrics = ["mae", "em", "sm", "wfm"] + sorted(BINARY_CLASSIFICATION_METRIC_MAPPING.keys())

    def __init__(self, group_names=None, metric_names=("sm", "wfm", "mae", "fmeasure", "em")):
        self.group_names = group_names
        self.metric_names = metric_names
        self.zero()

    def zero(self):
        self.metric_recorders = {}
        if self.group_names is not None:
            self.metric_recorders.update(
                {n: ImageMetricRecorder(metric_names=self.metric_names) for n in self.group_names}
            )

    def step(self, group_name: str, pre: np.ndarray, gt: np.ndarray, gt_path: str):
        if group_name not in self.metric_recorders:
            self.metric_recorders[group_name] = ImageMetricRecorder(metric_names=self.metric_names)
        self.metric_recorders[group_name].step(pre, gt, gt_path)

    def show(self, num_bits: int = 3, return_group: bool = False):
        groups_metrics = {
            n: r.get_all_results(num_bits=None, return_ndarray=True) for n, r in self.metric_recorders.items()
        }

        results = {}
        for group_metrics in groups_metrics.values():
            for metric_type, metric_group in group_metrics.items():  # sequential and numerical
                results.setdefault(metric_type, {})
                for metric_name, metric_array in metric_group.items():
                    results[metric_type].setdefault(metric_name, []).append(metric_array)

        numerical_results = {}
        for metric_type, metric_group in results.items():
            for metric_name, metric_array in metric_group.items():
                metric_array = np.mean(np.vstack(metric_array), axis=0)  # average over all groups

                if metric_name in BINARY_CLASSIFICATION_METRIC_MAPPING or metric_name == "em":
                    if metric_type == "sequential":
                        numerical_results[f"max{metric_name}"] = metric_array.max()
                        numerical_results[f"avg{metric_name}"] = metric_array.mean()
                else:
                    if metric_type == "numerical":
                        if metric_name.startswith(("max", "avg")):
                            # these metrics (maxfm, avgfm, maxem, avgem) will be recomputed within the group
                            continue
                        numerical_results[metric_name] = metric_array

        numerical_results = ndarray_to_basetype(numerical_results)
        numerical_results = {k: round(v, num_bits) for k, v in numerical_results.items()}
        numerical_results = self.sort_results(numerical_results)
        if not return_group:
            return numerical_results

        group_numerical_results = {}
        for group_name, group_metric in groups_metrics.items():
            group_metric = {k: v.round(num_bits) for k, v in group_metric["numerical"].items()}
            group_metric = ndarray_to_basetype(group_metric)
            group_numerical_results[group_name] = self.sort_results(group_metric)
        return numerical_results, group_numerical_results

    def sort_results(self, results: dict) -> OrderedDict:
        """for a single group of metrics"""
        sorted_results = OrderedDict()
        all_keys = sorted(results.keys(), key=lambda item: item[::-1])
        for name in self.metric_names:
            for key in all_keys:
                if key.endswith(name):
                    sorted_results[key] = results[key]
        return sorted_results
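A minimal usage sketch for the recorders above (the group names and arrays are made up for illustration; real inputs are uint8 prediction and ground-truth masks of equal shape):

# Evaluate random placeholder predictions grouped by hypothetical dataset names.
import numpy as np

from utils.recorder import GroupedMetricRecorder

recorder = GroupedMetricRecorder(metric_names=("sm", "mae", "fmeasure"))
rng = np.random.default_rng(0)
for dataset_name in ("DUTS", "ECSSD"):  # hypothetical group names
    for _ in range(3):
        pred = (rng.random((64, 64)) * 255).astype(np.uint8)
        mask = ((rng.random((64, 64)) > 0.5) * 255).astype(np.uint8)
        recorder.step(group_name=dataset_name, pre=pred, gt=mask, gt_path="<dummy>")
print(recorder.show(num_bits=3))  # metrics averaged over the two groups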
utils/recorder/logger.py
ADDED
@@ -0,0 +1,23 @@
from torch.utils.tensorboard import SummaryWriter


class TBLogger:
    def __init__(self, tb_root):
        self.tb_root = tb_root
        self.tb = None

    def write_to_tb(self, name, data, curr_iter):
        assert self.tb_root is not None

        if self.tb is None:
            self.tb = SummaryWriter(self.tb_root)

        if not isinstance(data, (tuple, list)):
            self.tb.add_scalar(f"data/{name}", data, curr_iter)
        else:
            for idx, data_item in enumerate(data):
                self.tb.add_scalar(f"data/{name}_{idx}", data_item, curr_iter)

    def close_tb(self):
        if self.tb is not None:
            self.tb.close()
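A minimal usage sketch for TBLogger (the log directory and scalar values are arbitrary examples, not taken from the repository):

# Log one scalar and one list of scalars; the SummaryWriter is created lazily on the first write.
from utils.recorder import TBLogger

tb_logger = TBLogger(tb_root="./output/tb")  # example directory
for it in range(10):
    loss = 1.0 / (it + 1)  # stand-in for a real training loss
    tb_logger.write_to_tb(name="total_loss", data=loss, curr_iter=it)
    tb_logger.write_to_tb(name="group_lr", data=[0.01, 0.001], curr_iter=it)  # one curve per list entry
tb_logger.close_tb()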
utils/recorder/meter_recorder.py
ADDED
@@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-
from collections import deque


class AvgMeter(object):
    __slots__ = ["value", "sum", "count"]

    def __init__(self):
        self.value = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.value = 0
        self.sum = 0
        self.count = 0

    def update(self, value, num=1):
        self.value = value
        self.sum += value * num
        self.count += num

    @property
    def avg(self):
        return self.sum / self.count

    def __repr__(self) -> str:
        return f"{self.avg:.5f}"


class HistoryBuffer:
    """The class tracks a series of values and provides access to the smoothed
    value over a window or the global average / sum of the sequence.

    Args:
        window_size (int): The maximal number of values that can
            be stored in the buffer. Defaults to 20.

    Example::

        >>> his_buf = HistoryBuffer()
        >>> his_buf.update(0.1)
        >>> his_buf.update(0.2)
        >>> his_buf.avg
        0.15
    """

    def __init__(self, window_size: int = 20) -> None:
        self._history = deque(maxlen=window_size)
        self._count: int = 0
        self._sum: float = 0
        self.reset()

    def reset(self):
        self._history.clear()
        self._count = 0
        self._sum = 0

    def update(self, value: float, num: int = 1) -> None:
        """Add a new scalar value. If the length of queue exceeds ``window_size``,
        the oldest element will be removed from the queue.
        """
        self._history.append(value)
        self._count += num
        self._sum += value * num

    @property
    def latest(self) -> float:
        """The latest value of the queue."""
        return self._history[-1]

    @property
    def avg(self) -> float:
        """The average over the window."""
        if len(self._history) == 0:
            return 0
        else:
            return sum(self._history) / len(self._history)

    @property
    def global_avg(self) -> float:
        """The global average of the queue."""
        if self._count == 0:
            return 0
        else:
            return self._sum / self._count

    @property
    def global_sum(self) -> float:
        """The global sum of the queue."""
        return self._sum
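A minimal sketch contrasting the two trackers above: AvgMeter keeps only a global running average, while HistoryBuffer additionally exposes a windowed average over the most recent values. The loss values are placeholders.

from utils.recorder import AvgMeter, HistoryBuffer

loss_meter = AvgMeter()
loss_buffer = HistoryBuffer(window_size=3)
for loss in (1.0, 0.8, 0.6, 0.4):
    loss_meter.update(loss)
    loss_buffer.update(loss)
print(loss_meter.avg)          # ~0.7, mean over all four updates
print(loss_buffer.avg)         # ~0.6, mean over the last three values only
print(loss_buffer.global_avg)  # ~0.7, same global statistic as AvgMeter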
utils/recorder/visualize_results.py
ADDED
@@ -0,0 +1,43 @@
import os

import cv2
import matplotlib

matplotlib.use("Agg")
import numpy as np
import torchvision.transforms.functional as tv_tf
from torchvision.utils import make_grid


def plot_results(data_container, save_path, base_size=256, is_rgb=True):
    """Plot the results corresponding to the batched images based on the `make_grid` method from `torchvision`.

    Args:
        data_container (dict): Dict containing data you want to plot.
        save_path (str): Path of the exported image.
    """
    font_cfg = dict(fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, thickness=2)

    grids = []
    for subplot_id, (name, data) in enumerate(data_container.items()):
        if data.ndim == 3:
            data = data.unsqueeze(1)

        grid = make_grid(data, nrow=data.shape[0], padding=2, normalize=False)
        grid = np.array(tv_tf.to_pil_image(grid.float()))
        h, w = grid.shape[:2]
        ratio = base_size / h
        grid = cv2.resize(grid, dsize=None, fx=ratio, fy=ratio, interpolation=cv2.INTER_LINEAR)

        (text_w, text_h), baseline = cv2.getTextSize(text=name, **font_cfg)
        text_xy = 20, 20 + text_h // 2 + baseline
        cv2.putText(grid, text=name, org=text_xy, color=(255, 255, 255), **font_cfg)

        grids.append(grid)
    grids = np.concatenate(grids, axis=0)  # H,W,C

    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    if is_rgb:
        grids = cv2.cvtColor(grids, cv2.COLOR_RGB2BGR)
    cv2.imwrite(save_path, grids)
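A minimal usage sketch for plot_results (the tensors below are random placeholders shaped like typical batch outputs, and the save path is an arbitrary example):

# Stack an image batch, a mask batch, and a prediction batch into one labeled grid image.
import torch

from utils.recorder import plot_results

batch_vis = dict(
    image=torch.rand(4, 3, 64, 64),  # B,C,H,W in [0, 1]
    mask=torch.rand(4, 64, 64),      # B,H,W; unsqueezed to B,1,H,W inside plot_results
    pred=torch.rand(4, 64, 64),
)
plot_results(batch_vis, save_path="./output/vis/step_0.png", base_size=128)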