| | |
| | |
| |
|
| | from __future__ import annotations |
| |
|
| | import logging |
| | import argparse |
| | import json |
| | import safetensors.torch |
| | import os |
| | import sys |
| | from pathlib import Path |
| | from typing import Any, ContextManager, cast |
| | from torch import Tensor |
| |
|
| | import numpy as np |
| | import torch |
| | import gguf |
| |
|
| | |
# Model architectures this converter can tag in the GGUF header.
SUPPORTED_ARCHS = ["flux", "sd3", "ltxv", "hyvid", "wan", "hidream", "qwen"]

# Module-level logger; convert() configures the root handler via basicConfig.
logger = logging.getLogger(__name__)
| |
|
| |
|
class QuantConfig:
    """Pairs a GGUF container file-type tag with the tensor quantization type.

    ``ftype`` is the overall file-type marker written to the GGUF header,
    while ``qtype`` is the quantization applied to individual tensors.
    """

    ftype: gguf.LlamaFileType
    qtype: gguf.GGMLQuantizationType

    def __init__(self, ftype: gguf.LlamaFileType, qtype: gguf.GGMLQuantizationType):
        self.ftype, self.qtype = ftype, qtype
|
| |
|
# CLI --outtype choice -> (GGUF file-type tag, per-tensor quantization type).
# Note the "_S" entries map to the base K-quant tensor type (e.g. Q5_K); the
# S/M/L distinction lives only in the container-level file-type tag.
qconfig_map: dict[str, QuantConfig] = {
    "F16": QuantConfig(gguf.LlamaFileType.MOSTLY_F16, gguf.GGMLQuantizationType.F16),
    "BF16": QuantConfig(gguf.LlamaFileType.MOSTLY_BF16, gguf.GGMLQuantizationType.BF16),
    "Q8_0": QuantConfig(gguf.LlamaFileType.MOSTLY_Q8_0, gguf.GGMLQuantizationType.Q8_0),
    "Q6_K": QuantConfig(gguf.LlamaFileType.MOSTLY_Q6_K, gguf.GGMLQuantizationType.Q6_K),
    "Q5_K_S": QuantConfig(gguf.LlamaFileType.MOSTLY_Q5_K_S, gguf.GGMLQuantizationType.Q5_K),
    "Q5_1": QuantConfig(gguf.LlamaFileType.MOSTLY_Q5_1, gguf.GGMLQuantizationType.Q5_1),
    "Q5_0": QuantConfig(gguf.LlamaFileType.MOSTLY_Q5_0, gguf.GGMLQuantizationType.Q5_0),
    "Q4_K_S": QuantConfig(gguf.LlamaFileType.MOSTLY_Q4_K_S, gguf.GGMLQuantizationType.Q4_K),
    "Q4_1": QuantConfig(gguf.LlamaFileType.MOSTLY_Q4_1, gguf.GGMLQuantizationType.Q4_1),
    "Q4_0": QuantConfig(gguf.LlamaFileType.MOSTLY_Q4_0, gguf.GGMLQuantizationType.Q4_0),
    "Q3_K_S": QuantConfig(gguf.LlamaFileType.MOSTLY_Q3_K_S, gguf.GGMLQuantizationType.Q3_K),
}
| |
|
| |
|
| | |
class LazyTorchTensor(gguf.LazyBase):
    """Lazily-evaluated torch tensor backed by a safetensors slice.

    Lets the converter inspect shape/dtype of every tensor up front while
    deferring the actual data read until the GGUF writer serializes it.
    """

    _tensor_type = torch.Tensor

    # Dtype/shape of the underlying, not-yet-materialized tensor.
    dtype: torch.dtype
    shape: torch.Size

    # torch dtype -> numpy dtype for the lazy numpy view. Only f16/f32 appear
    # here; Converter.process_tensor upcasts every other dtype to f32 first.
    _dtype_map: dict[torch.dtype, type] = {
        torch.float16: np.float16,
        torch.float32: np.float32,
    }

    # safetensors dtype-string -> torch dtype, used when wrapping a slice.
    _dtype_str_map: dict[str, torch.dtype] = {
        "F64": torch.float64,
        "F32": torch.float32,
        "BF16": torch.bfloat16,
        "F16": torch.float16,
        "I64": torch.int64,
        "I32": torch.int32,
        "I16": torch.int16,
        "U8": torch.uint8,
        "I8": torch.int8,
        "BOOL": torch.bool,
        "F8_E4M3": torch.float8_e4m3fn,
        "F8_E5M2": torch.float8_e5m2,
    }

    def numpy(self) -> gguf.LazyNumpyTensor:
        """Wrap this lazy tensor as a lazy numpy tensor (still no data read)."""
        dtype = self._dtype_map[self.dtype]
        return gguf.LazyNumpyTensor(
            meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape),
            args=(self,),
            func=(lambda s: s.numpy()),
        )

    @classmethod
    def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor:
        """Return a zero-cost 'meta'-device tensor carrying only dtype and shape."""
        return torch.empty(size=shape, dtype=dtype, device="meta")

    @classmethod
    def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
        """Build a lazy tensor from a safetensors slice; `s[:]` loads it on demand."""
        dtype = cls._dtype_str_map[st_slice.get_dtype()]
        shape: tuple[int, ...] = tuple(st_slice.get_shape())
        lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:])
        return cast(torch.Tensor, lazy)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Intercept torch ops so results stay lazy (wrapped via LazyBase).
        del types  # unused

        if kwargs is None:
            kwargs = {}

        # .numpy() is handled specially to produce a LazyNumpyTensor.
        if func is torch.Tensor.numpy:
            return args[0].numpy()

        return cls._wrap_fn(func)(*args, **kwargs)
