| | from fastapi import FastAPI, HTTPException, Request |
| | from fastapi.responses import JSONResponse |
| | from fastapi.middleware.cors import CORSMiddleware |
| | from pydantic import BaseModel, Field |
| | from typing import List, Optional, Dict, Any |
| | import numpy as np |
| | import base64 |
| | import logging |
| | import sys |
| | import traceback |
| | import io |
| | from PIL import Image |
| | import json |
| |
|
| | |
# Configure logging first. Module-level helpers such as ``logging.warning``
# call ``basicConfig()`` implicitly when the root logger has no handlers, after
# which an explicit ``basicConfig`` call is silently ignored. Logging the import
# failure before this point would therefore discard the DEBUG/stdout setup.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger("faceforge_api")

# Optional native core. When any of these imports fails, HAS_CORE is False and
# the mock implementations defined in this module are used instead.
try:
    import faceforge_core
    from faceforge_core.latent_explorer import LatentSpaceExplorer
    from faceforge_core.attribute_directions import LatentDirectionFinder
    from faceforge_core.custom_loss import attribute_preserving_loss
    HAS_CORE = True
except ImportError as e:
    logger.warning("Failed to import faceforge_core modules: %s", e)
    logger.warning("Using mock implementations instead")
    HAS_CORE = False
| |
|
| | |
| |
|
# A prompt text with an optional pre-computed encoding and an optional
# position; both optional fields default to None.
# (Comments rather than a docstring: pydantic would publish a class docstring
# as the schema description.)
class PointIn(BaseModel):
    text: str
    encoding: Optional[List[float]] = Field(default=None)
    xy_pos: Optional[List[float]] = Field(default=None)
| |
|
# Request body for POST /generate.
#   prompts    - one text prompt per explorer point
#   positions  - optional per-prompt positions, matched to prompts by index
#   mode       - sampling mode forwarded to explorer.sample_encoding
#   player_pos - optional sampling position (the endpoint falls back to [0.0, 0.0])
class GenerateRequest(BaseModel):
    prompts: List[str]
    positions: Optional[List[List[float]]] = Field(default=None)
    mode: str = Field(default="distance")
    player_pos: Optional[List[float]] = Field(default=None)
| | player_pos: Optional[List[float]] = Field(None) |
| |
|
# Request body for POST /manipulate; the endpoint computes
# encoding + alpha * direction.
class ManipulateRequest(BaseModel):
    encoding: List[float]   # vector being shifted
    direction: List[float]  # vector added to it, scaled by alpha
    alpha: float            # scale factor applied to direction
| |
|
# Request body for POST /attribute_direction. When labels is provided the
# endpoint uses the classifier-based method; otherwise it runs PCA with
# n_components (default 10).
class AttributeDirectionRequest(BaseModel):
    latents: List[List[float]]
    labels: Optional[List[int]] = Field(default=None)
    n_components: Optional[int] = Field(default=10)
| |
|
| | |
| |
|
class MockLatentSpaceExplorer:
    """Stand-in for ``LatentSpaceExplorer`` when faceforge_core is unavailable.

    Points are recorded as ``{"text", "xy_pos"}`` dicts (the encoding is
    dropped), and sampling always returns fresh random noise.
    """

    def __init__(self):
        self.points = []
        logger.warning("Using mock LatentSpaceExplorer")

    def add_point(self, text, encoding=None, xy_pos=None):
        """Record ``text`` and ``xy_pos``; ``encoding`` is accepted but ignored."""
        logger.debug(f"Mock add_point: {text}")
        entry = {"text": text, "xy_pos": xy_pos}
        self.points.append(entry)

    def sample_encoding(self, player_pos, mode="distance"):
        """Return a standard-normal array of shape ``(1, 4, 64, 64)``; inputs are only logged."""
        logger.debug(f"Mock sample_encoding: {player_pos}, {mode}")
        noise_shape = (1, 4, 64, 64)
        return np.random.randn(*noise_shape)
