Spaces:
Build error
Build error
| import os | |
| import tensorflow as tf | |
| import tensorflow_hub as hub | |
| import numpy as np | |
| import csv | |
| import requests | |
| import json | |
| import logging | |
| import scipy | |
| from scipy.io import wavfile | |
| from pydub import AudioSegment | |
| import io | |
| from io import BytesIO | |
# Load the audio model a single time at import so every request reuses it.
model = hub.load('Audio_Multiple_v1')
def class_names_from_csv(class_map_csv_text):
    """Return the ``display_name`` column of the class-map CSV, in file order.

    Args:
        class_map_csv_text: Location of the CSV, opened through
            ``tf.io.gfile.GFile``. The file must have a header row that
            contains a ``display_name`` column.

    Returns:
        List of display names, one per CSV row, so that position ``i``
        lines up with entry ``i`` of the model's score vector.
    """
    with tf.io.gfile.GFile(class_map_csv_text) as handle:
        return [record['display_name'] for record in csv.DictReader(handle)]
# Ask the loaded model where its class-map CSV lives, then read the labels
# once so they can be indexed by score position on every request.
class_map_path = model.class_map_path().numpy()
class_names = class_names_from_csv(class_map_path)
def ensure_sample_rate(original_sample_rate, waveform, desired_sample_rate=16000):
    """Resample ``waveform`` to ``desired_sample_rate`` when the rates differ.

    Args:
        original_sample_rate: Sample rate of ``waveform`` in Hz.
        waveform: Sequence of audio samples; its length determines the
            length of the resampled output.
        desired_sample_rate: Target sample rate in Hz (default 16000).

    Returns:
        Tuple ``(desired_sample_rate, waveform)``. When the rates already
        match, ``waveform`` is returned untouched; otherwise it is a
        float32 NumPy array produced by ``scipy.signal.resample``.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not make
    # ``scipy.signal`` available on every SciPy version, which would turn
    # the resample call into an AttributeError.
    from scipy import signal

    if original_sample_rate != desired_sample_rate:
        # Keep the clip duration constant: new_length = duration * new_rate.
        desired_length = int(
            round(float(len(waveform)) / original_sample_rate * desired_sample_rate)
        )
        waveform = np.array(signal.resample(waveform, desired_length), dtype=np.float32)
    return desired_sample_rate, waveform
def convert_mp3_to_wav(mp3_data):
    """Transcode in-memory MP3 bytes to WAV bytes using pydub.

    Args:
        mp3_data: Raw bytes of an MP3 file.

    Returns:
        The bytes of the WAV file that pydub exports for the decoded audio.
    """
    mp3_stream = io.BytesIO(mp3_data)
    segment = AudioSegment.from_file(mp3_stream, format="mp3")

    wav_stream = io.BytesIO()
    segment.export(wav_stream, format='wav')
    return wav_stream.getvalue()
def process_audio_file(file_data, url, file_id):
    """Classify one WAV payload and build the answer record for it.

    Args:
        file_data: Raw bytes of a WAV file.
        url: Source URL of the audio; echoed in the result and used in the
            error log message.
        file_id: Caller-supplied identifier, echoed back under
            ``"normalfileID"`` (may be None).

    Returns:
        A dict with the keys ``'url'``, ``'answer'`` (single-element list
        holding the predicted class name), ``'qcUser'`` (always None) and
        ``'normalfileID'``; or None when any step fails, in which case the
        error is logged instead of raised.
    """
    try:
        sample_rate, wav_data = wavfile.read(BytesIO(file_data))
        # Collapse multi-channel audio to mono by averaging the channels.
        if wav_data.ndim > 1:
            wav_data = np.mean(wav_data, axis=1)
        sample_rate, wav_data = ensure_sample_rate(sample_rate, wav_data)
        # Scale by the int16 maximum.
        # NOTE(review): assumes 16-bit PCM input - confirm for other depths.
        waveform = wav_data / tf.int16.max

        scores, _embeddings, _spectrogram = model(waveform)
        # Average the per-frame scores into one score per class.
        mean_scores = np.mean(scores.numpy(), axis=0)
        # Indices of the two highest-scoring classes, best first.
        top_two_indices = np.argsort(mean_scores)[-2:][::-1]

        inferred_class = class_names[top_two_indices[0]]
        # Prefer the runner-up over a bare "Silence" prediction.
        if inferred_class == "Silence" and len(top_two_indices) > 1:
            inferred_class = class_names[top_two_indices[1]]

        # "qcUser" must be a string key: the bare name ``qcUser`` raised a
        # NameError that the handler below swallowed, so every file failed.
        return {
            'url': url,
            'answer': [inferred_class],
            "qcUser": None,
            "normalfileID": file_id,
        }
    except Exception as e:
        # Best-effort per file: log and let the caller skip this entry.
        logging.error(f"Error processing {url}: {e}")
        return None
def get_audio_data(url, *, timeout=30):
    """Download ``url`` and return the response body as bytes.

    Args:
        url: Address of the audio file to fetch.
        timeout: Seconds to wait for the server before giving up
            (keyword-only, default 30). Without a timeout ``requests``
            may block indefinitely on an unresponsive host.

    Returns:
        The raw response content.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: For connection problems or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return response.content