| |
|
| |
|
class Converter:
    """Streams tensors from a safetensors file into a GGUF file.

    Construction opens the safetensors file, writes file-level metadata,
    and queues every tensor (lazily, via LazyTorchTensor); write() then
    serializes everything to ``outfile``.
    """

    path_safetensors: Path
    endianess: gguf.GGUFEndian
    outtype: QuantConfig
    outfile: Path
    gguf_writer: gguf.GGUFWriter

    def __init__(
        self,
        arch: str,
        path_safetensors: Path,
        endianess: gguf.GGUFEndian,
        outtype: QuantConfig,
        outfile: Path,
        subfolder: str | None = None,
        repo_id: str | None = None,
        is_diffusers: bool = False,
    ):
        self.path_safetensors = path_safetensors
        self.endianess = endianess
        self.outtype = outtype
        self.outfile = outfile

        # Writer is created path-less; the path is supplied in write().
        self.gguf_writer = gguf.GGUFWriter(path=None, arch=arch, endianess=self.endianess)
        self.gguf_writer.add_file_type(self.outtype.ftype)
        self.gguf_writer.add_type("diffusion")
        # Optional provenance metadata for Hub-sourced checkpoints.
        if repo_id:
            self.gguf_writer.add_string("repo_id", repo_id)
        if subfolder:
            self.gguf_writer.add_string("subfolder", subfolder)
        if is_diffusers:
            self.gguf_writer.add_bool("is_diffusers", True)

        # Deferred import keeps safetensors out of module import time.
        from safetensors import safe_open

        ctx = cast(ContextManager[Any], safe_open(path_safetensors, framework="pt", device="cpu"))
        with ctx as model_part:
            for name in model_part.keys():
                # get_slice avoids materializing the tensor; LazyTorchTensor
                # defers the actual read until serialization.
                data = model_part.get_slice(name)
                data = LazyTorchTensor.from_safetensors_slice(data)
                self.process_tensor(name, data)

    def process_tensor(self, name: str, data_torch: LazyTorchTensor) -> None:
        """Quantize (when applicable) and queue a single tensor on the writer."""
        # 1-D tensors (biases, norm scales) are kept at F32 for accuracy.
        is_1d = len(data_torch.shape) == 1
        current_dtype = data_torch.dtype
        target_dtype = gguf.GGMLQuantizationType.F32 if is_1d else self.outtype.qtype

        # LazyTorchTensor._dtype_map only covers f16/f32; upcast anything else.
        if data_torch.dtype not in (torch.float16, torch.float32):
            data_torch = data_torch.to(torch.float32)

        data = data_torch.numpy()

        # NOTE(review): current_dtype is a torch.dtype while target_dtype is a
        # gguf.GGMLQuantizationType, so this comparison is always True and every
        # tensor goes through custom_quantize — confirm no skip was intended.
        if current_dtype != target_dtype:
            from custom_quants import quantize as custom_quantize, QuantError

            try:
                data = custom_quantize(data, target_dtype)
            except QuantError as e:
                # Some tensor shapes can't take the requested quant; degrade to F16.
                logger.warning("%s, %s", e, "falling back to F16")
                target_dtype = gguf.GGMLQuantizationType.F16
                data = custom_quantize(data, target_dtype)

        # Quantized output arrives as raw uint8 bytes; recover the logical shape.
        shape = gguf.quant_shape_from_byte_shape(data.shape, target_dtype) if data.dtype == np.uint8 else data.shape
        shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}"
        logger.info(f"{f'%-32s' % f'{name},'} {current_dtype} --> {target_dtype.name}, shape = {shape_str}")

        self.gguf_writer.add_tensor(name, data, raw_dtype=target_dtype)

    def write(self) -> None:
        """Serialize header, metadata, and all queued tensors, then close."""
        self.gguf_writer.write_header_to_file(path=self.outfile)
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.write_tensors_to_file(progress=True)
        self.gguf_writer.close()
