import os
import tempfile
from typing import Optional

import torch
import uvicorn
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from model import SimpleTransformerModel, FullChatDataset, VoiceInterface, generate_response
|
|
app = FastAPI()

# Accept cross-origin requests from any origin, using any method or header.
# NOTE(review): this CORS policy is wide open; restrict allow_origins before
# exposing the service beyond local/dev use.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Objects built once at import time and shared by every request handler below.
dataset = FullChatDataset()
# The model is sized from the dataset tokenizer's vocabulary length.
_vocab_size = len(dataset.tokenizer)
model = SimpleTransformerModel(_vocab_size)
# NOTE(review): no checkpoint/weights are loaded here, and `torch` is not used
# in this file's visible code; confirm weights are restored elsewhere.
voice_interface = VoiceInterface()
|
|
class ChatRequest(BaseModel):
    """Request schema carrying a chat prompt and generation options.

    NOTE(review): no endpoint in this file references this model; ``/chat/``
    accepts the same three fields as form data instead.
    """

    # The user's input text.
    prompt: str
    # Length limit forwarded to generation; defaults to 100.
    max_length: int = 100
    # Whether the voice interface should be used for the reply.
    use_voice: bool = False
|
|
def _transcribe_upload(contents: bytes) -> str:
    """Write uploaded audio bytes to a private temp file and return its transcript.

    Args:
        contents: Raw bytes of the uploaded audio file.

    Returns:
        The text returned by ``voice_interface.recognizer.recognize_google``.

    The temp file is unique per call and is always removed, even when
    recording or recognition raises.
    """
    # A unique file per request: a fixed name such as "temp_audio.wav" lets
    # concurrent requests overwrite or delete each other's audio.
    fd, path = tempfile.mkstemp(suffix=".wav")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(contents)
        # NOTE(review): `sr` is not imported by this file's import block; it
        # presumably refers to the speech_recognition package
        # (`import speech_recognition as sr`). Without that import this line
        # raises NameError, which chat_endpoint reports as a 500.
        with sr.AudioFile(path) as source:
            audio = voice_interface.recognizer.record(source)
        return voice_interface.recognizer.recognize_google(audio)
    finally:
        os.remove(path)


@app.post("/chat/")
async def chat_endpoint(
    prompt: Optional[str] = Form(None),
    max_length: int = Form(100),
    use_voice: bool = Form(False),
    audio_file: Optional[UploadFile] = File(None),
):
    """Generate a chat reply from a text prompt or an uploaded audio clip.

    Args:
        prompt: Text prompt (form field). Ignored when ``audio_file`` is given,
            because the audio transcript replaces it.
        max_length: Length limit forwarded to ``generate_response``.
        use_voice: When true, ``voice_interface`` is passed to
            ``generate_response`` (presumably to speak the reply — confirm in
            ``model``); otherwise ``None`` is passed.
        audio_file: Optional uploaded audio to transcribe into the prompt.

    Returns:
        ``{"response": <generate_response output>}``.

    Raises:
        HTTPException: 400 if no prompt text is available after optional
            transcription; 500 wrapping any other failure.
    """
    try:
        if audio_file:
            contents = await audio_file.read()
            prompt = _transcribe_upload(contents)

        if not prompt:
            raise HTTPException(status_code=400, detail="No input provided")

        response = generate_response(
            model,
            dataset.tokenizer,
            prompt,
            max_length,
            voice_interface if use_voice else None,
        )
        return {"response": response}

    except HTTPException:
        # Propagate deliberate HTTP errors (e.g. the 400 above) unchanged;
        # otherwise the generic handler below would turn them into 500s.
        raise
    except Exception as e:
        # NOTE(review): str(e) exposes internal error text to clients.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
@app.get("/")
async def read_root():
    """Report that the service is running."""
    status = {"message": "CyberFuture Running"}
    return status
|
|
if __name__ == "__main__":
    # When executed directly as a script, serve on all interfaces at port 8000.
    _host, _port = "0.0.0.0", 8000
    uvicorn.run(app, host=_host, port=_port)