| | """
|
| | Tests for REST API endpoints.
|
| |
|
| | Also checks that solver configuration (solverConfig) is accepted by the plan model, survives JSON (de)serialization, and is accepted by POST /portfolios.
|
| | """
|
| | import pytest
|
| | from fastapi.testclient import TestClient
|
| | from portfolio_optimization.rest_api import app
|
| | from portfolio_optimization.domain import (
|
| | PortfolioOptimizationPlanModel,
|
| | StockSelectionModel,
|
| | SolverConfigModel,
|
| | )
|
| |
|
| |
|
@pytest.fixture
def client():
    """Provide a FastAPI ``TestClient`` wrapping the application under test.

    The client is handed to the test as-is; no teardown runs after the
    test finishes.
    """
    test_client = TestClient(app)
    yield test_client
|
| |
|
| |
|
class TestDemoDataEndpoints:
    """Exercise the read-only demo-data endpoints."""

    @staticmethod
    def _fetch_stocks(client, path):
        """GET ``path``, require a 200 with a ``stocks`` key, and return that list."""
        reply = client.get(path)
        assert reply.status_code == 200
        payload = reply.json()
        assert "stocks" in payload
        return payload["stocks"]

    def test_list_demo_data(self, client):
        """GET /demo-data must list both the SMALL and the LARGE dataset."""
        reply = client.get("/demo-data")
        assert reply.status_code == 200
        available = reply.json()
        for dataset_name in ("SMALL", "LARGE"):
            assert dataset_name in available

    def test_get_small_demo_data(self, client):
        """GET /demo-data/SMALL must yield exactly 25 stocks."""
        stocks = self._fetch_stocks(client, "/demo-data/SMALL")
        assert len(stocks) == 25

    def test_get_large_demo_data(self, client):
        """GET /demo-data/LARGE must yield exactly 51 stocks."""
        stocks = self._fetch_stocks(client, "/demo-data/LARGE")
        assert len(stocks) == 51
|
| |
|
| |
|
class TestSolverConfigEndpoints:
    """Tests for solver configuration (``solverConfig``) handling."""

    def test_plan_model_accepts_solver_config(self):
        """PortfolioOptimizationPlanModel accepts a SolverConfigModel via ``solverConfig``."""
        model = PortfolioOptimizationPlanModel(
            stocks=[
                StockSelectionModel(
                    stockId="AAPL",
                    stockName="Apple",
                    sector="Technology",
                    predictedReturn=0.12,
                    selected=None,
                )
            ],
            targetPositionCount=20,
            maxSectorPercentage=0.25,
            solverConfig=SolverConfigModel(terminationSeconds=60),
        )
        assert model.solver_config is not None
        assert model.solver_config.termination_seconds == 60

    def test_plan_model_serializes_solver_config(self):
        """solverConfig serializes with camelCase aliases when ``by_alias=True``."""
        model = PortfolioOptimizationPlanModel(
            stocks=[],
            solverConfig=SolverConfigModel(terminationSeconds=90),
        )
        data = model.model_dump(by_alias=True)
        assert "solverConfig" in data
        assert data["solverConfig"]["terminationSeconds"] == 90

    def test_plan_model_deserializes_solver_config(self):
        """solverConfig (and the other camelCase fields) deserialize from JSON-like data."""
        json_data = {
            "stocks": [
                {
                    "stockId": "AAPL",
                    "stockName": "Apple",
                    "sector": "Technology",
                    "predictedReturn": 0.12,
                    "selected": None,
                }
            ],
            "targetPositionCount": 15,
            "maxSectorPercentage": 0.30,
            "solverConfig": {
                "terminationSeconds": 120,
            },
        }
        model = PortfolioOptimizationPlanModel.model_validate(json_data)
        assert model.target_position_count == 15
        # Compare floats with a tolerance rather than exact equality.
        assert model.max_sector_percentage == pytest.approx(0.30)
        assert model.solver_config is not None
        assert model.solver_config.termination_seconds == 120

    def test_plan_without_solver_config(self):
        """solverConfig is optional; when omitted it deserializes to None."""
        json_data = {
            "stocks": [],
            "targetPositionCount": 20,
            "maxSectorPercentage": 0.25,
        }
        model = PortfolioOptimizationPlanModel.model_validate(json_data)
        assert model.solver_config is None

    def test_post_portfolio_with_solver_config(self, client):
        """POST /portfolios accepts solverConfig in the request body.

        The job started by the POST is always stopped via DELETE, even if
        an assertion about the returned job id fails, so it cannot keep
        running for the rest of the test session.
        """
        demo_response = client.get("/demo-data/SMALL")
        # Fail here with a clear status error rather than with a confusing
        # validation/KeyError further down if the demo data is unavailable.
        assert demo_response.status_code == 200
        plan_data = demo_response.json()

        # Presumably caps the solve time, keeping the job short.
        plan_data["solverConfig"] = {
            "terminationSeconds": 10,
        }

        response = client.post("/portfolios", json=plan_data)
        assert response.status_code == 200
        job_id = response.json()

        try:
            assert job_id is not None
            assert len(job_id) > 0
        finally:
            # Always stop the job, even when the assertions above fail.
            stop_response = client.delete(f"/portfolios/{job_id}")
        assert stop_response.status_code == 200
|
| |
|