| import torch |
| from torch import nn |
| import torch.nn.functional as F |
| from torch.nn.utils.parametrize import register_parametrization |
| from torchcomp import ms2coef, coef2ms, db2amp, amp2db |
| from torchaudio.transforms import Spectrogram, InverseSpectrogram |
|
|
| from typing import List, Tuple, Union, Any, Optional, Callable |
| import math |
| from torch_fftconv import fft_conv1d |
| from functools import reduce |
|
|
| from .functional import ( |
| compressor_expander, |
| lowpass_biquad, |
| highpass_biquad, |
| equalizer_biquad, |
| lowshelf_biquad, |
| highshelf_biquad, |
| lowpass_biquad_coef, |
| highpass_biquad_coef, |
| highshelf_biquad_coef, |
| lowshelf_biquad_coef, |
| equalizer_biquad_coef, |
| ) |
| from .utils import chain_functions |
|
|
|
|
class Clip(nn.Module):
    """Clamp a tensor to an optional lower and/or upper bound.

    A bound left as ``None`` is not applied.
    """

    def __init__(self, max: Optional[float] = None, min: Optional[float] = None):
        super().__init__()
        self.max = max
        self.min = min

    def forward(self, x):
        """Return ``x`` clipped from below by ``min`` and then from above by ``max``."""
        lower, upper = self.min, self.max
        if lower is not None:
            x = torch.clip(x, min=lower)
        if upper is not None:
            x = torch.clip(x, max=upper)
        return x
|
|
|
|
def clip_delay_eq_Q(m: nn.Module, Q: float):
    """Cap the ``Q`` of a ``Delay``'s ``LowPass`` equaliser at ``Q``.

    Modules that are not a ``Delay`` whose ``eq`` is a ``LowPass`` are returned
    untouched.  The module is always returned.
    """
    targets_lowpassed_delay = isinstance(m, Delay) and isinstance(m.eq, LowPass)
    if not targets_lowpassed_delay:
        return m

    # Stack a Clip(max=Q) parametrization on the equaliser's "Q" entry.
    register_parametrization(m.eq.params, "Q", Clip(max=Q))
    return m
|
|
|
|
def float2param(x):
    """Wrap ``x`` in an ``nn.Parameter``.

    Tensors are wrapped as-is; anything else is first converted to a float32
    tensor.
    """
    if isinstance(x, torch.Tensor):
        return nn.Parameter(x)
    return nn.Parameter(torch.tensor(x, dtype=torch.float32))
|
|
# sqrt(2): `hadamard` divides its sum/difference pair by this constant, and
# `Panning.forward` multiplies its cos/sin channel gains by it.
STEREO_NORM = math.sqrt(2)
|
|
|
|
def broadcast2stereo(m, args):
    """Forward pre-hook that expands a one-channel input to two channels.

    Only the first positional argument is inspected; when its dimension 1 has
    size 1 it is expanded (without copying) to size 2, otherwise it is returned
    unchanged.  ``m`` (the hooked module) is not used.
    """
    x, *_ = args
    if x.shape[1] != 1:
        return x
    return x.expand(-1, 2, -1)
|
|
|
|
def hadamard(x):
    """Map channels 0/1 of ``x`` (dimension 1) to a normalised sum/difference pair.

    The first output channel is the sum over dimension 1, the second is
    channel 0 minus channel 1; both are divided by ``STEREO_NORM``.
    """
    channel_sum = x.sum(1)
    channel_diff = x[:, 0] - x[:, 1]
    return torch.stack([channel_sum, channel_diff], 1) / STEREO_NORM
|
|
|
|
class Hadamard(nn.Module):
    """``nn.Module`` wrapper around the module-level ``hadamard`` function."""

    def forward(self, x):
        """Return ``hadamard(x)``."""
        transformed = hadamard(x)
        return transformed
