"""Memory data structures for agents.

Defines the step records (task, action, planning, system prompt) that make up
an agent's memory, how each one is rendered into chat messages, and the
``AgentMemory`` container that stores and replays them.
"""

from dataclasses import asdict, dataclass
from logging import getLogger
from typing import TYPE_CHECKING, Any, Dict, List, TypedDict, Union

from smolagents.models import ChatMessage, MessageRole
from smolagents.monitoring import AgentLogger, LogLevel
from smolagents.utils import AgentError, make_json_serializable


if TYPE_CHECKING:
    from smolagents.models import ChatMessage
    from smolagents.monitoring import AgentLogger


logger = getLogger(__name__)


class Message(TypedDict):
    """A single chat message: a role plus either plain text or a list of content parts."""

    role: MessageRole
    content: str | list[dict]


@dataclass
class ToolCall:
    """A tool invocation requested by the model."""

    name: str
    arguments: Any
    id: str

    def dict(self) -> Dict[str, Any]:
        """Return the call as a ``{"id", "type": "function", "function": {...}}`` mapping.

        The arguments are passed through ``make_json_serializable``.
        """
        return {
            "id": self.id,
            "type": "function",
            "function": {
                "name": self.name,
                "arguments": make_json_serializable(self.arguments),
            },
        }


@dataclass
class MemoryStep:
    """Base class for every entry stored in an agent's memory."""

    def dict(self) -> Dict[str, Any]:
        """Return the dataclass fields as a dictionary (via ``dataclasses.asdict``)."""
        return asdict(self)

    def to_messages(self, **kwargs) -> List[Dict[str, Any]]:
        """Render this step as a list of chat messages. Subclasses must override."""
        raise NotImplementedError


@dataclass
class ActionStep(MemoryStep):
    """One action iteration: model input/output, tool calls, observations and errors."""

    model_input_messages: List[Message] | None = None
    tool_calls: List[ToolCall] | None = None
    start_time: float | None = None
    end_time: float | None = None
    step_number: int | None = None
    error: AgentError | None = None
    duration: float | None = None
    model_output_message: ChatMessage | None = None
    model_output: str | None = None
    observations: str | None = None
    observations_images: List[str] | None = None
    action_output: Any = None

    def dict(self) -> Dict[str, Any]:
        """Return the step as a dictionary.

        Overrides the base implementation to convert ``tool_calls`` and
        ``action_output`` manually. ``step_number`` is exported under the key
        ``"step"``; ``observations_images`` is not included.
        """
        # We overwrite the method to parse the tool_calls and action_output manually
        return {
            "model_input_messages": self.model_input_messages,
            "tool_calls": [tc.dict() for tc in self.tool_calls] if self.tool_calls else [],
            "start_time": self.start_time,
            "end_time": self.end_time,
            "step": self.step_number,
            "error": self.error.dict() if self.error else None,
            "duration": self.duration,
            "model_output_message": self.model_output_message,
            "model_output": self.model_output,
            "observations": self.observations,
            "action_output": make_json_serializable(self.action_output),
        }

    def _call_id_prefix(self) -> str:
        """Return ``"Call id: <id>\\n"`` for the first tool call, or ``""`` when there is none."""
        if self.tool_calls:
            return f"Call id: {self.tool_calls[0].id}\n"
        return ""

    def to_messages(self, summary_mode: bool = False, show_model_input_messages: bool = False) -> List[Message]:
        """Render this step as chat messages.

        Args:
            summary_mode: If True, the raw model output message is omitted.
            show_model_input_messages: If True, the stored model input messages
                are included as a leading system message.

        Returns:
            The messages for this step, in order: model input (optional), model
            output (unless in summary mode), tool calls, observations, error,
            and observed images. Each part is only emitted when it is set.
        """
        messages = []
        if self.model_input_messages is not None and show_model_input_messages:
            messages.append(Message(role=MessageRole.SYSTEM, content=self.model_input_messages))
        if self.model_output is not None and not summary_mode:
            messages.append(
                Message(role=MessageRole.ASSISTANT, content=[{"type": "text", "text": self.model_output.strip()}])
            )

        if self.tool_calls is not None:
            messages.append(
                Message(
                    role=MessageRole.ASSISTANT,
                    content=[
                        {
                            "type": "text",
                            "text": "Calling tools:\n" + str([tc.dict() for tc in self.tool_calls]),
                        }
                    ],
                )
            )

        if self.observations is not None:
            # Observations can exist without any tool call; only reference a
            # call id when there is one, instead of indexing unconditionally.
            messages.append(
                Message(
                    role=MessageRole.TOOL_RESPONSE,
                    content=[
                        {
                            "type": "text",
                            "text": self._call_id_prefix() + f"Observation:\n{self.observations}",
                        }
                    ],
                )
            )
        if self.error is not None:
            error_message = (
                "Error:\n"
                + str(self.error)
                + "\nNow let's retry: take care not to repeat previous errors! If you have retried several times, try a completely different approach.\n"
            )
            message_content = self._call_id_prefix() + error_message
            messages.append(
                Message(role=MessageRole.TOOL_RESPONSE, content=[{"type": "text", "text": message_content}])
            )

        if self.observations_images:
            messages.append(
                Message(
                    role=MessageRole.USER,
                    content=[{"type": "text", "text": "Here are the observed images:"}]
                    + [
                        {
                            "type": "image",
                            "image": image,
                        }
                        for image in self.observations_images
                    ],
                )
            )
        return messages


