| import torch |
| from transformers import WhisperFeatureExtractor |
| from models.tinyoctopus import TINYOCTOPUS |
| from utils import prepare_one_sample |
|
|
| |
# --- Module-level setup (runs on import; loads the model once) ---

# Run on GPU when one is available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): `cfg` is not defined anywhere in this file — as written this
# line raises NameError. Presumably a config object loaded elsewhere (e.g. an
# OmegaConf/argparse config); confirm and add the missing loading step.
model = TINYOCTOPUS.from_config(cfg.config.model)
model.to(device)
model.eval()  # inference mode: disables dropout / batch-norm updates

# Feature extractor that turns raw audio into the log-mel features the
# speech encoder expects; downloaded/cached from the HF hub on first use.
wav_processor = WhisperFeatureExtractor.from_pretrained("distil-whisper/distil-large-v3")
|
|
def transcribe(audio_path, task="dialect"):
    """
    Run the model on a single audio file and return the generated text.

    Args:
        audio_path (str): Path to the audio file.
        task (str): Task to perform. Options: "dialect", "asr", "translation".

    Returns:
        str: The generated text with sentence markers stripped, or a string
            of the form "Error: <details>" if audio loading / inference
            failed (best-effort contract kept for existing callers).

    Raises:
        ValueError: If `task` is not one of the recognized options.
    """
    task_prompts = {
        "dialect": "What is the dialect of the speaker?",
        "asr": "تعرف على الكلام وأعطني النص.",
        "translation": "الرجاء ترجمة هذا المقطع الصوتي إلى اللغة الإنجليزية."
    }

    if task not in task_prompts:
        raise ValueError("Invalid task. Choose from: 'dialect', 'asr', or 'translation'.")

    # Wrap the instruction in the speech-placeholder template; the model
    # substitutes the audio embeddings at the <SpeechHere> marker.
    prompt = [f"<Speech><SpeechHere></Speech> {task_prompts[task].strip()}"]

    try:
        # Keep the try body minimal: only the steps that can legitimately
        # fail at runtime (audio I/O + model inference) are guarded, so
        # programming errors elsewhere surface normally instead of being
        # masked as an "Error: ..." string.
        samples = prepare_one_sample(audio_path, wav_processor)
        generated_text = model.generate(samples, {"temperature": 0.7}, prompts=prompt)[0]
    except Exception as e:  # best-effort: report failure as text to the caller
        return f"Error: {e}"

    # Strip the BOS/EOS sentence markers left in the decoded output.
    return generated_text.replace('<s>', '').replace('</s>', '').strip()
|
|