| import math |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Union |
|
|
| import numpy as np |
| import torch |
| import tqdm |
| from audiotools import AudioSignal |
| from torch import nn |
|
|
| SUPPORTED_VERSIONS = ["1.0.0"] |
|
|
|
|
@dataclass
class DACFile:
    """On-disk container for compressed DAC codes plus the metadata
    needed to reconstruct the original signal."""

    codes: torch.Tensor

    # Metadata

    chunk_length: int
    original_length: int
    input_db: float
    channels: int
    sample_rate: int
    padding: bool
    dac_version: str

    def save(self, path):
        """Write codes + metadata to ``path`` (suffix forced to ``.dac``)
        via a pickled numpy archive. Returns the path written."""
        metadata = {
            "input_db": self.input_db.numpy().astype(np.float32),
            "original_length": self.original_length,
            "sample_rate": self.sample_rate,
            "chunk_length": self.chunk_length,
            "channels": self.channels,
            "padding": self.padding,
            "dac_version": SUPPORTED_VERSIONS[-1],
        }
        payload = {
            # uint16 is enough for codebook indices and halves file size.
            "codes": self.codes.numpy().astype(np.uint16),
            "metadata": metadata,
        }
        target = Path(path).with_suffix(".dac")
        with open(target, "wb") as handle:
            np.save(handle, payload)
        return target

    @classmethod
    def load(cls, path):
        """Read a ``.dac`` file written by :meth:`save` back into a
        :class:`DACFile`. Raises ``RuntimeError`` on version mismatch."""
        payload = np.load(path, allow_pickle=True)[()]
        metadata = payload["metadata"]
        if metadata.get("dac_version", None) not in SUPPORTED_VERSIONS:
            raise RuntimeError(
                f"Given file {path} can't be loaded with this version of descript-audio-codec."
            )
        restored_codes = torch.from_numpy(payload["codes"].astype(int))
        return cls(codes=restored_codes, **metadata)
