File size: 3,661 Bytes
cb20efa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e030e6
 
cb20efa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# std lib
import os

# 3rd party imports
from typing import TypedDict, Annotated, Optional
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.graph.message import add_messages
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition

# local imports
from tools import (
    select_tools_for_input,
)


# OpenAI API key taken from the environment; None when the variable is unset.
openai_token = os.environ.get("HF_FINAL_ASSIGNMENT_OPENAI")

# Shared chat model used by the assistant node. temperature=0 asks the API for
# the least random sampling. (A cheaper alternative previously tried here was
# model="gpt-4.1-nano".)
llm = ChatOpenAI(
    model="gpt-5.2",
    api_key=openai_token,
    temperature=0,
)


class AgentState(TypedDict):
    """Graph state shared by the ``assistant`` and ``tools`` nodes."""

    # Local path of an optional input file accompanying the question (None when
    # there is no file). NOTE(review): presumably any file type is allowed; the
    # tools offered are chosen from it by select_tools_for_input — confirm there.
    input_file: Optional[str]
    # Conversation history. The add_messages reducer merges the messages each
    # node returns into this list instead of replacing it, which is why the
    # assistant node returns only its newest reply.
    messages: Annotated[list[AnyMessage], add_messages]


def _selected_tools_from_state(state: AgentState):
    """Return the tools applicable to the state's ``input_file``.

    The key may be absent from ``state``; ``None`` is then passed through to
    ``select_tools_for_input``.
    """
    input_path = state.get("input_file")
    return select_tools_for_input(input_path)


def _build_tools_description(selected_tools: list) -> str:
    lines = []
    for fn in selected_tools:
        doc = (fn.__doc__ or "").strip().split("\n")[0]
        if doc:
            lines.append(f"- {fn.__name__}: {doc}")
        else:
            lines.append(f"- {fn.__name__}")
    return "\n".join(lines)


def assistant(state: AgentState):
    """Run one LLM step of the ReAct loop.

    Selects the tools for the state's input file, binds them to the shared
    ``llm``, and builds the prompt. The prompt is the system instructions, an
    optional note with the local input-file path, then the conversation so far.
    The model is then invoked on that prompt.

    Args:
        state: Current graph state. ``input_file`` may be missing or ``None``.

    Returns:
        A partial state update: the model's reply (merged into ``messages`` by
        the state's reducer) and ``input_file`` echoed back unchanged.
    """
    # .get(): the graph may be invoked without an input file, in which case the
    # key can be absent (_selected_tools_from_state already tolerates this).
    data_file = state.get("input_file")
    selected_tools = _selected_tools_from_state(state)
    if selected_tools:
        # One tool call per model turn keeps the ReAct trace sequential.
        llm_with_tools = llm.bind_tools(selected_tools, parallel_tool_calls=False)
    else:
        # Binding an empty list would send an empty `tools` array together with
        # `parallel_tool_calls`, which the OpenAI API rejects; use the bare model.
        llm_with_tools = llm
    tools_description = _build_tools_description(selected_tools) or "(none)"

    sys_msg = SystemMessage(
        content=(
            "You are a general AI assistant. I will ask you a question. "
            "Report your thoughts, and finish your answer with the following template: "
            "FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible "
            "OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma "
            "to write your number neither use units such as $ or percent sign unless specified otherwise. "
            "If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write "
            "the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply "
            "the above rules depending of whether the element to be put in the list is a number or a string.\n\n"
            f"Available tools for this input:\n{tools_description}"
        )
    )

    prompt_messages = [sys_msg]
    if data_file:
        # Place the file note before the conversation, not after it. Appended at
        # the end, it would follow the tool results on every later loop
        # iteration, making it (not the tool output) the model's latest input.
        prompt_messages.append(
            HumanMessage(content=f"Input file path (local): {data_file}")
        )
    prompt_messages.extend(state["messages"])

    print("Prompt messages for assistant:")
    for msg in prompt_messages:
        print(f"- {msg.content}")

    response = llm_with_tools.invoke(prompt_messages)
    return {"messages": [response], "input_file": data_file}


def tools_node(state: AgentState):
    """Execute the tool call(s) requested by the assistant's latest message.

    The tool set is rebuilt from the state on every call with the same helper
    the assistant uses. Assuming ``select_tools_for_input`` is deterministic,
    the executable tools therefore match the ones the model was offered.
    """
    executor = ToolNode(_selected_tools_from_state(state))
    return executor.invoke(state)


# Graph
# Graph wiring: START -> assistant -> (tools -> assistant)* -> END
builder = StateGraph(AgentState)

# Nodes: the LLM step and the tool-execution step.
builder.add_node("assistant", assistant)
builder.add_node("tools", tools_node)

# Every run begins at the assistant.
builder.add_edge(START, "assistant")
# tools_condition looks at the assistant's newest message. If that message
# requests a tool call, the run continues at "tools". Otherwise the run goes
# to END.
builder.add_conditional_edges("assistant", tools_condition)
# Tool results always flow back to the assistant for the next step.
builder.add_edge("tools", "assistant")

react_graph = builder.compile()