ABVM committed on
Commit
882238c
·
verified ·
1 Parent(s): a4dd17a

Upload 2 files

Browse files
Files changed (2) hide show
  1. multi_agent.py +173 -0
  2. vision_tool.py +70 -0
multi_agent.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import (
2
+ CodeAgent,
3
+ VisitWebpageTool,
4
+ WebSearchTool,
5
+ WikipediaSearchTool,
6
+ PythonInterpreterTool,
7
+ FinalAnswerTool,
8
+ )
9
+ from groq import Groq
10
+ from vision_tool import image_reasoning_tool
11
+ import os
12
+ import time
13
+ from types import SimpleNamespace
14
+
15
+ # ---- TOOLS ----
16
+
17
+
18
+ # ---- GROQ MODEL WRAPPER ----
19
# ---- GROQ MODEL WRAPPER ----
class GroqModel:
    """Thin wrapper around the Groq chat-completions API with retry support.

    Instances are callable: ``model(prompt)`` returns the completion text.
    ``generate`` returns the full message object instead.
    """

    def __init__(self, model_name: str):
        # BUG FIX: original signature was ``model_name= str`` — the ``str``
        # type object as a default value instead of a type annotation.
        self.model_name = model_name
        self.client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

    def _chat(self, messages, max_tokens=8096, **extra):
        """Issue one chat-completion request, retrying on rate limits.

        Up to 3 attempts with linear backoff (10s, then 20s) when the error
        message mentions a rate limit; any other error is re-raised
        immediately, as is a rate-limit error on the final attempt.
        """
        # BUG FIX: this method was originally a second ``__call__`` that the
        # later ``__call__(self, prompt, ...)`` definition silently shadowed,
        # leaving ``generate``'s call to ``self._chat`` unresolved.
        params = {
            "model": self.model_name,
            "messages": messages,
            "stream": False,
            "max_completion_tokens": max_tokens,
            **extra,
        }

        for attempt in range(3):
            try:
                return self.client.chat.completions.create(**params)
            except Exception as e:
                # Only back off on rate-limit errors; fail fast otherwise.
                if "rate limit" in str(e).lower() and attempt < 2:
                    time.sleep(10 * (attempt + 1))
                else:
                    raise

    def generate(self, prompt, max_tokens=8096, **extra):
        """Return the first choice's message for *prompt*.

        *prompt* may be a plain string (wrapped as a single user message)
        or an already-built list of message dicts.
        """
        messages = prompt if not isinstance(prompt, str) else [
            {"role": "user", "content": prompt}
        ]
        # BUG FIX: original referenced ``extra`` while the parameter was
        # named ``kwargs`` (NameError on any call).
        response = self._chat(messages, max_tokens, **extra)
        return response.choices[0].message

    def __call__(self, prompt, max_tokens=8_096, **extra):
        """Convenience shortcut: return just the completion text."""
        return self.generate(prompt, max_tokens, **extra).content
