File size: 15,341 Bytes
c8c0ef5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7ab0ec
c8c0ef5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c298f3c
c8c0ef5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
"""
Main model for using CodecLM. This will combine all the required components
and provide easy access to the generation API.
"""

import typing as tp
import warnings

import torch

from codeclm.tokenizer.audio_tokenizer import AudioTokenizer
# from .lm_llama import LMModel
from ..utils.autocast import TorchAutocast
import torch
from torch.nn import functional as F
import torchaudio
# from optim.ema import EMA
from codeclm.utils.utils import dict_from_config
from codeclm.modules.pattern import (
    CodebooksPatternProvider,
    DelayedPatternProvider,
)
from codeclm.modules.conditioners import (
    ConditioningAttributes,
    AudioCondition,
    BaseConditioner,
    QuantizedEmbeddingConditioner,
    ConditionerProvider,
    ConditionFuser,
    QwTextConditioner,
    QwTokenizerConditioner,
    ClassifierFreeGuidanceDropoutInference,
)
import omegaconf

def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig, version: str = 'v1.0') -> ConditionerProvider:
    """Instantiate a :class:`ConditionerProvider` from the ``conditioners`` config section.

    Args:
        output_dim (int): Dimension of the conditioning embeddings (typically the LM width).
        cfg (omegaconf.DictConfig): Full model config; only ``cfg.conditioners`` is read.
        version (str): Version tag forwarded to :class:`QwTextConditioner` only.

    Returns:
        ConditionerProvider: Provider wrapping one conditioner per configured entry.

    Raises:
        ValueError: If a conditioner entry declares an unrecognized ``model`` type.
    """
    # Use a default of None (instead of letting getattr raise) so a config with no
    # 'conditioners' section yields an empty provider, matching the None-handling below.
    # Also avoid shadowing the `cfg` parameter.
    cond_section = getattr(cfg, 'conditioners', None)
    dict_cfg = {} if cond_section is None else dict_from_config(cond_section)
    conditioners: tp.Dict[str, BaseConditioner] = {}
    # 'args' carries provider-level kwargs, not a conditioner entry — pop it first.
    condition_provider_args = dict_cfg.pop('args', {})

    for cond, cond_cfg in dict_cfg.items():
        # Each entry names its model type and nests that model's kwargs under the same key.
        model_type = cond_cfg['model']
        model_args = cond_cfg[model_type]
        if model_type == 'QwTokenizer':
            conditioners[str(cond)] = QwTokenizerConditioner(
                output_dim=output_dim,
                **model_args
            )
        elif model_type == "QwTextTokenizer":
            conditioners[str(cond)] = QwTextConditioner(
                output_dim=output_dim,
                version=version,
                **model_args
            )
        elif model_type == "qt_embedding":
            conditioners[str(cond)] = QuantizedEmbeddingConditioner(
                dim=output_dim,
                **model_args
            )
        else:
            raise ValueError(f"Unrecognized conditioning model: {model_type}")
    return ConditionerProvider(conditioners, **condition_provider_args)

