From a61420952e293f230c3a73832c27192db8a1a13f Mon Sep 17 00:00:00 2001 From: justincc Date: Tue, 19 May 2026 20:24:30 +0100 Subject: [PATCH] fix(agent): set tool_name on tool-result messages at construction time Introduces make_tool_result_message() in tool_dispatch_helpers.py as the single place where tool-result message dicts are built. All six construction sites in tool_executor.py, agent_runtime_helpers.py, and mini_swe_runner.py now use it, so tool_name is set in memory from the moment a message is created rather than relying on fallback logic in the flush paths. Fixes blank tool_name in both state.db and JSON session logs. Adds tests. --- agent/agent_runtime_helpers.py | 13 +++--- agent/tool_dispatch_helpers.py | 14 ++++++ agent/tool_executor.py | 40 ++++++----------- mini_swe_runner.py | 9 ++-- .../test_tool_name_db_persistence.py | 45 +++++++++++++++++++ tests/test_hermes_state.py | 17 +++++++ 6 files changed, 99 insertions(+), 39 deletions(-) create mode 100644 tests/run_agent/test_tool_name_db_persistence.py diff --git a/agent/agent_runtime_helpers.py b/agent/agent_runtime_helpers.py index 56f4e5ba3a..7a9a0961a7 100644 --- a/agent/agent_runtime_helpers.py +++ b/agent/agent_runtime_helpers.py @@ -39,7 +39,7 @@ from agent.message_sanitization import ( _repair_tool_call_arguments, _sanitize_surrogates, ) -from agent.tool_dispatch_helpers import _trajectory_normalize_msg +from agent.tool_dispatch_helpers import _trajectory_normalize_msg, make_tool_result_message from agent.trajectory import convert_scratchpad_to_think from agent.error_classifier import classify_api_error, FailoverReason from utils import base_url_host_matches, base_url_hostname, env_var_enabled, atomic_json_write @@ -317,12 +317,11 @@ def sanitize_tool_call_arguments( if existing_tool_msg is None: messages.insert( insert_at, - { - "role": "tool", - "name": function_name if function_name != "?" else "", - "tool_call_id": tool_call_id, - "content": marker, - }, + make_tool_result_message( + function_name if function_name != "?" else "", + marker, + tool_call_id, + ), ) insert_at += 1 else: diff --git a/agent/tool_dispatch_helpers.py b/agent/tool_dispatch_helpers.py index 30aa8869db..789371edfa 100644 --- a/agent/tool_dispatch_helpers.py +++ b/agent/tool_dispatch_helpers.py @@ -317,6 +317,19 @@ def _trajectory_normalize_msg(msg: Dict[str, Any]) -> Dict[str, Any]: return msg +def make_tool_result_message(name: str, content: Any, tool_call_id: str) -> dict: + """Build a tool-result message dict with both the OpenAI-format ``name`` + field (required by the wire format and provider adapters) and the internal + ``tool_name`` field (written to the session DB messages table).""" + return { + "role": "tool", + "name": name, + "tool_name": name, + "content": content, + "tool_call_id": tool_call_id, + } + + __all__ = [ "_NEVER_PARALLEL_TOOLS", "_PARALLEL_SAFE_TOOLS", @@ -333,4 +346,5 @@ __all__ = [ "_extract_file_mutation_targets", "_extract_error_preview", "_trajectory_normalize_msg", + "make_tool_result_message", ] diff --git a/agent/tool_executor.py b/agent/tool_executor.py index 12bc725513..b161b507e8 100644 --- a/agent/tool_executor.py +++ b/agent/tool_executor.py @@ -35,6 +35,7 @@ from agent.tool_dispatch_helpers import ( _is_multimodal_tool_result, _multimodal_text_summary, _append_subdir_hint_to_multimodal, + make_tool_result_message, ) from tools.terminal_tool import ( _get_approval_callback, @@ -74,12 +75,11 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe if agent._interrupt_requested: print(f"{agent.log_prefix}⚡ Interrupt: skipping {num_tools} tool call(s)") for tc in tool_calls: - messages.append({ - "role": "tool", - "name": tc.function.name, - "content": f"[Tool execution cancelled — {tc.function.name} was skipped due to user interrupt]", - "tool_call_id": tc.id, - }) + messages.append(make_tool_result_message( + tc.function.name, + f"[Tool execution cancelled — {tc.function.name} was skipped due to user interrupt]", + tc.id, + )) return # ── Parse args + pre-execution bookkeeping ─────────────────────── @@ -443,13 +443,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe # image tool result never poisons canonical session history. # String results pass through unchanged. _tool_content = agent._tool_result_content_for_active_model(name, function_result) - tool_msg = { - "role": "tool", - "name": name, - "content": _tool_content, - "tool_call_id": tc.id, - } - messages.append(tool_msg) + messages.append(make_tool_result_message(name, _tool_content, tc.id)) # ── Per-tool /steer drain ─────────────────────────────────── # Same as the sequential path: drain between each collected @@ -864,13 +858,7 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe # Unwrap _multimodal dicts to an OpenAI-style content list # (see parallel path for rationale). String results pass through. _tool_content = agent._tool_result_content_for_active_model(function_name, function_result) - tool_msg = { - "role": "tool", - "name": function_name, - "content": _tool_content, - "tool_call_id": tool_call.id - } - messages.append(tool_msg) + messages.append(make_tool_result_message(function_name, _tool_content, tool_call.id)) # ── Per-tool /steer drain ─────────────────────────────────── # Drain pending steer BETWEEN individual tool calls so the @@ -892,13 +880,11 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe agent._vprint(f"{agent.log_prefix}⚡ Interrupt: skipping {remaining} remaining tool call(s)", force=True) for skipped_tc in assistant_message.tool_calls[i:]: skipped_name = skipped_tc.function.name - skip_msg = { - "role": "tool", - "name": skipped_name, - "content": f"[Tool execution skipped — {skipped_name} was not started. User sent a new message]", - "tool_call_id": skipped_tc.id - } - messages.append(skip_msg) + messages.append(make_tool_result_message( + skipped_name, + f"[Tool execution skipped — {skipped_name} was not started. User sent a new message]", + skipped_tc.id, + )) break if agent.tool_delay > 0 and i < len(assistant_message.tool_calls): diff --git a/mini_swe_runner.py b/mini_swe_runner.py index c434515045..e3d2f174e9 100644 --- a/mini_swe_runner.py +++ b/mini_swe_runner.py @@ -38,6 +38,7 @@ from typing import List, Dict, Any, Optional, Literal import fire from dotenv import load_dotenv +from agent.tool_dispatch_helpers import make_tool_result_message # Load environment variables load_dotenv() @@ -536,11 +537,9 @@ Complete the user's task step by step.""" completed = True # Add tool response - messages.append({ - "role": "tool", - "content": result_json, - "tool_call_id": tc.id - }) + messages.append(make_tool_result_message( + tc.function.name, result_json, tc.id, + )) print(f" ✅ exit_code={result['exit_code']}, output={len(result['output'])} chars") diff --git a/tests/run_agent/test_tool_name_db_persistence.py b/tests/run_agent/test_tool_name_db_persistence.py new file mode 100644 index 0000000000..3fcf7f33c3 --- /dev/null +++ b/tests/run_agent/test_tool_name_db_persistence.py @@ -0,0 +1,45 @@ +"""Test that tool_name is correctly persisted to the session DB for tool-result messages. + +make_tool_result_message() sets tool_name on every tool-result dict at construction +time. This test verifies that the value survives the flush path into the session DB. +""" +from unittest.mock import MagicMock, patch + +from run_agent import AIAgent +from agent.tool_dispatch_helpers import make_tool_result_message + + +def _make_agent(session_db): + with ( + patch("run_agent.get_tool_definitions", return_value=[]), + patch("run_agent.check_toolset_requirements", return_value={}), + patch("run_agent.OpenAI"), + ): + return AIAgent( + api_key="test-key", + base_url="https://openrouter.ai/api/v1", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + session_db=session_db, + ) + + +def test_tool_name_persisted_to_session_db(): + """tool_name set by make_tool_result_message must be passed through to + append_message so the column is populated on first flush to the session DB.""" + session_db = MagicMock() + agent = _make_agent(session_db) + + messages = [ + {"role": "user", "content": "run a command"}, + make_tool_result_message("terminal", "$ ls\nfile.txt", "c1"), + ] + agent._flush_messages_to_session_db(messages) + + tool_appends = [ + c for c in session_db.append_message.call_args_list + if c.kwargs.get("role") == "tool" + ] + assert len(tool_appends) == 1 + assert tool_appends[0].kwargs["tool_name"] == "terminal" diff --git a/tests/test_hermes_state.py b/tests/test_hermes_state.py index 3bae763b94..2676457f58 100644 --- a/tests/test_hermes_state.py +++ b/tests/test_hermes_state.py @@ -267,6 +267,23 @@ class TestMessageStorage: ).fetchone() assert row["content"] == "plain text" + def test_replace_messages_persists_tool_name(self, db): + """`replace_messages` (used by /retry, /undo, /compress) must write + tool_name to the DB for messages built by make_tool_result_message.""" + from agent.tool_dispatch_helpers import make_tool_result_message + db.create_session(session_id="s1", source="cli") + db.replace_messages( + "s1", + [ + {"role": "user", "content": "do something"}, + make_tool_result_message("web_search", "some results", "c1"), + ], + ) + + msgs = db.get_messages("s1") + tool_msg = next(m for m in msgs if m["role"] == "tool") + assert tool_msg["tool_name"] == "web_search" + def test_replace_messages_handles_multimodal_content(self, db): """`replace_messages` (used by /retry, /undo, /compress) must also handle list content without crashing."""