| | import gc |
| | import logging |
| | import threading |
| |
|
| | import torch |
| | from transformers import LlamaTokenizer |
| |
|
| | from modules import config |
| | from modules.ChatTTS import ChatTTS |
| | from modules.devices import devices |
| |
|
# Logger named after this module; the functions below log through it.
logger = logging.getLogger(__name__)

# Held by load_chat_tts() while it checks for, and if needed creates, the
# shared instance.
lock = threading.Lock()
# Shared ChatTTS instance; None until the models have been loaded and again
# after unload_chat_tts().
chat_tts = None
| |
|
| |
|
def load_chat_tts_in_thread():
    """Create the shared ChatTTS instance and load its models.

    Returns immediately when the module-level ``chat_tts`` is already set.
    This function takes no lock itself; ``load_chat_tts`` calls it while
    holding the module-level ``lock``.

    Models are read from ``./models/ChatTTS`` using the device and dtypes
    reported by ``devices`` and the ``compile`` flag from
    ``config.runtime_env_vars``.

    Raises:
        Any exception raised by ``ChatTTS.Chat()`` or ``load_models``. In
        that case ``chat_tts`` is left unchanged, so a later call retries
        the load instead of reusing a half-initialized instance.
    """
    global chat_tts
    if chat_tts:
        return

    logger.info("Loading ChatTTS models")
    # Build into a local and publish to the global only after load_models()
    # succeeds; otherwise a failed load would leave an unusable instance
    # that every later caller would receive.
    instance = ChatTTS.Chat()
    device = devices.get_device_for("chattts")
    dtype = devices.dtype
    instance.load_models(
        compile=config.runtime_env_vars.compile,
        source="local",
        local_path="./models/ChatTTS",
        device=device,
        dtype=dtype,
        dtype_vocos=devices.dtype_vocos,
        dtype_dvae=devices.dtype_dvae,
        dtype_gpt=devices.dtype_gpt,
        dtype_decoder=devices.dtype_decoder,
    )
    chat_tts = instance

    if device == devices.cpu and dtype == torch.float16:
        logger.warning(
            "The device is CPU and dtype is float16, which may not work properly. It is recommended to use float32 by enabling the `--no_half` parameter."
        )

    devices.torch_gc()
    logger.info("ChatTTS models loaded")
| |
|
| |
|
def load_chat_tts():
    """Return the shared ChatTTS instance, loading the models on first use.

    The check-and-load step runs under the module-level ``lock`` so that
    concurrent callers do not load the models twice.

    Returns:
        The module-level ChatTTS instance.

    Raises:
        Exception: if no instance is available after the load attempt.
        Any exception raised by ``load_chat_tts_in_thread`` propagates.
    """
    with lock:
        if chat_tts is None:
            load_chat_tts_in_thread()
        # Read the global exactly once: unload_chat_tts() does not take the
        # lock, so checking and returning the global separately could hand
        # back None if an unload slipped in between the two reads.
        instance = chat_tts
    if instance is None:
        raise Exception("Failed to load ChatTTS models")
    return instance
| |
|
| |
|
def unload_chat_tts():
    """Drop the shared ChatTTS instance and run garbage collection.

    Every ``torch.nn.Module`` found in ``chat_tts.pretrain_models`` is moved
    to the CPU, the module-level ``chat_tts`` is reset to ``None``, and then
    ``devices.torch_gc()`` and ``gc.collect()`` are called. Calling this
    when nothing is loaded only performs the reset and the collection.

    Note: this function does not acquire the module-level ``lock``.
    """
    global chat_tts
    # Use the module logger (not the root ``logging`` functions) so this
    # message follows the same logger configuration as the rest of the file.
    logger.info("Unloading ChatTTS models")

    if chat_tts:
        for model in chat_tts.pretrain_models.values():
            if isinstance(model, torch.nn.Module):
                model.cpu()
            # Drop the loop variable's reference so it does not keep the
            # last entry alive during the collection below.
            del model
    chat_tts = None
    devices.torch_gc()
    gc.collect()
    logger.info("ChatTTS models unloaded")
| |
|
| |
|
def reload_chat_tts():
    """Unload the current ChatTTS models and load them again.

    Returns:
        The ChatTTS instance returned by ``load_chat_tts`` after the reload.

    Raises:
        Any exception raised by ``load_chat_tts``.
    """
    # Use the module logger (not the root ``logging`` functions) so this
    # message follows the same logger configuration as the rest of the file.
    logger.info("Reloading ChatTTS models")
    unload_chat_tts()
    instance = load_chat_tts()
    logger.info("ChatTTS models reloaded")
    return instance
| |
|
| |
|
def get_tokenizer() -> LlamaTokenizer:
    """Return the ``"tokenizer"`` entry of the ChatTTS ``pretrain_models``.

    The instance is obtained through ``load_chat_tts``, so the models are
    loaded first if they are not loaded yet.
    """
    return load_chat_tts().pretrain_models["tokenizer"]
| |
|