import asyncio

import numpy as np
import torch
from fastapi import FastAPI, WebSocket, Request, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_microphone_live
| |
|
# Run on the first CUDA device when torch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Audio classifier whose labels are checked against the wake word in launch_fn.
classifier = pipeline(
    task="audio-classification",
    model="MIT/ast-finetuned-speech-commands-v2",
    device=device,
)

# Audio classifier used by listen / is_silence on the captured query audio.
intent_class_pipe = pipeline(
    task="audio-classification",
    model="anton-l/xtreme_s_xlsr_minds14",
    device=device,
)
| |
|
| |
|
async def launch_fn(
    wake_word="marvin",
    prob_threshold=0.5,
    chunk_length_s=2.0,
    stream_chunk_s=0.25,
    debug=False,
):
    """Wait until the wake word is heard on the live microphone stream.

    The microphone capture and the classifier inference are blocking calls,
    so they run in the event loop's default executor; other coroutines keep
    being served while this one waits.

    Args:
        wake_word: Label that must be the top prediction to count as a hit.
            Must be one of ``classifier.model.config.label2id``.
        prob_threshold: Score the top prediction must exceed.
        chunk_length_s: Forwarded to ``ffmpeg_microphone_live``.
        stream_chunk_s: Forwarded to ``ffmpeg_microphone_live``.
        debug: When true, print the top prediction of every chunk.

    Returns:
        ``True`` once the wake word is the top prediction with a score above
        ``prob_threshold``; ``False`` if the microphone stream ends first.

    Raises:
        ValueError: If ``wake_word`` is not a label known to the classifier.
    """
    valid_labels = classifier.model.config.label2id
    if wake_word not in valid_labels:
        raise ValueError(
            f"Wake word {wake_word} not in set of valid class labels, pick a wake word in the set {valid_labels.keys()}."
        )

    sampling_rate = classifier.feature_extractor.sampling_rate

    def _wait_for_wake_word():
        # Blocking detection loop; executed in a worker thread.
        mic = ffmpeg_microphone_live(
            sampling_rate=sampling_rate,
            chunk_length_s=chunk_length_s,
            stream_chunk_s=stream_chunk_s,
        )
        try:
            for prediction in classifier(mic):
                top_prediction = prediction[0]
                if debug:
                    print(top_prediction)
                if (
                    top_prediction["label"] == wake_word
                    and top_prediction["score"] > prob_threshold
                ):
                    return True
            # Stream ended without a detection: report it explicitly instead
            # of falling off the end and returning None.
            return False
        finally:
            # Close the microphone generator deterministically rather than
            # waiting for garbage collection to release the capture.
            mic.close()

    print("Listening for wake word...")
    loop = asyncio.get_running_loop()
    # NOTE: if the awaiting task is cancelled, the worker thread keeps
    # listening until it returns on its own; executor jobs cannot be aborted.
    return await loop.run_in_executor(None, _wait_for_wake_word)
