import os
import yaml
import asyncio
import inspect
import json
import time
from typing import Any, Dict, AsyncIterator, List, Tuple, Optional

import gradio as gr
from dotenv import load_dotenv

from retrac.graph import build_graph

# Load .env so RECURSION_LIMIT / RETRAC_CONFIG are actually picked up.
load_dotenv()

# -----------------------------
# Your existing logic (kept)
# -----------------------------
def load_config(config_path: str) -> dict:
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

async def stream_graph_execution(
    graph,
    initial_state: Dict[str, Any],
) -> AsyncIterator[Dict[str, Any]]:
    compiled_graph = graph.compile()
    recursion_limit = int(os.getenv("RECURSION_LIMIT", "10000"))
    run_config = {"recursion_limit": recursion_limit}
    state = initial_state.copy()
    async for event in compiled_graph.astream(state, config=run_config):
        yield event
        # same as your code: merge node outputs into state
        for node_name, node_output in event.items():
            if node_output is not None:
                state.update(node_output)

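# A note on event shape: each streamed event is a mapping of
# {node_name: node_output}, following LangGraph's astream contract, so an
# event like {"planner": {"messages": [...]}} (illustrative node name)
# replaces state["messages"] with the planner's updated list.
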
async def run_streaming(config_path: str, question: str) -> AsyncIterator[Dict[str, Any]]:
    cfg = load_config(config_path)
    graph = build_graph(cfg)
    initial_state = {"question": question}
    async for event in stream_graph_execution(graph, initial_state):
        yield event

# -----------------------------
# Message serialization helpers
# -----------------------------
def _safe_str(x: Any) -> str:
    try:
        return str(x)
    except Exception:
        return repr(x)

def _strip_think_tags(text: str) -> str:
    if not text:
        return text
    return (
        text.replace("<think>", "")
        .replace("</think>", "")
        .replace("<think/>", "")
        .strip()
    )

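# Only the tags are removed; the reasoning text between them is kept.
# Illustrative example:
#   _strip_think_tags("<think>plan</think> answer") -> "plan answer"
# Splitting reasoning from final output is handled by _split_ai_think below.
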
def _try_parse_json(value: Any) -> Any:
    if isinstance(value, (dict, list)):
        return value
    if isinstance(value, str):
        try:
            return json.loads(value)
        except Exception:
            return value
    return value

def _format_json_block(value: Any) -> str:
    parsed = _try_parse_json(value)
    if isinstance(parsed, (dict, list)):
        body = json.dumps(parsed, ensure_ascii=False, indent=2)
    else:
        body = _safe_str(parsed)
    return f"```json\n{body}\n```"

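# Illustrative example: a dict and its JSON string render identically.
#   _format_json_block({"query": "x"}) == _format_json_block('{"query": "x"}')
# Both produce a fenced block:
#   ```json
#   {
#     "query": "x"
#   }
#   ```
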
def _debug_log(message: str, data: Dict[str, Any], hypothesis_id: str, run_id: str = "pre-fix") -> None:
    # #region agent log
    # Best-effort instrumentation: appends one JSON line per event and
    # silently ignores failures (e.g. when the log path does not exist).
    try:
        payload = {
            "id": f"log_{int(time.time() * 1000)}_{os.getpid()}",
            "timestamp": int(time.time() * 1000),
            "location": "RE-TRAC/app.py",
            "message": message,
            "data": data,
            "sessionId": "debug-session",
            "runId": run_id,
            "hypothesisId": hypothesis_id,
        }
        with open("/home/azureuser/jialiang/repo/tz-dr/.cursor/debug.log", "a", encoding="utf-8") as f:
            f.write(json.dumps(payload, ensure_ascii=True) + "\n")
    except Exception:
        pass
    # #endregion

CHATBOT_EXPECTS_MESSAGES: Optional[bool] = None
CHATBOT_SUPPORTS_GROUPING: Optional[bool] = None
CHATBOT_SUPPORTS_MESSAGES: Optional[bool] = None  # assigned after the signature check below


def _gradio_version_major() -> Optional[int]:
    version = getattr(gr, "__version__", None)
    if not version:
        return None
    try:
        return int(str(version).split(".")[0])
    except Exception:
        return None

def _message_fingerprint(msg: Any) -> str:
    role, content = _extract_role_and_content(msg)
    return f"{role}::{content}"


def _is_tool_message(msg: Any) -> bool:
    if isinstance(msg, dict):
        return (msg.get("role") == "tool") or (msg.get("type") == "tool")
    return getattr(msg, "type", None) == "tool"

def _extract_role_and_content(msg: Any) -> Tuple[str, str]:
    """
    Convert various message types into (role, content) for display.
    Supports:
    - LangChain BaseMessage-like objects (has .type/.content)
    - dict messages {"role": "...", "content": "..."} or similar
    """
    # dict-like
    if isinstance(msg, dict):
        role = msg.get("role") or msg.get("type") or "unknown"
        content = msg.get("content")
        if content is None:
            # fallback: show whole dict
            content = _safe_str(msg)
        return str(role), _safe_str(content)
    # BaseMessage-like
    mtype = getattr(msg, "type", None)
    content = getattr(msg, "content", None)
    # ToolMessage: only the content is rendered here; tool names/args
    # surface separately via _extract_tool_calls / _tool_call_to_text
    if mtype == "tool":
        return "tool", _safe_str(content)
    if mtype == "human":
        return "user", _safe_str(content)
    if mtype == "ai":
        return "assistant", _safe_str(content)
    if mtype == "system":
        return "system", _safe_str(content)
    # unknown object: try best
    if content is not None and mtype is not None:
        return _safe_str(mtype), _safe_str(content)
    return "unknown", _safe_str(msg)

def _extract_tool_calls(msg: Any) -> List[Dict[str, Any]]:
    if isinstance(msg, dict):
        tool_calls = msg.get("tool_calls") or []
        if isinstance(tool_calls, list):
            return tool_calls
        return []
    tool_calls = getattr(msg, "tool_calls", None)
    if isinstance(tool_calls, list):
        return tool_calls
    return []

def _split_ai_think(content: str) -> List[str]:
    if not content:
        return []
    if "</think>" not in content and "<think/>" not in content and "<think>" not in content:
        return [content]
    if "<think/>" in content and "</think>" not in content:
        after = _strip_think_tags(content.replace("<think/>", "", 1))
        return [after] if after else []
    if "<think>" in content and "</think>" not in content:
        cleaned = _strip_think_tags(content)
        return [cleaned] if cleaned else []
    before, after = content.split("</think>", 1)
    before = _strip_think_tags(before)
    after = _strip_think_tags(after.lstrip())
    parts = []
    if before:
        parts.append(before)
    if after:
        parts.append(after)
    return parts

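# Illustrative examples of the split behavior:
#   _split_ai_think("<think>plan</think>answer") -> ["plan", "answer"]
#   _split_ai_think("plain answer")              -> ["plain answer"]
#   _split_ai_think("<think>still thinking")     -> ["still thinking"]
# When a closing </think> is present, _explode_message labels the first
# part "Model Reasoning" and later parts "Model Output".
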
def _tool_call_to_text(tool_call: Any) -> str:
    if isinstance(tool_call, dict):
        name = tool_call.get("name") or tool_call.get("tool") or "tool_call"
        args = tool_call.get("args") or tool_call.get("arguments") or {}
    else:
        name = getattr(tool_call, "name", None) or getattr(tool_call, "tool", None) or "tool_call"
        args = getattr(tool_call, "args", None) or getattr(tool_call, "arguments", None) or {}
    return f"**Tool Call** ({name})\n{_format_json_block(args)}"

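# Illustrative example, assuming a LangChain-style tool-call dict:
#   _tool_call_to_text({"name": "search", "args": {"query": "retrac"}})
#   -> "**Tool Call** (search)" followed by the args as a fenced json block.
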
def _explode_message(msg: Any, is_nonfirst_human: bool) -> List[Tuple[str, str, str]]:
    role, content = _extract_role_and_content(msg)
    parts: List[Tuple[str, str, str]] = []
    # "human" messages from the agent trajectory are displayed on the
    # assistant side so the whole stream reads as agent output
    render_role = "assistant" if role == "user" else role
    _debug_log(
        "explode_message entry",
        {
            "role": role,
            "render_role": render_role,
            "has_think": ("<think" in content) or ("</think>" in content),
            "content_preview": _safe_str(content)[:200],
        },
        "H8",
    )
    if _is_tool_message(msg):
        tool_body = _format_json_block(content)
        return [(render_role, f"**Tool**\n{tool_body}", "tool")]
    if role in ("assistant", "user"):
        split_parts = _split_ai_think(content)
        _debug_log(
            "explode_message split",
            {
                "role": role,
                "split_count": len(split_parts),
                "split_previews": [p[:120] for p in split_parts],
            },
            "H8",
        )
        has_think = ("</think>" in content) or ("<think/>" in content) or ("<think>" in content)
        for idx, part in enumerate(split_parts):
            if has_think:
                label = "**Model Reasoning**" if idx == 0 and "</think>" in content else "**Model Output**"
                body = f"{label}\n{_strip_think_tags(part)}"
            else:
                body = _strip_think_tags(part)
            # non-first human (summary) messages are currently rendered unchanged
            parts.append((render_role, body, role))
        for tool_call in _extract_tool_calls(msg):
            parts.append((render_role, _tool_call_to_text(tool_call), "tool_call"))
        return parts or [(render_role, content, role)]
    return [(role, content, "assistant" if role in ("assistant", "system") else "user")]

def _explode_messages(messages: List[Any]) -> List[Tuple[str, str, str]]:
    exploded: List[Tuple[str, str, str]] = []
    human_count = 0
    for msg in messages:
        role, _ = _extract_role_and_content(msg)
        if role == "user":
            human_count += 1
            is_nonfirst_human = human_count > 1
        else:
            is_nonfirst_human = False
        exploded.extend(_explode_message(msg, is_nonfirst_human))
    return exploded

def serialize_messages_for_gradio(messages: List[Any]) -> List[Dict[str, str]]:
    """
    Gradio Chatbot(type="messages") expects:
        [{"role": "user"/"assistant"/"system", "content": "..."}, ...]
    Tool messages are mapped into role="assistant".
    """
    out: List[Dict[str, str]] = []
    for role, content, _kind in _explode_messages(messages):
        # normalize role into gradio-friendly roles
        if role not in ("user", "assistant", "system"):
            # map everything else onto the assistant side
            role = "assistant"
        out.append({"role": role, "content": content})
    return out

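# Illustrative output shape for one assistant turn that issued a tool call:
#   [{"role": "assistant", "content": "**Model Output**\n..."},
#    {"role": "assistant", "content": "**Tool Call** (search)\n```json..."}]
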
def _chatbot_supports_messages() -> bool:
    global CHATBOT_SUPPORTS_GROUPING
    try:
        params = list(inspect.signature(gr.Chatbot).parameters.keys())
        supports = "type" in params
        grouping = "group_consecutive_messages" in params
        major = _gradio_version_major()
        _debug_log(
            "chatbot signature check",
            {
                "supports_messages": supports,
                "supports_grouping": grouping,
                "params": params,
                "gradio_version": getattr(gr, "__version__", None),
                "gradio_major": major,
            },
            "H1",
        )
        CHATBOT_SUPPORTS_GROUPING = grouping
        return supports
    except Exception:
        _debug_log(
            "chatbot signature check failed",
            {"supports_messages": False},
            "H1",
        )
        return False

def _should_use_messages_format() -> bool:
    major = _gradio_version_major()
    if major is not None and major >= 6:
        _debug_log(
            "format decision",
            {"reason": "gradio_major>=6", "gradio_major": major},
            "H1",
        )
        return True
    if CHATBOT_SUPPORTS_MESSAGES:
        _debug_log(
            "format decision",
            {"reason": "chatbot_supports_type", "gradio_major": major},
            "H1",
        )
        return True
    _debug_log(
        "format decision",
        {"reason": "legacy_tuple", "gradio_major": major},
        "H1",
    )
    return False


CHATBOT_SUPPORTS_MESSAGES = _chatbot_supports_messages()

def serialize_messages_for_legacy_chatbot(
    messages: List[Any],
) -> List[Tuple[Optional[str], Optional[str]]]:
    """
    Legacy gr.Chatbot expects a list of (user, bot) tuples.
    All exploded parts are rendered on the bot side so the agent
    trajectory reads as a single stream.
    """
    rendered: List[Tuple[Optional[str], Optional[str]]] = []
    for role, content, kind in _explode_messages(messages):
        is_tool = kind in ("tool_call", "tool")
        if role == "user":
            rendered.append((None, f"**RE-TRAC: REcursive TRAjectory Compression**\n{content}"))
            continue
        if is_tool:
            rendered.append((None, content))
        else:
            # fill a pending bot slot if one exists, otherwise start a new row
            if rendered and rendered[-1][1] is None:
                rendered[-1] = (rendered[-1][0], content)
            else:
                rendered.append((None, content))
    return rendered

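# Illustrative output shape (legacy tuple format); the user slot stays None
# because every exploded part is shown as bot output:
#   [(None, "**Model Reasoning**\n..."), (None, "**Tool Call** (search)\n...")]
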
def serialize_messages_for_chatbot(messages: List[Any]) -> List[Any]:
    _debug_log(
        "serialize_messages_for_chatbot",
        {
            "supports_messages": CHATBOT_SUPPORTS_MESSAGES,
            "expects_messages": CHATBOT_EXPECTS_MESSAGES,
            "count": len(messages),
        },
        "H2",
    )
    if CHATBOT_EXPECTS_MESSAGES:
        return serialize_messages_for_gradio(messages)
    return serialize_messages_for_legacy_chatbot(messages)

def append_messages_for_chatbot(
    rendered: List[Any],
    tail_messages: List[Any],
) -> List[Any]:
    _debug_log(
        "append_messages_for_chatbot",
        {
            "supports_messages": CHATBOT_SUPPORTS_MESSAGES,
            "expects_messages": CHATBOT_EXPECTS_MESSAGES,
            "rendered_len": len(rendered),
            "tail_len": len(tail_messages),
        },
        "H3",
    )
    if CHATBOT_EXPECTS_MESSAGES:
        rendered.extend(serialize_messages_for_gradio(tail_messages))
        return rendered
    tail_rendered = serialize_messages_for_legacy_chatbot(tail_messages)
    if not tail_rendered:
        return rendered
    if not rendered:
        return tail_rendered
    first = tail_rendered[0]
    if rendered[-1][1] is None and first[0] is None and first[1] is not None:
        rendered[-1] = (rendered[-1][0], first[1])
        tail_rendered = tail_rendered[1:]
    rendered.extend(tail_rendered)
    return rendered

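# In the legacy path, the stitch above closes a dangling (user, None) row.
# Illustrative: rendered == [("hi", None)] plus tail == [(None, "ok")]
# merges into [("hi", "ok")] rather than adding an empty bot bubble.
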
def fingerprint_messages(messages: List[Any]) -> List[str]:
    """
    Create a stable-ish fingerprint list for prefix checks.
    We avoid id() because objects may be recreated.
    """
    fp: List[str] = []
    for m in messages:
        role, content = _extract_role_and_content(m)
        fp.append(f"{role}::{content}")
    return fp


def compute_render_strategy(
    prev_fp: List[str],
    new_fp: List[str],
) -> Tuple[str, int]:
    """
    Decide whether we can append incrementally or must rebuild.
    Returns:
        ("append", start_index) if prev is a prefix of new
        ("rebuild", 0) otherwise
    """
    if not prev_fp:
        return ("rebuild", 0)
    if len(new_fp) >= len(prev_fp) and new_fp[: len(prev_fp)] == prev_fp:
        return ("append", len(prev_fp))
    return ("rebuild", 0)

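# Illustrative prefix-check behavior:
#   compute_render_strategy(["user::hi"], ["user::hi", "assistant::ok"])
#   -> ("append", 1)   # prev is a prefix; render only the new tail
#   compute_render_strategy(["user::hi"], ["user::bye"])
#   -> ("rebuild", 0)  # history diverged; re-render everything
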
# -----------------------------
# Gradio app
# -----------------------------
CONFIG_PATH = os.getenv("RETRAC_CONFIG", "retrac/30B.yaml")

async def run_once(query: str, ui_state: Optional[Dict[str, Any]]):
    """
    Single-run streaming handler.
    ui_state stores:
    - "prev_fp": fingerprint list for diffing
    - "rendered": messages already rendered (in the chatbot's expected format)
    - "merged_state": current merged graph state (dict)
    - "done": bool
    """
    if ui_state is None:
        ui_state = {}
    _debug_log(
        "run_once entry",
        {
            "query": _safe_str(query),
            "done": bool(ui_state.get("done")),
            "supports_messages": CHATBOT_SUPPORTS_MESSAGES,
            "expects_messages": CHATBOT_EXPECTS_MESSAGES,
        },
        "H4",
    )
    # Enforce one-time input per run.
    if ui_state.get("done"):
        # already finished; no second run without reset
        yield (
            ui_state.get("rendered", []),
            ui_state,
            gr.update(interactive=False),
            gr.update(interactive=False),
            "Already finished. Click Reset to run again.",
        )
        return
    merged_state: Dict[str, Any] = ui_state.get("merged_state", {})
    prev_fp: List[str] = ui_state.get("prev_fp", [])
    seen_fp: List[str] = ui_state.get("seen_fp", [])
    all_messages: List[Any] = ui_state.get("all_messages", [])
    rendered: List[Dict[str, str]] = ui_state.get("rendered", [])
    # disable input immediately after submit
    yield (
        rendered,
        ui_state,
        gr.update(interactive=False),
        gr.update(interactive=False),
        "Running...",
    )
    try:
        async for event in run_streaming(CONFIG_PATH, query):
            # merge node outputs into merged_state (same logic as the runner)
            for _, node_output in event.items():
                if node_output is not None:
                    merged_state.update(node_output)
            # pick the full message list from merged_state
            full_messages = merged_state.get("messages", [])
            # accumulate messages; never clear already displayed ones
            new_messages: List[Any] = []
            for m in full_messages:
                fp = _message_fingerprint(m)
                if fp not in seen_fp:
                    seen_fp.append(fp)
                    all_messages.append(m)
                    new_messages.append(m)
            if not rendered:
                rendered = serialize_messages_for_chatbot(all_messages)
            elif new_messages:
                rendered = append_messages_for_chatbot(rendered, new_messages)
            new_fp = fingerprint_messages(full_messages)
            _debug_log(
                "rendered update",
                {
                    "mode": "accumulate",
                    "rendered_len": len(rendered),
                    "rendered_item_type": type(rendered[0]).__name__ if rendered else None,
                    "all_messages_len": len(all_messages),
                    "new_messages_len": len(new_messages),
                },
                "H5",
            )
            prev_fp = new_fp
            # update ui_state
            ui_state["merged_state"] = merged_state
            ui_state["prev_fp"] = prev_fp
            ui_state["seen_fp"] = seen_fp
            ui_state["all_messages"] = all_messages
            ui_state["rendered"] = rendered
            # show some lightweight status
            status = "Running..."
            if "final" in event:
                status = "Finished."
                ui_state["done"] = True
            yield (
                rendered,
                ui_state,
                gr.update(interactive=False),
                gr.update(interactive=False),
                status,
            )
            if "final" in event:
                break
    except Exception as e:
        ui_state["done"] = True
        yield (
            rendered,
            ui_state,
            gr.update(interactive=False),
            gr.update(interactive=False),
            f"Error: {_safe_str(e)}",
        )

def reset():
    # re-enable input
    return [], None, gr.update(interactive=True, value=""), gr.update(interactive=True), "Idle."

with gr.Blocks() as demo:
    gr.Markdown("## RE-TRAC (REcursive TRAjectory Compression) 30B Demo")
    with gr.Row():
        with gr.Column(scale=6):
            gr.Markdown(f"Config: `{CONFIG_PATH}`")
            chatbot_kwargs = {"height": 520, "label": "Messages (full stream)"}
            if CHATBOT_SUPPORTS_GROUPING:
                chatbot_kwargs["group_consecutive_messages"] = False
            if CHATBOT_SUPPORTS_MESSAGES:
                chatbot = gr.Chatbot(type="messages", **chatbot_kwargs)
            else:
                chatbot = gr.Chatbot(**chatbot_kwargs)
            # #region agent log
            try:
                attr_type = getattr(chatbot, "type", None)
                attr_format = getattr(chatbot, "format", None)
                attr_data_format = getattr(chatbot, "data_format", None)
                CHATBOT_EXPECTS_MESSAGES = _should_use_messages_format() or (
                    (attr_type == "messages") or (attr_format == "messages") or (attr_data_format == "messages")
                )
                _debug_log(
                    "chatbot instance attrs",
                    {
                        "attr_type": attr_type,
                        "attr_format": attr_format,
                        "attr_data_format": attr_data_format,
                        "expects_messages": CHATBOT_EXPECTS_MESSAGES,
                    },
                    "H1",
                )
            except Exception:
                pass
            # #endregion
            status = gr.Markdown("Idle.")
            ui_state = gr.State(None)
            query = gr.Textbox(label="Query (only once)", placeholder="Enter your query and press Run")
            run_btn = gr.Button("Run", variant="primary")
            reset_btn = gr.Button("Reset")
        with gr.Column(scale=4):
            gr.Markdown(
                """
## Introduction
We introduce RE-TRAC, a recursive framework that resolves the inefficiency of isolated trials in agentic search. It not only boosts both commercial and open-source models by 15–20% over ReAct on BrowseComp, but also drives SFT-only performance to new heights (30% for 4B, 53% for 30B).

<p align="center">
🤗 <a href="https://github.com/microsoft/InfoAgent/tree/main/retrac">HuggingFace (coming soon)</a> |
💻 <a href="https://github.com/microsoft/InfoAgent/tree/main/retrac">GitHub</a> |
📄 <a href="https://arxiv.org/abs/2602.02486">Paper</a> |
🚀 <a href="https://huggingface.co/spaces/JialiangZhu/RE-TRAC">Demo</a>
</p>

## Method Overview
"""
            )
            gr.Image(
                value="image/method.png",
                label="Method Overview",
                show_label=False,
            )
            gr.Markdown("## Main Result")
            gr.Image(
                value="image/main_table.png",
                label="Main Result",
                show_label=False,
            )
    run_btn.click(
        run_once,
        inputs=[query, ui_state],
        outputs=[chatbot, ui_state, query, run_btn, status],
    )
    # Also allow pressing Enter (still single-run because we disable immediately)
    query.submit(
        run_once,
        inputs=[query, ui_state],
        outputs=[chatbot, ui_state, query, run_btn, status],
    )
    reset_btn.click(
        reset,
        outputs=[chatbot, ui_state, query, run_btn, status],
    )

demo.queue()  # required for async streaming + multiple yields
demo.launch()