| import sys |
| import logging |
| import shutil |
| import tempfile |
| import zipfile |
| import io as python_io |
| import base64 |
| from pathlib import Path |
|
|
| from fastapi import FastAPI, UploadFile, File |
| from fastapi.responses import StreamingResponse, JSONResponse |
| from fastapi.middleware.cors import CORSMiddleware |
| import torch |
|
|
| |
| |
| |
# Make the directory two levels above this file's folder (presumably the
# repository root) importable, so the local `sharp` package resolves even
# when it has not been installed.
_repo_root = Path(__file__).parents[2]
sys.path.append(str(_repo_root))
|
|
| from sharp.models import PredictorParams, RGBGaussianPredictor, create_predictor |
| from sharp.utils import io as sharp_io |
| from sharp.utils.gaussians import save_ply |
| from sharp.cli.predict import predict_image, DEFAULT_MODEL_URL |
|
|
# Root logging at INFO so model-loading progress messages are emitted.
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger("sharp.api")

app = FastAPI()

# NOTE(review): browsers reject a literal "*" origin on credentialed requests.
# Depending on the Starlette version, CORSMiddleware may instead echo the
# caller's Origin back, which effectively permits credentialed cross-site
# requests from any site. Confirm that wildcard origins together with
# allow_credentials=True is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Process-wide inference state. Both names are None until the startup hook
# assigns them.
predictor: RGBGaussianPredictor | None = None
device: torch.device | None = None
|
|
|
|
def _select_device() -> torch.device:
    """Pick the inference device: CUDA if available, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


@app.on_event("startup")
async def startup_event():
    """Select a device and load the SHARP predictor into the module globals.

    On success, ``device`` is set and ``predictor`` holds the model in eval mode
    on that device. Any failure is logged and swallowed so the server still
    starts. In that case ``predictor`` stays ``None``, which the health and
    predict endpoints report as "model not loaded".
    """
    global predictor, device
    try:
        device = _select_device()
        LOGGER.info("Using device: %s", device)

        LOGGER.info("Loading SHARP model state dict...")
        # Deserialise onto the CPU. The weights are copied into the model
        # before the model itself is moved to `device`, so a separate
        # device-side copy of the checkpoint is never needed.
        state_dict = torch.hub.load_state_dict_from_url(
            DEFAULT_MODEL_URL, progress=True, map_location="cpu"
        )

        # Build in a local: if any step below raises, the global must not be
        # left pointing at a model with uninitialised weights.
        model = create_predictor(PredictorParams())
        model.load_state_dict(state_dict)
        model.eval()
        model.to(device)

        predictor = model
        LOGGER.info("Model loaded and ready.")
    except Exception as e:
        # Deliberate best effort: keep the server up so /health can report
        # that the model failed to load.
        LOGGER.exception("Failed during startup/model init: %s", e)
| |
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe.

    Reports the selected device (as a string, or None before startup has set
    it) and whether a predictor has been published.
    """
    device_label = None if not device else str(device)
    model_ready = predictor is not None
    return {
        "status": "ok",
        "device": device_label,
        "model_loaded": model_ready,
    }
|
|
|
|
def _generate_ply(upload: UploadFile, workdir: Path):
    """Run SHARP on one uploaded image and write the resulting Gaussians as PLY.

    The upload is copied into *workdir* and loaded with ``sharp_io.load_rgb``.
    It is then passed through the global ``predictor`` on ``device``, and the
    result is saved in *workdir* as ``<stem>.ply``.

    Returns:
        ``(ply_path, width, height, f_px)``, where ``f_px`` is the focal length
        exactly as returned by ``sharp_io.load_rgb``.

    Raises:
        ValueError: if the client-supplied filename has no usable base name.
        Any exception from loading, prediction or saving propagates unchanged.
    """
    # The filename is client-controlled. Keep only its final component so
    # names like "../../x.jpg" or "/etc/x.jpg" cannot write outside workdir.
    safe_name = Path(upload.filename or "").name
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Invalid upload filename: {upload.filename!r}")

    image_path = workdir / safe_name
    with open(image_path, "wb") as buffer:
        shutil.copyfileobj(upload.file, buffer)

    image, _, f_px = sharp_io.load_rgb(image_path)
    gaussians = predict_image(predictor, image, f_px, device)

    # Assumes the image is laid out (height, width, ...); TODO confirm
    # against sharp_io.load_rgb.
    height, width = image.shape[:2]
    ply_path = workdir / f"{image_path.stem}.ply"
    save_ply(gaussians, f_px, (height, width), ply_path)
    return ply_path, width, height, f_px


