Spaces:
Runtime error
Runtime error
| """ | |
| Tests for CustomerSupportEnv. | |
| Run: python -m pytest tests/ -v | |
| """ | |
| import sys, os | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) | |
| import pytest | |
| from env.environment import CustomerSupportEnv, TASKS | |
| from env.models import Action, ActionType, TicketStatus | |
| from graders.graders import grade | |
| # ββ Fixtures ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def env1(): | |
| e = CustomerSupportEnv(task_id="task_1", seed=0) | |
| e.reset() | |
| return e | |
| def env2(): | |
| e = CustomerSupportEnv(task_id="task_2", seed=0) | |
| e.reset() | |
| return e | |
| def env3(): | |
| e = CustomerSupportEnv(task_id="task_3", seed=0) | |
| e.reset() | |
| return e | |
| # ββ reset() βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def test_reset_returns_observation(): | |
| env = CustomerSupportEnv(task_id="task_1", seed=0) | |
| obs = env.reset() | |
| assert obs.ticket_id == "TKT-001" | |
| assert obs.done is False | |
| assert obs.turn == 0 | |
| assert obs.status == TicketStatus.OPEN.value or obs.status == TicketStatus.OPEN | |
| def test_reset_clears_state(env1): | |
| env1.step(Action(action_type=ActionType.SEARCH_KB)) | |
| obs = env1.reset() | |
| assert obs.kb_searched is False | |
| assert obs.turn == 0 | |
| assert obs.cumulative_reward == 0.0 | |
| def test_reset_loads_history(env1): | |
| obs = env1.state() | |
| assert len(obs.history) >= 1 | |
| assert obs.history[0].role == "customer" | |
| # ββ state() βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def test_state_does_not_advance(env1): | |
| obs_before = env1.state() | |
| env1.state() | |
| obs_after = env1.state() | |
| assert obs_before.turn == obs_after.turn | |
| # ββ step() ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def test_step_search_kb(env1): | |
| result = env1.step(Action(action_type=ActionType.SEARCH_KB)) | |
| assert result.reward.total == 2.0 | |
| assert result.observation.kb_searched is True | |
| assert len(result.observation.kb_results) > 0 | |
| def test_step_search_kb_duplicate_penalised(env1): | |
| env1.step(Action(action_type=ActionType.SEARCH_KB)) | |
| result = env1.step(Action(action_type=ActionType.SEARCH_KB)) | |
| assert result.reward.total < 0 | |
| def test_step_empathize(env1): | |
| result = env1.step(Action(action_type=ActionType.EMPATHIZE)) | |
| assert result.reward.total == 1.0 | |
| assert result.observation.empathized is True | |
| def test_step_empathize_no_double_reward(env1): | |
| env1.step(Action(action_type=ActionType.EMPATHIZE)) | |
| result = env1.step(Action(action_type=ActionType.EMPATHIZE)) | |
| assert result.reward.total == 0.0 | |
| def test_step_offer_solution_without_kb_penalised(env1): | |
| result = env1.step(Action( | |
| action_type=ActionType.OFFER_SOLUTION, | |
| payload="I have unlocked your account and sent a reset link." | |
| )) | |
| assert result.reward.penalties == -1.0 | |
| def test_step_offer_solution_with_kb_rewarded(env1): | |
| env1.step(Action(action_type=ActionType.SEARCH_KB)) | |
| result = env1.step(Action( | |
| action_type=ActionType.OFFER_SOLUTION, | |
| payload="I have unlocked your account and sent a password reset link." | |
| )) | |
| assert result.reward.total > 0 | |
| def test_step_resolve_without_solution_penalised(env1): | |
| result = env1.step(Action(action_type=ActionType.RESOLVE)) | |
| assert result.reward.total == -3.0 | |
| assert result.done is True | |
| def test_step_resolve_good(env1): | |
| env1.step(Action(action_type=ActionType.SEARCH_KB)) | |
| env1.step(Action( | |
| action_type=ActionType.OFFER_SOLUTION, | |
| payload="Account unlocked and reset email sent." | |
| )) | |
| result = env1.step(Action(action_type=ActionType.RESOLVE)) | |
| assert result.reward.total >= 5.0 | |
| assert result.done is True | |
| def test_step_raises_before_reset(): | |
| env = CustomerSupportEnv(task_id="task_1") | |
| with pytest.raises(RuntimeError): | |
| env.step(Action(action_type=ActionType.SEARCH_KB)) | |
| def test_step_raises_after_done(env1): | |
| env1.step(Action(action_type=ActionType.RESOLVE)) | |
| with pytest.raises(RuntimeError): | |
| env1.step(Action(action_type=ActionType.SEARCH_KB)) | |
| def test_timeout_penalty(env1): | |
| """Exceeding max_turns gives timeout penalty.""" | |
| for _ in range(env1._obs.max_turns - 1): | |
| env1.step(Action(action_type=ActionType.EMPATHIZE)) | |
| obs = env1.state() | |
| assert obs.turn >= obs.max_turns - 1 | |
| # ββ Graders βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def test_grader_task1_optimal(env1): | |
| env1.step(Action(action_type=ActionType.SEARCH_KB)) | |
| env1.step(Action(action_type=ActionType.EMPATHIZE)) | |
| env1.step(Action( | |
| action_type=ActionType.OFFER_SOLUTION, | |
| payload="I have unlocked your account and sent a password reset link to your email." | |
| )) | |
| env1.step(Action(action_type=ActionType.RESOLVE)) | |
| result = grade("task_1", env1.state()) | |
| assert result.score >= 0.90 | |
| assert result.passed is True | |
| def test_grader_task1_minimal(env1): | |
| """Just resolve with no steps β should fail.""" | |
| env1.step(Action(action_type=ActionType.RESOLVE)) | |
| result = grade("task_1", env1.state()) | |
| assert result.score < 0.40 | |
| assert result.passed is False | |
| def test_grader_task1_score_in_range(env1): | |
| result = grade("task_1", env1.state()) | |
| assert 0.0 <= result.score <= 1.0 | |
| def test_grader_task2_requires_clarify(env2): | |
| """Medium task: no clarify β lower score.""" | |
| env2.step(Action(action_type=ActionType.SEARCH_KB)) | |
| env2.step(Action( | |
| action_type=ActionType.OFFER_SOLUTION, | |
| payload="I have applied a $20 credit to your account." | |
| )) | |
| env2.step(Action(action_type=ActionType.RESOLVE)) | |
| result = grade("task_2", env2.state()) | |
| assert result.breakdown.get("ask_clarify", 0) == 0.0 | |
| def test_grader_task2_full_score(env2): | |
| env2.step(Action(action_type=ActionType.SEARCH_KB)) | |
| env2.step(Action(action_type=ActionType.ASK_CLARIFY, payload="Can you confirm your account email and invoice number?")) | |
| env2.step(Action(action_type=ActionType.EMPATHIZE)) | |
| env2.step(Action( | |
| action_type=ActionType.OFFER_SOLUTION, | |
| payload="I have issued a $20 credit to your account. Your plan is now corrected to $29/month." | |
| )) | |
| env2.step(Action(action_type=ActionType.RESOLVE)) | |
| result = grade("task_2", env2.state()) | |
| assert result.score >= 0.70 | |
| def test_grader_task3_two_part_solution(env3): | |
| env3.step(Action(action_type=ActionType.SEARCH_KB)) | |
| env3.step(Action(action_type=ActionType.EMPATHIZE)) | |
| env3.step(Action( | |
| action_type=ActionType.OFFER_SOLUTION, | |
| payload="I have moved your export job to the priority queue β it will complete in 1-2 hours. " | |
| "As a backup, please start a partial export by date range which will be much faster. " | |
| "I will email you when the full export completes." | |
| )) | |
| env3.step(Action(action_type=ActionType.RESOLVE)) | |
| result = grade("task_3", env3.state()) | |
| assert result.score >= 0.70 | |
| assert result.passed is True | |
| def test_grader_task3_escalation_capped(env3): | |
| env3.step(Action(action_type=ActionType.SEARCH_KB)) | |
| env3.step(Action(action_type=ActionType.ESCALATE)) | |
| env3.step(Action(action_type=ActionType.RESOLVE)) | |
| result = grade("task_3", env3.state()) | |
| assert result.score <= 0.55 | |
| def test_grader_deterministic(env1): | |
| """Same inputs β same grader output every time.""" | |
| env1.step(Action(action_type=ActionType.SEARCH_KB)) | |
| env1.step(Action(action_type=ActionType.RESOLVE)) | |
| r1 = grade("task_1", env1.state()) | |
| env1.reset() | |
| env1.step(Action(action_type=ActionType.SEARCH_KB)) | |
| env1.step(Action(action_type=ActionType.RESOLVE)) | |
| r2 = grade("task_1", env1.state()) | |
| assert r1.score == r2.score | |
| # ββ Task specs ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def test_task_list(): | |
| assert set(CustomerSupportEnv.list_tasks()) == {"task_1", "task_2", "task_3"} | |
| def test_task_difficulty_progression(): | |
| diffs = [TASKS[tid].difficulty for tid in ["task_1", "task_2", "task_3"]] | |
| assert diffs == ["easy", "medium", "hard"] | |