File size: 5,093 Bytes
d403233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Wavelet utilities: 2D and 3D Haar discrete wavelet transform patchers.

References:

[Cosmos](https://github.com/nvidia-cosmos/cosmos-predict1/blob/main/cosmos_predict1/tokenizer/modules/patching.py)
"""

import math

import torch
from torch import nn


class Patcher2D(nn.Module):
    """2D discrete wavelet transform.

    Uses the Haar filter pair [1, 1] / sqrt(2) (lowpass) and [1, -1] / sqrt(2) (highpass).
    Each ``dwt`` step halves H and W (rounding up) and multiplies the channel count by 4;
    ``forward`` applies ``log2(patch_size)`` steps and ``idwt`` inverts a single step.
    """

    def __init__(self, patch_size=4):
        """Create the patcher.

        Args:
            patch_size: Total spatial downsampling factor; must be a power of 2 (>= 1).

        Raises:
            ValueError: If ``patch_size`` is not a power of 2 that is at least 1.
        """
        super(Patcher2D, self).__init__()
        self.rescale_factor = 2
        num_strides = int(math.log2(patch_size))
        # log2 is truncated above, so without this check a size such as 6 would silently
        # downsample by 4 instead of the requested factor.
        if num_strides < 0 or 2**num_strides != patch_size:
            raise ValueError(f"patch_size must be a power of 2, got {patch_size}")
        self.patch_size, self.num_strides = patch_size, num_strides
        wavelets1 = torch.tensor([0.7071067811865476] * 2)  # Lowpass taps: 1 / sqrt(2).
        wavelets2 = wavelets1 * ((-1) ** torch.arange(2))  # Highpass taps: [1, -1] / sqrt(2).
        self.register_buffer("wavelets1", wavelets1, persistent=False)
        self.register_buffer("wavelets2", wavelets2, persistent=False)

    def _filters(self, x, groups):
        """Return the (lowpass, highpass) filters as ``(groups, 1, 2)`` tensors matching ``x``.

        Casting to ``x``'s dtype and device lets the grouped convolutions run on inputs whose
        dtype differs from the default-dtype buffers (e.g. float64).
        """
        lowpass = self.wavelets1.to(x).flip(0).view(1, 1, -1).repeat(groups, 1, 1)
        highpass = self.wavelets2.to(x).view(1, 1, -1).repeat(groups, 1, 1)
        return lowpass, highpass

    def dwt(self, x) -> torch.Tensor:
        """Apply one level of the 2D transform.

        Args:
            x: Tensor of shape (N, C, H, W); H and W must be >= 2 for the reflect padding.

        Returns:
            Tensor of shape (N, 4 * C, ceil(H / 2), ceil(W / 2)). The four C-channel groups are
            ordered (W-low, H-low), (W-low, H-high), (W-high, H-low), (W-high, H-high).
        """
        g = x.size(1)
        lowpass, highpass = self._filters(x, g)
        # One trailing sample of reflect padding on H and W so odd sizes are fully covered.
        x = nn.functional.pad(x, (0, 1, 0, 1), "reflect")
        bands = []
        for w_filter in (lowpass, highpass):
            # Filter and decimate along W first, then along H.
            xw = nn.functional.conv2d(x, w_filter[:, :, None, :], stride=(1, 2), groups=g)
            for h_filter in (lowpass, highpass):
                h_kernel = h_filter[:, :, :, None]
                bands.append(nn.functional.conv2d(xw, h_kernel, stride=(2, 1), groups=g))
        return torch.cat(bands, dim=1).mul_(1 / self.rescale_factor)

    def idwt(self, x) -> torch.Tensor:
        """Invert one ``dwt`` step.

        Args:
            x: Tensor of shape (N, 4 * C, h, w) laid out as produced by ``dwt``.

        Returns:
            Tensor of shape (N, C, 2 * h, 2 * w).
        """
        g = x.size(1) // 4
        lowpass, highpass = self._filters(x, g)
        conv_t = nn.functional.conv_transpose2d
        ll, lh, hl, hh = torch.chunk(x, 4, dim=1)
        # Undo the H step separately for the W-low and W-high halves.
        h_lo, h_hi = lowpass[:, :, :, None], highpass[:, :, :, None]
        y_lo = conv_t(ll, h_lo, stride=(2, 1), groups=g) + conv_t(lh, h_hi, stride=(2, 1), groups=g)
        y_hi = conv_t(hl, h_lo, stride=(2, 1), groups=g) + conv_t(hh, h_hi, stride=(2, 1), groups=g)
        # Undo the W step.
        w_lo, w_hi = lowpass[:, :, None, :], highpass[:, :, None, :]
        y = conv_t(y_lo, w_lo, stride=(1, 2), groups=g) + conv_t(y_hi, w_hi, stride=(1, 2), groups=g)
        return y.mul_(self.rescale_factor)

    def forward(self, x) -> torch.Tensor:
        """Apply ``num_strides`` levels of ``dwt`` (total downsampling ``patch_size``)."""
        for _ in range(self.num_strides):
            x = self.dwt(x)
        return x


class Patcher3D(Patcher2D):
    """3D discrete wavelet transform over the (T, H, W) axes of a (N, C, T, H, W) tensor.

    Each ``dwt`` step halves T, H and W (rounding up) and multiplies the channel count by 8.
    The C-channel groups are ordered by (T, H, W) filter as LLL, LLH, LHL, LHH, HLL, HLH,
    HHL, HHH (L = lowpass, H = highpass).
    """

    def __init__(self, patch_size=4):
        super(Patcher3D, self).__init__(patch_size)
        # Each 1D Haar pass scales a constant signal by sqrt(2); three passes give 2 * sqrt(2).
        self.rescale_factor = 2 * 2**0.5

    def dwt(self, x) -> torch.Tensor:
        """Apply one level of the 3D transform.

        Args:
            x: Tensor of shape (N, C, T, H, W); T, H and W must be >= 2 for the reflect padding.

        Returns:
            Tensor of shape (N, 8 * C, ceil(T / 2), ceil(H / 2), ceil(W / 2)).
        """
        g = x.size(1)
        # Filters follow x's dtype/device so inputs need not match the buffers' dtype.
        lowpass = self.wavelets1.to(x).flip(0).view(1, 1, -1).repeat(g, 1, 1)
        highpass = self.wavelets2.to(x).view(1, 1, -1).repeat(g, 1, 1)
        # One trailing sample of reflect padding on W, H and T so odd sizes are fully covered.
        x = nn.functional.pad(x, (0, 1, 0, 1, 0, 1), "reflect")
        bands = []
        for t_filter in (lowpass, highpass):
            t_kernel = t_filter[:, :, :, None, None]
            xt = nn.functional.conv3d(x, t_kernel, stride=(2, 1, 1), groups=g)
            for h_filter in (lowpass, highpass):
                h_kernel = h_filter[:, :, None, :, None]
                xh = nn.functional.conv3d(xt, h_kernel, stride=(1, 2, 1), groups=g)
                for w_filter in (lowpass, highpass):
                    w_kernel = w_filter[:, :, None, None, :]
                    bands.append(nn.functional.conv3d(xh, w_kernel, stride=(1, 1, 2), groups=g))
        return torch.cat(bands, dim=1).mul_(1.0 / self.rescale_factor)

    def idwt(self, x) -> torch.Tensor:
        """Invert one ``dwt`` step.

        Args:
            x: Tensor of shape (N, 8 * C, t, h, w) laid out as produced by ``dwt``.

        Returns:
            Tensor of shape (N, C, 2 * t, 2 * h, 2 * w).
        """
        g = x.size(1) // 8
        lowpass = self.wavelets1.to(x).flip(0).view(1, 1, -1).repeat(g, 1, 1)
        highpass = self.wavelets2.to(x).view(1, 1, -1).repeat(g, 1, 1)
        conv_t = nn.functional.conv_transpose3d

        def merge(bands, lo_kernel, hi_kernel, stride):
            # Invert one analysis step: combine each (low, high) neighbouring pair of bands.
            return [
                conv_t(bands[i], lo_kernel, stride=stride, groups=g)
                + conv_t(bands[i + 1], hi_kernel, stride=stride, groups=g)
                for i in range(0, len(bands), 2)
            ]

        bands = list(torch.chunk(x, 8, dim=1))
        # Undo the W step, then the H step, then the T step (reverse of dwt's order).
        bands = merge(bands, lowpass[:, :, None, None, :], highpass[:, :, None, None, :], (1, 1, 2))
        bands = merge(bands, lowpass[:, :, None, :, None], highpass[:, :, None, :, None], (1, 2, 1))
        (y,) = merge(bands, lowpass[:, :, :, None, None], highpass[:, :, :, None, None], (2, 1, 1))
        return y.mul_(self.rescale_factor)

    def forward(self, x) -> torch.Tensor:
        """Repeat the first frame ``patch_size`` times along T, then apply ``num_strides`` levels.

        When ``patch_size`` is a power of 2, the first output frame is computed only from copies
        of input frame 0.
        """
        first, rest = x[:, :, :1], x[:, :, 1:]
        x = torch.cat([first.repeat_interleave(self.patch_size, 2), rest], 2)
        return super(Patcher3D, self).forward(x)