| |
|
class MockLatentDirectionFinder:
    """Stand-in for ``LatentDirectionFinder`` that returns random arrays.

    The latents passed to the constructor are stored but never used.
    """

    def __init__(self, latents):
        self.latents = latents
        logger.warning("Using mock LatentDirectionFinder")

    def classifier_direction(self, labels):
        """Ignore ``labels`` and return a standard-normal vector of length 512."""
        vector_len = 512
        return np.random.randn(vector_len)

    def pca_direction(self, n_components=10):
        """Return ``(components, explained)`` filled with random values.

        ``components`` has shape ``(n_components, 512)`` (standard-normal draws);
        ``explained`` has shape ``(n_components,)`` (uniform draws in [0, 1)).
        """
        # Tuple elements evaluate left to right, so the RNG is consumed in the
        # same order as before: randn first, then rand.
        return (
            np.random.randn(n_components, 512),
            np.random.rand(n_components),
        )
| |
|
| | |
| |
|
app = FastAPI(
    title="FaceForge API",
    description="API for latent space exploration and manipulation",
    version="1.0.0",
    root_path="",
)

# NOTE(review): wildcard origins together with allow_credentials=True is risky.
# Starlette's CORSMiddleware appears to echo the caller's Origin in that case,
# which would let any site make credentialed requests. Confirm this against the
# installed Starlette version and decide whether it is intended.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)

# One explorer shared by all requests: the real implementation when
# faceforge_core imported successfully, otherwise the mock.
if HAS_CORE:
    explorer = LatentSpaceExplorer()
else:
    explorer = MockLatentSpaceExplorer()
| |
|
| | |
@app.middleware("http")
async def error_handling_middleware(request: Request, call_next):
    """Convert any exception that escapes the app into a JSON 500 response.

    The exception message and traceback are logged together at ERROR level.
    """
    try:
        return await call_next(request)
    except Exception as e:
        # logger.exception logs at ERROR and includes the traceback, so the
        # stack is kept even when DEBUG output is disabled.
        logger.exception("Unhandled exception: %s", e)
        # NOTE(review): str(e) is sent to the client and may leak internal
        # details; consider dropping "error" once clients no longer rely on it.
        return JSONResponse(
            status_code=500,
            content={"detail": "Internal server error", "error": str(e)},
        )
| |
|
# GET / - returns a static message showing the API process is up.
# (Comment rather than a docstring: FastAPI publishes endpoint docstrings in
# the OpenAPI description.)
@app.get("/")
def read_root():
    logger.debug("API root endpoint called")
    payload = {"message": "FaceForge API is running"}
    return payload
| |
|
def _array_to_png_base64(img):
    """Encode ``img`` (anything ``Image.fromarray`` accepts) as a base64 PNG string."""
    buffer = io.BytesIO()
    Image.fromarray(img).save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