def _unique_arcname(name: str, taken: set[str]) -> str:
    """Return *name*, or *name* with a ``_1``, ``_2``, ... suffix if already in *taken*.

    The returned name is added to *taken*.
    """
    stem, suffix = Path(name).stem, Path(name).suffix
    candidate, counter = name, 1
    while candidate in taken:
        candidate = f"{stem}_{counter}{suffix}"
        counter += 1
    taken.add(candidate)
    return candidate


@app.post("/predict")
async def predict(files: list[UploadFile] = File(...)):
    """Accept images and return JSON with per-image metadata and PLY as base64.

    Each entry of ``results`` is either the metadata and base64 PLY for one
    upload, or ``{"filename", "error"}`` if that upload failed. One bad file
    does not abort the others. Returns HTTP 500 if the model is not loaded.
    """
    # Explicit None check: a torch module that defines __len__ can be falsy,
    # so truthiness is not a reliable "is loaded" signal.
    if predictor is None:
        return JSONResponse({"error": "Model not loaded"}, status_code=500)

    # NOTE(review): inference runs synchronously inside this async handler and
    # blocks the event loop while it runs. Consider a worker thread (guarded
    # by a lock) if concurrent requests or a responsive /health matter.
    results = []
    with tempfile.TemporaryDirectory() as temp_dir:
        workdir = Path(temp_dir)
        for file in files:
            try:
                ply_path, width, height, f_px = _generate_ply(file, workdir)
                ply_data = base64.b64encode(ply_path.read_bytes()).decode("utf-8")
            except Exception as e:
                LOGGER.exception("Error processing %s: %s", file.filename, e)
                results.append({"filename": file.filename, "error": str(e)})
                continue

            results.append(
                {
                    "filename": file.filename,
                    "ply_filename": ply_path.name,
                    "ply_data": ply_data,
                    "width": width,
                    "height": height,
                    # NOTE(review): passed through unchanged from load_rgb. If
                    # it can be a NumPy float32 or a tensor, JSON encoding
                    # will fail; confirm.
                    "focal_length": f_px,
                }
            )
    return {"results": results}


@app.post("/predict/download")
async def predict_download(files: list[UploadFile] = File(...)):
    """Accept images and return a ZIP of generated PLY files.

    Uploads that fail are logged and left out of the archive, so the archive
    is empty if every upload fails. PLY names that would collide inside the
    archive get a numeric suffix. Returns HTTP 500 if the model is not loaded.
    """
    if predictor is None:
        return JSONResponse({"error": "Model not loaded"}, status_code=500)

    output_zip = python_io.BytesIO()
    arcnames: set[str] = set()
    with tempfile.TemporaryDirectory() as temp_dir:
        workdir = Path(temp_dir)
        with zipfile.ZipFile(output_zip, "w") as zf:
            for file in files:
                try:
                    ply_path = _generate_ply(file, workdir)[0]
                    # "a.jpg" and "a.png" both yield "a.ply". Without
                    # deduplication, zipfile stores two entries under one name
                    # and extractors keep only one of them.
                    zf.write(ply_path, _unique_arcname(ply_path.name, arcnames))
                except Exception as e:
                    LOGGER.exception("Error processing %s: %s", file.filename, e)

    # The archive lives in memory, so it outlives the temp directory.
    output_zip.seek(0)
    return StreamingResponse(
        output_zip,
        media_type="application/zip",
        headers={"Content-Disposition": "attachment; filename=gaussians.zip"},
    )
|
|
|
|
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this module runs as a script.
    import uvicorn

    # Serve on every network interface, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")
|
|