@dataclass
class PlanningStep(MemoryStep):
    """A planning iteration holding the facts list and the plan produced by the model."""

    model_input_messages: List[Message]
    model_output_message_facts: ChatMessage
    facts: str
    model_output_message_plan: ChatMessage
    plan: str

    def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Message]:
        """Render the facts (always) and the plan (unless in summary mode) as assistant messages.

        Args:
            summary_mode: If True, the plan message is omitted. Defaults to
                False, like the other step types.
            **kwargs: Accepted and ignored.
        """
        messages = []
        messages.append(
            Message(
                role=MessageRole.ASSISTANT, content=[{"type": "text", "text": f"[FACTS LIST]:\n{self.facts.strip()}"}]
            )
        )

        if not summary_mode:
            # This step is not shown to a model writing a plan to avoid influencing the new plan
            messages.append(
                Message(
                    role=MessageRole.ASSISTANT, content=[{"type": "text", "text": f"[PLAN]:\n{self.plan.strip()}"}]
                )
            )
        return messages


@dataclass
class TaskStep(MemoryStep):
    """A task given to the agent, optionally with images."""

    task: str
    task_images: List[str] | None = None

    def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Message]:
        """Render the task as a single user message; task images are appended as image parts.

        ``summary_mode`` and ``**kwargs`` are accepted but not used.
        """
        content = [{"type": "text", "text": f"New task:\n{self.task}"}]
        if self.task_images:
            for image in self.task_images:
                content.append({"type": "image", "image": image})

        return [Message(role=MessageRole.USER, content=content)]


@dataclass
class SystemPromptStep(MemoryStep):
    """The system prompt of the agent."""

    system_prompt: str

    def to_messages(self, summary_mode: bool = False, **kwargs) -> List[Message]:
        """Render the system prompt as one system message, or nothing in summary mode."""
        if summary_mode:
            return []
        return [Message(role=MessageRole.SYSTEM, content=[{"type": "text", "text": self.system_prompt}])]


class AgentMemory:
    """Container for an agent's system prompt and the ordered list of its steps."""

    def __init__(self, system_prompt: str):
        """Wrap ``system_prompt`` in a ``SystemPromptStep`` and start with no steps."""
        self.system_prompt = SystemPromptStep(system_prompt=system_prompt)
        self.steps: List[Union[TaskStep, ActionStep, PlanningStep]] = []

    def reset(self):
        """Discard all recorded steps. The system prompt is kept."""
        self.steps = []

    def get_succinct_steps(self) -> list[dict]:
        """Return every step as a dictionary, without the ``model_input_messages`` entry."""
        return [
            {key: value for key, value in step.dict().items() if key != "model_input_messages"} for step in self.steps
        ]

    def get_full_steps(self) -> list[dict]:
        """Return every step as a dictionary, as produced by each step's ``dict()``."""
        return [step.dict() for step in self.steps]

    def replay(self, logger: AgentLogger, detailed: bool = False):
        """Prints a pretty replay of the agent's steps.

        Args:
            logger (AgentLogger): The logger to print replay logs to.
            detailed (bool, optional): If True, also displays the memory at each step. Defaults to False.
                Careful: will increase log length exponentially. Use only for debugging.
        """
        logger.console.log("Replaying the agent's steps:")
        for step in self.steps:
            # NOTE: __init__ keeps the system prompt on self.system_prompt rather than
            # in self.steps, so this branch only fires if a SystemPromptStep was
            # appended to self.steps by the caller.
            if isinstance(step, SystemPromptStep) and detailed:
                logger.log_markdown(title="System prompt", content=step.system_prompt, level=LogLevel.ERROR)
            elif isinstance(step, TaskStep):
                logger.log_task(step.task, "", level=LogLevel.ERROR)
            elif isinstance(step, ActionStep):
                logger.log_rule(f"Step {step.step_number}", level=LogLevel.ERROR)
                if detailed:
                    # Use the same level as every other replay call so the input
                    # messages are not filtered differently from the rest.
                    logger.log_messages(step.model_input_messages, level=LogLevel.ERROR)
                logger.log_markdown(title="Agent output:", content=step.model_output, level=LogLevel.ERROR)
            elif isinstance(step, PlanningStep):
                logger.log_rule("Planning step", level=LogLevel.ERROR)
                if detailed:
                    logger.log_messages(step.model_input_messages, level=LogLevel.ERROR)
                logger.log_markdown(title="Agent output:", content=step.facts + "\n" + step.plan, level=LogLevel.ERROR)


__all__ = ["AgentMemory"]