from fastapi import FastAPI
from pydantic import BaseModel
import joblib
import logging
import json
import shutil
from pathlib import Path
from huggingface_hub import hf_hub_download
from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
from groq import Groq
import os

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

app = FastAPI()

# Lazily-populated 7-tuple of loaded models (see load_models) and the optional
# English label-id -> emotion-name mapping loaded from emotion_map.json.
models = None
en_emotion_map = None

# ====================== GROQ SETUP ======================
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))

# Closed label set the Groq prompt asks for; replies outside it are rejected.
_ALLOWED_EMOTIONS = {"joy", "sadness", "anger", "fear", "surprise", "disgust", "love", "neutral"}


def get_emotion_from_groq(text: str) -> str:
    """Use Groq to get accurate emotion for Sinhala text.

    Returns one of: joy, sadness, anger, fear, surprise, disgust, love, neutral.
    Falls back to "neutral" on API failure or an out-of-vocabulary reply.
    """
    try:
        prompt = f"""
Analyze the emotion of this Sinhala text.
Reply with ONLY one word from: joy, sadness, anger, fear, surprise, disgust, love, neutral.

Text: {text}

Emotion:
"""
        response = groq_client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.1,  # near-deterministic classification
            max_tokens=15,
        )
        emotion = response.choices[0].message.content.strip().lower()
        logging.info(f"Groq returned: {emotion}")
        # The LLM may ignore instructions and reply with extra words or an
        # unexpected label; anything outside the closed set becomes neutral.
        if emotion not in _ALLOWED_EMOTIONS:
            logging.warning(f"Groq returned unexpected label: {emotion}")
            return "neutral"
        return emotion
    except Exception as e:
        logging.error(f"Groq API failed: {e}")
        return "neutral"


def load_models():
    """Download and initialize all three language models.

    English and Sinhala are TF-IDF + logistic-regression joblib artifacts;
    Tamil is a transformers text-classification pipeline. Idempotent:
    returns the cached tuple on subsequent calls. Raises on load failure.
    """
    global models, en_emotion_map
    if models is not None:
        return models

    logging.info("📥 Loading models...")
    try:
        # English Model
        en_repo = "E-motionAssistant/English_LR_Model_New"
        en_vectorizer = joblib.load(hf_hub_download(en_repo, "tfidf_vectorizer.joblib"))
        en_classifier = joblib.load(hf_hub_download(en_repo, "logreg_model.joblib"))
        en_label_encoder = joblib.load(hf_hub_download(en_repo, "label_encoder.joblib"))
        try:
            map_path = hf_hub_download(en_repo, "emotion_map.json")
            with open(map_path, "r", encoding="utf-8") as f:
                en_emotion_map = json.load(f)
        except Exception as e:
            # The mapping file is optional; predict() falls back to the
            # label encoder when it is absent.
            logging.warning(f"emotion_map.json unavailable, using label encoder: {e}")
            en_emotion_map = None

        # Sinhala Model (loaded but not used for prediction — Groq handles Sinhala)
        si_repo = "E-motionAssistant/Sinhala_Text_Emotion_Model_LR"
        si_vectorizer = joblib.load(hf_hub_download(si_repo, "tfidf_vectorizer.joblib"))
        si_classifier = joblib.load(hf_hub_download(si_repo, "logreg_model.joblib"))
        si_label_encoder = joblib.load(hf_hub_download(si_repo, "label_encoder.joblib"))

        # Tamil Model
        logging.info("📥 Loading Tamil model...")
        try:
            # Purge any stale cached snapshot so from_pretrained re-downloads.
            cache_dir = Path.home() / ".cache" / "huggingface" / "hub"
            model_cache = cache_dir / "models--E-motionAssistant--Tamil_Emotion_Recognition_Model"
            if model_cache.exists():
                shutil.rmtree(model_cache)
        except OSError as e:
            # Best-effort cleanup; a stale cache is not fatal.
            logging.warning(f"Could not clear Tamil model cache: {e}")

        model_name = "E-motionAssistant/Tamil_Emotion_Recognition_Model"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        tamil_pipe = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
            device=-1,  # CPU
            truncation=True,
            max_length=512,
        )

        # Correct tuple assignment — order must match the unpacking in predict()
        models = (
            en_vectorizer, en_classifier, en_label_encoder,
            si_vectorizer, si_classifier, si_label_encoder,
            tamil_pipe,
        )
        logging.info("🎉 All models loaded successfully!")
        return models
    except Exception as e:
        logging.error(f"❌ Model loading failed: {e}")
        raise


@app.on_event("startup")
def startup_event():
    """Eagerly load all models when the server starts."""
    load_models()


class PredictRequest(BaseModel):
    # text: the input to classify; language: "english" / "sinhala" / "tamil"
    # (case-insensitive).
    text: str
    language: str


@app.get("/")
def root():
    """Health-check endpoint."""
    return {"status": "ok"}


@app.post("/predict")
def predict(req: PredictRequest):
    """Predict the emotion of req.text using the model for req.language.

    Returns {"emotion", "language"} (plus "confidence" for Tamil), or an
    {"error": ...} payload for empty text, unsupported languages, or
    unexpected failures.
    """
    if not req.text or not req.text.strip():
        return {"error": "Text cannot be empty"}
    if models is None:
        load_models()
    en_vec, en_clf, en_le, si_vec, si_clf, si_le, tamil_pipe = models
    try:
        lang = req.language.lower()
        if lang == "english":
            X = en_vec.transform([req.text])
            pred = int(en_clf.predict(X)[0])
            emotion = (
                en_emotion_map.get(str(pred), "unknown")
                if en_emotion_map
                else en_le.inverse_transform([pred])[0]
            )
            return {"emotion": emotion, "language": "English"}
        elif lang == "sinhala":
            # Use Groq API for accurate Sinhala emotion
            emotion = get_emotion_from_groq(req.text)
            return {"emotion": emotion, "language": "Sinhala"}
        elif lang == "tamil":
            result = tamil_pipe(req.text)
            emotion = result[0]['label']
            score = round(float(result[0]['score']), 4)
            return {"emotion": emotion, "confidence": score, "language": "Tamil"}
        else:
            # BUG FIX: previously an unknown language fell through every
            # branch and the endpoint silently returned None (null body).
            return {"error": f"Unsupported language: {req.language}"}
    except Exception as e:
        logging.error(f"Prediction error: {e}")
        return {"error": str(e)}