| |
|
| |
|
| | |
def _merge_sharded_checkpoints(folder: Path):
    """Merge a sharded diffusers safetensors checkpoint into one state dict.

    Reads ``diffusion_pytorch_model.safetensors.index.json`` in *folder* for
    the tensor-name -> shard-file mapping, loads each referenced shard once,
    and returns a dict of tensor name -> tensor.

    Raises:
        KeyError: if the index file has no ``weight_map`` entry.
        FileNotFoundError: if a shard referenced by the index is missing.
    """
    with open(folder / "diffusion_pytorch_model.safetensors.index.json", "r") as f:
        ckpt_metadata = json.load(f)
    weight_map = ckpt_metadata.get("weight_map", None)
    if weight_map is None:
        raise KeyError("'weight_map' key not found in the shard index file.")

    # A shard file holds many tensors; dedupe so each file is opened once.
    files_to_load = set(weight_map.values())
    merged_state_dict = {}

    for file_name in files_to_load:
        part_file_path = folder / file_name
        # pathlib idiom instead of os.path.exists() on a Path object.
        if not part_file_path.exists():
            raise FileNotFoundError(f"Part file {file_name} not found.")

        with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
            for tensor_key in f.keys():
                # Keep only tensors the index actually maps; ignore strays.
                if tensor_key in weight_map:
                    merged_state_dict[tensor_key] = f.get_tensor(tensor_key)

    return merged_state_dict
| |
|
| |
|
def parse_args() -> argparse.Namespace:
    """Build and validate the CLI arguments.

    ``model`` and ``--arch`` are declared optional so that friendlier error
    messages can be emitted below instead of argparse's defaults.
    """
    parser = argparse.ArgumentParser(description="Convert a flux model to GGUF")
    parser.add_argument(
        "--outfile",
        type=Path,
        default=Path("model-{ftype}.gguf"),
        help="path to write to; default: 'model-{ftype}.gguf' ; note: {ftype} will be replaced by the outtype",
    )
    parser.add_argument(
        "--outtype",
        type=str,
        choices=qconfig_map.keys(),
        default="F16",
        help="output quantization scheme",
    )
    parser.add_argument(
        "--arch",
        type=str,
        choices=SUPPORTED_ARCHS,
        help="output model architecture",
    )
    parser.add_argument(
        "--bigendian",
        action="store_true",
        help="model is executed on big endian machine",
    )
    parser.add_argument(
        "model",
        type=Path,
        help="directory containing safetensors model file",
        nargs="?",
    )
    parser.add_argument("--cache_dir", type=Path, help="Directory to store the intermediate files when needed.")
    parser.add_argument(
        "--subfolder", type=Path, default=None, help="Subfolder on the HF Hub to load checkpoints from."
    )
    # Fix: convert() reads args.hf_token when downloading from the Hub, but no
    # such flag was ever defined, causing an AttributeError. Declare it here.
    parser.add_argument(
        "--hf_token", type=str, default=None, help="Hugging Face token for gated/private Hub repos."
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="increase output verbosity",
    )

    args = parser.parse_args()
    if args.model is None:
        parser.error("the following arguments are required: model")
    if args.arch is None:
        parser.error("the following arguments are required: --arch")
    # No re-check of args.arch against SUPPORTED_ARCHS is needed:
    # `choices=SUPPORTED_ARCHS` above already makes argparse reject others.
    return args
