| from safetensors import safe_open |
| import torch, os |
|
|
|
|
class SafetensorsCompatibleTensor:
    """Thin wrapper around a tensor that offers a ``get_shape()`` accessor.

    ``SafetensorsCompatibleBinaryLoader.get_slice`` returns instances of this
    class, so code that only asks a slice for its shape can work with tensors
    that were loaded from a non-safetensors checkpoint.
    """

    def __init__(self, tensor):
        # The wrapped object is kept as-is; no copy is made.
        self.tensor = tensor

    def get_shape(self):
        """Return the wrapped tensor's shape as a plain ``list``."""
        dims = self.tensor.shape
        return [*dims]
|
|
|
|
class SafetensorsCompatibleBinaryLoader:
    """Expose a ``torch.load`` checkpoint through ``keys`` / ``get_tensor`` / ``get_slice``.

    The whole checkpoint is loaded eagerly in the constructor and kept in
    ``self.state_dict``; the accessor methods are plain lookups into it.
    """

    def __init__(self, path, device):
        """Load the checkpoint at ``path`` onto ``device``.

        A notice is printed on every construction because this code path is
        only taken for files that are not in safetensors format.
        """
        print("Detected non-safetensors files, which may cause slower loading. It's recommended to convert it to a safetensors file.")
        loaded = torch.load(path, map_location=device, weights_only=True)
        self.state_dict = loaded

    def keys(self):
        """Return the key view of the loaded state dict."""
        return self.state_dict.keys()

    def get_tensor(self, name):
        """Return the entry stored under ``name`` (``KeyError`` if absent)."""
        return self.state_dict[name]

    def get_slice(self, name):
        """Return the entry stored under ``name`` wrapped in ``SafetensorsCompatibleTensor``."""
        entry = self.state_dict[name]
        return SafetensorsCompatibleTensor(entry)
|
|
|
|
class DiskMap:
    """Read-only, dict-like view over one or more checkpoint files.

    Tensors are fetched from the underlying file handles on demand via
    ``disk_map[name]``. Files whose path ends with ``.safetensors`` are opened
    with ``safe_open``; every other file is loaded through
    ``SafetensorsCompatibleBinaryLoader``.

    The instance counts how many tensor elements it has handed out since the
    last flush. Once that count exceeds ``buffer_size`` the safetensors handles
    are re-opened (see ``flush_files``).
    """

    def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9):
        """Open the checkpoint files and index their tensor names.

        Args:
            path: A single path, or a list/tuple of paths.
            device: Device passed to the file loaders.
            torch_dtype: If not ``None``, every returned tensor is cast to it.
            state_dict_converter: Optional callable. It receives an identity
                dict ``{stored_name: stored_name}`` and must return a dict
                mapping each exposed name to the stored name it reads from.
            buffer_size: Number of tensor *elements* (not bytes) that may be
                handed out before the safetensors handles are re-opened. The
                environment variable ``DIFFSYNTH_DISK_MAP_BUFFER_SIZE``, when
                set, overrides this argument.
        """
        if isinstance(path, list):
            self.path = path
        elif isinstance(path, tuple):
            # A tuple of paths is treated like a list of paths.
            self.path = list(path)
        else:
            self.path = [path]
        self.device = device
        self.torch_dtype = torch_dtype
        env_buffer_size = os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE')
        self.buffer_size = buffer_size if env_buffer_size is None else int(env_buffer_size)
        self.files = []
        self.flush_files()
        # Map each stored tensor name to the index of the file that holds it.
        # If a name occurs in several files, the last file wins.
        self.name_map = {}
        for file_id, file in enumerate(self.files):
            for name in file.keys():
                self.name_map[name] = file_id
        self.rename_dict = self.fetch_rename_dict(state_dict_converter)

    def _open_safetensors(self, path):
        """Open one ``.safetensors`` file on ``self.device``."""
        return safe_open(path, framework="pt", device=str(self.device))

    def flush_files(self):
        """Open all files on first use; afterwards re-open the safetensors handles.

        Non-safetensors loaders are created only once and are kept across
        flushes. The element counter ``num_params`` is reset to zero.
        """
        if not self.files:
            for path in self.path:
                if path.endswith(".safetensors"):
                    self.files.append(self._open_safetensors(path))
                else:
                    self.files.append(SafetensorsCompatibleBinaryLoader(path, device=self.device))
        else:
            for i, path in enumerate(self.path):
                if path.endswith(".safetensors"):
                    self.files[i] = self._open_safetensors(path)
        self.num_params = 0

    def __getitem__(self, name):
        """Return the tensor exposed under ``name``.

        Raises:
            KeyError: If ``name`` is unknown to the rename mapping or the files.
        """
        if self.rename_dict is not None:
            name = self.rename_dict[name]
        file_id = self.name_map[name]
        param = self.files[file_id].get_tensor(name)
        if isinstance(param, torch.Tensor):
            if self.torch_dtype is not None:
                param = param.to(self.torch_dtype)
            # Compare the device *type*: a ``torch.device`` never compares
            # equal to a plain string, so ``param.device == "cpu"`` is always
            # False and the clone below would never run.
            # NOTE(review): the clone presumably exists so the returned tensor
            # does not keep referencing storage owned by a file handle that
            # ``flush_files`` may replace - confirm.
            if param.device.type == "cpu":
                param = param.clone()
            self.num_params += param.numel()
        if self.num_params > self.buffer_size:
            self.flush_files()
        return param

    def fetch_rename_dict(self, state_dict_converter):
        """Build the exposed-name -> stored-name mapping, or ``None`` without a converter."""
        if state_dict_converter is None:
            return None
        # The converter is run on names only, so no tensor data is read here.
        identity = {name: name for file in self.files for name in file.keys()}
        return state_dict_converter(identity)

    def __iter__(self):
        """Iterate over exposed names (renamed ones when a converter was given)."""
        if self.rename_dict is not None:
            return iter(self.rename_dict)
        return iter(self.name_map)

    def __contains__(self, x):
        """Return whether ``x`` is an exposed name."""
        if self.rename_dict is not None:
            return x in self.rename_dict
        return x in self.name_map
|
|