| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel |
| from configuration_bigcodec import BigCodecConfig |
|
|
| |
| from vq.codec_encoder import CodecEncoder_Transformer |
| from vq.codec_decoder_vocos import CodecDecoderVocos |
| from vq.module import SemanticEncoder |
| from transformers import AutoFeatureExtractor, Wav2Vec2BertModel |
|
|
class XCodec2Model(PreTrainedModel):
    """Neural speech codec model.

    Encodes a waveform into discrete VQ code indices by fusing two branches:
    semantic features from Wav2Vec2-BERT (hidden layer 16) and acoustic
    features from a transformer codec encoder. Decoding maps the code
    indices back to a waveform through a Vocos-style generator.
    """

    config_class = BigCodecConfig

    def __init__(self, config: BigCodecConfig):
        super().__init__(config)

        # Semantic feature extractor backbone.
        # NOTE(review): eval() is called here, but the parameters are not
        # frozen (requires_grad stays True) and a later model.train() will
        # re-enable dropout on this submodule — confirm this is intended.
        self.semantic_model = Wav2Vec2BertModel.from_pretrained(
            "facebook/w2v-bert-2.0",
            output_hidden_states=True,
        )
        self.semantic_model.eval()

        self.SemanticEncoder_module = SemanticEncoder(
            config.semantic_hidden_size,
            config.semantic_hidden_size,
            config.semantic_hidden_size,
        )

        # Acoustic encoder and vocoder-style decoder/quantizer.
        self.CodecEnc = CodecEncoder_Transformer()
        self.generator = CodecDecoderVocos()

        # Projections around the quantizer. 2048 is the channel size of the
        # concatenated semantic + acoustic features — presumably 1024 each;
        # verify against SemanticEncoder / CodecEncoder_Transformer dims.
        self.fc_prior = nn.Linear(2048, 2048)
        self.fc_post_a = nn.Linear(2048, 1024)

        # Feature extractor that turns raw waveforms into Wav2Vec2-BERT
        # input features.
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(
            "facebook/w2v-bert-2.0"
        )

    def _waveform_to_codes(self, input_waveform, sample_rate):
        """Shared encoder path: waveform -> quantized code indices.

        Gradient tracking follows the caller's context: `forward` calls this
        with gradients enabled, `encode_code` wraps it in torch.no_grad().

        Args:
            input_waveform: [batch_size, waveform_length]
            sample_rate: sampling rate of `input_waveform`.
        Returns:
            VQ code index tensor as produced by the generator's quantizer.
        """
        # Semantic branch: waveform -> Wav2Vec2-BERT features on the model's
        # device -> hidden layer 16 -> channels-first -> semantic encoder.
        input_features = self.feature_extractor(
            input_waveform,
            sampling_rate=sample_rate,
            return_tensors="pt",
        ).input_features.to(self.device)

        semantic_output = self.semantic_model(input_features)
        # Hidden layer 16 is used as the semantic representation.
        semantic_hidden_16 = semantic_output.hidden_states[16]
        semantic_hidden_16 = semantic_hidden_16.transpose(1, 2)
        semantic_encoded = self.SemanticEncoder_module(semantic_hidden_16)

        # Acoustic branch: [B, T_wav] -> [B, 1, T_wav] -> codec encoder,
        # then channels-first to match the semantic branch.
        wav = input_waveform.unsqueeze(1).to(self.device)
        vq_emb = self.CodecEnc(wav)
        vq_emb = vq_emb.transpose(1, 2)

        # The two branches may disagree by a frame or two; truncate both to
        # the shorter length so they can be concatenated.
        if vq_emb.shape[-1] != semantic_encoded.shape[-1]:
            min_len = min(vq_emb.shape[-1], semantic_encoded.shape[-1])
            vq_emb = vq_emb[:, :, :min_len]
            semantic_encoded = semantic_encoded[:, :, :min_len]

        # Fuse along the channel dimension, project (fc_prior operates on the
        # last dim, hence the transpose round trip), then quantize.
        concat_emb = torch.cat([semantic_encoded, vq_emb], dim=1)
        concat_emb = self.fc_prior(concat_emb.transpose(1, 2)).transpose(1, 2)

        # generator(..., vq=True) returns (_, code indices, _).
        _, vq_code, _ = self.generator(concat_emb, vq=True)
        return vq_code

    def _codes_to_audio(self, vq_code):
        """Shared decoder path: quantized code indices -> waveform.

        Gradient tracking follows the caller's context (see
        `_waveform_to_codes`).
        """
        # Look up the quantizer embeddings for the given indices.
        vq_post_emb = self.generator.quantizer.get_output_from_indices(
            vq_code.transpose(1, 2)
        )
        vq_post_emb = vq_post_emb.transpose(1, 2)

        # Post-projection on the channel dim, then vocoder decode (vq=False);
        # the generator returns a tuple whose first element is the audio.
        vq_post_emb = self.fc_post_a(vq_post_emb.transpose(1, 2)).transpose(1, 2)
        recon_audio = self.generator(vq_post_emb.transpose(1, 2), vq=False)[0]
        return recon_audio

    def forward(self, input_waveform, sample_rate=16000):
        """Full round trip: encode the waveform and reconstruct it.

        Unlike `encode_code`/`decode_code`, this runs with gradient tracking
        enabled so it can be used for training / pipeline compatibility.

        Args:
            input_waveform: [batch_size, waveform_length]
            sample_rate: input sampling rate, default 16000.
        Returns:
            Reconstructed audio (Tensor).
        """
        vq_code = self._waveform_to_codes(input_waveform, sample_rate)
        return self._codes_to_audio(vq_code)

    def encode_code(self, input_waveform, sample_rate=16000):
        """Encode a waveform into discrete VQ code indices (no grad).

        Args:
            input_waveform: [batch_size, waveform_length]
            sample_rate: input sampling rate, default 16000.
        Returns:
            Encoded VQ code indices (Tensor).
        """
        with torch.no_grad():
            return self._waveform_to_codes(input_waveform, sample_rate)

    def decode_code(self, vq_code):
        """Decode VQ code indices back to audio (no grad).

        Args:
            vq_code: encoded code indices (Tensor) [batch, frames]
        Returns:
            Decoded audio (Tensor) [batch, waveform_length]
        """
        with torch.no_grad():
            return self._codes_to_audio(vq_code)
|
|