| import gradio as gr |
| import os |
| import shutil |
| import logging |
| from huggingface_hub import hf_hub_download |
| from audio_separator.separator import Separator |
| |
| from audio_separator.separator.architectures import bs_roformer_separator |
|
|
| |
# Hugging Face Hub repository that hosts the dereverb checkpoint and config.
REPO_ID = "anvuew/dereverb_room"
# Checkpoint file to download from REPO_ID.
MODEL_FILENAME = "dereverb_room_anvuew_sdr_13.7432.ckpt"
# YAML config file to download from REPO_ID alongside the checkpoint.
CONFIG_FILENAME = "dereverb_room_anvuew.yaml"

# Root logging is configured at INFO so the progress messages emitted through
# the module-level logger below are actually shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
def inference(audio_path):
    """Run the dereverb model on an audio file and return the result path.

    The checkpoint and its YAML config are fetched with ``hf_hub_download``
    into a local ``models`` directory, the checkpoint is registered in
    ``bs_roformer_separator.BS_ROFORMER_MODELS``, and a ``Separator`` is
    created, loaded with that checkpoint and run on ``audio_path``.

    Args:
        audio_path: Path of the input audio file. A falsy value (``None`` or
            an empty string) means no input was provided.

    Returns:
        The first entry of the list returned by ``Separator.separate``, or
        ``None`` when ``audio_path`` is falsy.

    Raises:
        gr.Error: If separation returned no output files.
    """
    if not audio_path:
        return None

    local_models_dir = os.path.abspath("models")
    os.makedirs(local_models_dir, exist_ok=True)

    logger.info("Downloading model files...")
    model_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=MODEL_FILENAME,
        local_dir=local_models_dir,
    )
    config_path = hf_hub_download(
        repo_id=REPO_ID,
        filename=CONFIG_FILENAME,
        local_dir=local_models_dir,
    )

    # Place a copy of the config next to the checkpoint under the checkpoint's
    # base name. NOTE(review): presumably this is where the separator looks
    # for the config -- confirm against the audio_separator version in use.
    expected_config_path = os.path.splitext(model_path)[0] + ".yaml"
    if config_path != expected_config_path:
        shutil.copyfile(config_path, expected_config_path)

    logger.info("Registering custom model...")
    bs_roformer_separator.BS_ROFORMER_MODELS[MODEL_FILENAME] = {
        "model_type": "bs_roformer",
        "config_filename": os.path.basename(expected_config_path),
        "model_filename": MODEL_FILENAME,
        "model_friendly_name": "Custom Dereverb",
        "domain": "dereverb",
        "source": "local",
    }

    logger.info("Initializing separator...")
    separator = Separator(
        model_file_dir=local_models_dir,
        output_dir=".",
        output_format="FLAC",
        log_level=logging.INFO,
    )

    logger.info("Loading model: %s...", MODEL_FILENAME)
    separator.load_model(model_filename=MODEL_FILENAME)

    logger.info("Starting separation...")
    output_files = separator.separate(audio_path)
    if not output_files:
        # Surface a readable message in the UI instead of a bare IndexError.
        logger.error("Separation produced no output files for %s", audio_path)
        raise gr.Error("Separation produced no output files.")
    return output_files[0]
|
|
# Build the web UI: a title, an input/output audio pair, and a button that
# feeds the uploaded file path to `inference` and shows the returned file.
with gr.Blocks(title="Dereverb Room Web UI") as demo:
    gr.Markdown("# Dereverb Room Inference")
    with gr.Row():
        input_audio = gr.Audio(label="Input", type="filepath")
        output_audio = gr.Audio(
            label="Output (Dereverbed)",
            type="filepath",
            interactive=False,
        )
    # NOTE(review): the original indentation was lost; the button is assumed
    # to sit below the row rather than inside it -- confirm the intended layout.
    run_button = gr.Button("Remove Reverb")
    run_button.click(fn=inference, inputs=input_audio, outputs=output_audio)


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()