def get_codebooks_pattern_provider(code_depth: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
    """Build the codebook interleaving pattern provider selected by ``cfg.modeling``.

    Args:
        code_depth (int): Number of parallel codebooks the pattern must cover.
        cfg (omegaconf.DictConfig): The ``codebooks_pattern`` config section; its
            ``modeling`` field names the pattern, and an optional sub-section of the
            same name supplies extra constructor kwargs.
    """
    providers = {
        'delay': DelayedPatternProvider,
    }
    pattern_name = cfg.modeling
    provider_cls = providers[pattern_name]
    # Pattern-specific kwargs live in an optional sub-section keyed by the pattern name.
    if hasattr(cfg, pattern_name):
        provider_kwargs = dict_from_config(cfg.get(pattern_name))
    else:
        provider_kwargs = {}
    return provider_cls(code_depth, **provider_kwargs)

# Melody conditioning inputs: a batch may mix conditioned (tensor) and
# unconditioned (None) items, or arrive as a single stacked tensor.
MelodyList = tp.List[tp.Optional[torch.Tensor]]
MelodyType = tp.Union[torch.Tensor, MelodyList]

def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
    """Build the :class:`ConditionFuser` described by ``cfg.fuser``.

    The 'sum' and 'prepend' keys route conditions to fuse methods; every other
    key in the section is forwarded as a constructor kwarg.
    """
    fuser_cfg = getattr(cfg, 'fuser')
    methods = ['sum', 'prepend']
    # Routing map: both method keys must be present in the config.
    routing = {method: fuser_cfg[method] for method in methods}
    extra_kwargs = {key: val for key, val in fuser_cfg.items() if key not in methods}
    return ConditionFuser(fuse2cond=routing, **extra_kwargs)

class CodecLM_gen:
    """CodecLM main model with convenient generation API.

    Args:
        cfg: Full model configuration (omegaconf). Read for ``max_dur``, LM dims,
            conditioner/pattern/fuser sub-configs and ``vllm.cfg``.
        name (str): name of the model.
        audiotokenizer (AudioTokenizer): Compression model used to map audio to
            invertible discrete representations.
        max_duration (float, optional): maximum duration the model can produce,
            otherwise inferred from ``cfg.max_dur``.
    """
    def __init__(self, cfg, name: str, audiotokenizer: AudioTokenizer, 
                 max_duration: tp.Optional[float] = None):
        self.cfg = cfg
        self.name = name
        self.audiotokenizer = audiotokenizer
        # Optional two-stream (vocal/bgm) decoder; assigned externally after construction.
        self.seperate_tokenizer = None
        if max_duration is None:
            max_duration = self.cfg.max_dur
        assert max_duration is not None

        self.max_duration: float = max_duration
        self.generation_params: dict = {}
        # Default: 15s generations, striding by half the max window for extended generation.
        self.set_generation_params(duration=15, extend_stride=self.max_duration // 2)
        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
        self.autocast = TorchAutocast(enabled=False)        
        self.condition_provider = get_conditioner_provider(cfg.lm.dim, self.cfg)
        codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
        self.pattern_provider = get_codebooks_pattern_provider(cfg.lm.code_depth, codebooks_pattern_cfg)
        self.fuser = get_condition_fuser(cfg)
        # EOS id sits one past the last valid codebook entry.
        self.eos_token_id = cfg.lm.code_size

    @property
    def frame_rate(self) -> float:
        """Roughly the number of AR steps per seconds."""
        return self.audiotokenizer.frame_rate

    @property
    def sample_rate(self) -> int:
        """Sample rate of the generated audio."""
        return self.audiotokenizer.sample_rate

    @property
    def audio_channels(self) -> int:
        """Audio channels of the generated audio."""
        return self.audiotokenizer.channels

    def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
                              top_p: float = 0.0, temperature: float = 1.0,
                              duration: float = 30.0, cfg_coef: float = 3.0,
                             extend_stride: float = 18, record_tokens: bool = False,
                             record_window: int = 50):
        """Set the generation parameters for CodecLM.

        Args:
            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
            top_k (int, optional): top_k used for sampling. Defaults to 250.
            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
            temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
            duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
            cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
            extend_stride: when doing extended generation (i.e. more than the max duration),
                by how much should we extend the audio each time. Larger values will mean less
                context is preserved, and shorter value will require extra computations.
            record_tokens (bool, optional): Whether to record generated tokens. Defaults to False.
            record_window (int, optional): Window size used when recording tokens. Defaults to 50.
        """
        assert extend_stride <= self.max_duration, "Cannot stride by more than max generation duration."
        self.extend_stride = extend_stride
        self.duration = duration
        self.generation_params = {
            'use_sampling': use_sampling,
            'temp': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'cfg_coef': cfg_coef,
            'record_tokens': record_tokens,
            'record_window': record_window,
        }

    def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
        """Override the default progress callback."""
        self._progress_callback = progress_callback

    # Inference
    def generate_condition(self, descriptions: tp.List[str],
                            melody_wavs: torch.Tensor = None, 
                            return_tokens: bool = False,
                            melody_is_wav: bool = True,
                            type_info: tp.List[str] = None,
                            embeded_eosp1: torch.Tensor = None,
                            ) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Build the fused conditioning input for generation.

        Args:
            descriptions: Text conditioning, one string per sample.
            melody_wavs: Optional melody conditioning, shape [B, C, T] (a 2D
                tensor is treated as a single unbatched item).
            return_tokens: Unused here; kept for interface compatibility.
            melody_is_wav: If True, ``melody_wavs`` is waveform audio to encode;
                otherwise it is already a token tensor.
            type_info: Optional per-sample type strings for the 'type_info' conditioner.
            embeded_eosp1: Pre-embedded input sequence fed to the fuser.

        Returns:
            Tuple of (fused_input, audio_qt_embs).
        """
        if melody_wavs is not None:
            if melody_wavs.dim() == 2:
                # Promote a single [C, T] item to a batch of one.
                melody_wavs = melody_wavs[None]
            if melody_wavs.dim() != 3:
                raise ValueError("Melody wavs should have a shape [B, C, T].")
            melody_wavs = list(melody_wavs)

        texts, audio_qt_embs = self._prepare_tokens_and_attributes(descriptions=descriptions,
                                                                        melody_wavs=melody_wavs,
                                                                        melody_is_wav=melody_is_wav)
        fused_input = self.get_condition_tensors(texts, audio_qt_embs, type_info, embeded_eosp1)

        return fused_input, audio_qt_embs

    @torch.no_grad()
    def _prepare_tokens_and_attributes(
            self,
            descriptions: tp.Sequence[tp.Optional[str]],
            melody_wavs: tp.Optional[MelodyList] = None,
            melody_is_wav: bool = True
    ) -> tp.Tuple[tp.List[str], tp.Optional[tp.List[torch.Tensor]]]:
        """Prepare model inputs.

        Args:
            descriptions (list of str): A list of strings used as text conditioning.
            melody_wavs (list of torch.Tensor, optional): A batch of waveforms
                used as melody conditioning. Defaults to None.
            melody_is_wav (bool): If True, encode waveforms to tokens via the
                audio tokenizer; otherwise treat ``melody_wavs`` as tokens.

        Returns:
            Tuple of (texts, audio_qt_embs); ``audio_qt_embs`` is None when no
            melody conditioning was given.
        """
        texts = [description for description in descriptions]
        audio_qt_embs = []

        if melody_wavs is None:
            audio_qt_embs = None
        else:
            if 'prompt_audio' not in self.condition_provider.conditioners:
                raise RuntimeError("This model doesn't support melody conditioning. "
                                   "Use the `melody` model.")
            assert len(melody_wavs) == len(texts), \
                f"number of melody wavs must match number of descriptions! " \
                f"got melody len={len(melody_wavs)}, and descriptions len={len(texts)}"
            if isinstance(melody_wavs, list):
                melody_wavs = torch.stack(melody_wavs, dim=0)
            if melody_is_wav:
                melody_tokens, _ = self.audiotokenizer.encode(melody_wavs)
            else:
                melody_tokens = melody_wavs
            # Truncate the prompt to at most cfg.prompt_len seconds of tokens.
            target_melody_token_len = self.cfg.prompt_len * self.audiotokenizer.frame_rate
            if melody_tokens.shape[-1] > target_melody_token_len:
                melody_tokens = melody_tokens[...,:target_melody_token_len]
            for melody in melody_tokens:
                audio_qt_embs.append(melody.long())
        return texts, audio_qt_embs

    @torch.no_grad()
    def prepare_condition_tensors(self,
                                   batch_size: int = 1,
                                   text: tp.Optional[tp.List[str]] = None, 
                                   audio_qt_emb: tp.Optional[tp.List[torch.Tensor]] = None,
                                   type_info: tp.Optional[tp.List[str]] = None,
                                   prepare_null_condition: bool = False,
                                   ):
        """Assemble conditioning attributes, tokenize them and run the conditioners.

        Args:
            batch_size: Number of samples in the batch.
            text: Optional per-sample text descriptions.
            audio_qt_emb: Optional per-sample audio prompt token tensors.
            type_info: Optional per-sample type strings.
            prepare_null_condition: If True, append null (dropped-out) conditions
                for classifier-free guidance; the returned tensors then cover 2x batch.

        Returns:
            Condition tensors produced by the condition provider.
        """
        conditions = []
        for i in range(batch_size):
            attr = ConditioningAttributes()
            if 'description' in self.condition_provider.conditioners:
                attr["text"]["description"] = ""
                if text is not None:
                    attr["text"]["description"] = text[i]
            if 'prompt_audio' in self.condition_provider.conditioners:
                if audio_qt_emb is None:    # tokenize stage will padding to max length
                    # Empty audio prompt; 16385 is the pad/placeholder id used here —
                    # presumably code_size + special offset; confirm against tokenizer.
                    attr["audio"]['prompt_audio'] = AudioCondition(
                        wav=torch.zeros((1, self.cfg.audio_tokenizer_code_depth, 0)).long().cuda() + 16385, 
                        length=torch.Tensor([0]).long(),
                        sample_rate=[self.cfg.sample_rate],)
                else:
                    # Interleave the prompt tokens with the model's codebook pattern.
                    aT = audio_qt_emb[i].shape[-1]
                    pattern = self.pattern_provider.get_pattern(aT)
                    audio_qt_seq, _, _ = pattern.build_pattern_sequence(audio_qt_emb[i][None], 
                                                                        self.eos_token_id, keep_only_valid_steps=False)   
                    attr["audio"]['prompt_audio'] = AudioCondition(
                        wav=audio_qt_seq.long().cuda(), 
                        length=torch.Tensor([audio_qt_seq.shape[-1]]).long(),
                        sample_rate=[self.cfg.sample_rate],)
            if 'type_info' in self.condition_provider.conditioners:
                attr["text"]["type_info"] = ""
                if type_info is not None:
                    attr["text"]["type_info"] = type_info[i]
            conditions.append(attr)
        if prepare_null_condition:
            # Append null conditions so conditioned and unconditioned passes batch together.
            cfg_inference = ClassifierFreeGuidanceDropoutInference() 
            null_conditions = cfg_inference(conditions, condition_types=["audio", "text"], 
                                            customized=None)
            conditions = conditions + null_conditions
        tokenized_conditions = self.condition_provider.tokenize(conditions)
        condition_tensors = self.condition_provider(tokenized_conditions)
        return condition_tensors

    def get_condition_tensors(self, texts, audio_qt_embs, type_info, embeded_eosp1):
        """Fuse the pre-embedded input with the prepared condition tensors.

        When CFG is enabled (``cfg.vllm.cfg``), the input is duplicated so the
        conditioned and null-conditioned halves run in one batch.
        """
        condition_tensors = self.prepare_condition_tensors(batch_size=1, text=texts, audio_qt_emb=audio_qt_embs, type_info=type_info, prepare_null_condition=self.cfg.vllm.cfg)
        if self.cfg.vllm.cfg:
            input_ = torch.cat((embeded_eosp1, embeded_eosp1), dim=0)
        else:
            input_ = embeded_eosp1
        fused_input = self.fuser(input_, condition_tensors)
        return fused_input

    @torch.no_grad()
    def generate_audio(self, gen_tokens: torch.Tensor, prompt=None, vocal_prompt=None, bgm_prompt=None, chunked=False, chunk_size=128, gen_type='mixed'):
        """Generate audio from generated tokens.

        Args:
            gen_tokens: Token tensor of shape [B, K, T] (K codebook streams).
            prompt: Optional decode prompt for the single-stream tokenizer path.
            vocal_prompt / bgm_prompt: Optional per-stream decode prompts for the
                separate-tokenizer path.
            chunked / chunk_size: Chunked decoding controls, forwarded to the decoder.
            gen_type: One of 'mixed', 'vocal', 'bgm' — which streams to render.

        Returns:
            Decoded audio from the separate tokenizer when set, otherwise from
            the main audio tokenizer.
        """
        assert gen_tokens.dim() == 3
        if self.seperate_tokenizer is not None:
            gen_tokens_song = gen_tokens[:, [0], :]
            gen_tokens_vocal = gen_tokens[:, [1], :]
            gen_tokens_bgm = gen_tokens[:, [2], :]
            if gen_type == 'bgm':
                # Silence the vocal stream: 3142 is the vocal-silence token id —
                # presumably tokenizer-specific; confirm against the separate tokenizer.
                gen_tokens_vocal = torch.full_like(gen_tokens_vocal, 3142)
                if vocal_prompt is not None:
                    vocal_prompt = torch.zeros_like(vocal_prompt)
            elif gen_type == 'vocal':
                # Silence the bgm stream: 9670 is the bgm-silence token id (see above note).
                gen_tokens_bgm = torch.full_like(gen_tokens_bgm, 9670)
                if bgm_prompt is not None:
                    bgm_prompt = torch.zeros_like(bgm_prompt)
            else:
                assert gen_type == 'mixed', f"gen_type {gen_type} not supported"
            gen_audio_seperate = self.seperate_tokenizer.decode([gen_tokens_vocal, gen_tokens_bgm], vocal_prompt, bgm_prompt, chunked=chunked, chunk_size=chunk_size)
            return gen_audio_seperate
        else:
            gen_audio = self.audiotokenizer.decode(gen_tokens, prompt, chunked=chunked, chunk_size=chunk_size)
            return gen_audio