File size: 5,151 Bytes
d403233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# ------------------------------------------------------------------------
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Pipeline builders."""

from typing import Dict

import json
import os
import tempfile

import torch

from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffnext.utils.omegaconf_utils import OmegaConfEncoder


def get_pipeline_path(
    pretrained_path,
    module_dict: dict = None,
    module_config: Dict[str, dict] = None,
    target_path: str = None,
) -> str:
    """Return the pipeline loading path.

    Args:
        pretrained_path (str)
            The pretrained path to load pipeline.
        module_dict (dict, *optional*)
            The path dict to load custom modules. A falsy value removes the
            component from ``model_index.json``; a path symlinks the component.
            The special key ``"model_index"`` overrides the index file path.
        module_config (Dict[str, dict], *optional*)
            The custom configurations to dump into ``config.json``.
        target_path (str, *optional*)
            The path to store custom modules and configs.

    Returns:
       str: The pipeline loading path.
    """
    if module_dict is None and module_config is None:
        # Nothing to customize: load directly from the pretrained path.
        return pretrained_path
    target_path = target_path or tempfile.mkdtemp()
    # Mirror every component directory via symlinks to avoid copying weights.
    for name in os.listdir(pretrained_path):
        src_dir = os.path.join(pretrained_path, name)
        if not os.path.isdir(src_dir):
            continue
        dst_dir = os.path.join(target_path, name)
        os.makedirs(dst_dir, exist_ok=True)
        for fname in os.listdir(src_dir):
            try:
                os.symlink(os.path.join(src_dir, fname), os.path.join(dst_dir, fname))
            except FileExistsError:  # Some components may be provided.
                pass
    # Copy so popping "model_index" does not mutate the caller's dict.
    module_dict = module_dict.copy() if module_dict is not None else {}
    model_index_path = module_dict.pop("model_index", os.path.join(pretrained_path, "model_index.json"))
    with open(model_index_path) as f:  # Close the handle instead of leaking it.
        model_index = json.load(f)
    for k, v in module_dict.items():
        if v:
            try:
                os.symlink(v, os.path.join(target_path, k))
            except FileExistsError:  # Some components may be provided.
                pass
        else:
            # A falsy path disables the component entirely.
            model_index.pop(k)
    for k, v in (module_config or {}).items():
        if not v:
            continue
        config_file = os.path.join(target_path, k, "config.json")
        if os.path.exists(config_file):
            # Replace the symlinked pretrained config with a real file.
            os.remove(config_file)
        with open(config_file, "w") as f:
            json.dump(v, f, cls=OmegaConfEncoder)
    with open(os.path.join(target_path, "model_index.json"), "w") as f:
        json.dump(model_index, f)
    return target_path


def build_diffusion_scheduler(scheduler_path, sample=False, **kwargs) -> SchedulerMixin:
    """Create a diffusion scheduler instance.

    Args:
        scheduler_path (str or scheduler instance)
            The path to load a diffusion scheduler.
        sample (bool, *optional*, default to False)
            Whether to create the sampling-specific scheduler.

    Returns:
        SchedulerMixin: The diffusion scheduler.
    """
    # NOTE: these imports are function-local on purpose: the ``locals()[...]``
    # lookups below resolve scheduler class names through this local namespace.
    from diffnext.schedulers.scheduling_cfm import FlowMatchEulerDiscreteScheduler  # noqa
    from diffnext.schedulers.scheduling_ddpm import DDPMScheduler

    if isinstance(scheduler_path, str):
        # The config carries separate class names for the noise (training) and
        # sample (inference) phases, keyed "_noise_class_name"/"_sample_class_name".
        class_key = "_{}_class_name".format("sample" if sample else "noise")
        # HACK(review): ``**locals()`` forwards every local binding
        # (scheduler_path, sample, kwargs, the imported classes, ...) as keyword
        # arguments — presumably the pinned diffusers ``load_config`` /
        # ``from_pretrained`` tolerate the extra keywords; verify against the
        # diffusers version this project pins before touching this.
        class_type = locals()[DDPMScheduler.load_config(**locals())[class_key]]
        return class_type.from_pretrained(**locals())
    elif hasattr(scheduler_path, "config"):
        # Re-instantiate a scheduler object as the locally-imported class that
        # shares its type name, copying its existing config.
        class_type = locals()[type(scheduler_path).__name__]
        return class_type.from_config(scheduler_path.config)
    # Unsupported input (neither a path nor a configured scheduler).
    return None


def build_pipeline(pretrained_path, pipe_cls, dtype=torch.float16, **kwargs) -> DiffusionPipeline:
    """Create a diffnext pipeline instance.

    Examples:
        ```py
        >>> from diffnext.pipelines import NOVAPipeline
        >>> from diffnext.pipelines.builder import build_pipeline
        >>> pipe = build_pipeline("BAAI/nova-d48w768-sdxl1024", NOVAPipeline)

    Args:
        pretrained_path (str):
            The model path that includes ``model_index.json`` to create pipeline.
        pipe_cls (object)
            The pipeline class object that defines the ``from_pretrained`` method.
        dtype (torch.dtype, *optional*, default to ``torch.float16``)
            The compute dtype used for all pipeline components.

    Returns:
        DiffusionPipeline: The diffusion pipeline.
    """
    # Start from the defaults, then let caller-provided kwargs take precedence.
    load_args = {"trust_remote_code": True, "torch_dtype": dtype}
    load_args.update(kwargs)
    return pipe_cls.from_pretrained(pretrained_path, **load_args)