| import re |
| import os |
| import yaml |
| import cv2 |
| import argparse |
| import warnings |
| import numpy as np |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from easydict import EasyDict as ed |
|
|
class Simplify(nn.Module):
    """Adapt a dict-in / dict-out model to a plain tensor-in / tensor-out module.

    ``forward(x)`` calls the wrapped ``model`` with ``{'image': x}`` and returns
    the ``'pred'`` entry of its output. Presumably this exists so the model can
    be used where a simple ``tensor -> tensor`` module is expected — confirm
    against callers.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def cuda(self, device=None):
        """Move the wrapped model to CUDA and return ``self`` for chaining.

        Accepts the same optional ``device`` argument as ``nn.Module.cuda`` so
        ``wrapper.cuda(0)`` works as it does for any other module.
        """
        # Forward ``device`` only when given, so wrapped models whose own
        # ``cuda`` override takes no argument keep working.
        if device is None:
            self.model = self.model.cuda()
        else:
            self.model = self.model.cuda(device)
        return self

    def forward(self, x):
        out = self.model({'image': x})
        return out['pred']
|
|
def parse_args(argv=None):
    """Parse command-line options and attach distributed-launch information.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like ``ArgumentParser.parse_args()``.

    Returns:
        ``argparse.Namespace`` with ``config``, ``resume``, ``verbose`` and
        ``debug``, plus:
          * ``local_rank`` — ``int(LOCAL_RANK)`` from the environment, or -1
            when the variable is unset;
          * ``device_num`` — 1 when ``local_rank == -1``; otherwise the number
            of entries in ``CUDA_VISIBLE_DEVICES``, or
            ``torch.cuda.device_count()`` when that variable is unset.

    Side effect: silences all Python warnings process-wide.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', '-c', type=str, default='configs/InSPyReNet_SwinB.yaml')
    parser.add_argument('--resume', '-r', action='store_true', default=False)
    parser.add_argument('--verbose', '-v', action='store_true', default=False)
    parser.add_argument('--debug', '-d', action='store_true', default=False)
    args = parser.parse_args(argv)

    local_rank = int(os.environ.get("LOCAL_RANK", -1))
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")

    if local_rank == -1:
        device_num = 1
    elif visible_devices is None:
        device_num = torch.cuda.device_count()
    else:
        # Count entries rather than int()-parsing them: the variable may be
        # empty (all GPUs hidden), have a trailing comma, or list GPU UUIDs
        # such as "GPU-8932f937-...".
        device_num = len([d for d in visible_devices.split(',') if d.strip()])

    args.device_num = device_num
    args.local_rank = local_rank

    warnings.simplefilter("ignore")

    return args
|
|
def sort(x):
    """Return the items of ``x`` in natural ("human") order.

    Runs of ASCII digits compare numerically ("img2" < "img10"); all other
    text compares case-insensitively via ``str.lower``.
    """
    splitter = re.compile(r'([0-9]+)')

    def natural_key(key):
        # With a capturing group, re.split always alternates
        # text, digits, text, ... so odd positions are exactly the ASCII digit
        # runs. Choosing int/str by position (rather than str.isdigit, which
        # also accepts Unicode digits like '²') keeps keys type-aligned and
        # never hands int() something it cannot parse.
        parts = splitter.split(key)
        return [int(p) if i % 2 else p.lower() for i, p in enumerate(parts)]

    return sorted(x, key=natural_key)
|
|
def load_config(config_dir, easy=True):
    """Load a YAML configuration file.

    Args:
        config_dir: Path of the YAML file to open (a file path, despite the
            parameter name).
        easy: When exactly ``True``, wrap the parsed result in ``EasyDict`` so
            keys can also be read as attributes.

    Returns:
        The parsed YAML document, wrapped in ``EasyDict`` when ``easy is True``.
    """
    # Context manager so the file handle is closed deterministically.
    # NOTE(review): FullLoader still constructs some Python-specific tags;
    # prefer yaml.safe_load if configs can come from untrusted sources.
    with open(config_dir) as f:
        cfg = yaml.load(f, yaml.FullLoader)
    if easy is True:
        cfg = ed(cfg)
    return cfg
|
|
def to_cuda(sample, device=None):
    """Move every tensor value of the dict ``sample`` to a device.

    Args:
        sample: Dict whose ``torch.Tensor`` values (including subclasses such
            as ``nn.Parameter``) are moved; other values are left untouched.
        device: Target device. ``None`` (the default) uses ``Tensor.cuda()``,
            i.e. the current CUDA device; anything else goes to ``Tensor.to``.

    Returns:
        The same dict object, updated in place.
    """
    for key, value in sample.items():
        # isinstance (not an exact type check) so tensor subclasses are moved too.
        if isinstance(value, torch.Tensor):
            sample[key] = value.cuda() if device is None else value.to(device)
    return sample
|
|
def to_numpy(pred, shape):
    """Resize ``pred`` to ``shape`` and return it as a squeezed NumPy array.

    Args:
        pred: Tensor to resize with bilinear interpolation, which requires a
            4-D ``(N, C, H, W)`` input.
        shape: Target spatial size ``(H, W)``.

    Returns:
        ``numpy.ndarray`` on the CPU with all size-1 dimensions removed.
    """
    # Detach before interpolating: no autograd graph is needed for a NumPy
    # export, and .detach() is the supported replacement for legacy .data.
    resized = F.interpolate(pred.detach(), shape, mode='bilinear', align_corners=True)
    return resized.cpu().numpy().squeeze()
|
|
def _normalize_to_uint8(arr):
    """Min-max scale ``arr`` to 0..255 as uint8; constant arrays map to 0."""
    lo, hi = arr.min(), arr.max()
    span = hi - lo
    if span > 0:
        scaled = (arr - lo) / span
    else:
        # A constant map (e.g. an all-zero prediction) would otherwise divide
        # by zero and yield NaNs, whose cast to uint8 is undefined.
        scaled = np.zeros(arr.shape)
    return (scaled * 255).astype(np.uint8)


def debug_tile(deblist, size=(100, 100), activation=None):
    """Tile nested lists of tensors into one RGB debug image.

    Each inner list of ``deblist`` becomes one column (its tiles stacked
    vertically); the columns are then placed side by side.

    Args:
        deblist: Iterable of iterables of tensors. Each tensor is squeezed and
            must end up 2-D for the grayscale-to-RGB conversion — presumably
            single-image maps such as (1, 1, H, W); confirm against callers.
        size: Tile size passed to ``cv2.resize``, i.e. ``(width, height)``.
        activation: Optional callable applied to each tensor before
            normalization.

    Returns:
        ``uint8`` NumPy image of shape
        ``(rows * size[1], columns * size[0], 3)``.
    """
    columns = []
    for debs in deblist:
        tiles = []
        for deb in debs:
            if activation is not None:
                deb = activation(deb)
            gray = _normalize_to_uint8(deb.detach().cpu().numpy().squeeze())
            rgb = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
            tiles.append(cv2.resize(rgb, size))
        columns.append(np.vstack(tiles))
    return np.hstack(columns)
|
|
|
|
| if __name__ == "__main__": |
| x = torch.rand(4, 3, 576, 576) |