File size: 2,893 Bytes
5dd1bb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9759a5
5dd1bb4
 
 
 
 
 
 
 
d9759a5
 
 
 
 
 
 
 
5dd1bb4
 
 
 
d9759a5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
FastAPI application for the SQLEnv environment.

Exposes the SQLEnvironment over HTTP and WebSocket endpoints,
compatible with the OpenEnv EnvClient.

Usage:
    # Development (with auto-reload):
    uv run uvicorn server.app:app --reload --host 0.0.0.0 --port 8000

    # Via the `server` entry point (calls main(); accepts --port):
    uv run server
"""

import os
from pathlib import Path

# Load environment variables from the .env file two directories above this
# module, if it exists. python-dotenv is optional: when it is not installed,
# only the existing process environment is used.
try:
    from dotenv import load_dotenv
except ImportError:
    pass  # python-dotenv not installed, use system env vars
else:
    # Only the import is guarded, so an ImportError raised while actually
    # loading the file is not mistaken for python-dotenv being absent.
    env_file = Path(__file__).parent.parent / ".env"
    if env_file.exists():
        load_dotenv(env_file)

from openenv.core.env_server import create_app

try:
    from sql_env.models import SQLAction, SQLObservation
    from sql_env.server.sql_environment import SQLEnvironment
except ImportError:
    # Fallback for Docker where PYTHONPATH=/app/env
    from models import SQLAction, SQLObservation  # type: ignore[no-redef]
    from server.sql_environment import SQLEnvironment  # type: ignore[no-redef]


def get_tokenizer():
    """Load the tokenizer named by ``TOKENIZER_NAME``, or a mock for testing.

    The tokenizer name defaults to ``mistralai/Mistral-7B-Instruct-v0.1``.
    Only when the ``transformers`` package itself cannot be imported is a
    ``MockTokenizer`` from the test module returned instead (testing only).

    Returns:
        The tokenizer from ``AutoTokenizer.from_pretrained``, or a
        ``MockTokenizer`` instance if ``transformers`` is unavailable.

    Raises:
        Errors raised by ``AutoTokenizer.from_pretrained`` propagate rather
        than being masked by the mock fallback.
    """
    tokenizer_name = os.environ.get(
        "TOKENIZER_NAME", "mistralai/Mistral-7B-Instruct-v0.1"
    )

    try:
        from transformers import AutoTokenizer
    except ImportError:
        print(
            "Warning: transformers not installed, using mock tokenizer for testing only"
        )
        # Mirror the module-level import fallback: package layout first,
        # then the flat Docker layout where PYTHONPATH=/app/env.
        try:
            from sql_env.server.test_sql_env import MockTokenizer
        except ImportError:
            from server.test_sql_env import MockTokenizer  # type: ignore[no-redef]

        return MockTokenizer()

    # Kept outside the try above so an ImportError raised while loading the
    # tokenizer is not misreported as "transformers not installed" and
    # silently swapped for the mock.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    print(f"Loaded tokenizer: {tokenizer_name}")
    return tokenizer


def create_sql_environment():
    """Build a ``SQLEnvironment`` wired up with a tokenizer and data paths.

    The ``QUESTIONS_PATH`` and ``DB_DIR`` environment variables override the
    defaults. Those defaults live under the ``data/`` directory two levels
    above this file: ``data/questions/student_assessment.json`` and
    ``data/databases``.
    """
    tokenizer = get_tokenizer()

    data_root = Path(__file__).parent.parent / "data"
    default_questions = data_root / "questions" / "student_assessment.json"
    default_db_dir = data_root / "databases"

    return SQLEnvironment(
        questions_path=os.environ.get("QUESTIONS_PATH", str(default_questions)),
        db_dir=os.environ.get("DB_DIR", str(default_db_dir)),
        tokenizer=tokenizer,
    )


# Create the FastAPI app.
# The factory function itself (not an SQLEnvironment instance) is passed,
# together with the action and observation model types for this environment.
# NOTE(review): presumably openenv invokes the factory to create environments
# on demand (e.g. per session) — confirm against openenv's create_app.
app = create_app(
    create_sql_environment,
    SQLAction,
    SQLObservation,
    env_name="sql_env",
)


def main(host: str = "0.0.0.0", port: int | None = None):
    """Entry point for running the server directly.

    Enables:
        uv run server
        python -m sql_env.server.app

    Args:
        host: Interface to bind. When ``port`` is None, this becomes the
            default for the ``--host`` command-line option.
        port: Port to bind. When None, ``--port`` (default 8000) and
            ``--host`` are read from the command line.
    """
    import uvicorn

    if port is None:
        import argparse

        parser = argparse.ArgumentParser(description="Run the SQLEnv server.")
        # Defaulting --host to the `host` argument keeps main() with no CLI
        # flags behaving exactly as before, while allowing an override.
        parser.add_argument(
            "--host",
            default=host,
            help="Interface to bind (default: %(default)s).",
        )
        parser.add_argument(
            "--port",
            type=int,
            default=8000,
            help="Port to bind (default: %(default)s).",
        )
        args = parser.parse_args()
        host = args.host
        port = args.port

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()