| | import os |
| | import queue |
| | import asyncio |
| | import concurrent.futures |
| | import functools |
| | import io |
| | import sys |
| | import random |
| | from threading import Thread |
| | import time |
| |
|
| | from dotenv import load_dotenv |
| |
|
| | import pyaudio |
| | import speech_recognition as sr |
| | import websockets |
| | from aioconsole import ainput |
| | from pydub import AudioSegment |
| | from simpleaudio import WaveObject |
| |
|
# Load environment settings before anything reads them; start_client() later
# looks up AUTH_API_KEY through os.getenv().
load_dotenv()

# Thread pool that handle_audio() uses to run the blocking microphone capture.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)

# True until the first binary frame received from the second websocket has
# been skipped by receive_message().
web2_initial_message = True

# Audio stream parameters. RATE is the sample rate handed to sr.Microphone;
# the other values are not referenced in the visible part of this file.
CHUNK = 1024
FORMAT = pyaudio.paInt16
CHANNELS = 1
RATE = 44100
| |
|
| |
|
class AudioPlayer:
    """Play queued WAV clips one after another on a background thread.

    Attributes:
        play_thread: The worker ``Thread`` currently running ``play_audio``,
            or ``None`` when no worker has been started / it was stopped.
        stop_flag: Set to ``True`` to interrupt the current clip and make the
            worker drop whatever is still queued.
        queue: ``queue.Queue`` of WAV sources accepted by
            ``WaveObject.from_wave_file``.
    """

    def __init__(self):
        self.play_thread = None
        self.stop_flag = False
        self.queue = queue.Queue()

    def play_audio(self):
        """Worker loop: play queued clips until a stop is requested.

        The loop keeps running while no stop was requested; once
        ``stop_flag`` is set it drains the remaining queue and returns.
        """
        while not self.stop_flag or not self.queue.empty():
            try:
                # Block briefly rather than spinning on get_nowait(), which
                # kept a CPU core busy whenever the queue was empty.
                wav_data = self.queue.get(timeout=0.1)
            except queue.Empty:
                continue

            if self.stop_flag:
                # A stop was requested: discard the clip instead of starting
                # playback only to cut it off immediately.
                continue

            wave_obj = WaveObject.from_wave_file(wav_data)
            play_obj = wave_obj.play()

            # Poll so that a stop request can interrupt a clip mid-playback.
            while play_obj.is_playing() and not self.stop_flag:
                time.sleep(0.1)

            if self.stop_flag:
                play_obj.stop()

    def start_playing(self, wav_data):
        """Queue *wav_data* and make sure a worker thread is running."""
        self.stop_flag = False
        self.queue.put(wav_data)

        if self.play_thread is None or not self.play_thread.is_alive():
            # Daemon thread: the worker loop only ends on an explicit stop,
            # so a non-daemon thread would keep the interpreter from exiting.
            self.play_thread = Thread(target=self.play_audio, daemon=True)
            self.play_thread.start()

    def stop_playing(self):
        """Interrupt playback and wait for the worker thread to finish.

        Does nothing when no worker thread is alive.
        """
        if self.play_thread and self.play_thread.is_alive():
            self.stop_flag = True
            self.play_thread.join()
            self.play_thread = None

    def add_to_queue(self, wav_data):
        """Queue *wav_data* without starting or waking a worker thread."""
        self.queue.put(wav_data)


# Shared player used by the websocket message handlers in this module.
audio_player = AudioPlayer()
| |
|
| |
|
def get_input_device_id():
    """List the audio devices that report input channels and ask for one.

    Returns:
        int: The device id typed by the user. The value is not checked
        against the printed list.
    """
    audio = pyaudio.PyAudio()
    try:
        devices = []
        for index in range(audio.get_device_count()):
            # Query each device once and reuse the info for filter and name.
            info = audio.get_device_info_by_index(index)
            if info.get('maxInputChannels'):
                devices.append((index, info['name']))
    finally:
        # Release the PyAudio instance; it is only needed for enumeration.
        audio.terminate()

    print('Available devices:')
    for device_id, name in devices:
        print(f"Device id {device_id} - {name}")

    # Re-prompt instead of crashing with ValueError on non-numeric input.
    while True:
        try:
            return int(input('Please select device id: '))
        except ValueError:
            print('Please enter a numeric device id.')
