customer_support / tests /test_env.py
Sayed223's picture
Upload 8 files
f672bcf verified
"""
Tests for CustomerSupportEnv.
Run: python -m pytest tests/ -v
"""
import sys, os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import pytest
from env.environment import CustomerSupportEnv, TASKS
from env.models import Action, ActionType, TicketStatus
from graders.graders import grade
# ── Fixtures ──────────────────────────────────────────────────────────────────
@pytest.fixture
def env1():
e = CustomerSupportEnv(task_id="task_1", seed=0)
e.reset()
return e
@pytest.fixture
def env2():
e = CustomerSupportEnv(task_id="task_2", seed=0)
e.reset()
return e
@pytest.fixture
def env3():
e = CustomerSupportEnv(task_id="task_3", seed=0)
e.reset()
return e
# ── reset() ───────────────────────────────────────────────────────────────────
def test_reset_returns_observation():
env = CustomerSupportEnv(task_id="task_1", seed=0)
obs = env.reset()
assert obs.ticket_id == "TKT-001"
assert obs.done is False
assert obs.turn == 0
assert obs.status == TicketStatus.OPEN.value or obs.status == TicketStatus.OPEN
def test_reset_clears_state(env1):
env1.step(Action(action_type=ActionType.SEARCH_KB))
obs = env1.reset()
assert obs.kb_searched is False
assert obs.turn == 0
assert obs.cumulative_reward == 0.0
def test_reset_loads_history(env1):
obs = env1.state()
assert len(obs.history) >= 1
assert obs.history[0].role == "customer"
# ── state() ───────────────────────────────────────────────────────────────────
def test_state_does_not_advance(env1):
obs_before = env1.state()
env1.state()
obs_after = env1.state()
assert obs_before.turn == obs_after.turn
# ── step() ────────────────────────────────────────────────────────────────────
def test_step_search_kb(env1):
result = env1.step(Action(action_type=ActionType.SEARCH_KB))
assert result.reward.total == 2.0
assert result.observation.kb_searched is True
assert len(result.observation.kb_results) > 0
def test_step_search_kb_duplicate_penalised(env1):
env1.step(Action(action_type=ActionType.SEARCH_KB))
result = env1.step(Action(action_type=ActionType.SEARCH_KB))
assert result.reward.total < 0
def test_step_empathize(env1):
result = env1.step(Action(action_type=ActionType.EMPATHIZE))
assert result.reward.total == 1.0
assert result.observation.empathized is True
def test_step_empathize_no_double_reward(env1):
env1.step(Action(action_type=ActionType.EMPATHIZE))
result = env1.step(Action(action_type=ActionType.EMPATHIZE))
assert result.reward.total == 0.0
def test_step_offer_solution_without_kb_penalised(env1):
result = env1.step(Action(
action_type=ActionType.OFFER_SOLUTION,
payload="I have unlocked your account and sent a reset link."
))
assert result.reward.penalties == -1.0
def test_step_offer_solution_with_kb_rewarded(env1):
env1.step(Action(action_type=ActionType.SEARCH_KB))
result = env1.step(Action(
action_type=ActionType.OFFER_SOLUTION,
payload="I have unlocked your account and sent a password reset link."
))
assert result.reward.total > 0
def test_step_resolve_without_solution_penalised(env1):
result = env1.step(Action(action_type=ActionType.RESOLVE))
assert result.reward.total == -3.0
assert result.done is True
def test_step_resolve_good(env1):
env1.step(Action(action_type=ActionType.SEARCH_KB))
env1.step(Action(
action_type=ActionType.OFFER_SOLUTION,
payload="Account unlocked and reset email sent."
))
result = env1.step(Action(action_type=ActionType.RESOLVE))
assert result.reward.total >= 5.0
assert result.done is True
def test_step_raises_before_reset():
env = CustomerSupportEnv(task_id="task_1")
with pytest.raises(RuntimeError):
env.step(Action(action_type=ActionType.SEARCH_KB))
def test_step_raises_after_done(env1):
env1.step(Action(action_type=ActionType.RESOLVE))
with pytest.raises(RuntimeError):
env1.step(Action(action_type=ActionType.SEARCH_KB))
def test_timeout_penalty(env1):
"""Exceeding max_turns gives timeout penalty."""
for _ in range(env1._obs.max_turns - 1):
env1.step(Action(action_type=ActionType.EMPATHIZE))
obs = env1.state()
assert obs.turn >= obs.max_turns - 1
# ── Graders ───────────────────────────────────────────────────────────────────
def test_grader_task1_optimal(env1):
env1.step(Action(action_type=ActionType.SEARCH_KB))
env1.step(Action(action_type=ActionType.EMPATHIZE))
env1.step(Action(
action_type=ActionType.OFFER_SOLUTION,
payload="I have unlocked your account and sent a password reset link to your email."
))
env1.step(Action(action_type=ActionType.RESOLVE))
result = grade("task_1", env1.state())
assert result.score >= 0.90
assert result.passed is True
def test_grader_task1_minimal(env1):
"""Just resolve with no steps β€” should fail."""
env1.step(Action(action_type=ActionType.RESOLVE))
result = grade("task_1", env1.state())
assert result.score < 0.40
assert result.passed is False
def test_grader_task1_score_in_range(env1):
result = grade("task_1", env1.state())
assert 0.0 <= result.score <= 1.0
def test_grader_task2_requires_clarify(env2):
"""Medium task: no clarify β†’ lower score."""
env2.step(Action(action_type=ActionType.SEARCH_KB))
env2.step(Action(
action_type=ActionType.OFFER_SOLUTION,
payload="I have applied a $20 credit to your account."
))
env2.step(Action(action_type=ActionType.RESOLVE))
result = grade("task_2", env2.state())
assert result.breakdown.get("ask_clarify", 0) == 0.0
def test_grader_task2_full_score(env2):
env2.step(Action(action_type=ActionType.SEARCH_KB))
env2.step(Action(action_type=ActionType.ASK_CLARIFY, payload="Can you confirm your account email and invoice number?"))
env2.step(Action(action_type=ActionType.EMPATHIZE))
env2.step(Action(
action_type=ActionType.OFFER_SOLUTION,
payload="I have issued a $20 credit to your account. Your plan is now corrected to $29/month."
))
env2.step(Action(action_type=ActionType.RESOLVE))
result = grade("task_2", env2.state())
assert result.score >= 0.70
def test_grader_task3_two_part_solution(env3):
env3.step(Action(action_type=ActionType.SEARCH_KB))
env3.step(Action(action_type=ActionType.EMPATHIZE))
env3.step(Action(
action_type=ActionType.OFFER_SOLUTION,
payload="I have moved your export job to the priority queue β€” it will complete in 1-2 hours. "
"As a backup, please start a partial export by date range which will be much faster. "
"I will email you when the full export completes."
))
env3.step(Action(action_type=ActionType.RESOLVE))
result = grade("task_3", env3.state())
assert result.score >= 0.70
assert result.passed is True
def test_grader_task3_escalation_capped(env3):
env3.step(Action(action_type=ActionType.SEARCH_KB))
env3.step(Action(action_type=ActionType.ESCALATE))
env3.step(Action(action_type=ActionType.RESOLVE))
result = grade("task_3", env3.state())
assert result.score <= 0.55
def test_grader_deterministic(env1):
"""Same inputs β†’ same grader output every time."""
env1.step(Action(action_type=ActionType.SEARCH_KB))
env1.step(Action(action_type=ActionType.RESOLVE))
r1 = grade("task_1", env1.state())
env1.reset()
env1.step(Action(action_type=ActionType.SEARCH_KB))
env1.step(Action(action_type=ActionType.RESOLVE))
r2 = grade("task_1", env1.state())
assert r1.score == r2.score
# ── Task specs ────────────────────────────────────────────────────────────────
def test_task_list():
assert set(CustomerSupportEnv.list_tasks()) == {"task_1", "task_2", "task_3"}
def test_task_difficulty_progression():
diffs = [TASKS[tid].difficulty for tid in ["task_1", "task_2", "task_3"]]
assert diffs == ["easy", "medium", "hard"]