52
+
53
+
54
+ # ---- MULTI-AGENT SYSTEM ----
55
# ---- MULTI-AGENT SYSTEM ----
class MultyAgentSystem:
    """Planner/worker multi-agent system built on smolagents + Groq models.

    A manager agent delegates to a web-browsing agent and an info/compute
    agent, then optionally verifies long or high-stakes answers with a
    larger DeepSeek model. Wall-clock usage is tracked so the manager can
    be switched to a cheaper fallback model when calls get slow.

    NOTE(review): the class name typo ("Multy") is kept intentionally so
    existing callers do not break.
    """

    def __init__(self):
        self.primary_model_name = "deepseek-r1-distill-llama-70b"
        # NOTE(review): Groq's published model id is "llama3-70b-8192";
        # "llama3-70b-8k" looks wrong — confirm before the fallback is
        # ever activated, or the switch will fail at request time.
        self.fallback_model_name = "llama3-70b-8k"

        self.deepseek_model = GroqModel(self.primary_model_name)
        # Stored as an attribute for consistency with self.deepseek_model
        # (the original kept this only in a local variable).
        self.qwen_model = GroqModel("qwen-qwq-32b")
        # Word-count threshold above which answers get a verification pass.
        self.verification_limit = int(os.getenv("VERIFY_WORD_LIMIT", "75"))

        self.web_agent = self._build_web_agent(self.qwen_model)
        self.info_agent = self._build_info_agent(self.qwen_model)
        self.manager_agent = self._build_manager_agent(self.qwen_model)

        # runtime tracking for fallback switching
        self.total_runtime = 0.0
        self.first_call_duration = None
        self.model_switched = False

    def _build_web_agent(self, model):
        """Create the agent handling web search / page retrieval tasks."""
        return CodeAgent(
            model=model,
            tools=[WebSearchTool(), VisitWebpageTool(), WikipediaSearchTool()],
            name="web_agent",
            description=(
                "You are a web browsing agent. Whenever the given {task} involves browsing "
                "the web or a specific website such as Wikipedia or YouTube, you will use "
                "the provided tools. For web-based factual and retrieval tasks, be as precise and source-reliable as possible."
            ),
            additional_authorized_imports=[
                "markdownify",
                "json",
                "requests",
                "urllib.request",
                "urllib.parse",
                "wikipedia-api",
            ],
            verbosity_level=0,
            max_steps=10,
        )

    def _build_info_agent(self, model):
        """Create the agent handling math, parsing, code, and image/OCR work."""
        return CodeAgent(
            model=model,
            tools=[PythonInterpreterTool(), image_reasoning_tool],
            name="info_agent",
            description=(
                "You are an agent tasked with cleaning, parsing, calculating information, and performing OCR if images are provided in the {task}. "
                "You can also analyze images using a vision model. You handle all math, code, and data manipulation. Use numpy, math, and available libraries. "
                "For image or chess tasks, use pytesseract, PIL, chess, or the image_reasoning_tool as required."
            ),
            additional_authorized_imports=[
                "numpy",
                "math",
                "pytesseract",
                "PIL",
                "chess",
            ],
        )

    def _build_manager_agent(self, model):
        """Create the planner agent that delegates to the managed agents."""
        # Planning cadence and step budget are env-tunable for deployment.
        manager_planning_interval = int(os.getenv("MANAGER_PLANNING_INTERVAL", "3"))
        manager_max_steps = int(os.getenv("MANAGER_MAX_STEPS", "8"))

        return CodeAgent(
            model=model,
            tools=[FinalAnswerTool()],
            managed_agents=[self.web_agent, self.info_agent],
            name="manager_agent",
            description=(
                "You are the manager. Given a {task}, plan which agent to use: "
                "If web data is needed, delegate to web_agent. If math, parsing, image reasoning, or code is needed, use info_agent. "
                "After collecting outputs, optionally cross-validate and check correctness, then finalize and submit the best answer using FinalAnswerTool. "
                "For each task, explicitly explain your planning steps and reasons for choosing which agent, and always prefer the most accurate and complete answer possible."
            ),
            additional_authorized_imports=[
                "json",
                "pandas",
                "numpy",
            ],
            planning_interval=manager_planning_interval,
            verbosity_level=2,
            max_steps=manager_max_steps,
        )

    def _switch_to_fallback(self):
        """Swap the manager's model for the cheaper fallback (idempotent)."""
        if self.model_switched:
            return
        self.manager_agent.model = GroqModel(self.fallback_model_name)
        self.model_switched = True

    def run(self, question, high_stakes: bool = False, **kwargs):
        """Answer *question*, verifying with DeepSeek when warranted.

        Verification runs when *high_stakes* is set or the initial answer
        exceeds ``verification_limit`` words; a verification failure falls
        back to the unverified answer. After each call, runtime stats may
        trigger a switch to the fallback model (first call > 30s, or
        cumulative runtime > 300s).
        """
        start_time = time.time()
        print("Generating initial answer with Qwen-32B")
        # NOTE(review): relies on the manager agent being callable —
        # confirm the smolagents agent exposes __call__ for this usage.
        initial_answer = self.manager_agent(question, **kwargs)
        call_duration = time.time() - start_time

        answer = initial_answer
        if high_stakes or len(initial_answer.split()) > self.verification_limit:
            print("Verifying answer using DeepSeek-70B")
            verification_prompt = (
                "Review the following answer for accuracy and rewrite if needed:"
                f"\n\n{initial_answer}"
            )
            try:
                answer = self.deepseek_model(verification_prompt)
            except Exception as e:
                # Best-effort verification: never lose the initial answer.
                print(f"Verification failed: {e}. Using initial answer.")
                answer = initial_answer

        if self.first_call_duration is None:
            self.first_call_duration = call_duration
            # A slow first call suggests the primary model is overloaded.
            if self.first_call_duration > 30:
                self._switch_to_fallback()

        self.total_runtime += call_duration
        if self.total_runtime > 300 and not self.model_switched:
            self._switch_to_fallback()

        return answer

    def __call__(self, question, high_stakes: bool = False, **kwargs):
        """Callable alias for :meth:`run`."""
        return self.run(question, high_stakes=high_stakes, **kwargs)
vision_tool.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Vision tool using Groq's Meta-Llama Scout model
2
+ from smolagents import tool
3
+ from groq import Groq
4
+
5
+ import os
6
+
7
+
8
def _llama_analyze(image_b64: str, prompt: str) -> str:
    """Internal helper to query the Llama vision model.

    Sends a single user turn containing *prompt* plus the base64-encoded
    PNG image and returns the model's text reply.
    """
    api = Groq(api_key=os.environ.get("GROQ_API_KEY"))

    user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
        ],
    }

    completion = api.chat.completions.create(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        messages=[user_turn],
        stream=False,
        max_completion_tokens=512,
    )
    return completion.choices[0].message.content
27
+
28
+
29
@tool
def image_reasoning_tool(image_file: str, prompt: str | None = None) -> dict:
    """Perform OCR and optional vision analysis on an image.

    This single entry point unifies OCR extraction and Llama vision reasoning so
    the planner only sees one image tool.

    Args:
        image_file: Path to the image file to analyze.
        prompt: Optional instruction for the vision model. If omitted, only OCR
            is performed.

    Returns:
        Dictionary with OCR text, base64 image data and optional vision model
        response.
    """
    # Heavy deps are imported lazily so the tool definition itself never
    # fails to load; any failure is reported in the returned dict instead.
    try:
        import pytesseract
        from PIL import Image
        from smolagents.utils import encode_image_base64

        img = Image.open(image_file)
        encoded = encode_image_base64(img)
        extracted = pytesseract.image_to_string(img)

        analysis = ""
        if prompt:
            try:
                analysis = _llama_analyze(encoded, prompt)
            except Exception as e:  # vision errors shouldn't break OCR result
                analysis = f"Error processing image with vision model: {e}"

        return {
            "ocr_text": extracted,
            "vision_text": analysis,
            "base64_image": encoded,
        }
    except Exception as e:
        # Total failure (bad path, missing deps, unreadable image):
        # return an empty payload carrying the error message.
        return {
            "ocr_text": "",
            "vision_text": "",
            "base64_image": "",
            "error": str(e),
        }
69
+
70
+