mirror of
https://github.com/droidrun/droidrun.git
synced 2026-05-23 07:40:37 +00:00
Merge pull request #273 from droidrun/thinking
Preserve thinking tokens & ephemeral state injection
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional, Type
|
||||
|
||||
from llama_index.core.base.llms.types import ChatMessage, ImageBlock, TextBlock
|
||||
from llama_index.core.llms.llm import LLM
|
||||
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step
|
||||
from opentelemetry import trace
|
||||
@@ -21,7 +23,6 @@ from droidrun.agent.usage import get_usage_from_response
|
||||
from droidrun.agent.utils.chat_utils import (
|
||||
extract_code_and_thought,
|
||||
limit_history,
|
||||
to_chat_messages,
|
||||
)
|
||||
from droidrun.agent.utils.executer import ExecuterState, SimpleCodeExecutor
|
||||
from droidrun.agent.utils.inference import acall_with_retries
|
||||
@@ -52,7 +53,7 @@ class CodeActAgent(Workflow):
|
||||
Agent that generates and executes Python code using atomic actions.
|
||||
|
||||
Uses ReAct cycle: Thought -> Code -> Observation -> repeat until complete().
|
||||
Messages stored as list[dict], converted to ChatMessage only for LLM calls.
|
||||
Messages stored as list[ChatMessage] to preserve thinking tokens across turns.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -90,7 +91,7 @@ class CodeActAgent(Workflow):
|
||||
self.prompt_resolver = prompt_resolver or PromptResolver()
|
||||
self.tracing_config = tracing_config
|
||||
|
||||
self.system_prompt: dict | None = None
|
||||
self.system_prompt: ChatMessage | None = None
|
||||
self.code_exec_counter = 0
|
||||
self.remembered_info: list[str] | None = None
|
||||
|
||||
@@ -150,7 +151,7 @@ class CodeActAgent(Workflow):
|
||||
|
||||
logger.debug("CodeActAgent initialized.")
|
||||
|
||||
async def _build_system_prompt(self) -> dict:
|
||||
async def _build_system_prompt(self) -> ChatMessage:
|
||||
"""Build system prompt message."""
|
||||
# Build template context with available tools for conditional examples
|
||||
template_context = {
|
||||
@@ -178,9 +179,9 @@ class CodeActAgent(Workflow):
|
||||
str(PathResolver.resolve(prompt_path, must_exist=True)),
|
||||
template_context,
|
||||
)
|
||||
return {"role": "system", "content": [{"text": system_text}]}
|
||||
return ChatMessage(role="system", content=system_text)
|
||||
|
||||
async def _build_user_prompt(self, goal: str) -> dict:
|
||||
async def _build_user_prompt(self, goal: str) -> ChatMessage:
|
||||
"""Build initial user prompt message."""
|
||||
custom_user_prompt = self.prompt_resolver.get_prompt("fast_agent_user")
|
||||
if custom_user_prompt:
|
||||
@@ -207,7 +208,7 @@ class CodeActAgent(Workflow):
|
||||
),
|
||||
},
|
||||
)
|
||||
return {"role": "user", "content": [{"text": user_text}]}
|
||||
return ChatMessage(role="user", content=user_text)
|
||||
|
||||
@step
|
||||
async def prepare_chat(self, ctx: Context, ev: StartEvent) -> CodeActInputEvent:
|
||||
@@ -240,8 +241,8 @@ class CodeActAgent(Workflow):
|
||||
for idx, item in enumerate(remembered_info, 1):
|
||||
memory_text += f"{idx}. {item}\n"
|
||||
# Append to first user message
|
||||
self.shared_state.message_history[0]["content"].append(
|
||||
{"text": memory_text}
|
||||
self.shared_state.message_history[0].blocks.append(
|
||||
TextBlock(text=memory_text)
|
||||
)
|
||||
|
||||
return CodeActInputEvent()
|
||||
@@ -296,7 +297,10 @@ class CodeActAgent(Workflow):
|
||||
ui_state = await self.state_provider.get_state()
|
||||
self.action_ctx.ui = ui_state
|
||||
|
||||
# Update shared state
|
||||
# Update shared state (previous ← current, current ← new)
|
||||
self.shared_state.previous_formatted_device_state = (
|
||||
self.shared_state.formatted_device_state
|
||||
)
|
||||
self.shared_state.formatted_device_state = ui_state.formatted_text
|
||||
self.shared_state.focused_text = ui_state.focused_text
|
||||
self.shared_state.a11y_tree = ui_state.elements
|
||||
@@ -311,11 +315,6 @@ class CodeActAgent(Workflow):
|
||||
# Stream formatted state for trajectory
|
||||
ctx.write_event_to_stream(RecordUIStateEvent(ui_state=ui_state.elements))
|
||||
|
||||
# Add device state to last user message
|
||||
self.shared_state.message_history[-1]["content"].append(
|
||||
{"text": f"\n{ui_state.formatted_text}\n"}
|
||||
)
|
||||
|
||||
except DeviceDisconnectedError:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -323,27 +322,51 @@ class CodeActAgent(Workflow):
|
||||
if self.debug:
|
||||
logger.error("State retrieval error details:", exc_info=True)
|
||||
|
||||
# Add screenshot to message if vision enabled
|
||||
if self.vision and screenshot:
|
||||
self.shared_state.message_history[-1]["content"].append(
|
||||
{"image": screenshot}
|
||||
)
|
||||
|
||||
# Limit history and prepare for LLM
|
||||
# Limit history and build ephemeral copy for LLM
|
||||
limited_history = limit_history(
|
||||
self.shared_state.message_history,
|
||||
LLM_HISTORY_LIMIT * 2,
|
||||
preserve_first=True,
|
||||
)
|
||||
messages_to_send = [self.system_prompt] + copy.deepcopy(limited_history)
|
||||
|
||||
# Build final messages: system + history
|
||||
messages_to_send = [self.system_prompt] + limited_history
|
||||
chat_messages = to_chat_messages(messages_to_send)
|
||||
# Inject device state and screenshot into the copy (not the original)
|
||||
user_indices = [
|
||||
i for i, msg in enumerate(messages_to_send) if msg.role == "user"
|
||||
]
|
||||
if user_indices:
|
||||
last_user_idx = user_indices[-1]
|
||||
|
||||
# Current device state → last user message
|
||||
current_state = self.shared_state.formatted_device_state.strip()
|
||||
if current_state:
|
||||
messages_to_send[last_user_idx].blocks.append(
|
||||
TextBlock(
|
||||
text=f"\n<device_state>\n{current_state}\n</device_state>\n"
|
||||
)
|
||||
)
|
||||
|
||||
# Screenshot → last user message
|
||||
if self.vision and screenshot:
|
||||
messages_to_send[last_user_idx].blocks.append(
|
||||
ImageBlock(image=screenshot)
|
||||
)
|
||||
|
||||
# Previous device state → second-to-last user message
|
||||
if len(user_indices) >= 2:
|
||||
second_last_idx = user_indices[-2]
|
||||
prev_state = self.shared_state.previous_formatted_device_state.strip()
|
||||
if prev_state:
|
||||
messages_to_send[second_last_idx].blocks.append(
|
||||
TextBlock(
|
||||
text=f"\n<previous_device_state>\n{prev_state}\n</previous_device_state>\n"
|
||||
)
|
||||
)
|
||||
|
||||
# Call LLM
|
||||
logger.info("CodeAct response:", extra={"color": "yellow"})
|
||||
response = await acall_with_retries(
|
||||
self.llm, chat_messages, stream=self.agent_config.streaming
|
||||
self.llm, messages_to_send, stream=self.agent_config.streaming
|
||||
)
|
||||
|
||||
if response is None:
|
||||
@@ -360,11 +383,9 @@ class CodeActAgent(Workflow):
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not get usage: {e}")
|
||||
|
||||
# Store assistant response
|
||||
# Store assistant response (preserves ThinkingBlock, additional_kwargs, etc.)
|
||||
self.shared_state.message_history.append(response.message)
|
||||
response_text = response.message.content
|
||||
self.shared_state.message_history.append(
|
||||
{"role": "assistant", "content": [{"text": response_text}]}
|
||||
)
|
||||
|
||||
# Extract thought and code
|
||||
code, thought = extract_code_and_thought(response_text)
|
||||
@@ -391,7 +412,7 @@ class CodeActAgent(Workflow):
|
||||
"Now, describe the next step you will take to address the original goal."
|
||||
)
|
||||
self.shared_state.message_history.append(
|
||||
{"role": "user", "content": [{"text": no_thoughts_text}]}
|
||||
ChatMessage(role="user", content=no_thoughts_text)
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Reasoning: {ev.thought}")
|
||||
@@ -408,7 +429,7 @@ class CodeActAgent(Workflow):
|
||||
"function within a <python></python> code block."
|
||||
)
|
||||
self.shared_state.message_history.append(
|
||||
{"role": "user", "content": [{"text": no_code_text}]}
|
||||
ChatMessage(role="user", content=no_code_text)
|
||||
)
|
||||
return CodeActInputEvent()
|
||||
|
||||
@@ -482,7 +503,7 @@ class CodeActAgent(Workflow):
|
||||
# Add execution output as user message
|
||||
observation_text = f"Execution Result:\n<result>\n{output}\n</result>"
|
||||
self.shared_state.message_history.append(
|
||||
{"role": "user", "content": [{"text": observation_text}]}
|
||||
ChatMessage(role="user", content=observation_text)
|
||||
)
|
||||
|
||||
return CodeActInputEvent()
|
||||
|
||||
@@ -9,10 +9,12 @@ compatibility with DroidAgent's execute_task() method.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional, Type
|
||||
|
||||
from llama_index.core.base.llms.types import ChatMessage, ImageBlock, TextBlock
|
||||
from llama_index.core.llms.llm import LLM
|
||||
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step
|
||||
from opentelemetry import trace
|
||||
@@ -36,7 +38,7 @@ from droidrun.agent.codeact.xml_parser import (
|
||||
from droidrun.agent.common.constants import LLM_HISTORY_LIMIT
|
||||
from droidrun.agent.common.events import RecordUIStateEvent, ScreenshotEvent
|
||||
from droidrun.agent.usage import get_usage_from_response
|
||||
from droidrun.agent.utils.chat_utils import limit_history, to_chat_messages
|
||||
from droidrun.agent.utils.chat_utils import limit_history
|
||||
from droidrun.agent.utils.inference import acall_with_retries
|
||||
from droidrun.agent.utils.prompt_resolver import PromptResolver
|
||||
from droidrun.agent.utils.tracing_setup import record_langfuse_screenshot
|
||||
@@ -57,7 +59,7 @@ class FastAgent(Workflow):
|
||||
"""Agent that uses XML tool-calling instead of code generation.
|
||||
|
||||
Uses ReAct cycle: Thought -> Tool Call -> Observation -> repeat until complete().
|
||||
Messages stored as list[dict], converted to ChatMessage only for LLM calls.
|
||||
Messages stored as list[ChatMessage] to preserve thinking tokens across turns.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -97,7 +99,7 @@ class FastAgent(Workflow):
|
||||
self.prompt_resolver = prompt_resolver or PromptResolver()
|
||||
self.tracing_config = tracing_config
|
||||
|
||||
self.system_prompt: dict | None = None
|
||||
self.system_prompt: ChatMessage | None = None
|
||||
self.tool_call_counter = 0
|
||||
self.remembered_info: list[str] | None = None
|
||||
|
||||
@@ -112,7 +114,7 @@ class FastAgent(Workflow):
|
||||
|
||||
logger.debug("FastAgent initialized.")
|
||||
|
||||
async def _build_system_prompt(self) -> dict:
|
||||
async def _build_system_prompt(self) -> ChatMessage:
|
||||
"""Build system prompt message."""
|
||||
template_context = {
|
||||
"tool_descriptions": self.tool_descriptions,
|
||||
@@ -137,9 +139,9 @@ class FastAgent(Workflow):
|
||||
self.agent_config.get_fast_agent_system_prompt_path(),
|
||||
template_context,
|
||||
)
|
||||
return {"role": "system", "content": [{"text": system_text}]}
|
||||
return ChatMessage(role="system", content=system_text)
|
||||
|
||||
async def _build_user_prompt(self, goal: str) -> dict:
|
||||
async def _build_user_prompt(self, goal: str) -> ChatMessage:
|
||||
"""Build initial user prompt message."""
|
||||
custom_user_prompt = self.prompt_resolver.get_prompt("fast_agent_user")
|
||||
if custom_user_prompt:
|
||||
@@ -162,7 +164,7 @@ class FastAgent(Workflow):
|
||||
),
|
||||
},
|
||||
)
|
||||
return {"role": "user", "content": [{"text": user_text}]}
|
||||
return ChatMessage(role="user", content=user_text)
|
||||
|
||||
@step
|
||||
async def prepare_chat(self, ctx: Context, ev: StartEvent) -> FastAgentInputEvent:
|
||||
@@ -194,8 +196,8 @@ class FastAgent(Workflow):
|
||||
memory_text = "\n### Remembered Information:\n"
|
||||
for idx, item in enumerate(remembered_info, 1):
|
||||
memory_text += f"{idx}. {item}\n"
|
||||
self.shared_state.message_history[0]["content"].append(
|
||||
{"text": memory_text}
|
||||
self.shared_state.message_history[0].blocks.append(
|
||||
TextBlock(text=memory_text)
|
||||
)
|
||||
|
||||
return FastAgentInputEvent()
|
||||
@@ -250,7 +252,10 @@ class FastAgent(Workflow):
|
||||
ui_state = await self.state_provider.get_state()
|
||||
self.action_ctx.ui = ui_state
|
||||
|
||||
# Update shared state
|
||||
# Update shared state (previous ← current, current ← new)
|
||||
self.shared_state.previous_formatted_device_state = (
|
||||
self.shared_state.formatted_device_state
|
||||
)
|
||||
self.shared_state.formatted_device_state = ui_state.formatted_text
|
||||
self.shared_state.focused_text = ui_state.focused_text
|
||||
self.shared_state.a11y_tree = ui_state.elements
|
||||
@@ -265,11 +270,6 @@ class FastAgent(Workflow):
|
||||
# Stream formatted state for trajectory
|
||||
ctx.write_event_to_stream(RecordUIStateEvent(ui_state=ui_state.elements))
|
||||
|
||||
# Add device state to last user message
|
||||
self.shared_state.message_history[-1]["content"].append(
|
||||
{"text": f"\n{ui_state.formatted_text}\n"}
|
||||
)
|
||||
|
||||
except DeviceDisconnectedError:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -277,27 +277,51 @@ class FastAgent(Workflow):
|
||||
if self.debug:
|
||||
logger.error("State retrieval error details:", exc_info=True)
|
||||
|
||||
# Add screenshot to message if vision enabled
|
||||
if self.vision and screenshot:
|
||||
self.shared_state.message_history[-1]["content"].append(
|
||||
{"image": screenshot}
|
||||
)
|
||||
|
||||
# Limit history and prepare for LLM
|
||||
# Limit history and build ephemeral copy for LLM
|
||||
limited_history = limit_history(
|
||||
self.shared_state.message_history,
|
||||
LLM_HISTORY_LIMIT * 2,
|
||||
preserve_first=True,
|
||||
)
|
||||
messages_to_send = [self.system_prompt] + copy.deepcopy(limited_history)
|
||||
|
||||
# Build final messages: system + history
|
||||
messages_to_send = [self.system_prompt] + limited_history
|
||||
chat_messages = to_chat_messages(messages_to_send)
|
||||
# Inject device state and screenshot into the copy (not the original)
|
||||
user_indices = [
|
||||
i for i, msg in enumerate(messages_to_send) if msg.role == "user"
|
||||
]
|
||||
if user_indices:
|
||||
last_user_idx = user_indices[-1]
|
||||
|
||||
# Current device state → last user message
|
||||
current_state = self.shared_state.formatted_device_state.strip()
|
||||
if current_state:
|
||||
messages_to_send[last_user_idx].blocks.append(
|
||||
TextBlock(
|
||||
text=f"\n<device_state>\n{current_state}\n</device_state>\n"
|
||||
)
|
||||
)
|
||||
|
||||
# Screenshot → last user message
|
||||
if self.vision and screenshot:
|
||||
messages_to_send[last_user_idx].blocks.append(
|
||||
ImageBlock(image=screenshot)
|
||||
)
|
||||
|
||||
# Previous device state → second-to-last user message
|
||||
if len(user_indices) >= 2:
|
||||
second_last_idx = user_indices[-2]
|
||||
prev_state = self.shared_state.previous_formatted_device_state.strip()
|
||||
if prev_state:
|
||||
messages_to_send[second_last_idx].blocks.append(
|
||||
TextBlock(
|
||||
text=f"\n<previous_device_state>\n{prev_state}\n</previous_device_state>\n"
|
||||
)
|
||||
)
|
||||
|
||||
# Call LLM
|
||||
logger.info("FastAgent response:", extra={"color": "yellow"})
|
||||
response = await acall_with_retries(
|
||||
self.llm, chat_messages, stream=self.agent_config.streaming
|
||||
self.llm, messages_to_send, stream=self.agent_config.streaming
|
||||
)
|
||||
|
||||
if response is None:
|
||||
@@ -314,11 +338,9 @@ class FastAgent(Workflow):
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not get usage: {e}")
|
||||
|
||||
# Store assistant response
|
||||
# Store assistant response (preserves ThinkingBlock, additional_kwargs, etc.)
|
||||
self.shared_state.message_history.append(response.message)
|
||||
response_text = response.message.content
|
||||
self.shared_state.message_history.append(
|
||||
{"role": "assistant", "content": [{"text": response_text}]}
|
||||
)
|
||||
|
||||
# Parse tool calls from response
|
||||
thought, tool_calls = parse_tool_calls(response_text, self.param_types)
|
||||
@@ -364,7 +386,7 @@ class FastAgent(Workflow):
|
||||
"Now, describe the next step you will take to address the original goal."
|
||||
)
|
||||
self.shared_state.message_history.append(
|
||||
{"role": "user", "content": [{"text": no_thoughts_text}]}
|
||||
ChatMessage(role="user", content=no_thoughts_text)
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Reasoning: {ev.thought}")
|
||||
@@ -386,7 +408,7 @@ class FastAgent(Workflow):
|
||||
"</function_calls>"
|
||||
)
|
||||
self.shared_state.message_history.append(
|
||||
{"role": "user", "content": [{"text": no_tools_text}]}
|
||||
ChatMessage(role="user", content=no_tools_text)
|
||||
)
|
||||
return FastAgentInputEvent()
|
||||
|
||||
@@ -473,7 +495,7 @@ class FastAgent(Workflow):
|
||||
|
||||
# Add results as user message
|
||||
self.shared_state.message_history.append(
|
||||
{"role": "user", "content": [{"text": output}]}
|
||||
ChatMessage(role="user", content=output)
|
||||
)
|
||||
|
||||
return FastAgentInputEvent()
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from llama_index.core.base.llms.types import ChatMessage
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from droidrun.telemetry import PackageVisitEvent, capture
|
||||
@@ -86,9 +87,9 @@ class DroidAgentState(BaseModel):
|
||||
success: Optional[bool] = None
|
||||
|
||||
# ========================================================================
|
||||
# Message History (for stateful agents - list of dicts)
|
||||
# Message History (for stateful agents - preserves ChatMessage blocks)
|
||||
# ========================================================================
|
||||
message_history: List[Dict] = Field(default_factory=list)
|
||||
message_history: List[ChatMessage] = Field(default_factory=list)
|
||||
|
||||
# ========================================================================
|
||||
# Error Handling
|
||||
|
||||
@@ -14,18 +14,18 @@ import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from llama_index.core.base.llms.types import ChatMessage, ImageBlock, TextBlock
|
||||
from llama_index.core.llms.llm import LLM
|
||||
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step
|
||||
|
||||
from droidrun.agent.executor.events import (
|
||||
ExecutorActionEvent,
|
||||
ExecutorActionResultEvent,
|
||||
ExecutorContextEvent,
|
||||
ExecutorResponseEvent,
|
||||
ExecutorActionResultEvent,
|
||||
)
|
||||
from droidrun.agent.executor.prompts import parse_executor_response
|
||||
from droidrun.agent.usage import get_usage_from_response
|
||||
from droidrun.agent.utils.chat_utils import to_chat_messages
|
||||
from droidrun.agent.utils.inference import acall_with_retries
|
||||
from droidrun.agent.utils.prompt_resolver import PromptResolver
|
||||
from droidrun.config_manager.config_manager import AgentConfig
|
||||
@@ -44,7 +44,7 @@ class ExecutorAgent(Workflow):
|
||||
Action execution agent that performs specific actions.
|
||||
|
||||
Single-turn agent: receives subgoal, selects action, executes it.
|
||||
Uses dict messages, converts to ChatMessage at LLM call time.
|
||||
Uses ChatMessage objects directly for LLM calls.
|
||||
"""
|
||||
|
||||
# Flow-control tools hidden from executor's LLM prompt
|
||||
@@ -123,14 +123,14 @@ class ExecutorAgent(Workflow):
|
||||
variables,
|
||||
)
|
||||
|
||||
# Build message as dict
|
||||
messages = [{"role": "user", "content": [{"text": prompt_text}]}]
|
||||
# Build message
|
||||
messages = [ChatMessage(role="user", blocks=[TextBlock(text=prompt_text)])]
|
||||
|
||||
# Add screenshot if vision enabled
|
||||
if self.vision:
|
||||
screenshot = self.shared_state.screenshot
|
||||
if screenshot is not None:
|
||||
messages[0]["content"].append({"image": screenshot})
|
||||
messages[0].blocks.append(ImageBlock(image=screenshot))
|
||||
logger.debug("📸 Using screenshot for Executor")
|
||||
else:
|
||||
logger.warning("⚠️ Vision enabled but no screenshot available")
|
||||
@@ -149,13 +149,10 @@ class ExecutorAgent(Workflow):
|
||||
# Get messages from context
|
||||
messages = await ctx.store.get("executor_messages")
|
||||
|
||||
# Convert to ChatMessage and call LLM
|
||||
chat_messages = to_chat_messages(messages)
|
||||
|
||||
try:
|
||||
logger.info("Executor response:", extra={"color": "green"})
|
||||
response = await acall_with_retries(
|
||||
self.llm, chat_messages, stream=self.agent_config.streaming
|
||||
self.llm, messages, stream=self.agent_config.streaming
|
||||
)
|
||||
response_text = str(response)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -16,6 +16,12 @@ import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional, Type
|
||||
|
||||
from llama_index.core.base.llms.types import (
|
||||
ChatMessage,
|
||||
ImageBlock,
|
||||
MessageRole,
|
||||
TextBlock,
|
||||
)
|
||||
from llama_index.core.llms.llm import LLM
|
||||
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step
|
||||
from opentelemetry import trace
|
||||
@@ -29,10 +35,7 @@ from droidrun.agent.manager.events import (
|
||||
)
|
||||
from droidrun.agent.manager.prompts import parse_manager_response
|
||||
from droidrun.agent.usage import get_usage_from_response
|
||||
from droidrun.agent.utils.chat_utils import (
|
||||
filter_empty_messages,
|
||||
to_chat_messages,
|
||||
)
|
||||
from droidrun.agent.utils.chat_utils import filter_empty_messages
|
||||
from droidrun.agent.utils.inference import acall_with_retries
|
||||
from droidrun.agent.utils.prompt_resolver import PromptResolver
|
||||
from droidrun.agent.utils.signatures import ATOMIC_ACTION_SIGNATURES
|
||||
@@ -238,7 +241,7 @@ class ManagerAgent(Workflow):
|
||||
|
||||
def _build_messages_with_context(
|
||||
self, system_prompt: str, screenshot: bytes | None = None
|
||||
) -> list[dict]:
|
||||
) -> list[ChatMessage]:
|
||||
"""
|
||||
Build messages from history and inject current context.
|
||||
|
||||
@@ -247,17 +250,19 @@ class ManagerAgent(Workflow):
|
||||
screenshot: Current screenshot if vision enabled
|
||||
|
||||
Returns:
|
||||
List of message dicts ready for conversion
|
||||
List of ChatMessage objects ready for LLM
|
||||
"""
|
||||
|
||||
# Start with system message
|
||||
messages = [{"role": "system", "content": [{"text": system_prompt}]}]
|
||||
messages = [ChatMessage(role="system", content=system_prompt)]
|
||||
|
||||
# Add accumulated message history (deep copy to avoid mutation)
|
||||
messages.extend(copy.deepcopy(self.shared_state.message_history))
|
||||
|
||||
# Find last user message
|
||||
user_indices = [i for i, msg in enumerate(messages) if msg["role"] == "user"]
|
||||
user_indices = [
|
||||
i for i, msg in enumerate(messages) if msg.role == MessageRole.USER
|
||||
]
|
||||
|
||||
if user_indices:
|
||||
last_user_idx = user_indices[-1]
|
||||
@@ -265,20 +270,22 @@ class ManagerAgent(Workflow):
|
||||
# Add memory to last user message
|
||||
current_memory = (self.shared_state.manager_memory or "").strip()
|
||||
if current_memory:
|
||||
messages[last_user_idx]["content"].append(
|
||||
{"text": f"\n<memory>\n{current_memory}\n</memory>\n"}
|
||||
messages[last_user_idx].blocks.append(
|
||||
TextBlock(text=f"\n<memory>\n{current_memory}\n</memory>\n")
|
||||
)
|
||||
|
||||
# Add current device state
|
||||
current_state = self.shared_state.formatted_device_state.strip()
|
||||
if current_state:
|
||||
messages[last_user_idx]["content"].append(
|
||||
{"text": f"\n<device_state>\n{current_state}\n</device_state>\n"}
|
||||
messages[last_user_idx].blocks.append(
|
||||
TextBlock(
|
||||
text=f"\n<device_state>\n{current_state}\n</device_state>\n"
|
||||
)
|
||||
)
|
||||
|
||||
# Add screenshot if vision enabled
|
||||
if screenshot and self.vision:
|
||||
messages[last_user_idx]["content"].append({"image": screenshot})
|
||||
messages[last_user_idx].blocks.append(ImageBlock(image=screenshot))
|
||||
|
||||
# Add script result if available
|
||||
if self.shared_state.last_scripter_message:
|
||||
@@ -290,7 +297,7 @@ class ManagerAgent(Workflow):
|
||||
f"{self.shared_state.last_scripter_message}\n"
|
||||
f"</script_result>\n"
|
||||
)
|
||||
messages[last_user_idx]["content"].append({"text": script_context})
|
||||
messages[last_user_idx].blocks.append(TextBlock(text=script_context))
|
||||
self.shared_state.last_scripter_message = ""
|
||||
|
||||
# Add previous device state to second-to-last user message
|
||||
@@ -298,15 +305,17 @@ class ManagerAgent(Workflow):
|
||||
second_last_idx = user_indices[-2]
|
||||
prev_state = self.shared_state.previous_formatted_device_state.strip()
|
||||
if prev_state:
|
||||
messages[second_last_idx]["content"].append(
|
||||
{"text": f"\n<device_state>\n{prev_state}\n</device_state>\n"}
|
||||
messages[second_last_idx].blocks.append(
|
||||
TextBlock(
|
||||
text=f"\n<previous_device_state>\n{prev_state}\n</previous_device_state>\n"
|
||||
)
|
||||
)
|
||||
|
||||
messages = filter_empty_messages(messages)
|
||||
return messages
|
||||
|
||||
async def _validate_and_retry(
|
||||
self, messages: list[dict], initial_response: str
|
||||
self, messages: list[ChatMessage], initial_response: str
|
||||
) -> str:
|
||||
"""Validate LLM response and retry if needed."""
|
||||
output = initial_response
|
||||
@@ -350,15 +359,13 @@ class ManagerAgent(Workflow):
|
||||
|
||||
# Build retry messages
|
||||
retry_messages = messages + [
|
||||
{"role": "assistant", "content": [{"text": output}]},
|
||||
{"role": "user", "content": [{"text": error_message}]},
|
||||
ChatMessage(role="assistant", content=output),
|
||||
ChatMessage(role="user", content=error_message),
|
||||
]
|
||||
|
||||
chat_messages = to_chat_messages(retry_messages)
|
||||
|
||||
try:
|
||||
response = await acall_with_retries(
|
||||
self.llm, chat_messages, stream=self.agent_config.streaming
|
||||
self.llm, retry_messages, stream=self.agent_config.streaming
|
||||
)
|
||||
output = response.message.content
|
||||
parsed = parse_manager_response(output)
|
||||
@@ -449,7 +456,7 @@ class ManagerAgent(Workflow):
|
||||
# Build user message and add to history
|
||||
user_content = self._build_user_message_content()
|
||||
self.shared_state.message_history.append(
|
||||
{"role": "user", "content": [{"text": user_content}]}
|
||||
ChatMessage(role="user", content=user_content)
|
||||
)
|
||||
|
||||
event = ManagerContextEvent()
|
||||
@@ -474,13 +481,10 @@ class ManagerAgent(Workflow):
|
||||
system_prompt=system_prompt, screenshot=screenshot
|
||||
)
|
||||
|
||||
# Convert and call LLM
|
||||
chat_messages = to_chat_messages(messages)
|
||||
|
||||
try:
|
||||
logger.info("📋 Manager response:", extra={"color": "cyan"})
|
||||
response = await acall_with_retries(
|
||||
self.llm, chat_messages, stream=self.agent_config.streaming
|
||||
self.llm, messages, stream=self.agent_config.streaming
|
||||
)
|
||||
output = response.message.content
|
||||
except Exception as e:
|
||||
@@ -520,7 +524,7 @@ class ManagerAgent(Workflow):
|
||||
|
||||
# Append assistant response to message history
|
||||
self.shared_state.message_history.append(
|
||||
{"role": "assistant", "content": [{"text": output}]}
|
||||
ChatMessage(role="assistant", content=output)
|
||||
)
|
||||
|
||||
# Update unified state fields
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from PIL import Image
|
||||
from llama_index.core.base.llms.types import ChatMessage, ImageBlock, TextBlock
|
||||
from PIL import Image
|
||||
|
||||
logger = logging.getLogger("droidrun")
|
||||
|
||||
@@ -94,22 +94,22 @@ def extract_code_and_thought(response_text: str) -> Tuple[Optional[str], str]:
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def has_content(message: dict) -> bool:
|
||||
for item in message.get("content", []):
|
||||
if "text" in item and item["text"].strip():
|
||||
def has_content(message: ChatMessage) -> bool:
|
||||
for block in message.blocks:
|
||||
if isinstance(block, TextBlock) and block.text and block.text.strip():
|
||||
return True
|
||||
if "image" in item and item["image"]:
|
||||
if isinstance(block, ImageBlock) and block.image:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def filter_empty_messages(messages: list[dict]) -> list[dict]:
|
||||
def filter_empty_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
|
||||
return [msg for msg in messages if has_content(msg)]
|
||||
|
||||
|
||||
def limit_history(
|
||||
messages: list[dict], max_messages: int, preserve_first: bool = True
|
||||
) -> list[dict]:
|
||||
messages: list[ChatMessage], max_messages: int, preserve_first: bool = True
|
||||
) -> list[ChatMessage]:
|
||||
if len(messages) <= max_messages:
|
||||
return messages
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Optional, Type, TypeVar
|
||||
from typing import Optional, Type, TypeVar
|
||||
|
||||
from llama_index.core.base.llms.types import (
|
||||
ChatMessage,
|
||||
@@ -106,8 +106,14 @@ async def _stream_response(llm, messages: list, timeout: float) -> ChatResponse:
|
||||
await asyncio.wait_for(stream_chunks(), timeout=timeout)
|
||||
|
||||
# Build response matching non-streaming format
|
||||
# Use last_chunk.message to preserve all blocks (ThinkingBlock, etc.)
|
||||
# that providers accumulate during streaming
|
||||
response = ChatResponse(
|
||||
message=ChatMessage(role="assistant", content=content),
|
||||
message=(
|
||||
last_chunk.message
|
||||
if last_chunk
|
||||
else ChatMessage(role="assistant", content=content)
|
||||
),
|
||||
raw=last_chunk.raw if last_chunk else None,
|
||||
additional_kwargs=last_chunk.additional_kwargs if last_chunk else {},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user