File size: 2,893 Bytes
"""
FastAPI application for the SQLEnv environment.

Exposes the SQLEnvironment over HTTP and WebSocket endpoints,
compatible with the OpenEnv EnvClient.

Usage:
    # Development (with auto-reload):
    uv run uvicorn server.app:app --reload --host 0.0.0.0 --port 8000

    # Via uv:
    uv run server
"""
import os
from pathlib import Path
# Optionally populate os.environ from a ".env" file located in the parent of
# this file's directory. python-dotenv is treated as an optional dependency.
try:
    from dotenv import load_dotenv

    env_file = Path(__file__).parent.parent.joinpath(".env")
    if env_file.exists():
        load_dotenv(env_file)
except ImportError:
    # python-dotenv is not installed: rely on the process environment as-is.
    pass
from openenv.core.env_server import create_app
try:
from sql_env.models import SQLAction, SQLObservation
from sql_env.server.sql_environment import SQLEnvironment
except ImportError:
# Fallback for Docker where PYTHONPATH=/app/env
from models import SQLAction, SQLObservation # type: ignore[no-redef]
from server.sql_environment import SQLEnvironment # type: ignore[no-redef]
def get_tokenizer():
    """Return the tokenizer used by the SQL environment.

    The tokenizer name is read from the ``TOKENIZER_NAME`` environment
    variable and defaults to ``mistralai/Mistral-7B-Instruct-v0.1``. It is
    loaded with ``transformers.AutoTokenizer.from_pretrained``.

    If an ``ImportError`` is raised while importing or loading (for example
    because ``transformers`` is not installed), a warning is printed and a
    ``MockTokenizer`` instance is returned instead. The mock is intended for
    testing only.

    Returns:
        The loaded tokenizer, or a ``MockTokenizer`` instance on fallback.
    """
    tokenizer_name = os.environ.get(
        "TOKENIZER_NAME", "mistralai/Mistral-7B-Instruct-v0.1"
    )
    try:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        print(f"Loaded tokenizer: {tokenizer_name}")
        return tokenizer
    except ImportError:
        print(
            "Warning: transformers not installed, using mock tokenizer for testing only"
        )
        # Resolve the mock the same way this module resolves its other
        # project imports: installed-package layout first, then the flat
        # layout used in Docker (PYTHONPATH=/app/env). Importing only from
        # ``server.`` made the fallback itself fail under the package layout.
        try:
            from sql_env.server.test_sql_env import MockTokenizer
        except ImportError:
            from server.test_sql_env import MockTokenizer  # type: ignore[no-redef]
        return MockTokenizer()
def create_sql_environment():
    """Build the SQLEnvironment instance handed to the app factory.

    The tokenizer comes from ``get_tokenizer()``. The questions file and the
    database directory are taken from the ``QUESTIONS_PATH`` and ``DB_DIR``
    environment variables. When unset they default to
    ``data/questions/student_assessment.json`` and ``data/databases`` under
    the parent of this file's directory.
    """
    tokenizer = get_tokenizer()

    # Both defaults live under the same "data" directory.
    data_root = Path(__file__).parent.parent / "data"
    default_questions = data_root / "questions" / "student_assessment.json"
    default_db_dir = data_root / "databases"

    env = os.environ
    return SQLEnvironment(
        questions_path=env.get("QUESTIONS_PATH", str(default_questions)),
        db_dir=env.get("DB_DIR", str(default_db_dir)),
        tokenizer=tokenizer,
    )
# Module-level application object. This is the target of "server.app:app"
# when launching with uvicorn. create_app receives the environment factory
# together with the action and observation models.
app = create_app(create_sql_environment, SQLAction, SQLObservation, env_name="sql_env")
def main(host: str = "0.0.0.0", port: int | None = None):
    """Start the uvicorn server for ``app``.

    This is the entry point behind:
        uv run server
        python -m sql_env.server.app

    Args:
        host: Interface to bind to.
        port: Port to listen on. When ``None``, the port is taken from the
            ``--port`` command-line option, which defaults to 8000.
    """
    import uvicorn

    if port is None:
        # Only consult the command line when the caller gave no explicit port.
        from argparse import ArgumentParser

        cli = ArgumentParser()
        cli.add_argument("--port", type=int, default=8000)
        port = cli.parse_args().port

    uvicorn.run(app, host=host, port=port)
# Start the server only when this file is executed as a script or with
# "python -m"; importing the module does not call main().
if __name__ == "__main__":
    main()
|