| import torch |
| import logging |
| from contextlib import asynccontextmanager |
| from fastapi import FastAPI, Request, Form |
| from fastapi.responses import HTMLResponse |
| from fastapi.templating import Jinja2Templates |
| from fastapi.staticfiles import StaticFiles |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger(__name__) |
|
|
# Model state is held in module-level globals; both are populated once by
# lifespan() at application startup and read by the request handlers.
model = None  # AutoModelForSequenceClassification, set in lifespan()
tokenizer = None  # AutoTokenizer, set in lifespan()
# Prefer GPU when available; all tensors are moved to this device before inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the sentiment model on startup and log on shutdown.

    Populates the module-level ``model`` and ``tokenizer`` globals used by the
    request handlers. Any load failure is logged with its traceback and
    re-raised so the application refuses to start half-initialized.
    """
    global model, tokenizer

    try:
        model_id = "codeby-hp/FinetuneTinybert-SentimentClassification"

        # Lazy %-style args: formatting is skipped when the log level is off.
        logger.info("Loading tokenizer from %s...", model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        logger.info("Loading model from %s...", model_id)
        model = AutoModelForSequenceClassification.from_pretrained(model_id)
        model.to(device)
        model.eval()  # inference mode: disables dropout etc.

        logger.info("Model loaded successfully on %s", device)
    except Exception:
        # logger.exception records the full traceback; the previous
        # logger.error(f"...") discarded it.
        logger.exception("Error loading model")
        raise

    yield  # application serves requests here

    logger.info("Shutting down...")
|
|
|
|
# Application object; the lifespan handler above loads the model at startup.
app = FastAPI(title="Sentiment Analysis API", lifespan=lifespan)


# Jinja2 templates rendered by the HTML endpoints (expects ./templates/index.html).
templates = Jinja2Templates(directory="templates")
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def home(request: Request):
    """Serve the landing page with an empty analysis form."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
|
|
|
def _classify(text: str) -> tuple[str, float]:
    """Run the sentiment model on *text*.

    Returns a ``(label, confidence)`` pair where confidence is in [0, 1].
    Relies on the module-level ``tokenizer``, ``model`` and ``device`` set up
    by ``lifespan()``.
    """
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=512, padding=True
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():  # inference only — skip autograd bookkeeping
        outputs = model(**inputs)

    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_class = torch.argmax(probabilities, dim=-1).item()
    confidence = probabilities[0][predicted_class].item()

    # Assumes the fine-tuned checkpoint's label ids are 0=Negative, 1=Positive.
    # TODO(review): confirm against model.config.id2label.
    sentiment = {0: "Negative", 1: "Positive"}.get(predicted_class, "Unknown")
    return sentiment, confidence


@app.post("/predict")
async def predict(request: Request, text: str = Form(...)):
    """Predict sentiment for the submitted form text and render the result.

    Renders index.html with either the prediction (``sentiment`` plus
    ``confidence`` as a percentage) or an ``error`` message; never raises to
    the client.
    """
    if not text.strip():
        return templates.TemplateResponse(
            "index.html",
            {"request": request, "error": "Please enter some text to analyze"},
        )

    try:
        # NOTE(review): inference is CPU/GPU-bound and runs synchronously inside
        # an async handler, blocking the event loop for its duration; consider
        # starlette.concurrency.run_in_threadpool if throughput matters.
        sentiment, confidence = _classify(text)
    except Exception as e:
        # logger.exception keeps the traceback (the old logger.error(f"...") did not).
        logger.exception("Prediction error")
        return templates.TemplateResponse(
            "index.html", {"request": request, "error": f"An error occurred: {str(e)}"}
        )

    return templates.TemplateResponse(
        "index.html",
        {
            "request": request,
            "text": text,
            "sentiment": sentiment,
            "confidence": round(confidence * 100, 2),
        },
    )
|
|
|
|
@app.get("/health")
async def health_check():
    """Report liveness, whether the model finished loading, and the device."""
    loaded = model is not None
    return {
        "status": "healthy",
        "model_loaded": loaded,
        "device": str(device),
    }
|
|
|
|
if __name__ == "__main__":
    # Local import: uvicorn is only needed when running this file directly.
    import uvicorn

    # Bind on all interfaces; 7860 is the conventional Hugging Face Spaces port.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
|