import os

import gradio as gr
import requests
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from transformers import ClapModel, ClapProcessor

# Load environment variables from a local .env file (if present).
load_dotenv()

# Qdrant connection settings, read from the environment.
QDRANT_URL = os.getenv('QDRANT_URL')
QDRANT_KEY = os.getenv('QDRANT_KEY')

# Debug output. NOTE: never echo the API key itself — it is a secret and
# would end up in logs; only report whether it is configured.
print(f"QDRANT_URL: {QDRANT_URL}")
print(f"QDRANT_KEY set: {QDRANT_KEY is not None}")
|
try:
    # Fail fast if configuration is missing, BEFORE any network call —
    # otherwise requests.get(None) raises a confusing requests error
    # instead of this explicit message.
    if not QDRANT_URL or not QDRANT_KEY:
        raise ValueError("Les variables d'environnement QDRANT_URL ou QDRANT_KEY ne sont pas définies")

    # Smoke-test: plain HTTP GET against the Qdrant endpoint.
    response = requests.get(QDRANT_URL)
    print(f"Test de la connexion à Qdrant: {response.status_code}")

    client = QdrantClient(QDRANT_URL, api_key=QDRANT_KEY)
    print("[INFO] Client created...")

    # Load the CLAP model used to embed the text queries.
    print("[INFO] Loading the model...")
    model_name = "laion/larger_clap_general"
    model = ClapModel.from_pretrained(model_name)
    processor = ClapProcessor.from_pretrained(model_name)

    # Number of audio results shown in the UI (and fetched from Qdrant).
    max_results = 10

    def sound_search(query):
        """Embed *query* with CLAP and return the top matching audio widgets.

        Args:
            query: free-text description of the sound to look for.

        Returns:
            A list of up to ``max_results`` ``gr.Audio`` components, one per
            Qdrant hit, labelled with the hit's style and similarity score.
        """
        text_inputs = processor(text=query, return_tensors="pt")
        # get_text_features returns a (1, dim) tensor; take the single row.
        text_embed = model.get_text_features(**text_inputs)[0]

        hits = client.search(
            collection_name="demo_spaces_db",
            # The Qdrant client expects a plain list of floats, not a torch
            # tensor (which may also carry gradients).
            query_vector=text_embed.detach().cpu().numpy().tolist(),
            limit=max_results,
        )
        return [
            gr.Audio(
                hit.payload['audio_path'],
                label=f"style: {hit.payload['style']} -- score: {hit.score}")
            for hit in hits
        ]

    # Minimal Gradio UI: a textbox that re-runs the search on every change
    # and a fixed bank of audio slots for the results.
    with gr.Blocks() as demo:
        gr.Markdown("# Sound search database")
        inp = gr.Textbox(placeholder="What sound are you looking for ?")
        out = [gr.Audio(label=f"{x}") for x in range(max_results)]
        inp.change(sound_search, inp, out)

    demo.launch()

except Exception as e:
    # Top-level boundary: surface any startup failure instead of crashing
    # with a bare traceback.
    print(f"[ERROR] Failed to create Qdrant client: {e}")