xjsc0's picture
1
61e6f25
import torch
# from vocos import Vocos
from singer.model import Singer
def load_model(model_cls, model_cfg, ckpt_path, vocab_char_map, device="cuda"):
    """Build a Singer model and load its weights from a checkpoint.

    Args:
        model_cls: Backbone class; instantiated with the kwargs in
            ``model_cfg.model.arch`` plus ``text_num_embeds`` and ``mel_dim``.
        model_cfg: Config exposing ``model.arch`` (backbone kwargs) and
            ``model.mel_spec`` (mel settings, must have ``n_mel_channels``).
        ckpt_path: Checkpoint path passed to ``torch.load``. It may hold an
            ``ema_model_state_dict``, a ``model_state_dict``, or be a bare
            state dict.
        vocab_char_map: Text vocabulary; its length sets the text embedding
            table size.
        device: Device the model is moved to before returning.

    Returns:
        The ``Singer`` model with weights loaded, on ``device``, in eval mode.

    Raises:
        RuntimeError: From ``load_state_dict`` when keys/shapes do not match.
    """
    mel_spec_kwargs = model_cfg.model.mel_spec
    backbone = model_cls(
        **model_cfg.model.arch,
        text_num_embeds=len(vocab_char_map),
        mel_dim=mel_spec_kwargs.n_mel_channels,
    )
    model = Singer(
        transformer=backbone,
        mel_spec_kwargs=mel_spec_kwargs,
        vocab_char_map=vocab_char_map,
    )

    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    if "ema_model_state_dict" in checkpoint:
        state_dict = checkpoint["ema_model_state_dict"]
        # EMA wrappers (e.g. ema_pytorch.EMA) store the averaged weights as
        # "ema_model.<name>" next to bookkeeping entries ("initted", "step",
        # possibly "online_model.*") that the bare model does not own. Keep
        # only the averaged weights, with the wrapper prefix removed.
        ema_prefix = "ema_model."
        if any(k.startswith(ema_prefix) for k in state_dict):
            state_dict = {
                k[len(ema_prefix):]: v
                for k, v in state_dict.items()
                if k.startswith(ema_prefix)
            }
    elif "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint

    # Checkpoints saved from a DataParallel/DDP-wrapped model prefix every
    # key with "module."; strip it so the keys match the unwrapped model.
    ddp_prefix = "module."
    state_dict = {
        (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k): v
        for k, v in state_dict.items()
    }

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
def _load_local_vocos(local_path, device):
    """Build a Vocos model from a local config and load its trained weights.

    ``local_path`` is either a directory holding ``config.yaml`` and
    ``pytorch_model.bin``, or the path of the config file itself (the weights
    are then expected in the same directory).
    """
    import os

    if os.path.isdir(local_path):
        config_path = os.path.join(local_path, "config.yaml")
    else:
        config_path = local_path
    model_path = os.path.join(os.path.dirname(config_path), "pytorch_model.bin")
    for path in (config_path, model_path):
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Vocos file not found: {path}")

    # Imported lazily so the vocos package is only needed when actually used.
    from vocos import Vocos

    # from_hparams only builds the architecture (random weights); the trained
    # weights must be loaded explicitly, mirroring Vocos.from_pretrained.
    vocoder = Vocos.from_hparams(config_path)
    state_dict = torch.load(model_path, map_location="cpu")
    vocoder.load_state_dict(state_dict)
    return vocoder.eval().to(device)


def load_vocoder(vocoder_name, is_local, local_path, device="cuda"):
    """Load a neural vocoder and move it to ``device``.

    Args:
        vocoder_name: ``"vocos"`` or ``"bigvgan"``. Any other name is treated
            as a Vocos checkpoint when ``is_local`` is true.
        is_local: If true, load from ``local_path``; otherwise download
            ``charactr/vocos-mel-24khz`` via ``Vocos.from_pretrained``.
        local_path: Directory with ``config.yaml`` + ``pytorch_model.bin``,
            or the path to that ``config.yaml``.
        device: Device the vocoder is moved to.

    Returns:
        The vocoder module in eval mode.

    Raises:
        NotImplementedError: For ``"bigvgan"``.
        ValueError: For an unknown name when ``is_local`` is false.
        FileNotFoundError: If the local config or weight file is missing.
    """
    if vocoder_name == "vocos":
        if is_local:
            vocoder = _load_local_vocos(local_path, device)
        else:
            from vocos import Vocos

            vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
    elif vocoder_name == "bigvgan":
        # Placeholder: BigVGAN support still needs its own loader.
        raise NotImplementedError("BigVGAN loading not implemented yet")
    else:
        print(
            f"Warning: Unknown vocoder {vocoder_name}, trying to load from local path if provided"
        )
        if is_local:
            # Best effort: assume the local checkpoint is Vocos-compatible.
            vocoder = _load_local_vocos(local_path, device)
        else:
            raise ValueError(f"Unknown vocoder: {vocoder_name}")
    return vocoder