| |
|
| |
|
def convert(args):
    """Resolve the model source and convert it to GGUF.

    Accepts three kinds of ``args.model``: a single .safetensors file, a
    directory (single or sharded checkpoint), or a Hugging Face Hub repo id
    ("user/repo"). Sharded and Hub checkpoints are first merged into an
    intermediate safetensors file which is deleted after conversion.
    Exits with status 1 on any invalid input.
    """
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    else:
        logging.basicConfig(level=logging.INFO)

    # Reject paths that neither exist nor look like a "user/repo" Hub id.
    if not args.model.is_dir() and not args.model.is_file():
        if not len(str(args.model).split("/")) == 2:
            logger.error(f"Model path {args.model} does not exist.")
            sys.exit(1)

    is_diffusers = False
    repo_id = None
    merged_state_dict = None
    filepath = None  # set whenever an intermediate merged file is written
    if args.model.is_dir():
        logger.info("Supplied a directory.")
        files = list(args.model.glob("*.safetensors"))
        n = len(files)
        if n == 0:
            logger.error("No safetensors files found.")
            sys.exit(1)
        if n == 1:
            logger.info(f"Assigning {files[0]} to `args.model`")  # typo fixed ("Assinging")
            args.model = files[0]
        if n > 1:
            # Sharded checkpoint. Explicit checks replace the previous
            # `assert` statements, which `python -O` silently strips.
            index_file = args.model / "diffusion_pytorch_model.safetensors.index.json"
            if not index_file.exists():
                logger.error(f"Shard index file {index_file} not found.")
                sys.exit(1)
            if not args.cache_dir:
                logger.error("--cache_dir is required to merge sharded checkpoints.")
                sys.exit(1)
            merged_state_dict = _merge_sharded_checkpoints(args.model)
            filepath = args.cache_dir / "merged_state_dict.safetensors"
            safetensors.torch.save_file(merged_state_dict, filepath)
            logger.info(f"Serialized merged state dict to {filepath}")
            args.model = Path(filepath)

    elif len(str(args.model).split("/")) == 2:
        from huggingface_hub import snapshot_download

        logger.info("Hub repo ID detected.")
        allow_patterns = f"{args.subfolder}/*.*" if args.subfolder else None
        local_dir = snapshot_download(
            repo_id=str(args.model), local_dir=args.cache_dir, allow_patterns=allow_patterns, token=args.hf_token
        )
        repo_id = str(args.model)
        local_dir = Path(local_dir)
        local_dir = local_dir / args.subfolder if args.subfolder else local_dir
        merged_state_dict = _merge_sharded_checkpoints(local_dir)
        filepath = (
            args.cache_dir / "merged_state_dict.safetensors" if args.cache_dir else "merged_state_dict.safetensors"
        )
        safetensors.torch.save_file(merged_state_dict, filepath)
        logger.info(f"Serialized merged state dict to {filepath}")
        args.model = Path(filepath)
        is_diffusers = True

    if args.model.suffix != ".safetensors":
        logger.error(f"Model path {args.model} is not a safetensors file.")
        sys.exit(1)

    if args.outfile.suffix != ".gguf":
        logger.error("Output file must have .gguf extension.")
        sys.exit(1)

    qconfig = qconfig_map[args.outtype]
    # The {ftype} placeholder in --outfile is replaced with the quant name.
    outfile = Path(str(args.outfile).format(ftype=args.outtype.upper()))

    logger.info(f"Converting model in {args.model} to {outfile} with quantization {args.outtype}")
    converter = Converter(
        arch=args.arch,
        path_safetensors=args.model,
        endianess=gguf.GGUFEndian.BIG if args.bigendian else gguf.GGUFEndian.LITTLE,
        outtype=qconfig,
        outfile=outfile,
        repo_id=repo_id,
        subfolder=str(args.subfolder) if args.subfolder else None,
        is_diffusers=is_diffusers,
    )
    converter.write()
    logger.info(
        f"Conversion complete. Output written to {outfile}, architecture: {args.arch}, quantization: {qconfig.qtype.name}"
    )
    if merged_state_dict is not None:
        # Clean up the intermediate merged file created above.
        os.remove(filepath)
        logger.info(f"Removed the intermediate {filepath}.")
| |
|