| import logging |
| import os |
| import sys |
|
|
| from modules.ffmpeg_env import setup_ffmpeg_path |
|
|
# Best-effort bootstrap: neither a bad LOG_LEVEL nor a failure to locate ffmpeg
# should stop the WebUI from starting, but such failures are reported rather
# than silently swallowed, and Ctrl-C / SystemExit are not intercepted.
_log_level = os.getenv("LOG_LEVEL", "INFO").upper()
# logging.getLevelName returns the int value for a registered level name and a
# "Level <name>" string for unknown names; basicConfig raises ValueError on the
# latter, so fall back to INFO instead.
_log_level_valid = isinstance(logging.getLevelName(_log_level), int)
logging.basicConfig(
    level=_log_level if _log_level_valid else "INFO",
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
if not _log_level_valid:
    logging.getLogger(__name__).warning(
        "Unknown LOG_LEVEL %r; falling back to INFO", _log_level
    )

try:
    setup_ffmpeg_path()
except Exception:
    # Deliberately non-fatal (as before), but leave a trace of why it failed.
    logging.getLogger(__name__).warning(
        "setup_ffmpeg_path() failed; continuing without it", exc_info=True
    )
|
|
| import argparse |
|
|
| from modules import config |
| from modules.api.api_setup import process_api_args, setup_api_args |
| from modules.api.app_config import app_description, app_title, app_version |
| from modules.gradio_dcls_fix import dcls_patch |
| from modules.models_setup import process_model_args, setup_model_args |
| from modules.utils.env import get_and_update_env |
| from modules.utils.ignore_warn import ignore_useless_warnings |
| from modules.utils.torch_opt import configure_torch_optimizations |
| from modules.webui import webui_config |
| from modules.webui.app import create_interface, webui_init |
|
|
| import subprocess |
|
|
# Install flash-attn at import time with its CUDA build skipped.
# - Run pip through the current interpreter so the package lands in the same
#   environment this script runs in (a bare `pip` on the shell may belong to a
#   different Python).
# - Extend, rather than replace, the current environment: passing only
#   FLASH_ATTENTION_SKIP_CUDA_BUILD would drop PATH, HOME, virtualenv and
#   proxy variables from the child process.
# The install stays best-effort: a failure is logged, not raised.
_flash_attn_install = subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
if _flash_attn_install.returncode != 0:
    logging.getLogger(__name__).warning(
        "pip install flash-attn exited with code %d",
        _flash_attn_install.returncode,
    )
|
|
# Module-level patches applied at import time, i.e. before process_webui_args()
# builds the Gradio interface.
# NOTE(review): judging by its module name (gradio_dcls_fix), dcls_patch
# presumably works around a gradio dataclass issue — confirm in that module.
dcls_patch()
# Presumably installs warning filters for known-noisy third-party warnings;
# see modules.utils.ignore_warn for the exact list.
ignore_useless_warnings()
|
|
|
|
def setup_webui_args(parser: argparse.ArgumentParser):
    """Register the WebUI command-line options on *parser*.

    Every option defaults to ``None`` (or ``False`` for the store_true flags);
    the effective defaults are applied later by ``process_webui_args``.

    Args:
        parser: the argument parser to extend in place.
    """
    # (flag, add_argument keyword arguments) in registration order; the order
    # determines how the options are listed in ``--help``.
    option_specs = [
        ("--server_name", {"type": str, "help": "server name"}),
        ("--server_port", {"type": int, "help": "server port"}),
        ("--share", {"action": "store_true", "help": "share the gradio interface"}),
        ("--debug", {"action": "store_true", "help": "enable debug mode"}),
        ("--auth", {"type": str, "help": "username:password for authentication"}),
        ("--tts_max_len", {"type": int, "help": "Max length of text for TTS"}),
        ("--ssml_max_len", {"type": int, "help": "Max length of text for SSML"}),
        ("--max_batch_size", {"type": int, "help": "Max batch size for TTS"}),
        (
            "--webui_experimental",
            {"action": "store_true", "help": "Enable webui_experimental features"},
        ),
        (
            "--language",
            {"type": str, "help": "Set the default language for the webui"},
        ),
        (
            "--api",
            {
                "action": "store_true",
                "help": "use api=True to launch the API together with the webui (run launch.py for only API server)",
            },
        ),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
|
|
|
|
def _parse_auth(auth):
    """Turn a ``"username:password"`` string into a ``(username, password)`` tuple.

    Only the first ``:`` separates the two fields, so passwords may contain
    colons. Returns ``None`` when *auth* is empty or ``None``.

    Raises:
        ValueError: if *auth* is non-empty but contains no ``:`` separator.
    """
    if not auth:
        return None
    username, sep, password = auth.partition(":")
    if not sep:
        raise ValueError("auth must be in the form username:password")
    return (username, password)


def _docs_path(api, path):
    """Return *path* if API docs should be served, otherwise ``None``.

    Docs are exposed only when the API is mounted (``api`` is not ``False``)
    and ``config.runtime_env_vars.no_docs`` is falsy.
    """
    if api is False or config.runtime_env_vars.no_docs:
        return None
    return path


def process_webui_args(args):
    """Resolve WebUI settings, launch the Gradio app, and block until it exits.

    Each setting is read via ``get_and_update_env(args, name, default, type)``
    using the defaults shown below. The app is launched with
    ``prevent_thread_lock=True`` so the API routes can be attached to the
    returned ``app`` before the main thread blocks in ``demo.block_thread()``.

    Args:
        args: parsed CLI namespace (see ``setup_webui_args``).

    Raises:
        ValueError: if ``auth`` is set but not of the form ``username:password``.
    """
    server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
    server_port = get_and_update_env(args, "server_port", 7860, int)
    share = get_and_update_env(args, "share", False, bool)
    debug = get_and_update_env(args, "debug", False, bool)
    # Parsed up front so a malformed value fails before models/UI are loaded.
    auth = _parse_auth(get_and_update_env(args, "auth", None, str))
    # The value is not used here; the call is kept for its env-update side
    # effect (presumably read elsewhere by the webui) — confirm.
    get_and_update_env(args, "language", "en", str)
    api = get_and_update_env(args, "api", False, bool)

    # WebUI limits and feature toggles (resolved once).
    webui_config.experimental = get_and_update_env(
        args, "webui_experimental", False, bool
    )
    webui_config.tts_max = get_and_update_env(args, "tts_max_len", 1000, int)
    webui_config.ssml_max = get_and_update_env(args, "ssml_max_len", 5000, int)
    webui_config.max_batch_size = get_and_update_env(args, "max_batch_size", 8, int)

    configure_torch_optimizations()
    webui_init()
    demo = create_interface()

    app, local_url, share_url = demo.queue().launch(
        server_name=server_name,
        server_port=server_port,
        share=share,
        debug=debug,
        auth=auth,
        show_api=False,
        prevent_thread_lock=True,
        # Open a browser tab automatically only on Windows.
        inbrowser=sys.platform == "win32",
        app_kwargs={
            "title": app_title,
            "description": app_description,
            "version": app_version,
            "redoc_url": _docs_path(api, "/redoc"),
            "docs_url": _docs_path(api, "/docs"),
        },
    )

    # Drop gradio's CustomCORSMiddleware from the app's middleware list.
    # NOTE(review): presumably so the API's own CORS handling applies — confirm.
    # Starlette builds its middleware stack lazily on the first request, so
    # this only takes effect if the app has not served a request yet.
    app.user_middleware = [
        x for x in app.user_middleware if x.cls.__name__ != "CustomCORSMiddleware"
    ]

    if api:
        # Mount the API routes on the already-running Gradio FastAPI app.
        process_api_args(args, app)

    demo.block_thread()
|
|
|
|
if __name__ == "__main__":
    # Load the env file before parsing arguments so its variables are in
    # os.environ for the setup below (presumably get_and_update_env falls back
    # to them when a flag is not given — confirm).
    import dotenv

    env_file = os.getenv("ENV_FILE", ".env.webui")
    dotenv.load_dotenv(dotenv_path=env_file)

    parser = argparse.ArgumentParser(description="Gradio App")
    # Register the option groups in a fixed order: webui, model, api.
    for register in (setup_webui_args, setup_model_args, setup_api_args):
        register(parser)
    args = parser.parse_args()

    # Model options are processed first; process_webui_args then blocks.
    process_model_args(args)
    process_webui_args(args)
|
|