|
|
|
|
class FX(nn.Module):
    """Base class for effects: keeps every keyword argument as a learnable parameter.

    Each keyword passed to the constructor is converted with ``float2param`` and
    stored under the same name in ``self.params`` (an ``nn.ParameterDict``).
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()

        initial = {name: float2param(value) for name, value in kwargs.items()}
        self.params = nn.ParameterDict(initial)

    def toJSON(self) -> dict[str, Any]:
        """Return the single-element parameters as plain Python numbers, keyed by name."""
        scalars = {}
        for name, value in self.params.items():
            if value.numel() == 1:
                scalars[name] = value.item()
        return scalars
|
|
|
|
class SmoothingCoef(nn.Module):
    """Parametrization mapping an unconstrained value into (0, 1) with a sigmoid."""

    def forward(self, x):
        """Return ``sigmoid(x)``."""
        return torch.sigmoid(x)

    def right_inverse(self, y):
        """Return the logit ``log(y / (1 - y))`` of ``y``."""
        odds = y / (1 - y)
        return torch.log(odds)
|
|
|
|
class CompRatio(nn.Module):
    """Parametrization mapping an unconstrained value to ``exp(x) + 1`` (always > 1)."""

    def forward(self, x):
        """Return ``exp(x) + 1``."""
        return torch.exp(x) + 1

    def right_inverse(self, y):
        """Return ``log(y - 1)``, the inverse of ``forward``."""
        excess = y - 1
        return excess.log()
|
|
|
|
class MinMax(nn.Module):
    """Parametrization mapping an unconstrained value into the range (min, max).

    The value goes through ``SmoothingCoef`` and is then scaled and shifted.
    Tensor bounds are registered as non-persistent buffers; other bounds are
    stored as plain attributes.
    """

    def __init__(self, min=0.0, max: Union[float, torch.Tensor] = 1.0):
        super().__init__()
        for name, bound in (("min", min), ("max", max)):
            if isinstance(bound, torch.Tensor):
                self.register_buffer(name, bound, persistent=False)
            else:
                setattr(self, name, bound)

        self._m = SmoothingCoef()

    def forward(self, x):
        """Return ``_m(x) * (max - min) + min``."""
        span = self.max - self.min
        return self._m(x) * span + self.min

    def right_inverse(self, y):
        """Return the unconstrained value whose ``forward`` is ``y``."""
        normalised = (y - self.min) / (self.max - self.min)
        return self._m.right_inverse(normalised)
|
|
|
|
class WrappedPositive(nn.Module):
    """Parametrization returning ``|x|`` wrapped modulo ``period``."""

    def __init__(self, period):
        super().__init__()
        self.period = period

    def forward(self, x):
        """Return ``abs(x) % period``."""
        magnitude = torch.abs(x)
        return magnitude % self.period

    def right_inverse(self, y):
        """Return ``y`` unchanged."""
        return y
|
|
|
|
class CompressorExpander(FX):
    """Learnable compressor/expander; the processing is done by `compressor_expander`.

    Thresholds, make-up gain, averaging coefficient and both ratios are stored
    in ``self.params`` through ``FX``.  Attack and release are converted from
    milliseconds with ``ms2coef`` before being stored.

    Args:
        sr: Sample rate, used for the ms <-> coefficient/sample conversions.
        cmp_ratio: Initial compressor ratio (parametrized with ``MinMax`` between
            ``cmp_ratio_min`` and ``cmp_ratio_max``).
        exp_ratio: Initial expander ratio (parametrized with ``SmoothingCoef``).
        at_ms: Initial attack time in milliseconds.
        rt_ms: Initial release time in milliseconds.
        avg_coef: Initial averaging coefficient (parametrized with ``SmoothingCoef``).
        cmp_th: Initial compressor threshold (reported as dB by ``extra_repr``).
        exp_th: Initial expander threshold (reported as dB by ``extra_repr``).
        make_up: Initial make-up gain (reported as dB by ``extra_repr``).
        delay: Stored on ``self.delay``; not read anywhere else in this class.
        lookahead: When true, a learnable fractional lookahead (in ms) is added.
        max_lookahead: Period (in ms) of the ``WrappedPositive`` parametrization
            applied to the lookahead and basis for the sinc filter length.
    """

    # Bounds handed to the MinMax parametrization of "cmp_ratio".
    cmp_ratio_min: float = 1
    cmp_ratio_max: float = 20

    def __init__(
        self,
        sr: int,
        cmp_ratio: float = 2.0,
        exp_ratio: float = 0.5,
        at_ms: float = 50.0,
        rt_ms: float = 50.0,
        avg_coef: float = 0.3,
        cmp_th: float = -18.0,
        exp_th: float = -54.0,
        make_up: float = 0.0,
        delay: int = 0,
        lookahead: bool = False,
        max_lookahead: float = 15.0,
    ):
        super().__init__(
            cmp_th=cmp_th,
            exp_th=exp_th,
            make_up=make_up,
            avg_coef=avg_coef,
            cmp_ratio=cmp_ratio,
            exp_ratio=exp_ratio,
        )

        self.delay = delay
        self.sr = sr

        # Attack/release are kept as coefficients, not as milliseconds.
        self.params["at"] = nn.Parameter(ms2coef(torch.tensor(at_ms), sr))
        self.params["rt"] = nn.Parameter(ms2coef(torch.tensor(rt_ms), sr))

        if lookahead:
            # Initial lookahead is 1000 / sr ms, i.e. one sample expressed in ms.
            self.params["lookahead"] = nn.Parameter(torch.ones(1) / sr * 1000)
            register_parametrization(
                self.params, "lookahead", WrappedPositive(max_lookahead)
            )
            # The sinc kernel spans (max_lookahead + 1) ms; 1 ms worth of
            # samples sits to the left of index 0 of `_arange`.
            sinc_length = int(sr * (max_lookahead + 1) * 0.001) + 1
            left_pad_size = int(sr * 0.001)
            self._pad_size = (left_pad_size, sinc_length - left_pad_size - 1)
            self.register_buffer(
                "_arange",
                torch.arange(sinc_length) - left_pad_size,
                persistent=False,
            )
        self.lookahead = lookahead

        # Registered after the raw values are stored, so the values above are
        # the constrained (post-parametrization) initial values.
        register_parametrization(self.params, "at", SmoothingCoef())
        register_parametrization(self.params, "rt", SmoothingCoef())
        register_parametrization(self.params, "avg_coef", SmoothingCoef())
        register_parametrization(
            self.params, "cmp_ratio", MinMax(self.cmp_ratio_min, self.cmp_ratio_max)
        )
        register_parametrization(self.params, "exp_ratio", SmoothingCoef())

    def extra_repr(self) -> str:
        """Return a multi-line summary of the current parameter values."""
        with torch.no_grad():
            s = (
                f"attack: {coef2ms(self.params.at, self.sr).item()} (ms)\n"
                f"release: {coef2ms(self.params.rt, self.sr).item()} (ms)\n"
                f"avg_coef: {self.params.avg_coef.item()}\n"
                f"compressor_ratio: {self.params.cmp_ratio.item()}\n"
                f"expander_ratio: {self.params.exp_ratio.item()}\n"
                f"compressor_threshold: {self.params.cmp_th.item()} (dB)\n"
                f"expander_threshold: {self.params.exp_th.item()} (dB)\n"
                f"make_up: {self.params.make_up.item()} (dB)"
            )
            if self.lookahead:
                s += f"\nlookahead: {self.params.lookahead.item()} (ms)"
        return s

    def toJSON(self) -> dict[str, Any]:
        """Return the current settings as a dict; lookahead is included only if enabled."""
        return {
            "Attack (ms)": coef2ms(self.params.at, self.sr).item(),
            "Release (ms)": coef2ms(self.params.rt, self.sr).item(),
            "Average Coefficient": self.params.avg_coef.item(),
            "Compressor Ratio": self.params.cmp_ratio.item(),
            "Expander Ratio": self.params.exp_ratio.item(),
            "Compressor Threshold (dB)": self.params.cmp_th.item(),
            "Expander Threshold (dB)": self.params.exp_th.item(),
            "Make Up (dB)": self.params.make_up.item(),
        } | ({"Lookahead (ms)": self.params.lookahead.item()} if self.lookahead else {})

    def forward(self, x):
        """Apply the compressor/expander to ``x`` and return a tensor of the same shape.

        All leading dimensions of ``x`` are flattened into one before calling
        ``compressor_expander`` and restored afterwards.
        """
        if self.lookahead:
            # Fractional shift implemented as convolution with a shifted sinc.
            lookahead_in_samples = self.params.lookahead * 0.001 * self.sr
            sinc_filter = torch.sinc(self._arange - lookahead_in_samples)
            lookahead_func = lambda gain: F.conv1d(
                F.pad(
                    gain.view(-1, 1, gain.size(-1)), self._pad_size, mode="replicate"
                ),
                sinc_filter[None, None, :],
            ).view(*gain.shape)
        else:
            # No lookahead: identity.
            lookahead_func = lambda x: x

        # "lookahead" is consumed by lookahead_func above, so it is not
        # forwarded as a keyword argument.
        return compressor_expander(
            x.reshape(-1, x.shape[-1]),
            lookahead_func=lookahead_func,
            **{k: v for k, v in self.params.items() if k != "lookahead"},
        ).view(*x.shape)
|
|
|
|
class Panning(FX):
    """Learnable panner applying per-channel cos/sin gains.

    ``pan`` is given in [-100, 100] and stored internally as a value in [0, 1]
    behind a ``SmoothingCoef`` parametrization.  One-channel inputs are expanded
    to two channels by a forward pre-hook.
    """

    def __init__(self, pan: float = 0.0):
        assert -100 <= pan <= 100
        normalised_pan = (pan + 100) / 200
        super().__init__(pan=normalised_pan)

        register_parametrization(self.params, "pan", SmoothingCoef())

        self.register_forward_pre_hook(broadcast2stereo)

    def extra_repr(self) -> str:
        """Return the pan position mapped back to the [-100, 100] scale."""
        with torch.no_grad():
            pan_value = self.params.pan.item() * 200 - 100
        return f"pan: {pan_value}"

    def toJSON(self) -> dict[str, Any]:
        """Return ``{"Pan": value}`` with the value on the [-100, 100] scale."""
        pan_value = self.params.pan.item() * 200 - 100
        return {"Pan": pan_value}

    def forward(self, x: torch.Tensor):
        """Scale channel 0 by ``cos`` and channel 1 by ``sin`` of the pan angle."""
        # The stored pan in [0, 1] maps to an angle in [0, pi / 2].
        theta = self.params.pan.view(1) * torch.pi * 0.5
        channel_gains = torch.concat([torch.cos(theta), torch.sin(theta)])
        amp = channel_gains.view(2, 1) * STEREO_NORM
        return x * amp
|
|
|
|
class StereoWidth(Panning):
    """``Panning`` combined with ``hadamard`` through ``chain_functions``."""

    def forward(self, x: torch.Tensor):
        """Run ``x`` through the chain built from ``hadamard``, ``Panning.forward`` and ``hadamard``."""
        pipeline = chain_functions(hadamard, super().forward, hadamard)
        return pipeline(x)
|
|
|
|
| class ImpulseResponse(nn.Module): |
| def forward(self, h): |
| return torch.cat([torch.ones_like(h[..., :1]), h], dim=-1) |
|
|
|
|
class FIR(FX):
    """Learnable per-channel FIR filter with a fixed, pass-through first tap.

    Only the ``length - 1`` taps after the first are learnable (stored as
    ``kernel``, initialised to zero); the input itself is added to the output.

    Args:
        length: Total filter length, including the fixed first tap.
        channels: Number of channels; with 2 channels one-channel inputs are
            expanded by a forward pre-hook.
        conv_method: ``"direct"`` for ``F.conv1d`` or ``"fft"`` for ``fft_conv1d``.

    Raises:
        ValueError: If ``conv_method`` is neither ``"direct"`` nor ``"fft"``.
    """

    def __init__(
        self,
        length: int,
        channels: int = 2,
        conv_method: str = "direct",
    ):
        super().__init__(kernel=torch.zeros(channels, length - 1))
        self._padding = length - 1
        self.channels = channels

        if conv_method == "direct":
            self.conv_func = F.conv1d
        elif conv_method == "fft":
            self.conv_func = fft_conv1d
        else:
            raise ValueError(f"Unknown conv_method: {conv_method}")

        if channels == 2:
            self.register_forward_pre_hook(broadcast2stereo)

    def forward(self, x: torch.Tensor):
        """Return ``x`` plus the learnable taps applied to the past samples of ``x``."""
        # Drop the newest sample and left-pad so the kernel only sees the past.
        history = F.pad(x[..., :-1], (self._padding, 0), "constant", 0)
        taps = self.params.kernel.flip(1).unsqueeze(1)
        filtered = self.conv_func(history, taps, groups=self.channels)
        return x + filtered
|
|
|
|
class QFactor(nn.Module):
    """Parametrization mapping an unconstrained value to a positive one via ``exp``."""

    def forward(self, x):
        """Return ``exp(x)``."""
        return torch.exp(x)

    def right_inverse(self, y):
        """Return ``log(y)``."""
        return torch.log(y)
|
|
|
|
class LowPass(FX):
    """Learnable low-pass biquad with bounded cutoff frequency and Q.

    Args:
        sr: Sample rate passed to ``lowpass_biquad``.
        freq: Initial cutoff frequency.
        Q: Initial Q.
        min_freq, max_freq: Bounds of the ``MinMax`` parametrization of ``freq``.
        min_Q, max_Q: Bounds of the ``MinMax`` parametrization of ``Q``.
    """

    def __init__(
        self,
        sr: int,
        freq: float = 17500.0,
        Q: float = 0.707,
        min_freq: float = 200.0,
        max_freq: float = 18000,
        min_Q: float = 0.5,
        max_Q: float = 10.0,
    ):
        super().__init__(freq=freq, Q=Q)

        self.sr = sr
        # Registration order is "freq" first, then "Q".
        bounds = {"freq": (min_freq, max_freq), "Q": (min_Q, max_Q)}
        for name, (lower, upper) in bounds.items():
            register_parametrization(self.params, name, MinMax(lower, upper))

    def forward(self, x):
        """Filter ``x`` with ``lowpass_biquad`` using the current ``freq`` and ``Q``."""
        cutoff, q = self.params.freq, self.params.Q
        return lowpass_biquad(x, sample_rate=self.sr, cutoff_freq=cutoff, Q=q)

    def extra_repr(self) -> str:
        """Return the current frequency and Q formatted to four decimals."""
        with torch.no_grad():
            cutoff = self.params.freq.item()
            q = self.params.Q.item()
        return f"freq: {cutoff:.4f}, Q: {q:.4f}"

    def toJSON(self) -> dict[str, Any]:
        """Return the current frequency and Q as a dict of plain numbers."""
        cutoff = self.params.freq.item()
        q = self.params.Q.item()
        return {"Frequency (Hz)": cutoff, "Q": q}
|
|
|
|
class HighPass(LowPass):
    """High-pass counterpart of ``LowPass`` with its own frequency defaults.

    Construction is delegated to ``LowPass``; only the default ``freq``,
    ``min_freq`` and ``max_freq`` and the filter function differ.
    """

    def __init__(
        self,
        *args,
        freq: float = 200.0,
        min_freq: float = 16.0,
        max_freq: float = 5300.0,
        **kwargs,
    ):
        kwargs.update(freq=freq, min_freq=min_freq, max_freq=max_freq)
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """Filter ``x`` with ``highpass_biquad`` using the current ``freq`` and ``Q``."""
        cutoff, q = self.params.freq, self.params.Q
        return highpass_biquad(x, sample_rate=self.sr, cutoff_freq=cutoff, Q=q)
|
|
|
|
class Peak(FX):
    """Learnable peaking equaliser with bounded centre frequency and Q.

    Args:
        sr: Sample rate passed to ``equalizer_biquad``.
        gain: Initial gain (reported as dB by ``toJSON``); not parametrized.
        freq: Initial centre frequency.
        Q: Initial Q.
        min_freq, max_freq: Bounds of the ``MinMax`` parametrization of ``freq``.
        min_Q, max_Q: Bounds of the ``MinMax`` parametrization of ``Q``.
    """

    def __init__(
        self,
        sr: int,
        gain: float = 0.0,
        freq: float = 2000.0,
        Q: float = 0.707,
        min_freq: float = 33.0,
        max_freq: float = 17500.0,
        min_Q: float = 0.2,
        max_Q: float = 20,
    ):
        super().__init__(freq=freq, Q=Q, gain=gain)

        self.sr = sr

        freq_range = MinMax(min_freq, max_freq)
        register_parametrization(self.params, "freq", freq_range)
        q_range = MinMax(min_Q, max_Q)
        register_parametrization(self.params, "Q", q_range)

    def forward(self, x):
        """Filter ``x`` with ``equalizer_biquad`` using the current parameters."""
        settings = dict(
            sample_rate=self.sr,
            center_freq=self.params.freq,
            Q=self.params.Q,
            gain=self.params.gain,
        )
        return equalizer_biquad(x, **settings)

    def extra_repr(self) -> str:
        """Return the current frequency, gain and Q formatted to four decimals."""
        with torch.no_grad():
            centre = self.params.freq.item()
            gain = self.params.gain.item()
            q = self.params.Q.item()
        return f"freq: {centre:.4f}, gain: {gain:.4f}, Q: {q:.4f}"

    def toJSON(self) -> dict[str, Any]:
        """Return the current frequency, gain and Q as a dict of plain numbers."""
        centre = self.params.freq.item()
        gain = self.params.gain.item()
        q = self.params.Q.item()
        return {"Frequency (Hz)": centre, "Gain (dB)": gain, "Q": q}
|
|
|
|
class LowShelf(FX):
    """Learnable low-shelf biquad with bounded cutoff and a fixed Q of 0.707.

    Args:
        sr: Sample rate passed to ``lowshelf_biquad``.
        gain: Initial gain (reported as dB by ``toJSON``); not parametrized.
        freq: Initial cutoff frequency.
        min_freq, max_freq: Bounds of the ``MinMax`` parametrization of ``freq``.
    """

    def __init__(
        self,
        sr: int,
        gain: float = 0.0,
        freq: float = 115.0,
        min_freq: float = 30,
        max_freq: float = 200,
    ):
        super().__init__(freq=freq, gain=gain)

        self.sr = sr
        freq_range = MinMax(min_freq, max_freq)
        register_parametrization(self.params, "freq", freq_range)

        # Q is a constant, kept as a non-persistent buffer rather than a parameter.
        fixed_q = torch.tensor(0.707)
        self.register_buffer("Q", fixed_q, persistent=False)

    def forward(self, x):
        """Filter ``x`` with ``lowshelf_biquad`` using the current parameters."""
        settings = dict(
            sample_rate=self.sr,
            cutoff_freq=self.params.freq,
            gain=self.params.gain,
            Q=self.Q,
        )
        return lowshelf_biquad(x, **settings)

    def extra_repr(self) -> str:
        """Return the current frequency and gain formatted to four decimals."""
        with torch.no_grad():
            cutoff = self.params.freq.item()
            gain = self.params.gain.item()
        return f"freq: {cutoff:.4f}, gain: {gain:.4f}"

    def toJSON(self) -> dict[str, Any]:
        """Return the current frequency and gain as a dict of plain numbers."""
        cutoff = self.params.freq.item()
        gain = self.params.gain.item()
        return {"Frequency (Hz)": cutoff, "Gain (dB)": gain}
|
|
|
|
class HighShelf(LowShelf):
    """High-shelf counterpart of ``LowShelf`` with its own frequency defaults.

    Construction is delegated to ``LowShelf``; only the default ``freq``,
    ``min_freq`` and ``max_freq`` and the filter function differ.
    """

    def __init__(
        self,
        *args,
        freq: float = 4525,
        min_freq: float = 750,
        max_freq: float = 8300,
        **kwargs,
    ):
        kwargs.update(freq=freq, min_freq=min_freq, max_freq=max_freq)
        super().__init__(*args, **kwargs)

    def forward(self, x):
        """Filter ``x`` with ``highshelf_biquad`` using the current parameters."""
        settings = dict(
            sample_rate=self.sr,
            cutoff_freq=self.params.freq,
            gain=self.params.gain,
            Q=self.Q,
        )
        return highshelf_biquad(x, **settings)
|
|
|
|
| def module2coeffs( |
| m: Union[LowPass, HighPass, Peak, LowShelf, HighShelf], |
| ) -> Tuple[ |
| torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor |
| ]: |
| match m: |
| case LowPass(): |
| return lowpass_biquad_coef(m.sr, m.params.freq, m.params.Q) |
| case HighPass(): |
| return highpass_biquad_coef(m.sr, m.params.freq, m.params.Q) |
| case Peak(): |
| return equalizer_biquad_coef(m.sr, m.params.freq, m.params.Q, m.params.gain) |
| case LowShelf(): |
| return lowshelf_biquad_coef(m.sr, m.params.freq, m.params.gain, m.Q) |
| case HighShelf(): |
| return highshelf_biquad_coef(m.sr, m.params.freq, m.params.gain, m.Q) |
| case _: |
| raise ValueError(f"Unknown module: {m}") |
|
|
|
|
class AlwaysNegative(nn.Module):
    """Parametrization mapping an unconstrained value to a strictly negative one.

    ``forward`` computes ``-softplus(x)``; ``right_inverse`` is its inverse.
    """

    def forward(self, x):
        """Return ``-softplus(x)``."""
        return -F.softplus(x)

    def right_inverse(self, y):
        """Return ``x`` such that ``-softplus(x) == y`` for negative ``y``.

        Mathematically this is ``log(exp(-y) - 1)``.  It is evaluated as
        ``log(-expm1(y)) - y`` instead, because the naive form cancels to
        ``log(0) = -inf`` for small ``|y|`` and overflows to ``+inf`` for large
        ``|y|`` in float32.  ``y == 0`` still yields ``-inf`` and positive
        ``y`` still yields NaN, as before.
        """
        # exp(-y) - 1 == exp(-y) * (1 - exp(y)) == exp(-y) * (-expm1(y))
        return torch.log(-torch.expm1(y)) - y
|
|
|
|
class Reverb(FX):
    """Reverb applied in the STFT domain with a learnable, exponentially shaped response.

    Two learnable per-bin quantities, ``log_mag`` and ``log_mag_delta`` (both
    kept negative by ``AlwaysNegative``), are combined in ``forward`` as
    ``log_mag + log_mag_delta * step`` to form the log-magnitude at each frame
    step; ``_noise_angle`` supplies the phase.

    Args:
        ir_len: Used together with ``n_fft`` and ``hop_length`` to derive the
            number of frame steps (``self.steps``).
        n_fft: FFT size of the spectrogram transforms.
        hop_length: Hop size of the spectrogram transforms.
        downsample_factor: When > 1 the learnable curves are stored with
            ``n_fft // downsample_factor // 2 + 1`` bins and linearly
            interpolated to ``n_fft // 2 + 1`` bins in ``forward``.
    """

    def __init__(self, ir_len=60000, n_fft=384, hop_length=192, downsample_factor=1):
        super().__init__(
            log_mag=torch.full((2, n_fft // downsample_factor // 2 + 1), -1.0),
            log_mag_delta=torch.full((2, n_fft // downsample_factor // 2 + 1), -5.0),
        )

        # Ceiling division of (ir_len - n_fft) by hop_length.
        self.steps = (ir_len - n_fft + hop_length - 1) // hop_length
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.downsample_factor = downsample_factor

        # Uniform random angles in [0, 2*pi); note this is an nn.Parameter.
        self._noise_angle = nn.Parameter(
            torch.rand(2, n_fft // 2 + 1, self.steps) * 2 * torch.pi
        )

        self.register_buffer(
            "_arange", torch.arange(self.steps, dtype=torch.float32), persistent=False
        )
        # power=None keeps the spectrogram complex-valued.
        self.spec_forward = Spectrogram(n_fft, hop_length=hop_length, power=None)
        self.spec_inverse = InverseSpectrogram(
            n_fft,
            hop_length=hop_length,
        )

        register_parametrization(self.params, "log_mag", AlwaysNegative())
        register_parametrization(self.params, "log_mag_delta", AlwaysNegative())

        self.register_forward_pre_hook(broadcast2stereo)

    def forward(self, x):
        """Return ``x`` filtered, per frequency bin, along the frame axis of its spectrogram."""
        h = x
        H = self.spec_forward(h)

        log_mag = self.params.log_mag
        log_mag_delta = self.params.log_mag_delta

        if self.downsample_factor > 1:
            # Bring the coarse curves up to the bin count of `_noise_angle`.
            log_mag = F.interpolate(
                log_mag.unsqueeze(0),
                size=self._noise_angle.size(1),
                align_corners=True,
                mode="linear",
            ).squeeze(0)
            log_mag_delta = F.interpolate(
                log_mag_delta.unsqueeze(0),
                size=self._noise_angle.size(1),
                align_corners=True,
                mode="linear",
            ).squeeze(0)

        # Complex response per (channel, bin, step):
        # exp(log-magnitude) * exp(1j * angle).
        ir_2d = torch.exp(
            log_mag.unsqueeze(-1)
            + log_mag_delta.unsqueeze(-1) * self._arange
            + self._noise_angle * 1j
        )

        # Merge dims 1 and 2 of H and left-pad the last dim by steps - 1 so the
        # convolution below is causal along that dim.
        # NOTE(review): assumes H is (batch, 2, bins, frames) — confirm.
        padded_H = F.pad(H.flatten(1, 2), (ir_2d.shape[-1] - 1, 0))

        # Grouped convolution: one kernel per merged (channel, bin) row; the
        # response passes through `hadamard` before being used as the kernel.
        H = F.conv1d(
            padded_H,
            hadamard(ir_2d.unsqueeze(0)).flatten(1, 2).flip(-1).transpose(0, 1),
            groups=H.shape[2] * 2,
        ).view(*H.shape)

        h = self.spec_inverse(H)
        return h
|
|
|
|
class Delay(FX):
    """Feedback delay built as an FIR filter of decaying, fractionally delayed echoes.

    Odd-numbered and even-numbered echoes are summed into two separate filters
    and each result is sent through its own ``Panning`` module.

    Args:
        sr: Sample rate.
        delay: Initial delay in milliseconds (bounded by ``min_delay`` and
            ``max_delay`` through ``MinMax``).
        feedback: Initial feedback factor (``SmoothingCoef`` parametrization).
        gain: Initial gain (``SmoothingCoef`` parametrization).
        ir_duration: Filter duration in seconds; at least
            ``max_delay * 0.002`` seconds are always used.
        eq: Optional module applied to the filters.
        recursive_eq: When true (and ``eq`` is given) ``eq`` is applied once
            per repetition; otherwise it is applied once to the summed filters.
    """

    # Bounds of the delay parametrization, in milliseconds.
    min_delay: float = 100
    max_delay: float = 1000

    def __init__(
        self,
        sr: int,
        delay=200.0,
        feedback=0.1,
        gain=0.1,
        ir_duration: float = 2,
        eq: Optional[nn.Module] = None,
        recursive_eq=False,
    ):
        super().__init__(
            delay=delay,
            feedback=feedback,
            gain=gain,
        )
        self.sr = sr
        self.ir_length = int(sr * max(ir_duration, self.max_delay * 0.002))

        register_parametrization(
            self.params, "delay", MinMax(self.min_delay, self.max_delay)
        )
        register_parametrization(self.params, "feedback", SmoothingCoef())
        register_parametrization(self.params, "gain", SmoothingCoef())

        self.eq = eq
        self.recursive_eq = recursive_eq

        # Sample indices 0 .. ir_length - 1 (registered as a persistent buffer).
        self.register_buffer(
            "_arange", torch.arange(self.ir_length, dtype=torch.float32)
        )

        self.odd_pan = Panning(0)
        self.even_pan = Panning(0)

    def forward(self, x):
        """Apply the delay to a one-channel input and return the panned sum of echoes."""
        assert x.size(1) == 1, x.size()
        delay_in_samples = self.sr * self.params.delay * 0.001
        # Number of echoes that fit into the filter length.
        num_delays = self.ir_length // int(delay_in_samples.item() + 1)
        series = torch.arange(1, num_delays + 1, device=x.device)
        # The k-th echo is scaled by feedback ** (k - 1).
        decays = self.params.feedback ** (series - 1)

        if self.recursive_eq and self.eq is not None:
            # Equalise a single delayed sinc, then raise its spectrum to the
            # k-th power (magnitude ** k, phase * k) to get the k-th echo.
            sinc_index = self._arange - delay_in_samples
            single_sinc_filter = torch.sinc(sinc_index)
            eq_sinc_filter = self.eq(single_sinc_filter)
            H = torch.fft.rfft(eq_sinc_filter)
            H_powered = torch.polar(
                H.abs() ** series.unsqueeze(-1), H.angle() * series.unsqueeze(-1)
            )
            sinc_filters = torch.fft.irfft(H_powered, n=self.ir_length)
        else:
            # One sinc per echo, centred at k * delay_in_samples.
            delays_in_samples = delay_in_samples * series
            sinc_indexes = self._arange - delays_in_samples.unsqueeze(-1)
            sinc_filters = torch.sinc(sinc_indexes)

        decayed_sinc_filters = sinc_filters * decays.unsqueeze(-1)
        return self._filter(x, decayed_sinc_filters)

    def _filter(self, x: torch.Tensor, decayed_sinc_filters: torch.Tensor):
        """Convolve ``x`` with the odd/even echo filters and pan the two results."""
        # Rows 0, 2, 4, ... are echoes 1, 3, 5, ... (odd); rows 1, 3, ... are even.
        odd_delay_filters = torch.sum(decayed_sinc_filters[::2], 0)
        even_delay_filters = torch.sum(decayed_sinc_filters[1::2], 0)
        stacked_filters = torch.stack([odd_delay_filters, even_delay_filters])

        if self.eq is not None and not self.recursive_eq:
            stacked_filters = self.eq(stacked_filters)

        gained_odd_even_filters = stacked_filters * self.params.gain
        # Left-pad so the output has the same length as the input.
        padded_x = F.pad(x, (gained_odd_even_filters.size(-1) - 1, 0))
        # Direct convolution for inputs longer than 44100 * 20 samples,
        # FFT convolution otherwise.
        conv1d = F.conv1d if x.size(-1) > 44100 * 20 else fft_conv1d
        return sum(
            [
                panner(s)
                for panner, s in zip(
                    [self.odd_pan, self.even_pan],

                    conv1d(
                        padded_x,
                        gained_odd_even_filters.flip(-1).unsqueeze(1),
                    ).chunk(2, 1),
                )
            ]
        )

    def extra_repr(self) -> str:
        """Return the current delay (in samples), feedback and gain."""
        with torch.no_grad():
            s = (
                f"delay: {self.sr * self.params.delay.item() * 0.001} (samples)\n"
                f"feedback: {self.params.feedback.item()}\n"
                f"gain: {self.params.gain.item()}"
            )
        return s

    def toJSON(self) -> dict[str, Any]:
        """Return the current settings; feedback and gain are converted to dB."""
        return {
            "Delay (ms)": self.params.delay.item(),
            "Feedback (dB)": self.params.feedback.log10().mul(20).item(),
            "Gain (dB)": self.params.gain.log10().mul(20).item(),
            "Odd delays": self.odd_pan.toJSON(),
            "Even delays": self.even_pan.toJSON(),
        }
|
|
|
|
class SurrogateDelay(Delay):
    """``Delay`` that uses damped complex exponentials instead of sincs while training.

    In evaluation mode the module behaves exactly like ``Delay``.  In training
    mode the echo filters are built in the frequency domain from exponentials
    damped by the learnable ``log_damp``; a random subset of batch items is
    still processed by the plain ``Delay.forward``.

    Args:
        *args, **kwargs: Forwarded to ``Delay``.
        dropout: Probability with which a batch item uses ``Delay.forward``
            instead of the surrogate filters during training.
        straight_through: When true the forward value of the surrogate filters
            is replaced by the sinc filters while gradients still flow through
            the surrogate.
    """

    def __init__(self, *args, dropout=0.5, straight_through=False, **kwargs):
        super().__init__(*args, **kwargs)

        self.dropout = dropout
        self.straight_through = straight_through
        # Initial damping of -0.01, kept negative by AlwaysNegative.
        self.log_damp = nn.Parameter(torch.ones(1) * -0.01)
        register_parametrization(self, "log_damp", AlwaysNegative())

    def forward(self, x):
        """Apply the delay; see the class docstring for the training-mode behaviour."""
        assert x.size(1) == 1, x.size()
        if not self.training:
            return super().forward(x)

        log_damp = self.log_damp
        delay_in_samples = self.sr * self.params.delay * 0.001
        num_delays = self.ir_length // int(delay_in_samples.item() + 1)
        series = torch.arange(1, num_delays + 1, device=x.device)
        decays = self.params.feedback ** (series - 1)

        if self.recursive_eq and self.eq is not None:
            # Single damped delay in the frequency domain (bins 0 .. ir_length // 2).
            exp_factor = self._arange[: self.ir_length // 2 + 1]
            damped_exp = torch.exp(
                log_damp * exp_factor
                - 1j * delay_in_samples / self.ir_length * 2 * torch.pi * exp_factor
            )
            sinc_filter = torch.fft.irfft(damped_exp, n=self.ir_length)
            if self.straight_through:
                # Forward value of the hard sinc, gradient of the surrogate.
                sinc_index = self._arange - delay_in_samples
                hard_sinc_filter = torch.sinc(sinc_index)
                sinc_filter = sinc_filter + (hard_sinc_filter - sinc_filter).detach()

            eq_sinc_filter = self.eq(sinc_filter)
            H = torch.fft.rfft(eq_sinc_filter)


            # k-th echo: spectrum raised to the k-th power.
            H_powered = torch.polar(
                H.abs() ** series.unsqueeze(-1), H.angle() * series.unsqueeze(-1)
            )
            sinc_filters = torch.fft.irfft(H_powered, n=self.ir_length)
        else:
            # One damped exponential per echo; both damping and phase scale with k.
            exp_factors = series.unsqueeze(-1) * self._arange[: self.ir_length // 2 + 1]
            damped_exps = torch.exp(
                log_damp * exp_factors
                - 1j * delay_in_samples / self.ir_length * 2 * torch.pi * exp_factors
            )
            sinc_filters = torch.fft.irfft(damped_exps, n=self.ir_length)
            if self.straight_through:
                delays_in_samples = delay_in_samples * series
                sinc_indexes = self._arange - delays_in_samples.unsqueeze(-1)
                hard_sinc_filters = torch.sinc(sinc_indexes)
                sinc_filters = (
                    sinc_filters + (hard_sinc_filters - sinc_filters).detach()
                )

        decayed_sinc_filters = sinc_filters * decays.unsqueeze(-1)

        # True entries are processed by the plain Delay.forward.
        dropout_mask = torch.rand(x.size(0), device=x.device) < self.dropout
        if not torch.any(dropout_mask):
            return self._filter(x, decayed_sinc_filters)
        elif torch.all(dropout_mask):
            return super().forward(x)

        # Mixed batch: fill a two-channel output from both paths.
        out = torch.zeros((x.size(0), 2, x.size(2)), device=x.device)
        out[~dropout_mask] = self._filter(x[~dropout_mask], decayed_sinc_filters)
        out[dropout_mask] = super().forward(x[dropout_mask])
        return out

    def extra_repr(self) -> str:
        """Return ``Delay.extra_repr()`` plus the current damping factor ``exp(log_damp)``."""
        with torch.no_grad():
            return super().extra_repr() + f"\ndamp: {self.log_damp.exp().item()}"
|
|
|
|
class FSDelay(FX):
    """Feedback delay whose impulse response is computed in closed form by frequency sampling.

    The odd- and even-echo responses are evaluated on a uniform frequency grid
    and transformed to the time domain with an inverse real FFT of length
    ``ir_length``; each response is then convolved with the input and sent
    through its own ``Panning`` module.

    Args:
        sr: Sample rate.
        delay: Initial delay in milliseconds (bounded by ``Delay.min_delay``
            and ``Delay.max_delay`` through ``MinMax``).
        feedback: Initial feedback factor; bounded above by a decay derived
            from ``ir_duration`` (see ``__init__``).
        gain: Initial gain (``SmoothingCoef`` parametrization).
        ir_duration: Impulse-response duration in seconds; at least
            ``Delay.max_delay * 0.002`` seconds are always used.
        eq: Optional equaliser.
        recursive_eq: When true (and ``eq`` is given) the equaliser response is
            folded into the feedback path; otherwise ``eq`` is applied once to
            the final impulse responses.
    """

    def __init__(
        self,
        sr: int,
        delay=200.0,
        feedback=0.1,
        gain=0.1,
        ir_duration: float = 6,
        eq: Optional[LowPass] = None,
        recursive_eq=False,
    ):
        super().__init__(
            delay=delay,
            feedback=feedback,
            gain=gain,
        )
        self.sr = sr
        self.ir_length = int(sr * max(ir_duration, Delay.max_delay * 0.002))

        register_parametrization(
            self.params, "delay", MinMax(Delay.min_delay, Delay.max_delay)
        )
        register_parametrization(self.params, "gain", SmoothingCoef())

        # Upper bound for the feedback: -60 dB spread over T_60 seconds,
        # accumulated over one pass through the longest allowed delay.
        T_60 = ir_duration * 0.75
        max_delay_in_samples = sr * Delay.max_delay * 0.001
        maximum_decay = db2amp(torch.tensor(-60 / sr / T_60 * max_delay_in_samples))
        register_parametrization(self.params, "feedback", MinMax(0, maximum_decay))

        self.eq = eq
        self.recursive_eq = recursive_eq

        self.odd_pan = Panning(0)
        self.even_pan = Panning(0)

        self.register_buffer(
            "_arange", torch.arange(self.ir_length, dtype=torch.float32)
        )

    def _get_h(self):
        """Return the stacked odd/even echo impulse responses, scaled by ``gain``."""
        # Normalised angular frequencies of the rfft bins.
        freqs = self._arange[: self.ir_length // 2 + 1] / self.ir_length * 2 * torch.pi
        delay_in_samples = self.sr * self.params.delay * 0.001

        # Inverse of the delay response D = exp(-1j * w * d), and its square.
        Dinv = torch.exp(1j * freqs * delay_in_samples)
        Dinv2 = torch.exp(2j * freqs * delay_in_samples)
        if self.recursive_eq and self.eq is not None:
            # The feedback path contains the equaliser's frequency response.
            b0, b1, b2, a0, a1, a2 = module2coeffs(self.eq)
            z_inv = torch.exp(-1j * freqs)
            z_inv2 = torch.exp(-2j * freqs)
            eq_H = (b0 + b1 * z_inv + b2 * z_inv2) / (a0 + a1 * z_inv + a2 * z_inv2)
            damp = eq_H * self.params.feedback
            det = Dinv2 - damp * damp
        else:
            # Build the constant numerator by multiplication rather than
            # torch.full_like: full_like takes a plain scalar fill value, so it
            # cannot propagate the gradient of `feedback` into the even-echo
            # numerator.  The forward values are identical.
            damp = self.params.feedback * torch.ones_like(Dinv)
            det = Dinv2 - self.params.feedback.square()
        # Row 0: odd echoes  Dinv / det; row 1: even echoes  damp / det.
        inv_Dinv_m_A = torch.stack([Dinv, damp], 0) / det
        h = torch.fft.irfft(inv_Dinv_m_A, n=self.ir_length) * self.params.gain

        if self.eq is not None and not self.recursive_eq:
            h = self.eq(h)
        return h

    def forward(self, x):
        """Apply the delay to a one-channel input and return the panned sum of echoes."""
        assert x.size(1) == 1, x.size()
        h = self._get_h()

        # Left-pad so the output has the same length as the input.
        padded_x = F.pad(x, (h.size(-1) - 1, 0))
        # Direct convolution for inputs longer than 44100 * 20 samples,
        # FFT convolution otherwise.
        conv1d = F.conv1d if x.size(-1) > 44100 * 20 else fft_conv1d
        return sum(
            [
                panner(s)
                for panner, s in zip(
                    [self.odd_pan, self.even_pan],
                    conv1d(
                        padded_x,
                        h.flip(-1).unsqueeze(1),
                    ).chunk(2, 1),
                )
            ]
        )

    def extra_repr(self) -> str:
        """Return the current delay (in samples), feedback and gain."""
        with torch.no_grad():
            s = (
                f"delay: {self.sr * self.params.delay.item() * 0.001} (samples)\n"
                f"feedback: {self.params.feedback.item()}\n"
                f"gain: {self.params.gain.item()}"
            )
        return s
|
|
|
|
class FSSurrogateDelay(FSDelay):
    """``FSDelay`` that uses a damped delay response while training.

    In evaluation mode ``_get_h`` is exactly ``FSDelay._get_h``.  In training
    mode the delay response is multiplied by an exponential damping controlled
    by the learnable ``log_damp``.

    Args:
        *args, **kwargs: Forwarded to ``FSDelay``.
        straight_through: When true the forward value of the impulse response
            is that of the undamped delay while gradients flow through the
            damped one.
    """

    def __init__(self, *args, straight_through=False, **kwargs):
        super().__init__(*args, **kwargs)

        self.straight_through = straight_through
        # Initial damping of -0.0001, kept negative by AlwaysNegative.
        self.log_damp = nn.Parameter(torch.ones(1) * -0.0001)
        register_parametrization(self, "log_damp", AlwaysNegative())

    def _get_h(self):
        """Return the stacked odd/even echo impulse responses, scaled by ``gain``."""
        if not self.training:
            return super()._get_h()

        log_damp = self.log_damp
        delay_in_samples = self.sr * self.params.delay * 0.001

        # rfft bin indices and the matching normalised angular frequencies.
        exp_factor = self._arange[: self.ir_length // 2 + 1]
        freqs = exp_factor / self.ir_length * 2 * torch.pi
        # Damped delay response and its square.
        D = torch.exp(log_damp * exp_factor - 1j * delay_in_samples * freqs)
        D2 = torch.exp(log_damp * exp_factor * 2 - 2j * delay_in_samples * freqs)

        if self.straight_through:
            # Carry the undamped responses along in a new leading dimension.
            D_orig = torch.exp(-1j * delay_in_samples * freqs)
            D2_orig = torch.exp(-2j * delay_in_samples * freqs)
            D = torch.stack([D, D_orig], 0)
            D2 = torch.stack([D2, D2_orig], 0)

        if self.recursive_eq and self.eq is not None:
            # The feedback path contains the equaliser's frequency response.
            b0, b1, b2, a0, a1, a2 = module2coeffs(self.eq)
            z_inv = torch.exp(-1j * freqs)
            z_inv2 = torch.exp(-2j * freqs)
            eq_H = (b0 + b1 * z_inv + b2 * z_inv2) / (a0 + a1 * z_inv + a2 * z_inv2)
            damp = eq_H * self.params.feedback
            odd_H = D / (1 - damp * damp * D2)
            even_H = odd_H * D * damp
        else:
            # The feedback is used directly here; the former unused `damp`
            # tensor built with torch.full_like was dead code (and passed a
            # learnable tensor as a scalar fill value), so it was removed.
            odd_H = D / (1 - self.params.feedback.square() * D2)
            even_H = odd_H * D * self.params.feedback

        inv_Dinv_m_A = torch.stack([odd_H, even_H], 0)
        h = torch.fft.irfft(inv_Dinv_m_A, n=self.ir_length)

        if self.straight_through:
            # Dimension 1 holds (damped, undamped); keep the undamped value
            # in the forward pass and the damped gradient in the backward pass.
            damped_h, orig_h = h.unbind(1)
            h = damped_h + (orig_h - damped_h).detach()

        if self.eq is not None and not self.recursive_eq:
            h = self.eq(h)
        return h * self.params.gain

    def extra_repr(self) -> str:
        """Return ``FSDelay.extra_repr()`` plus the current damping factor ``exp(log_damp)``."""
        with torch.no_grad():
            return super().extra_repr() + f"\ndamp: {self.log_damp.exp().item()}"
|
|
|
|
class SendFXsAndSum(FX):
    """Run several effects on the same input and sum their outputs.

    With ``cross_send`` enabled, the output of each effect is additionally
    mixed (scaled by learnable send gains) into the inputs of all effects that
    come after it.

    Args:
        *args: The effect modules, in processing order.
        cross_send: Whether to create the learnable send gains ``sends_i``
            (one vector per effect except the last, with one gain for each
            later effect).
        pan_direct: When true a ``Panning`` module is applied to the direct
            signal.
    """

    def __init__(self, *args, cross_send=True, pan_direct=False):
        super().__init__(
            **(
                {
                    # sends_i holds one gain for every effect after effect i.
                    f"sends_{i}": torch.full([len(args) - i - 1], 0.01)
                    for i in range(len(args) - 1)
                }
                if cross_send
                else {}
            )
        )
        self.effects = nn.ModuleList(args)
        if pan_direct:
            self.pan = Panning()

        if cross_send:
            for i in range(len(args) - 1):
                register_parametrization(self.params, f"sends_{i}", SmoothingCoef())

    def forward(self, x):
        """Return ``(direct, wet)``: the (optionally panned) input and the summed effect outputs."""
        if hasattr(self, "pan"):
            di = self.pan(x)
        else:
            di = x

        if len(self.params) == 0:
            # No send gains: every effect sees the raw input.  Outputs are
            # cropped to their common length (last dimension) before summing.
            return di, reduce(
                lambda x, y: x[..., : y.shape[-1]] + y[..., : x.shape[-1]],
                map(lambda f: f(x), self.effects),
            )

        def f(states, ps):
            # states: (running wet sum, stacked inputs of the remaining effects)
            # ps:     (current effect, its send gains to the later effects)
            x, cum_sends = states
            m, send_gains = ps
            h = m(cum_sends[0])
            return (
                x[..., : h.shape[-1]] + h[..., : x.shape[-1]],
                (
                    # After the last effect nothing is left to send to.
                    None
                    if cum_sends.size(0) == 1
                    else cum_sends[1:, ..., : h.shape[-1]]
                    + send_gains[:, None, None, None] * h[..., : cum_sends.shape[-1]]
                ),
            )

        return (
            di,
            reduce(
                f,
                zip(
                    self.effects,
                    # The last effect has no later effect to send to.
                    [self.params[f"sends_{i}"] for i in range(len(self.effects) - 1)]
                    + [None],
                ),
                (
                    # Initial state: zero wet sum, and one copy of the input per effect.
                    torch.zeros_like(x),
                    x.unsqueeze(0).expand(len(self.effects), -1, -1, -1),
                ),
            )[0],
        )
|
|
|
|
class UniLossLess(nn.Module):
    """Parametrization building a matrix as the exponential of a skew-symmetric matrix.

    Only the strictly upper-triangular part of the input is used.
    """

    def forward(self, x):
        """Return ``matrix_exp(T - T.T)`` with ``T`` the strict upper triangle of ``x``."""
        upper = torch.triu(x, diagonal=1)
        skew = upper - upper.T
        return torch.linalg.matrix_exp(skew)
|
|
|
|
class FDN(FX):
    """Feedback delay network whose impulse response is obtained by frequency sampling.

    The transfer function ``c @ (D^-1 - A)^-1 @ b`` is evaluated on a uniform
    frequency grid, transformed to a length-``ir_length`` impulse response with
    an inverse real FFT and convolved with the (two-channel) input.

    Args:
        sr: Sample rate.
        ir_duration: Impulse-response duration in seconds.
        delays: Delay-line lengths in samples.
        trainable_delay: When true the delays become a learnable parameter,
            stored in milliseconds and bounded by ``max_delay``.
        num_decay_freq: Number of frequency points at which the decay
            ``gamma`` is learned; more than one point is linearly interpolated
            over the frequency grid.
        delay_independent_decay: When true a single decay (relative to the
            shortest delay) is learned and raised to ``delay / min(delay)``
            for each line; otherwise one decay per line is learned.
        eq: Optional module applied to the impulse responses.
    """

    # Upper bound (in milliseconds) of trainable delays.
    max_delay = 100

    def __init__(
        self,
        sr: int,
        ir_duration: float = 1.0,
        delays=(997, 1153, 1327, 1559, 1801, 2099),
        trainable_delay=False,
        num_decay_freq=1,
        delay_independent_decay=False,
        eq: Optional[nn.Module] = None,
    ):

        num_delays = len(delays)
        super().__init__(
            b=torch.ones(num_delays, 2) / num_delays,
            c=torch.zeros(2, num_delays),
            U=torch.randn(num_delays, num_delays) / num_delays**0.5,
            gamma=torch.rand(
                num_decay_freq, num_delays if not delay_independent_decay else 1
            )
            * 0.2
            + 0.4,

        )

        self.sr = sr
        self.ir_length = int(sr * ir_duration)

        # Upper bound for gamma: -60 dB spread over T_60 seconds, accumulated
        # over one pass through the delay line (or the shortest one).
        T_60 = ir_duration * 0.75
        delays = torch.tensor(delays)
        if delay_independent_decay:
            gamma_max = db2amp(-60 / sr / T_60 * delays.min())
        else:
            gamma_max = db2amp(-60 / sr / T_60 * delays)

        register_parametrization(self.params, "gamma", MinMax(0, gamma_max))
        register_parametrization(self.params, "U", UniLossLess())

        if not trainable_delay:
            self.register_buffer(
                "delays",
                delays,
            )
        else:
            # Trainable delays are stored in milliseconds.
            self.params["delays"] = nn.Parameter(delays / sr * 1000)
            register_parametrization(self.params, "delays", MinMax(0, self.max_delay))

        self.register_forward_pre_hook(broadcast2stereo)

        self.eq = eq

    def _delays_in_samples(self) -> torch.Tensor:
        """Return the delay-line lengths in samples, fixed or trainable."""
        if hasattr(self, "delays"):
            return self.delays
        # Trainable delays live in milliseconds; convert them back to samples.
        return self.params.delays * 0.001 * self.sr

    def forward(self, x):
        """Convolve ``x`` with the network's current impulse response."""
        # Direct convolution for inputs longer than 44100 * 20 samples,
        # FFT convolution otherwise.
        conv1d = F.conv1d if x.size(-1) > 44100 * 20 else fft_conv1d

        c = self.params.c + 0j
        b = self.params.b + 0j

        gamma = self.params.gamma
        # The phase term below needs delays in samples (freqs are in rad/sample).
        delays = self._delays_in_samples()

        if gamma.size(0) > 1:
            # Interpolate the decay over the frequency grid; the result has the
            # frequency axis first and the delay axis last.
            gamma = F.interpolate(
                gamma.T.unsqueeze(1),
                size=self.ir_length // 2 + 1,
                align_corners=True,
                mode="linear",
            ).transpose(0, 2)

        # The last axis is the delay axis whether gamma is still 2-D
        # (num_decay_freq == 1) or 3-D after interpolation; indexing dimension
        # 2 directly fails for the 2-D case.
        if gamma.size(-1) == 1:
            # Delay-independent decay: scale the shared decay to each line.
            gamma = gamma ** (delays / delays.min())

        A = self.params.U * gamma

        freqs = (
            torch.arange(self.ir_length // 2 + 1, device=x.device)
            / self.ir_length
            * 2
            * torch.pi
        )
        invD = torch.exp(1j * freqs[:, None] * delays)

        H = c @ torch.linalg.solve(torch.diag_embed(invD) - A, b)

        # Move the frequency axis last before the inverse FFT.
        h = torch.fft.irfft(H.permute(1, 2, 0), n=self.ir_length)

        if self.eq is not None:
            h = self.eq(h)

        return conv1d(
            F.pad(x, (self.ir_length - 1, 0)),
            h.flip(-1),
        )

    def toJSON(self) -> dict[str, Any]:
        """Return the decay times and an approximate gain as a dict.

        The sample rate of the module (``self.sr``) is used for both the
        frequency axis and the conversion to seconds.
        NOTE(review): each T60 entry is read with ``.item()``, which appears to
        require one decay value per frequency point (``delay_independent_decay``);
        confirm before relying on it for per-line decays.
        """
        delays = self._delays_in_samples()
        return {
            "T60 (s)": {
                f"{f:.2f} Hz": g.item()
                for f, g in zip(
                    torch.linspace(0, self.sr / 2, self.params.gamma.numel()),
                    -60 * delays.min() / amp2db(self.params.gamma) / self.sr,
                )
            },
            "Gain (dB, approx)": amp2db(
                torch.linalg.norm(self.params.b) * torch.linalg.norm(self.params.c)
            ).item(),
        }
|
|