| |
|
| |
|
async def handle_audio(websocket, device_id):
    """Capture speech from a microphone and send each phrase to *websocket*.

    Args:
        websocket: Connection whose ``send`` receives the raw frame data of
            every captured phrase.
        device_id: Index of the input device passed to ``sr.Microphone``.

    Runs until cancelled or until sending/capturing raises.
    """
    loop = asyncio.get_running_loop()

    with sr.Microphone(device_index=device_id, sample_rate=RATE) as source:
        recognizer = sr.Recognizer()
        print('Source sample rate: ', source.SAMPLE_RATE)
        print('Source width: ', source.SAMPLE_WIDTH)
        print('Adjusting for ambient noise...Wait for 2 seconds')
        recognizer.energy_threshold = 5000
        recognizer.dynamic_energy_ratio = 6
        recognizer.dynamic_energy_adjustment_damping = 0.85
        recognizer.non_speaking_duration = 0.5
        recognizer.pause_threshold = 0.8
        recognizer.phrase_threshold = 0.5

        # The calibration call blocks for its whole duration, so run it in
        # the worker pool to keep the event loop (and the receiving task)
        # responsive meanwhile.
        calibrate = functools.partial(
            recognizer.adjust_for_ambient_noise, source, duration=2)
        await loop.run_in_executor(executor, calibrate)

        listen_func = functools.partial(
            recognizer.listen, source, phrase_time_limit=30)

        print('Okay, start talking!')
        while True:
            print('[*]', end="")
            # listen() blocks until a phrase is captured; keep it off the loop.
            audio = await loop.run_in_executor(executor, listen_func)
            await websocket.send(audio.frame_data)
            print('[-]', end="")
            await asyncio.sleep(2)
| |
|
| |
|
async def handle_text(websocket):
    """Forward every line read from the console to *websocket*.

    Prints the ``You: `` prompt once, then loops forever: await one line
    from ``ainput()`` and send it unchanged.
    """
    print('You: ', end="", flush=False)
    while True:
        typed_line = await ainput()
        await websocket.send(typed_line)
| |
|
initial_message = True


def _strip_speaker_prefix(text):
    """Return *text* without everything up to the first ``'> '``, if present."""
    _, separator, remainder = text.partition('> ')
    # Without a separator the text is returned whole (slicing at find()+2
    # would silently drop its first character).
    return remainder if separator else text


def _play_mp3(mp3_bytes):
    """Convert an MP3 payload to WAV in memory and hand it to the player."""
    audio = AudioSegment.from_mp3(io.BytesIO(mp3_bytes))
    wav_data = io.BytesIO()
    audio.export(wav_data, format="wav")
    audio_player.start_playing(wav_data)


async def _relay_through_websocket2(websocket, websocket2, text):
    """Send *text* to websocket2 and forward its full reply to websocket.

    Text chunks from websocket2 are accumulated until ``'[end]\\n'`` arrives;
    the accumulated reply is printed, stripped of its speaker prefix and sent
    to *websocket*. Returns early (after printing the error) if receiving
    from websocket2 fails.
    """
    global web2_initial_message

    await websocket2.send(text)
    web2_message = ''
    while True:
        try:
            message = await websocket2.recv()
        except websockets.exceptions.ConnectionClosedError as e:
            print("Connection closed unexpectedly: ", e)
            return
        except Exception as e:
            print("An error occurred: ", e)
            return

        if isinstance(message, bytes):
            # The very first binary frame from websocket2 is skipped.
            if web2_initial_message:
                web2_initial_message = False
                continue
            _play_mp3(message)
        elif isinstance(message, str):
            if message == '[end]\n':
                if not web2_message:
                    # End marker with no text yet: keep waiting.
                    continue
                print(web2_message)
                await websocket.send(_strip_speaker_prefix(web2_message))
                return
            if message.startswith('[+]'):
                audio_player.stop_playing()
            elif message.startswith('[=]'):
                pass
            else:
                web2_message += message


async def receive_message(websocket, websocket2):
    """Relay a conversation between two websocket endpoints.

    After reading one initial message from each connection, an opening
    question is sent to *websocket*. From then on, text received from
    *websocket* is printed and accumulated; when ``'[end]\\n'`` arrives the
    accumulated text (minus its speaker prefix) is sent to *websocket2*, and
    websocket2's complete reply is sent back to *websocket*. Binary messages
    are treated as MP3 audio and played. Messages starting with ``'[+]'``
    stop the current playback and are printed; ``'[=]'`` messages are only
    printed.

    The loop ends when receiving from *websocket* fails or a message that is
    neither ``str`` nor ``bytes`` arrives.
    """
    web1_init_message = await websocket.recv()
    print('web1_init_message: ', web1_init_message)

    web2_init_message = await websocket2.recv()
    print('web2_init_message: ', web2_init_message)

    await websocket.send(
        "Suppose I'm Steve Jobs now. What question do you have for me?")

    web1_message = ''
    while True:
        try:
            message = await websocket.recv()
        except websockets.exceptions.ConnectionClosedError as e:
            print("Connection closed unexpectedly: ", e)
            break
        except Exception as e:
            print("An error occurred: ", e)
            break

        if isinstance(message, bytes):
            _play_mp3(message)
        elif not isinstance(message, str):
            print("Unexpected message")
            break
        elif message == '[end]\n':
            if not web1_message:
                continue
            outgoing = _strip_speaker_prefix(web1_message)
            # Start the next turn from scratch; otherwise every later turn
            # would re-send all previous replies as well.
            web1_message = ''
            await _relay_through_websocket2(websocket, websocket2, outgoing)
        elif message.startswith('[+]'):
            audio_player.stop_playing()
            print(f"\n{message}", end="\n", flush=False)
        elif message.startswith('[=]'):
            print(f"{message}", end="\n", flush=False)
        else:
            web1_message += message
            print(f"{message}", end="", flush=False)