@app.post("/generate")
async def generate_image(req: GenerateRequest):
    """Load the request prompts into the explorer, sample it, and return a PNG image.

    Responds with ``{"status": "success", "image": <base64 PNG>}``. Any failure
    is reported as HTTP 500 with the error message as ``detail``.
    """
    try:
        logger.debug("Generate image request: %s", json.dumps(req.dict(), default=str))

        # NOTE(review): `explorer` is module-wide state shared by all requests.
        # Resetting and refilling it is race-free only because this handler is
        # `async` and never awaits, so no other request runs on the event loop
        # in between. Turning it into a sync `def` (thread pool) would make this
        # racy; conversely, the CPU work below blocks the event loop meanwhile.
        # Assumes the core LatentSpaceExplorer also keeps its points in
        # `.points` - confirm.
        explorer.points = []

        for i, prompt in enumerate(req.prompts):
            logger.debug("Processing prompt %d: %s", i, prompt)
            # Random stand-in encoding; no text encoder is wired up here yet.
            encoding = np.random.randn(512)
            # Prompts without a matching entry in `positions` get None.
            xy_pos = req.positions[i] if req.positions and i < len(req.positions) else None
            logger.debug("Position for prompt %d: %s", i, xy_pos)
            explorer.add_point(prompt, encoding, xy_pos)

        player_pos = [0.0, 0.0] if req.player_pos is None else req.player_pos
        logger.debug("Player position: %s", player_pos)

        logger.debug("Sampling with mode: %s", req.mode)
        # TODO: the sampled encoding is not decoded yet and the image below is
        # placeholder noise. The call is kept so sampling still runs (and can
        # still fail the request).
        explorer.sample_encoding(tuple(player_pos), mode=req.mode)

        img = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)
        logger.debug("Converting image to base64")
        img_b64 = _array_to_png_base64(img)

        response = {"status": "success", "image": img_b64}
        logger.debug("Response structure: %s", list(response.keys()))
        logger.debug("Image base64 length: %d", len(img_b64))
        logger.debug("Image generated successfully")
        return response

    except Exception as e:
        logger.exception("Error in generate_image: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
| |
|
@app.post("/manipulate")
def manipulate(req: ManipulateRequest):
    """Shift an encoding along a direction: ``encoding + alpha * direction``.

    Responds with ``{"manipulated_encoding": [...]}``, the same length as the
    input. Returns HTTP 422 when ``encoding`` and ``direction`` differ in
    length, and HTTP 500 for any other failure.
    """
    try:
        logger.debug("Manipulate request: %s", json.dumps(req.dict(), default=str))
        # Without this check NumPy would silently broadcast a length-1 vector
        # across the whole encoding, and other mismatches would surface as a 500.
        if len(req.encoding) != len(req.direction):
            raise HTTPException(
                status_code=422,
                detail=(
                    "encoding and direction must have the same length "
                    f"(got {len(req.encoding)} and {len(req.direction)})"
                ),
            )
        encoding = np.asarray(req.encoding, dtype=float)
        direction = np.asarray(req.direction, dtype=float)
        manipulated = encoding + req.alpha * direction
        logger.debug("Manipulation successful")
        return {"manipulated_encoding": manipulated.tolist()}
    except HTTPException:
        # Pass deliberate client errors through instead of re-wrapping as 500.
        raise
    except Exception as e:
        logger.exception("Error in manipulate: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
| |
|
@app.post("/attribute_direction")
def attribute_direction(req: AttributeDirectionRequest):
    """Compute latent-space directions from a set of latent vectors.

    With ``labels``, responds with ``{"direction": [...]}`` from the finder's
    classifier-based method. Without ``labels``, responds with
    ``{"components": [...], "explained_variance": [...]}`` from its PCA method.
    Rows of ``latents`` with differing lengths are rejected with HTTP 422;
    other failures become HTTP 500.
    """
    try:
        # Log a summary instead of the full matrix: serialising every latent
        # vector on each request produces very large log lines.
        logger.debug(
            "Attribute direction request: %d latents, labels=%s, n_components=%s",
            len(req.latents),
            "none" if req.labels is None else len(req.labels),
            req.n_components,
        )

        # Ragged rows make np.array raise (newer NumPy) or build an object array
        # (older NumPy); report that as a client error rather than a 500.
        if len({len(row) for row in req.latents}) > 1:
            raise HTTPException(
                status_code=422,
                detail="All rows of 'latents' must have the same length",
            )
        latents = np.array(req.latents)

        finder = LatentDirectionFinder(latents) if HAS_CORE else MockLatentDirectionFinder(latents)

        if req.labels is not None:
            # NOTE(review): presumably one label per latent row; the lengths are
            # not cross-checked here - confirm the finder's contract.
            logger.debug("Using classifier-based direction finding")
            direction = finder.classifier_direction(req.labels)
            logger.debug("Direction found successfully")
            return {"direction": direction.tolist()}

        # NOTE(review): n_components is None if the client sends null explicitly.
        logger.debug("Using PCA with %s components", req.n_components)
        components, explained = finder.pca_direction(n_components=req.n_components)
        logger.debug("PCA completed successfully")
        return {"components": components.tolist(), "explained_variance": explained.tolist()}
    except HTTPException:
        # Pass deliberate client errors through instead of re-wrapping as 500.
        raise
    except Exception as e:
        logger.exception("Error in attribute_direction: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e