| |
| import os |
|
|
| |
| from typing import TypedDict, Annotated, Optional |
| from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage |
| from langchain_openai import ChatOpenAI |
| from langgraph.graph.message import add_messages |
| from langgraph.graph import START, StateGraph |
| from langgraph.prebuilt import ToolNode, tools_condition |
|
|
| |
| from tools import ( |
| select_tools_for_input, |
| ) |
|
|
|
|
# OpenAI API key, read from the HF_FINAL_ASSIGNMENT_OPENAI environment
# variable (None when that variable is not set).
openai_token = os.environ.get("HF_FINAL_ASSIGNMENT_OPENAI")

# Shared chat model used by the assistant node; temperature=0 keeps
# sampling as deterministic as the provider allows.
llm = ChatOpenAI(
    temperature=0,
    api_key=openai_token,
    model="gpt-5.2",
)
|
|
|
|
class AgentState(TypedDict):
    """State passed between the ``assistant`` and ``tools`` graph nodes."""

    # Local path of the file attached to the question, or None when there is none.
    input_file: Optional[str]
    # Conversation history; the ``add_messages`` reducer merges each node's
    # returned messages into the existing list instead of replacing it.
    messages: Annotated[list[AnyMessage], add_messages]
|
|
|
|
def _selected_tools_from_state(state: AgentState):
    """Return the tool list that ``select_tools_for_input`` picks for this state.

    The state's ``input_file`` is forwarded as-is; a missing key is passed as None.
    """
    input_file = state.get("input_file")
    return select_tools_for_input(input_file)
|
|
|
|
| def _build_tools_description(selected_tools: list) -> str: |
| lines = [] |
| for fn in selected_tools: |
| doc = (fn.__doc__ or "").strip().split("\n")[0] |
| if doc: |
| lines.append(f"- {fn.__name__}: {doc}") |
| else: |
| lines.append(f"- {fn.__name__}") |
| return "\n".join(lines) |
|
|
|
|
def _build_system_message(tools_description: str) -> SystemMessage:
    """Return the answer-format system prompt followed by the available tool list."""
    return SystemMessage(
        content=(
            "You are a general AI assistant. I will ask you a question. "
            "Report your thoughts, and finish your answer with the following template: "
            "FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible "
            "OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma "
            "to write your number neither use units such as $ or percent sign unless specified otherwise. "
            "If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write "
            "the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply "
            "the above rules depending of whether the element to be put in the list is a number or a string.\n\n"
            f"Available tools for this input:\n{tools_description}"
        )
    )


def assistant(state: AgentState):
    """Run one LLM step of the assistant/tools loop.

    The model is bound to the tools chosen for this state's input file, with
    parallel tool calls disabled. It is prompted with the "FINAL ANSWER"
    format instructions, the tool list, and an optional file-path hint,
    followed by the conversation history. The full prompt is printed for
    debugging.

    Args:
        state: Graph state. ``input_file`` may be absent or None.

    Returns:
        A state update with the model's reply under ``messages`` and the
        (unchanged) ``input_file`` value.
    """
    # .get: the key is absent when the graph is invoked without an input file,
    # matching how _selected_tools_from_state reads it.
    data_file = state.get("input_file")
    selected_tools = _selected_tools_from_state(state)
    llm_with_tools = llm.bind_tools(selected_tools, parallel_tool_calls=False)
    sys_msg = _build_system_message(_build_tools_description(selected_tools))

    prompt_messages = [sys_msg]
    if data_file:
        # Placed ahead of the history so that on later loop iterations it does
        # not land after the previous round's ToolMessages as a new user turn.
        prompt_messages.append(
            HumanMessage(content=f"Input file path (local): {data_file}")
        )
    prompt_messages.extend(state["messages"])

    print("Prompt messages for assistant:")
    for msg in prompt_messages:
        print(f"- {msg.content}")

    response = llm_with_tools.invoke(prompt_messages)
    return {"messages": [response], "input_file": data_file}
|
|
|
|
def tools_node(state: AgentState):
    """Execute the tool calls requested in the state via a ``ToolNode``.

    The tool set is re-derived from the state with the same selector the
    assistant node uses, so the runnable tools match the ones that were bound
    to the model.
    """
    executor = ToolNode(_selected_tools_from_state(state))
    result = executor.invoke(state)
    return result
|
|
|
|
| |
# Two-node loop: "assistant" calls the model; "tools" runs any tool calls it
# requested, and control then returns to "assistant".
builder = StateGraph(AgentState)
builder.add_node("assistant", assistant)
builder.add_node("tools", tools_node)

builder.add_edge(START, "assistant")
# tools_condition routes to "tools" when the last AI message contains tool
# calls and ends the run otherwise.
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")

react_graph = builder.compile()
|
|