File size: 620 Bytes
8097081
 
 
 
 
 
 
 
 
 
 
 
72a7241
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# tests/test_environment.py
import pytest
from src.pytorch_debug_env.environment import PyTorchDebugEnv
from src.pytorch_debug_env.scenario_generator import ScenarioGenerator
from src.pytorch_debug_env.bug_library import BUG_TEMPLATES

@pytest.mark.asyncio
async def test_env_reset():
    """Check the observation returned by resetting the env on the "easy" task.

    Builds a ``PyTorchDebugEnv`` backed by a ``ScenarioGenerator`` over
    ``BUG_TEMPLATES``, awaits ``reset("easy")``, and verifies that:

    * the observation reports the requested task id,
    * ``train.py`` and ``config/training_config.yaml`` are among the
      revealed files,
    * the step counter starts at zero,
    * the remaining-step budget is non-negative.
    """
    env = PyTorchDebugEnv(generator=ScenarioGenerator(BUG_TEMPLATES))
    observation = await env.reset("easy")

    assert observation.task_id == "easy"

    # Both of these files must be visible right after a reset.
    for expected_path in ("train.py", "config/training_config.yaml"):
        assert expected_path in observation.revealed_files

    assert observation.step_num == 0
    assert observation.steps_remaining >= 0