| |
|
| |
|
def select_model():
    """Ask on the console which LLM model to use and return its name.

    The prompt is repeated until one of the listed numbers is entered, so
    the function always returns a valid model name.

    Returns:
        str: ``'gpt-3.5-turbo-16k'``, ``'gpt-4'`` or ``'claude-2'``.
    """
    models = {
        '1': 'gpt-3.5-turbo-16k',
        '2': 'gpt-4',
        '3': 'claude-2',
    }
    prompt = (
        '1: gpt-3.5-turbo-16k \n'
        '2: gpt-4 \n'
        '3: claude-2 \n'
        'Select llm model:')

    while True:
        llm_model_selection = input(prompt).strip()
        if llm_model_selection in models:
            return models[llm_model_selection]
        # Previously an unknown answer fell through every branch and crashed
        # with UnboundLocalError on return.
        print(f"Invalid selection {llm_model_selection!r}; please enter 1, 2 or 3.")
| |
|
| |
|
async def start_client(client_id, url):
    """Open both websocket connections and run the send/receive tasks.

    Args:
        client_id: Id used in the path of the first connection.
        url: ``host:port`` of the server (no scheme).

    The second connection always uses client id ``9999999`` and is sent the
    character selection ``'6'``. Returns once either the sending or the
    receiving task finishes; an exception raised by that task is re-raised.
    """
    from urllib.parse import urlencode

    api_key = os.getenv('AUTH_API_KEY')
    llm_model = select_model()
    # Encode the query so that keys containing '&', '+', '#', ... survive.
    query = urlencode({'api_key': api_key, 'llm_model': llm_model})
    uri = f"ws://{url}/ws/{client_id}?{query}"
    uri2 = f"ws://{url}/ws/9999999?{query}"

    async with websockets.connect(uri) as websocket:
        async with websockets.connect(uri2) as websocket2:
            await websocket.send('terminal')
            await websocket2.send('terminal')
            print(f"Client #{client_id} connected to websocket1")
            print("Client 9999999 connected to websocket2")
            welcome_message = await websocket.recv()
            # The second connection's welcome is consumed but not shown.
            await websocket2.recv()
            print(f"{welcome_message}")
            # NOTE(review): input() blocks the event loop while the user is
            # typing; kept synchronous so Ctrl-C at the prompt still works.
            character = input('Select character: ')
            await websocket.send(character)
            await websocket2.send('6')

            mode = input('Select mode (1: audio, 2: text): ')
            if mode.strip() == '1':
                device_id = get_input_device_id()
                send_task = asyncio.create_task(handle_audio(websocket, device_id))
            else:
                send_task = asyncio.create_task(handle_text(websocket))

            receive_task = asyncio.create_task(receive_message(websocket, websocket2))
            tasks = (receive_task, send_task)
            try:
                # Stop as soon as either side ends: with gather() the sender
                # kept waiting for input after the receive loop had finished.
                done, _ = await asyncio.wait(
                    tasks, return_when=asyncio.FIRST_COMPLETED)
            finally:
                for task in tasks:
                    if not task.done():
                        task.cancel()
                # Let cancelled tasks unwind before the connections close.
                await asyncio.gather(*tasks, return_exceptions=True)

            for task in done:
                task.result()  # re-raise a failure from the finished task
| |
|
| |
|
async def main(url):
    """Run one client session against *url* with a random client id.

    Args:
        url: ``host:port`` of the server, passed through to ``start_client``.
    """
    client_id = random.randint(0, 1000000)
    await start_client(client_id, url)


if __name__ == "__main__":
    url = sys.argv[1] if len(sys.argv) > 1 else 'localhost:8000'
    try:
        asyncio.run(main(url))
    except KeyboardInterrupt:
        # Ctrl-C surfaces from asyncio.run() after it has cancelled the
        # running tasks; it is not delivered as KeyboardInterrupt inside the
        # coroutine, so it has to be handled here.
        print("Client stopped by user")
    finally:
        # Stop the playback thread so the interpreter can exit.
        audio_player.stop_playing()
| |
|