| | from fastapi import FastAPI |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from pydantic import BaseModel |
| | from transformers.pipelines import pipeline |
| | import os |
| |
|
# Redirect the Hugging Face cache to /tmp (common for read-only/serverless filesystems).
# NOTE(review): this is set AFTER `transformers` is imported above; huggingface_hub
# reads HF_HOME at import time, so this assignment may have no effect — confirm, and
# consider setting it before the transformers import (or in the deployment env).
os.environ["HF_HOME"] = "/tmp"
| |
|
# Hugging Face Hub model identifiers, one per moderation endpoint.
SPAM_MODEL = "cjell/spam-detector-roberta"                              # spam vs. ham
TOXIC_MODEL = "s-nlp/roberta_toxicity_classifier"                       # toxicity
SENTIMENT_MODEL = "nlptown/bert-base-multilingual-uncased-sentiment"    # 1-5 star sentiment
NSFW_MODEL = "michellejieli/NSFW_text_classifier"                       # NSFW text
HATE_MODEL = "facebook/roberta-hate-speech-dynabench-r4-target"         # hate speech
IMAGE_MODEL = "Falconsai/nsfw_image_detection"                          # NSFW image
| |
|
# Instantiate one classification pipeline per model at import time, so the
# first request doesn't pay the model-load cost. Downloads the weights on
# first run (cached under HF_HOME thereafter).
spam = pipeline("text-classification", model=SPAM_MODEL)
toxic = pipeline("text-classification", model=TOXIC_MODEL)
sentiment = pipeline("text-classification", model=SENTIMENT_MODEL)
nsfw = pipeline("text-classification", model=NSFW_MODEL)
hate = pipeline("text-classification", model=HATE_MODEL)

# Image moderation uses a separate task type from the text pipelines above.
image = pipeline("image-classification", model=IMAGE_MODEL)
| |
|
| |
|
# Application instance; endpoints are registered against it below.
app = FastAPI()

# Allow browser clients from any origin to call the API.
# NOTE(review): `allow_origins=["*"]` combined with `allow_credentials=True`
# is a security smell — the CORS spec forbids credentialed requests with a
# wildcard origin; verify whether credentials are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
| |
|
@app.get("/")
def root():
    """Liveness probe: confirms the HTTP server is up."""
    return {"status": "ok"}
| |
|
class Query(BaseModel):
    """Request body shared by all text-moderation endpoints."""

    # The raw text to classify.
    text: str
| |
|
@app.post("/spam")
def predict_spam(query: Query):
    """Classify the submitted text as spam or not."""
    # The pipeline returns a list of predictions; keep only the top one.
    top = spam(query.text)[0]
    return {"label": top["label"], "score": top["score"]}
| |
|
@app.post("/toxic")
def predict_toxic(query: Query):
    """Classify the submitted text for toxicity."""
    # Top prediction only; the pipeline yields a list of {label, score}.
    top = toxic(query.text)[0]
    return {"label": top["label"], "score": top["score"]}
| |
|
@app.post("/sentiment")
def predict_sentiment(query: Query):
    """Score the submitted text's sentiment."""
    top = sentiment(query.text)[0]
    # Echo back only the fields callers consume.
    return {"label": top["label"], "score": top["score"]}
| |
|
@app.post("/nsfw")
def predict_nsfw(query: Query):
    """Classify the submitted text as NSFW or safe."""
    top = nsfw(query.text)[0]
    return {"label": top["label"], "score": top["score"]}
| |
|
@app.post("/hate")
def predict_hate(query: Query):
    """Classify the submitted text for hate speech."""
    top = hate(query.text)[0]
    return {"label": top["label"], "score": top["score"]}
| |
|
| |
|
@app.get("/health")
def health_check():
    """Report server liveness and per-model health.

    Runs a one-word probe inference through every text pipeline; a model is
    reported as "running" when the call succeeds, otherwise its status carries
    the exception message. The image pipeline is not probed here because it
    requires an image input, not text.

    Returns:
        dict: {"server": "running", "models": {<name>: {"model_name", "status"}}}
    """
    status = {
        "server": "running",
        "models": {},
    }

    # One entry per text endpoint served above. Fix: "hate" was previously
    # missing here, so a broken hate-speech model would still report healthy.
    models = {
        "spam": (SPAM_MODEL, spam),
        "toxic": (TOXIC_MODEL, toxic),
        "sentiment": (SENTIMENT_MODEL, sentiment),
        "nsfw": (NSFW_MODEL, nsfw),
        "hate": (HATE_MODEL, hate),
    }

    for key, (model_name, model_pipeline) in models.items():
        try:
            # Full (if tiny) inference per check — adds latency proportional
            # to the number of models; acceptable for an on-demand probe.
            model_pipeline("test")
            status["models"][key] = {
                "model_name": model_name,
                "status": "running",
            }
        except Exception as e:  # broad on purpose: any failure marks the model down
            status["models"][key] = {
                "model_name": model_name,
                "status": f"error: {e}",
            }

    return status
| |
|