File size: 4,341 Bytes
b6ff324
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# ------------------------------------------------------------------------
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Diffusion MLP."""

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint as apply_ckpt

from diffnext.models.embeddings import PatchEmbed
from diffnext.models.normalization import AdaLayerNormZero


class Projector(nn.Module):
    """MLP Projector layer."""

    def __init__(self, dim, mlp_dim=None, out_dim=None):
        super(Projector, self).__init__()
        self.fc1 = nn.Linear(dim, mlp_dim or dim)
        self.fc2 = nn.Linear(mlp_dim or dim, out_dim or dim)
        self.activation = nn.SiLU()

    def forward(self, x) -> torch.Tensor:
        return self.fc2(self.activation(self.fc1(x)))


class DiffusionBlock(nn.Module):
    """Diffusion block."""

    def __init__(self, dim):
        super(DiffusionBlock, self).__init__()
        self.dim, self.mlp_checkpointing = dim, False
        self.norm1 = AdaLayerNormZero(dim, num_stats=3, eps=1e-6)
        self.proj, self.norm2 = Projector(dim, dim, dim), nn.LayerNorm(dim)

    def forward(self, x, z) -> torch.Tensor:
        if self.mlp_checkpointing and x.requires_grad:
            h, (gate,) = apply_ckpt(self.norm1, x, z, use_reentrant=False)
            return self.norm2(apply_ckpt(self.proj, h, use_reentrant=False)).mul(gate).add_(x)
        h, (gate,) = self.norm1(x, z)
        return self.norm2(self.proj(h)).mul(gate).add_(x)


class TimeCondEmbed(nn.Module):
    """Time-Condition embedding layer."""

    def __init__(self, cond_dim, embed_dim, freq_dim=256):
        super(TimeCondEmbed, self).__init__()
        self.timestep_proj = Projector(freq_dim, embed_dim, embed_dim)
        self.condition_proj = Projector(cond_dim, embed_dim, embed_dim)
        self.freq_dim, self.time_freq = freq_dim, None

    def get_freq_embed(self, timestep, dtype) -> torch.Tensor:
        if self.time_freq is None:
            dim, log_theta = self.freq_dim // 2, 9.210340371976184  # math.log(10000)
            freq = torch.arange(dim, dtype=torch.float32, device=timestep.device)
            self.time_freq = freq.mul(-log_theta / dim).exp().unsqueeze(0)
        emb = timestep.unsqueeze(-1).float() * self.time_freq
        return torch.cat([emb.cos(), emb.sin()], dim=-1).to(dtype=dtype)

    def forward(self, timestep, z) -> torch.Tensor:
        t = self.timestep_proj(self.get_freq_embed(timestep, z.dtype))
        return self.condition_proj(z).add_(t.unsqueeze_(1) if t.dim() == 2 else t)


class DiffusionMLP(nn.Module):
    """Diffusion MLP model."""

    def __init__(self, depth, embed_dim, cond_dim, patch_size=2, image_dim=4):
        super(DiffusionMLP, self).__init__()
        self.patch_embed = PatchEmbed(image_dim, embed_dim, patch_size)
        self.time_cond_embed = TimeCondEmbed(cond_dim, embed_dim)
        self.blocks = nn.ModuleList(DiffusionBlock(embed_dim) for _ in range(depth))
        self.norm = AdaLayerNormZero(embed_dim, num_stats=2, eps=1e-6)
        self.head = nn.Linear(embed_dim, patch_size**2 * image_dim)

    def forward(self, x, timestep, z, pred_ids=None) -> torch.Tensor:
        x, o = self.patch_embed(x), None if pred_ids is None else x
        o = None if pred_ids is None else self.patch_embed.patchify(o)
        x = x if pred_ids is None else x.gather(1, pred_ids.expand(-1, -1, x.size(-1)))
        z = z if pred_ids is None else z.gather(1, pred_ids.expand(-1, -1, z.size(-1)))
        z = self.time_cond_embed(timestep, z)
        for blk in self.blocks:
            x = blk(x, z)
        x = self.norm(x, z)[0]
        x = self.head(x)
        return x if pred_ids is None else o.scatter(1, pred_ids.expand(-1, -1, x.size(-1)), x)