|
|
|
|
class CodecMixin:
    """Mixin adding chunked, constant-memory compress/decompress to a codec.

    Expects the host class to be an ``nn.Module`` that also exposes
    ``sample_rate``, ``hop_length``, ``delay``, ``device``, ``preprocess``,
    ``encode``, ``decode`` and ``quantizer`` (all used below) — TODO confirm
    against the concrete model class.
    """

    @property
    def padding(self):
        # Lazily default to padded convolutions (whole-signal inference).
        if not hasattr(self, "_padding"):
            self._padding = True
        return self._padding

    @padding.setter
    def padding(self, value):
        # Toggle padding on every 1d (transposed) conv in the model.
        # Disabling stashes each layer's configured padding in
        # ``layer.original_padding`` and zeroes it (needed for chunked
        # inference); enabling restores the stashed value.
        assert isinstance(value, bool)

        layers = [
            l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
        ]

        for layer in layers:
            if value:
                # Restore only if it was previously disabled.
                if hasattr(layer, "original_padding"):
                    layer.padding = layer.original_padding
            else:
                layer.original_padding = layer.padding
                layer.padding = tuple(0 for _ in range(len(layer.padding)))

        self._padding = value

    def get_delay(self):
        """Return the model's one-sided delay in input samples when run
        without padding.

        Works backwards from a zero-length input: ``get_output_length(0)``
        gives the corresponding output length, then each conv layer's
        length formula is inverted in reverse layer order to recover the
        matching input length; the delay is half the difference.
        """
        l_out = self.get_output_length(0)
        L = l_out

        layers = []
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                layers.append(layer)

        # Invert each layer's output-length formula, last layer first.
        for layer in reversed(layers):
            d = layer.dilation[0]
            k = layer.kernel_size[0]
            s = layer.stride[0]

            if isinstance(layer, nn.ConvTranspose1d):
                # Inverse of the ConvTranspose1d length formula below.
                L = ((L - d * (k - 1) - 1) / s) + 1
            elif isinstance(layer, nn.Conv1d):
                # Inverse of the Conv1d length formula below.
                L = (L - 1) * s + d * (k - 1) + 1

            L = math.ceil(L)

        l_in = L

        # Half of the total input/output length mismatch on each side.
        return (l_in - l_out) // 2

    def get_output_length(self, input_length):
        """Trace ``input_length`` through every 1d conv layer (in module
        registration order) and return the resulting length.

        Note: the formulas omit ``layer.padding``, i.e. they assume padding
        has been disabled via the ``padding`` setter.
        """
        L = input_length
        for layer in self.modules():
            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
                d = layer.dilation[0]
                k = layer.kernel_size[0]
                s = layer.stride[0]

                if isinstance(layer, nn.Conv1d):
                    # Conv1d shrinks the sequence.
                    L = ((L - d * (k - 1) - 1) / s) + 1
                elif isinstance(layer, nn.ConvTranspose1d):
                    # ConvTranspose1d expands it.
                    L = (L - 1) * s + d * (k - 1) + 1

                L = math.floor(L)
        return L

    @torch.no_grad()
    def compress(
        self,
        audio_path_or_signal: Union[str, Path, AudioSignal],
        win_duration: float = 1.0,
        verbose: bool = False,
        normalize_db: float = -16,
        n_quantizers: int = None,
    ) -> DACFile:
        """Processes an audio signal from a file or AudioSignal object into
        discrete codes. This function processes the signal in short windows,
        using constant GPU memory.

        Parameters
        ----------
        audio_path_or_signal : Union[str, Path, AudioSignal]
            audio signal to reconstruct
        win_duration : float, optional
            window duration in seconds, by default 1.0; if None, the whole
            signal is treated as a single window
        verbose : bool, optional
            by default False
        normalize_db : float, optional
            normalize db, by default -16; pass None to skip normalization
        n_quantizers : int, optional
            number of quantizers to use while encoding, by default None
            (passed straight through to ``encode``)

        Returns
        -------
        DACFile
            Object containing compressed codes and metadata
            required for decompression
        """
        audio_signal = audio_path_or_signal
        if isinstance(audio_signal, (str, Path)):
            audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))

        self.eval()
        # Saved and restored at the end: compress mutates self.padding.
        original_padding = self.padding
        original_device = audio_signal.device

        # Work on a copy; normalization/resampling below mutate the signal.
        audio_signal = audio_signal.clone()
        original_sr = audio_signal.sample_rate

        resample_fn = audio_signal.resample
        loudness_fn = audio_signal.loudness

        # Very long signals (>= 10 * 60 * 60 s, i.e. 10 hours) fall back to
        # the ffmpeg-backed resample/loudness implementations.
        if audio_signal.signal_duration >= 10 * 60 * 60:
            resample_fn = audio_signal.ffmpeg_resample
            loudness_fn = audio_signal.ffmpeg_loudness

        original_length = audio_signal.signal_length
        resample_fn(self.sample_rate)
        # Loudness before normalization — stored so decompress can undo it.
        input_db = loudness_fn()

        if normalize_db is not None:
            audio_signal.normalize(normalize_db)
        audio_signal.ensure_max_of_audio()

        # Fold channels into the batch axis so each channel is coded mono.
        nb, nac, nt = audio_signal.audio_data.shape
        audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
        win_duration = (
            audio_signal.signal_duration if win_duration is None else win_duration
        )

        if audio_signal.signal_duration <= win_duration:
            # Whole signal fits in one window: encode it padded, in one shot.
            self.padding = True
            n_samples = nt
            hop = nt
        else:
            # Chunked inference: disable conv padding and compensate by
            # zero-padding the signal with the model delay on both sides.
            self.padding = False

            audio_signal.zero_pad(self.delay, self.delay)
            n_samples = int(win_duration * self.sample_rate)
            # Round the window up to a whole number of hops.
            n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
            hop = self.get_output_length(n_samples)

        codes = []
        range_fn = range if not verbose else tqdm.trange

        for i in range_fn(0, nt, hop):
            x = audio_signal[..., i : i + n_samples]
            # Right-pad the final (short) window up to the full window size.
            x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))

            audio_data = x.audio_data.to(self.device)
            audio_data = self.preprocess(audio_data, self.sample_rate)
            _, c, _, _, _ = self.encode(audio_data, n_quantizers)
            codes.append(c.to(original_device))
            chunk_length = c.shape[-1]

        codes = torch.cat(codes, dim=-1)

        dac_file = DACFile(
            codes=codes,
            chunk_length=chunk_length,
            original_length=original_length,
            input_db=input_db,
            channels=nac,
            sample_rate=original_sr,
            padding=self.padding,
            dac_version=SUPPORTED_VERSIONS[-1],
        )

        if n_quantizers is not None:
            # NOTE(review): this only rebinds the local `codes`; `dac_file`
            # was constructed above with the full tensor, so the slice has
            # no effect on the returned object. Presumably encode() already
            # limits the quantizers — confirm and remove.
            codes = codes[:, :n_quantizers, :]

        self.padding = original_padding
        return dac_file

    @torch.no_grad()
    def decompress(
        self,
        obj: Union[str, Path, DACFile],
        verbose: bool = False,
    ) -> AudioSignal:
        """Reconstruct audio from a given .dac file

        Parameters
        ----------
        obj : Union[str, Path, DACFile]
            .dac file location or corresponding DACFile object.
        verbose : bool, optional
            Prints progress if True, by default False

        Returns
        -------
        AudioSignal
            Object with the reconstructed audio
        """
        self.eval()
        if isinstance(obj, (str, Path)):
            obj = DACFile.load(obj)

        # Decode with the same padding mode used during compression.
        original_padding = self.padding
        self.padding = obj.padding

        range_fn = range if not verbose else tqdm.trange
        codes = obj.codes
        original_device = codes.device
        chunk_length = obj.chunk_length
        recons = []

        # Decode chunk by chunk, mirroring the windows made by compress().
        for i in range_fn(0, codes.shape[-1], chunk_length):
            c = codes[..., i : i + chunk_length].to(self.device)
            z = self.quantizer.from_codes(c)[0]
            r = self.decode(z)
            recons.append(r.to(original_device))

        recons = torch.cat(recons, dim=-1)
        recons = AudioSignal(recons, self.sample_rate)

        resample_fn = recons.resample
        loudness_fn = recons.loudness

        # Very long signals (>= 10 hours) fall back to the ffmpeg-backed
        # resample/loudness implementations, matching compress().
        if recons.signal_duration >= 10 * 60 * 60:
            resample_fn = recons.ffmpeg_resample
            loudness_fn = recons.ffmpeg_loudness

        # Restore the original loudness, sample rate and length.
        recons.normalize(obj.input_db)
        resample_fn(obj.sample_rate)
        recons = recons[..., : obj.original_length]
        # Result discarded — presumably called to populate the signal's
        # cached loudness metadata; TODO confirm.
        loudness_fn()
        # Un-fold channels out of the batch axis (inverse of compress()).
        recons.audio_data = recons.audio_data.reshape(
            -1, obj.channels, obj.original_length
        )

        self.padding = original_padding
        return recons
|
|