| |
|
| |
|
async def listen(
    websocket,
    chunk_length_s=2.0,
    stream_chunk_s=2.0,
    max_chunks=4,
    silence_threshold=0.7,
):
    """Capture a spoken query from the microphone and report its classes.

    Up to ``max_chunks`` chunks are read from the microphone. Each chunk is
    classified on its own and the top label is sent to the client as a
    progress message. Capture stops early when a chunk's top prediction is
    ``"silence"`` with a score above ``silence_threshold``. All captured
    chunks are then concatenated, classified once more, and the first three
    predictions are sent to the client.

    Blocking work (reading the microphone, running the model) is executed in
    the event loop's default executor so the loop stays responsive.

    Args:
        websocket: Connection the progress and result messages are sent on;
            only ``send_text`` is used.
        chunk_length_s: Forwarded to ``ffmpeg_microphone_live``.
        stream_chunk_s: Forwarded to ``ffmpeg_microphone_live``.
        max_chunks: Maximum number of chunks to capture.
        silence_threshold: Score above which a ``"silence"`` top prediction
            ends the capture.

    Returns:
        None. Results are delivered through ``websocket``.
    """
    sampling_rate = intent_class_pipe.feature_extractor.sampling_rate
    loop = asyncio.get_running_loop()

    mic = ffmpeg_microphone_live(
        sampling_rate=sampling_rate,
        chunk_length_s=chunk_length_s,
        stream_chunk_s=stream_chunk_s,
    )

    def _capture_and_classify():
        # Blocking: read one chunk and classify it. Returns None when the
        # stream is exhausted, because StopIteration must not escape into a
        # coroutine or an executor future.
        chunk = next(mic, None)
        if chunk is None:
            return None
        raw_audio = chunk["raw"]
        return raw_audio, intent_class_pipe(raw_audio)

    audio_buffer = []

    print("Listening")
    for i in range(max_chunks):
        captured = await loop.run_in_executor(None, _capture_and_classify)
        if captured is None:
            break
        raw_audio, prediction = captured
        audio_buffer.append(raw_audio)

        print(prediction)
        await websocket.send_text(
            f"chunk: {prediction[0]['label']} | {i+1} / {max_chunks}"
        )

        # Reuse the prediction just computed rather than classifying the same
        # chunk a second time; this is the same rule is_silence applies.
        # NOTE(review): confirm the intent model can emit a "silence" label.
        top_prediction = prediction[0]
        if (
            top_prediction["label"] == "silence"
            and top_prediction["score"] > silence_threshold
        ):
            print("Silence detected, processing audio.")
            break

    if not audio_buffer:
        # np.concatenate raises on an empty list; there is nothing to classify.
        print("No audio captured.")
        return

    combined_audio = np.concatenate(audio_buffer)
    prediction = await loop.run_in_executor(None, intent_class_pipe, combined_audio)
    top_3_predictions = prediction[:3]
    formatted_predictions = "\n".join(
        f"{pred['label']}: {pred['score'] * 100:.2f}%" for pred in top_3_predictions
    )
    await websocket.send_text(f"classes: \n{formatted_predictions}")
| |
|
| |
|
async def is_silence(audio_chunk, threshold):
    """Tell whether ``audio_chunk`` is classified as silence.

    The chunk is run through ``intent_class_pipe`` and only the first
    prediction is inspected.

    Args:
        audio_chunk: Audio data passed unchanged to ``intent_class_pipe``.
        threshold: Score the ``"silence"`` prediction must exceed.

    Returns:
        ``True`` when the first prediction is labelled ``"silence"`` with a
        score above ``threshold``, otherwise ``False``.
    """
    top_prediction = intent_class_pipe(audio_chunk)[0]
    if top_prediction["label"] != "silence":
        return False
    return bool(top_prediction["score"] > threshold)
| |
|
| |
|
| | |
# Application object the routes below are registered on.
app = FastAPI()

# Serve the contents of the "static" directory under the /static URL prefix.
app.mount(
    path="/static",
    app=StaticFiles(directory="static"),
    name="static",
)

# Template loader rooted at the "templates" directory; used by get_home.
templates = Jinja2Templates(directory="templates")
| |
|
| |
|
@app.get("/", response_class=HTMLResponse)
async def get_home(request: Request):
    """Render ``index.html`` for the site root.

    The incoming request is handed to the template under the ``"request"``
    context key.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
| |
|
| |
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Drive the wake-word / query cycle over a WebSocket connection.

    Protocol, as text messages from the client:

    * ``"start"``: announce listening, wait for the wake word via
      ``launch_fn`` and, if it was detected, capture and classify the query
      via ``listen``.
    * ``"stop"``: acknowledge, leave the receive loop and close the socket.

    Any other message is ignored. A client disconnect ends the handler.
    """
    await websocket.accept()
    try:
        process_active = False

        while True:
            message = await websocket.receive_text()

            if message == "start" and not process_active:
                process_active = True
                await websocket.send_text("Listening for wake word...")
                wake_word_detected = await launch_fn(debug=True)
                if wake_word_detected:
                    await websocket.send_text("Wake word detected. Listening for your query...")
                    await listen(websocket)
                process_active = False

            elif message == "stop":
                # The start branch is awaited inline, so no cycle can still be
                # running when a "stop" is received; just reset the flag.
                process_active = False
                await websocket.send_text("Process stopped. Ready to restart.")
                break

        # Finish with a proper close handshake instead of returning with the
        # connection still open.
        await websocket.close()

    except WebSocketDisconnect:
        print("Client disconnected.")
| |
|