| import json |
| import copy |
|
|
| from encoding_dsv32 import encode_messages, parse_message_from_completion_text |
|
|
# Load the main conversation fixture and attach its tool definitions to the
# first message, which is where this script hands them to the encoder.
# encoding="utf-8" pins decoding instead of relying on the platform's
# locale-dependent default (e.g. cp1252 on Windows), which could mis-decode
# any non-ASCII text in the fixture.
with open("test_input.json", "r", encoding="utf-8") as f:
    test_dict = json.load(f)
messages = test_dict["messages"]
messages[0]["tools"] = test_dict["tools"]
|
|
# Expected encoder output for `messages`. Surrounding whitespace (such as a
# trailing newline in the file) is stripped before it is compared. The file is
# read as UTF-8 explicitly so the result does not depend on the platform locale.
with open("test_output.txt", "r", encoding="utf-8") as f:
    gold_prompt = f.read().strip()
|
|
# Show the input conversation, then a 60-character separator line.
print(messages, "=" * 60, sep="\n")
|
|
# Encoder options shared by every check in this script.
encode_config = {
    "thinking_mode": "thinking",
    "drop_thinking": True,
    "add_default_bos_token": True,
}

# Encode the whole conversation, show it, and require an exact match with the
# gold text before printing the separator.
prompt = encode_messages(messages, **encode_config)
print(prompt)
assert prompt == gold_prompt
print("=" * 60)
|
|
# Round-trip check for a single tool-call turn:
#   1. Encode the message with all preceding messages as context.
#   2. Parse the encoded text back.
#   3. Compare it with the original message, ignoring tool-call ids and the
#      parsed "content" field.
# NOTE(review): assumes messages[4] is the fixture's tool-call turn; confirm
# if test_input.json changes.
tool_call_index = 4  # single source of truth for the message and its context slice
tool_call_message = messages[tool_call_index]
tool_call_prompt = encode_messages(
    [tool_call_message], context=messages[:tool_call_index], **encode_config
)

# Tool-call ids are excluded from the comparison. They are stripped from a deep
# copy so the shared `messages` fixture stays intact for the later checks.
# pop(..., None) tolerates a fixture call that carries no id.
tool_call_message_wo_id = copy.deepcopy(tool_call_message)
for tool_call in tool_call_message_wo_id["tool_calls"]:
    tool_call.pop("id", None)

parsed_tool_call_message = parse_message_from_completion_text(
    tool_call_prompt, thinking_mode="thinking"
)
# The expected message is compared without a "content" key, so drop the
# parsed one before comparing.
parsed_tool_call_message.pop("content")
assert tool_call_message_wo_id == parsed_tool_call_message, (
    "tool-call round trip mismatch:\n"
    f"expected: {tool_call_message_wo_id}\n"
    f"parsed:   {parsed_tool_call_message}"
)
|
|
# Round-trip check for the turn at index -6 (the thinking turn, per the
# variable name):
#   1. Encode it with every earlier message as context.
#   2. Parse the encoded text back.
#   3. Compare everything except the parsed "tool_calls" field.
# NOTE(review): index -6 depends on the fixture layout; confirm if
# test_input.json changes.
thinking_index = -6  # single source of truth for the message and its context slice
thinking_message = messages[thinking_index]
thinking_prompt = encode_messages(
    [thinking_message], context=messages[:thinking_index], **encode_config
)
parsed_thinking_message = parse_message_from_completion_text(
    thinking_prompt, thinking_mode="thinking"
)
# The expected message is compared without a "tool_calls" key, so drop the
# parsed one before comparing.
parsed_thinking_message.pop("tool_calls")
assert thinking_message == parsed_thinking_message, (
    "thinking-message round trip mismatch:\n"
    f"expected: {thinking_message}\n"
    f"parsed:   {parsed_thinking_message}"
)
|
|
# Search conversation without a date: encode it and require an exact match
# with the stripped gold text. Both files are read as UTF-8 explicitly rather
# than with the platform's locale-dependent default encoding.
with open("test_input_search_wo_date.json", "r", encoding="utf-8") as f:
    search_messages = json.load(f)["messages"]

with open("test_output_search_wo_date.txt", "r", encoding="utf-8") as f:
    search_gold_prompt = f.read().strip()

search_prompt = encode_messages(search_messages, **encode_config)
assert search_prompt == search_gold_prompt, (
    "search prompt (without date) does not match test_output_search_wo_date.txt"
)
|
|
# Search conversation with a date: encode it and require an exact match with
# the stripped gold text. All file I/O uses explicit UTF-8. The default
# encoding is locale-dependent, and on a non-UTF-8 locale the write below
# could raise UnicodeEncodeError for non-ASCII prompt text.
with open("test_input_search_w_date.json", "r", encoding="utf-8") as f:
    search_messages_w_date = json.load(f)["messages"]

with open("test_output_search_w_date.txt", "r", encoding="utf-8") as f:
    search_gold_prompt_w_date = f.read().strip()

search_prompt_w_date = encode_messages(search_messages_w_date, **encode_config)
# Dump the generated prompt before checking it, presumably so it can be diffed
# against the gold file when the assertion fails.
with open("test_output_search_w_date_2.txt", "w", encoding="utf-8") as f:
    f.write(search_prompt_w_date)
assert search_prompt_w_date == search_gold_prompt_w_date, (
    "search prompt (with date) does not match test_output_search_w_date.txt; "
    "see test_output_search_w_date_2.txt for the generated prompt"
)