| # def send_results_to_api(data, result_url): | |
| # headers = {"Content-Type": "application/json"} | |
| # try: | |
| # response = requests.post(result_url, json=data, headers=headers) | |
| # response.raise_for_status() # Raise error for non-200 responses | |
| # return response.json() # Return any JSON response from the API | |
| # except requests.exceptions.HTTPError as http_err: | |
| # logging.error(f"HTTP error occurred: {http_err}") | |
| # return {"error": f"HTTP error occurred: {http_err}"} | |
| # except requests.exceptions.RequestException as req_err: | |
| # logging.error(f"Request error occurred: {req_err}") | |
| # return {"error": f"Request error occurred: {req_err}"} | |
| # except ValueError as val_err: | |
| # logging.error(f"Error decoding JSON response: {val_err}") | |
| # return {"error": f"Error decoding JSON response: {val_err}"} | |
def process_audio(params):
    """Classify every audio URL listed in a JSON request string.

    Args:
        params: JSON text describing an object with the keys
            ``"urls"`` (list of ``.mp3`` / ``.wav`` URLs) and, optionally,
            ``"normalfileID"`` (list of identifiers matched to the URLs by
            position).

    Returns:
        On success, a JSON string ``{"solutions": [...]}`` holding one
        record per file that was classified. When ``params`` is not valid
        JSON, or is not a JSON object, a dict with a single ``"error"`` key.

    Files that have an unsupported extension, fail to download or convert,
    or fail classification are logged and left out of the result rather
    than aborting the whole batch.
    """
    try:
        params = json.loads(params)
    except json.JSONDecodeError as e:
        return {"error": f"Invalid JSON input: {e.msg} at line {e.lineno} column {e.colno}"}
    if not isinstance(params, dict):
        # Valid JSON such as a list or number has no "urls" to read.
        return {"error": "Invalid JSON input: expected a JSON object with a 'urls' list"}

    audio_files = params.get("urls", [])
    file_ids = list(params.get("normalfileID") or [])
    # Pad with None so a missing or short ID list never makes zip() drop URLs.
    file_ids += [None] * (len(audio_files) - len(file_ids))

    # NOTE: posting the results to f"{api}/{job_id}" through
    # send_results_to_api is currently disabled.
    solutions = []
    for audio_url, file_id in zip(audio_files, file_ids):
        # Check the extension first: unsupported files are neither
        # downloaded nor allowed to reuse a previous iteration's result.
        if not audio_url.endswith((".mp3", ".wav")):
            logging.warning(f"Skipping unsupported audio format: {audio_url}")
            continue
        try:
            audio_data = get_audio_data(audio_url)
            if audio_url.endswith(".mp3"):
                audio_data = convert_mp3_to_wav(audio_data)
        except Exception as e:
            # Per-file boundary: mirror process_audio_file and skip the entry.
            logging.error(f"Error fetching {audio_url}: {e}")
            continue
        result = process_audio_file(audio_data, audio_url, file_id)
        if result:
            solutions.append(result)
    return json.dumps({"solutions": solutions})
| import gradio as gr | |
# The example in the label uses double quotes so that it is valid JSON and
# can be pasted straight into the textbox; the single-quoted form is
# rejected by json.loads in process_audio.
inputt = gr.Textbox(
    label='Parameters (JSON format) Eg. {"urls": ["file1.mp3", "file2.wav"]}'
)
outputs = gr.JSON()
application = gr.Interface(
    fn=process_audio,
    inputs=inputt,
    outputs=outputs,
    title="Audio Classification with API Integration",
)
application.launch()