import base64
import logging
from typing import List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from medimageinsightmodel import MedImageInsight
|
|
| |
# The ASGI application that serves the model endpoints defined in this module.
app = FastAPI(title="Medical Image Analysis API")

# Cross-origin access is fully open: any origin, method and header is allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is a risky
# mix. With credentials enabled, Starlette appears to reflect the caller's
# Origin instead of sending "*" (confirm). If so, any site could make
# credentialed requests. Consider an explicit origin list before exposing
# this API publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
| |
# A single model wrapper is built and its weights are loaded at import time.
# The /predict and /encode handlers all reuse this one module-level instance.
_MODEL_CONFIG = {
    "model_dir": "2024.09.27",
    # NOTE(review): "medimageinsigt" (missing "h") is kept verbatim. It
    # presumably matches the checkpoint filename inside model_dir, so confirm
    # against the files on disk before "fixing" the spelling.
    "vision_model_name": "medimageinsigt-v1.0.0.pt",
    "language_model_name": "language_model.pth",
}
classifier = MedImageInsight(**_MODEL_CONFIG)
classifier.load_model()
|
|
|
|
class ClassificationRequest(BaseModel):
    """JSON body accepted by ``POST /predict``.

    Every field is forwarded unchanged to ``classifier.predict``.
    """

    # Image payloads, one string per image. The encoding is whatever
    # classifier.predict expects (presumably base64; TODO confirm).
    images: List[str]
    # Label strings passed to classifier.predict alongside the images.
    labels: List[str]
    # Forwarded as classifier.predict's ``multilabel`` argument; defaults to False.
    multilabel: bool = False
|
|
class EmbeddingRequest(BaseModel):
    """JSON body accepted by ``POST /encode``.

    Either field may be omitted or sent as ``null``. Both values are passed
    straight to ``classifier.encode``.
    """

    # Optional[...] makes an explicit JSON null valid as well as omission.
    # A bare ``List[str] = None`` only has a None default, so pydantic v2
    # rejects an explicit null with a validation error.
    images: Optional[List[str]] = None
    texts: Optional[List[str]] = None
|
|
@app.post("/predict")
async def predict(request: ClassificationRequest):
    """Classify ``request.images`` against ``request.labels``.

    Args:
        request: validated body. ``images``, ``labels`` and ``multilabel`` are
            forwarded unchanged to ``classifier.predict``.

    Returns:
        ``{"predictions": <classifier.predict output>}``.

    Raises:
        HTTPException: status 500, with ``str(e)`` as the detail, when the
            classifier raises.
    """
    # NOTE(review): classifier.predict is a synchronous call made directly inside
    # an ``async def`` handler. While it runs, the event loop (including /health)
    # is blocked. Consider a plain ``def`` handler or a thread pool once the model
    # is confirmed safe to call from several threads at once.
    try:
        results = classifier.predict(
            images=request.images,
            labels=request.labels,
            multilabel=request.multilabel,
        )
    except Exception as e:
        # The client only receives str(e), so keep the full traceback in the
        # server logs, and chain the cause for debugging.
        logging.getLogger(__name__).exception("Prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"predictions": results}
|
|
@app.post("/encode")
async def encode(request: EmbeddingRequest):
    """Compute embeddings for the supplied images and/or texts.

    Args:
        request: validated body. ``images`` and ``texts`` (each possibly
            ``None``) are passed straight to ``classifier.encode``.

    Returns:
        The mapping returned by ``classifier.encode``. Its
        ``"image_embeddings"`` and ``"text_embeddings"`` entries are replaced
        by their ``.tolist()`` form, or left as ``None`` when they are ``None``.

    Raises:
        HTTPException: status 500, with ``str(e)`` as the detail, when encoding
            or the conversion fails. This includes either expected key being
            absent from the result.
    """
    # NOTE(review): classifier.encode is synchronous and blocks the event loop
    # while it runs. See the same note on /predict.
    try:
        results = classifier.encode(images=request.images, texts=request.texts)
        # Same order as before: image embeddings first, then text embeddings.
        for key in ("image_embeddings", "text_embeddings"):
            embeddings = results[key]
            # Presumably an array/tensor; .tolist() yields plain lists for the
            # JSON response (confirm the type classifier.encode returns).
            results[key] = embeddings.tolist() if embeddings is not None else None
    except Exception as e:
        # The client only receives str(e), so keep the full traceback in the
        # server logs, and chain the cause for debugging.
        logging.getLogger(__name__).exception("Encoding failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    return results
|
|
@app.get("/health")
async def health():
    """Liveness endpoint that unconditionally reports ``healthy``.

    No model or dependency checks are performed. A response only shows that
    the server is accepting requests.
    """
    status_payload = {"status": "healthy"}
    return status_payload
|
|
if __name__ == "__main__":
    # Running the module directly starts uvicorn on all interfaces, port 8000.
    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)