|
|
| import argparse
|
| import binascii
|
| import os
|
| import os.path as osp
|
|
|
| import imageio
|
| import torch
|
| import torchvision
|
|
|
# Public API of this module; helpers such as rand_name stay importable by
# name but are excluded from `from ... import *`.
__all__ = [
    'cache_video',
    'cache_image',
    'str2bool',
]
|
|
|
|
|
def rand_name(length=8, suffix=''):
    """Return a random lowercase-hex name, optionally with a file extension.

    Args:
        length: Number of random bytes drawn from ``os.urandom``; the hex part
            of the name is therefore ``2 * length`` characters long.
        suffix: Optional extension to append. A leading ``'.'`` is inserted
            when it is missing; a falsy suffix appends nothing.

    Returns:
        str: The generated name, e.g. ``'9c1f...e2.mp4'``.
    """
    token = os.urandom(length).hex()
    if not suffix:
        return token
    separator = '' if suffix.startswith('.') else '.'
    return f'{token}{separator}{suffix}'
|
|
|
|
|
def _tensor_to_video_frames(tensor, nrow, normalize, value_range):
    """Convert a video tensor into a uint8 CPU frame tensor ready for encoding.

    The input is clamped to ``value_range``. Each slice along dim 2 is tiled
    into a (C, H, W) grid with ``torchvision.utils.make_grid``. The grids are
    stacked on dim 1 and permuted, which yields (T, H, W, C), where T is the
    size of dim 2 of the input.

    The input tensor is never modified; a new tensor is returned.
    """
    clamped = tensor.clamp(min(value_range), max(value_range))
    grids = [
        torchvision.utils.make_grid(
            u, nrow=nrow, normalize=normalize, value_range=value_range)
        for u in clamped.unbind(2)
    ]
    frames = torch.stack(grids, dim=1).permute(1, 2, 3, 0)
    # The scaling assumes grid values are in [0, 1], which make_grid produces
    # when normalize=True. NOTE(review): with normalize=False, values outside
    # [0, 1] will wrap or truncate on the uint8 cast.
    return (frames * 255).type(torch.uint8).cpu()


def cache_video(tensor,
                save_file=None,
                fps=30,
                suffix='.mp4',
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Encode a video tensor to an H.264 file and return the file path.

    Each slice of ``tensor`` along dim 2 becomes one video frame, tiled into
    a grid with ``torchvision.utils.make_grid``.

    Args:
        tensor: Video tensor whose slices along dim 2 are treated as frames.
            NOTE(review): presumably (batch, channels, frames, H, W); confirm
            against callers.
        save_file: Output path. When None, a random name with ``suffix``
            under ``/tmp`` is used.
        fps: Frame rate passed to the imageio writer.
        suffix: Extension for the auto-generated file name. It is ignored
            when ``save_file`` is given.
        nrow: Grid images per row, forwarded to ``make_grid``.
        normalize: Forwarded to ``make_grid``.
        value_range: (low, high) range used for clamping and normalization.
        retry: Maximum number of attempts.

    Returns:
        str | None: The path written, or None if every attempt failed. In
        that case the last error is printed.
    """
    cache_file = osp.join('/tmp', rand_name(
        suffix=suffix)) if save_file is None else save_file

    frames = None
    error = None
    for _ in range(retry):
        try:
            # Convert once and reuse across retries. The caller's tensor is
            # never rebound, so a retry after a writer failure does not
            # re-process already-converted uint8 frames.
            if frames is None:
                frames = _tensor_to_video_frames(
                    tensor, nrow, normalize, value_range)

            writer = imageio.get_writer(
                cache_file, fps=fps, codec='libx264', quality=8)
            try:
                for frame in frames.numpy():
                    writer.append_data(frame)
            finally:
                # Release the encoder and file handle even when a frame fails,
                # so failed attempts do not leak writers.
                writer.close()
            return cache_file
        except Exception as e:
            error = e

    print(f'cache_video failed, error: {error}', flush=True)
    return None
|
|
|
|
|
def cache_image(tensor,
                save_file,
                nrow=8,
                normalize=True,
                value_range=(-1, 1),
                retry=5):
    """Save an image tensor (tiled into a grid) to ``save_file``.

    Args:
        tensor: Image tensor forwarded to ``torchvision.utils.save_image``
            after clamping to ``value_range``.
        save_file: Destination path. The image format is inferred from its
            extension. When the path has no extension, PNG is written.
        nrow: Grid images per row, forwarded to ``save_image``.
        normalize: Forwarded to ``save_image``.
        value_range: (low, high) range used for clamping and normalization.
        retry: Maximum number of attempts.

    Returns:
        str | None: ``save_file`` on success, or None if every attempt
        failed. In that case the last error is printed.
    """
    suffix = osp.splitext(save_file)[1]
    # Without an extension the writer cannot infer a format and every attempt
    # would fail, so default to PNG in that case. Paths with an extension keep
    # extension-based inference, so formats such as .tif or .bmp still work.
    image_format = None if suffix else 'png'

    error = None
    for _ in range(retry):
        try:
            clamped = tensor.clamp(min(value_range), max(value_range))
            torchvision.utils.save_image(
                clamped,
                save_file,
                format=image_format,
                nrow=nrow,
                normalize=normalize,
                value_range=value_range)
            return save_file
        except Exception as e:
            error = e

    # Report the failure the same way cache_video does instead of silently
    # returning None.
    print(f'cache_image failed, error: {error}', flush=True)
    return None
|
|
|
|
|
def str2bool(v):
    """
    Convert a string to a boolean.

    Letter case and surrounding whitespace are ignored.

    Supported true values: 'yes', 'true', 't', 'y', '1'
    Supported false values: 'no', 'false', 'f', 'n', '0'

    Args:
        v (str | bool): Value to convert. ``bool`` values are returned
            unchanged. Any other non-string value is converted with ``str()``
            first, so ``1`` and ``0`` are accepted and ``None`` is rejected
            with the documented error.

    Returns:
        bool: Converted boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value cannot be converted to boolean.
    """
    if isinstance(v, bool):
        return v
    # str() avoids an AttributeError on non-str input, so callers always get
    # the documented ArgumentTypeError.
    v_lower = str(v).strip().lower()
    if v_lower in ('yes', 'true', 't', 'y', '1'):
        return True
    if v_lower in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
|
|
|