Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import asyncio | |
import logging | |
import time | |
from typing import ( | |
TYPE_CHECKING, | |
Any, | |
AsyncIterator, | |
Dict, | |
Iterator, | |
List, | |
Optional, | |
Tuple, | |
Union, | |
) | |
from uuid import UUID | |
from langchain_core.agents import ( | |
AgentAction, | |
AgentFinish, | |
AgentStep, | |
) | |
from langchain_core.callbacks import ( | |
AsyncCallbackManager, | |
AsyncCallbackManagerForChainRun, | |
CallbackManager, | |
CallbackManagerForChainRun, | |
Callbacks, | |
) | |
from langchain_core.load.dump import dumpd | |
from langchain_core.outputs import RunInfo | |
from langchain_core.runnables.utils import AddableDict | |
from langchain_core.tools import BaseTool | |
from langchain_core.utils.input import get_color_mapping | |
from langchain.schema import RUN_KEY | |
from langchain.utilities.asyncio import asyncio_timeout | |
if TYPE_CHECKING:
    # Imported only for static type checkers; at runtime these names are never
    # bound, so they may only appear in annotations (kept lazy by the
    # ``from __future__ import annotations`` at the top of the file).
    # NOTE(review): presumably guarded to avoid a circular import with
    # langchain.agents.agent — confirm.
    from langchain.agents.agent import AgentExecutor, NextStepOutput

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class AgentExecutorIterator:
    """Iterator for AgentExecutor.

    Drives an ``AgentExecutor``'s agent loop one step at a time, yielding
    ``AddableDict`` chunks from ``__iter__`` (sync) or ``__aiter__`` (async).
    Non-final chunks contain ``intermediate_step`` (or, when ``yield_actions``
    is set, ``actions`` / ``steps`` chunks as they are produced); the final
    chunk contains the executor's prepared outputs.
    """

    def __init__(
        self,
        agent_executor: AgentExecutor,
        inputs: Any,
        callbacks: Callbacks = None,
        *,
        tags: Optional[list[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        run_id: Optional[UUID] = None,
        include_run_info: bool = False,
        yield_actions: bool = False,
    ):
        """Initialize the AgentExecutorIterator.

        Args:
            agent_executor: The executor whose agent loop is driven.
            inputs: Raw inputs; normalised through
                ``agent_executor.prep_inputs`` by the ``inputs`` setter.
            callbacks: Callbacks for this run, passed to the callback manager's
                ``configure`` together with the executor's own callbacks.
            tags: Tags for this run, passed to ``configure`` together with the
                executor's tags.
            metadata: Metadata for this run, passed to ``configure`` together
                with the executor's metadata.
            run_name: Name passed to ``on_chain_start``.
            run_id: Run id passed to ``on_chain_start``.
            include_run_info: If True, the final output also carries
                ``RUN_KEY`` with a ``RunInfo`` for the run.
            yield_actions: If True, yield each ``AgentAction`` / ``AgentStep``
                as soon as it is produced; intermediate step summaries are
                then not yielded (only the final output is).
        """
        # Must be assigned before ``inputs``: the inputs setter reads it.
        self._agent_executor = agent_executor
        self.inputs = inputs
        self.callbacks = callbacks
        self.tags = tags
        self.metadata = metadata
        self.run_name = run_name
        self.run_id = run_id
        self.include_run_info = include_run_info
        self.yield_actions = yield_actions
        self.reset()

    # Inputs after ``agent_executor.prep_inputs``; exposed via ``inputs``.
    _inputs: Dict[str, str]
    callbacks: Callbacks
    tags: Optional[list[str]]
    metadata: Optional[Dict[str, Any]]
    run_name: Optional[str]
    run_id: Optional[UUID]
    include_run_info: bool
    yield_actions: bool

    @property
    def inputs(self) -> Dict[str, str]:
        """The executor-prepared inputs for this run."""
        return self._inputs

    @inputs.setter
    def inputs(self, inputs: Any) -> None:
        self._inputs = self.agent_executor.prep_inputs(inputs)

    @property
    def agent_executor(self) -> AgentExecutor:
        """The executor being iterated."""
        return self._agent_executor

    @agent_executor.setter
    def agent_executor(self, agent_executor: AgentExecutor) -> None:
        self._agent_executor = agent_executor
        # force re-prep inputs in case agent_executor's prep_inputs fn changed
        self.inputs = self.inputs

    @property
    def name_to_tool_map(self) -> Dict[str, BaseTool]:
        """Map each of the executor's tools by its ``name``."""
        return {tool.name: tool for tool in self.agent_executor.tools}

    @property
    def color_mapping(self) -> Dict[str, str]:
        """Color per tool name, never using green or red."""
        return get_color_mapping(
            [tool.name for tool in self.agent_executor.tools],
            excluded_colors=["green", "red"],
        )

    def reset(self) -> None:
        """
        Reset the iterator to its initial state, clearing intermediate steps,
        iterations, and time elapsed.
        """
        logger.debug("(Re)setting AgentExecutorIterator to fresh state")
        self.intermediate_steps: list[tuple[AgentAction, str]] = []
        self.iterations = 0
        # maybe better to start these on the first __anext__ call?
        self.time_elapsed = 0.0
        self.start_time = time.time()

    def update_iterations(self) -> None:
        """
        Increment the number of iterations and update the time elapsed.
        """
        self.iterations += 1
        self.time_elapsed = time.time() - self.start_time
        # Lazy %-formatting: the message is only built if DEBUG is enabled.
        logger.debug(
            "Agent Iterations: %s (%.2fs elapsed)",
            self.iterations,
            self.time_elapsed,
        )

    def make_final_outputs(
        self,
        outputs: Dict[str, Any],
        run_manager: Union[CallbackManagerForChainRun, AsyncCallbackManagerForChainRun],
    ) -> AddableDict:
        """Prepare the final output chunk.

        Runs ``outputs`` through ``agent_executor.prep_outputs`` and, when
        ``include_run_info`` is set, adds ``RUN_KEY`` with the run's id.
        """
        # have access to intermediate steps by design in iterator,
        # so return only outputs may as well always be true.
        prepared_outputs = AddableDict(
            self.agent_executor.prep_outputs(
                self.inputs, outputs, return_only_outputs=True
            )
        )
        if self.include_run_info:
            prepared_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
        return prepared_outputs

    def __iter__(self) -> Iterator[AddableDict]:
        """Synchronously drive the agent loop, yielding output chunks.

        As a generator, nothing (not even ``reset`` or ``on_chain_start``)
        runs until the first ``next()``. Any exception raised while running
        the loop or building the stopped response is reported through
        ``on_chain_error`` and re-raised.
        """
        logger.debug("Initialising AgentExecutorIterator")
        self.reset()
        callback_manager = CallbackManager.configure(
            self.callbacks,
            self.agent_executor.callbacks,
            self.agent_executor.verbose,
            self.tags,
            self.agent_executor.tags,
            self.metadata,
            self.agent_executor.metadata,
        )
        run_manager = callback_manager.on_chain_start(
            dumpd(self.agent_executor),
            self.inputs,
            self.run_id,
            name=self.run_name,
        )
        try:
            while self.agent_executor._should_continue(
                self.iterations, self.time_elapsed
            ):
                # take the next step: this plans next action, executes it,
                # yielding action and observation as they are generated
                next_step_seq: NextStepOutput = []
                for chunk in self.agent_executor._iter_next_step(
                    self.name_to_tool_map,
                    self.color_mapping,
                    self.inputs,
                    self.intermediate_steps,
                    run_manager,
                ):
                    next_step_seq.append(chunk)
                    # if we're yielding actions, yield them as they come
                    # do not yield AgentFinish, which will be handled below
                    if self.yield_actions:
                        if isinstance(chunk, AgentAction):
                            yield AddableDict(actions=[chunk], messages=chunk.messages)
                        elif isinstance(chunk, AgentStep):
                            yield AddableDict(steps=[chunk], messages=chunk.messages)
                # convert iterator output to format handled by _process_next_step_output
                next_step = self.agent_executor._consume_next_step(next_step_seq)
                # update iterations and time elapsed
                self.update_iterations()
                # decide if this is the final output
                output = self._process_next_step_output(next_step, run_manager)
                is_final = "intermediate_step" not in output
                # yield the final output always
                # for backwards compat, yield int. output if not yielding actions
                if not self.yield_actions or is_final:
                    yield output
                # if final output reached, stop iteration
                if is_final:
                    return
            # if we got here means we exhausted iterations or time.
            # Build the stopped response inside the try so that a failure
            # while producing it is still reported via on_chain_error.
            stopped_output = self._stop(run_manager)
        except BaseException as e:
            run_manager.on_chain_error(e)
            raise
        yield stopped_output

    async def __aiter__(self) -> AsyncIterator[AddableDict]:
        """Asynchronously drive the agent loop, yielding output chunks.

        Defined as an async generator: calling ``__aiter__`` returns the
        iterator without awaiting anything, and the body (including the
        awaited ``on_chain_start``) first runs on ``__anext__``. The loop runs
        under ``asyncio_timeout(max_execution_time)``; on timeout the stopped
        response is yielded. Any other exception raised while running the loop
        or building the stopped response is reported through
        ``on_chain_error`` and re-raised.
        """
        logger.debug("Initialising AgentExecutorIterator (async)")
        self.reset()
        callback_manager = AsyncCallbackManager.configure(
            self.callbacks,
            self.agent_executor.callbacks,
            self.agent_executor.verbose,
            self.tags,
            self.agent_executor.tags,
            self.metadata,
            self.agent_executor.metadata,
        )
        run_manager = await callback_manager.on_chain_start(
            dumpd(self.agent_executor),
            self.inputs,
            self.run_id,
            name=self.run_name,
        )
        try:
            try:
                async with asyncio_timeout(self.agent_executor.max_execution_time):
                    while self.agent_executor._should_continue(
                        self.iterations, self.time_elapsed
                    ):
                        # take the next step: this plans next action, executes it,
                        # yielding action and observation as they are generated
                        next_step_seq: NextStepOutput = []
                        async for chunk in self.agent_executor._aiter_next_step(
                            self.name_to_tool_map,
                            self.color_mapping,
                            self.inputs,
                            self.intermediate_steps,
                            run_manager,
                        ):
                            next_step_seq.append(chunk)
                            # if we're yielding actions, yield them as they come
                            # do not yield AgentFinish, which will be handled below
                            if self.yield_actions:
                                if isinstance(chunk, AgentAction):
                                    yield AddableDict(
                                        actions=[chunk], messages=chunk.messages
                                    )
                                elif isinstance(chunk, AgentStep):
                                    yield AddableDict(
                                        steps=[chunk], messages=chunk.messages
                                    )
                        # convert iterator output to format handled by _process_next_step
                        next_step = self.agent_executor._consume_next_step(next_step_seq)
                        # update iterations and time elapsed
                        self.update_iterations()
                        # decide if this is the final output
                        output = await self._aprocess_next_step_output(
                            next_step, run_manager
                        )
                        is_final = "intermediate_step" not in output
                        # yield the final output always
                        # for backwards compat, yield int. output if not yielding actions
                        if not self.yield_actions or is_final:
                            yield output
                        # if final output reached, stop iteration
                        if is_final:
                            return
            except (TimeoutError, asyncio.TimeoutError):
                # max_execution_time elapsed: fall through to the stopped response
                pass
            # exhausted iterations or time (or timed out). Built inside the
            # outer try so a failure here is still reported via on_chain_error.
            stopped_output = await self._astop(run_manager)
        except BaseException as e:
            await run_manager.on_chain_error(e)
            raise
        yield stopped_output

    def _process_next_step_output(
        self,
        next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
        run_manager: CallbackManagerForChainRun,
    ) -> AddableDict:
        """
        Process the output of the next step,
        handling AgentFinish and tool return cases.

        Returns the final outputs (via ``_return``) for an ``AgentFinish`` or
        when ``_get_tool_return`` returns non-None for a single-step output;
        otherwise records the steps and returns ``{"intermediate_step": ...}``.
        """
        logger.debug("Processing output of Agent loop step")
        if isinstance(next_step_output, AgentFinish):
            logger.debug(
                "Hit AgentFinish: _return -> on_chain_end -> run final output logic"
            )
            return self._return(next_step_output, run_manager=run_manager)

        self.intermediate_steps.extend(next_step_output)
        logger.debug("Updated intermediate_steps with step output")

        # Check for tool return
        if len(next_step_output) == 1:
            next_step_action = next_step_output[0]
            tool_return = self.agent_executor._get_tool_return(next_step_action)
            if tool_return is not None:
                return self._return(tool_return, run_manager=run_manager)

        return AddableDict(intermediate_step=next_step_output)

    async def _aprocess_next_step_output(
        self,
        next_step_output: Union[AgentFinish, List[Tuple[AgentAction, str]]],
        run_manager: AsyncCallbackManagerForChainRun,
    ) -> AddableDict:
        """
        Process the output of the next async step,
        handling AgentFinish and tool return cases.

        Async counterpart of ``_process_next_step_output``.
        """
        logger.debug("Processing output of async Agent loop step")
        if isinstance(next_step_output, AgentFinish):
            logger.debug(
                "Hit AgentFinish: _areturn -> on_chain_end -> run final output logic"
            )
            return await self._areturn(next_step_output, run_manager=run_manager)

        self.intermediate_steps.extend(next_step_output)
        logger.debug("Updated intermediate_steps with step output")

        # Check for tool return
        if len(next_step_output) == 1:
            next_step_action = next_step_output[0]
            tool_return = self.agent_executor._get_tool_return(next_step_action)
            if tool_return is not None:
                return await self._areturn(tool_return, run_manager=run_manager)

        return AddableDict(intermediate_step=next_step_output)

    def _stop(self, run_manager: CallbackManagerForChainRun) -> AddableDict:
        """
        Build and return the final output for a run stopped before the agent
        finished (iterations or time exhausted), using the agent's
        ``return_stopped_response`` with ``early_stopping_method``.
        """
        logger.warning("Stopping agent prematurely due to triggering stop condition")
        # this manually constructs agent finish with output key
        output = self.agent_executor.agent.return_stopped_response(
            self.agent_executor.early_stopping_method,
            self.intermediate_steps,
            **self.inputs,
        )
        return self._return(output, run_manager=run_manager)

    async def _astop(self, run_manager: AsyncCallbackManagerForChainRun) -> AddableDict:
        """
        Async counterpart of ``_stop``: build and return the final output for
        a run stopped before the agent finished.
        """
        logger.warning("Stopping agent prematurely due to triggering stop condition")
        output = self.agent_executor.agent.return_stopped_response(
            self.agent_executor.early_stopping_method,
            self.intermediate_steps,
            **self.inputs,
        )
        return await self._areturn(output, run_manager=run_manager)

    def _return(
        self, output: AgentFinish, run_manager: CallbackManagerForChainRun
    ) -> AddableDict:
        """
        Return the final output of the iterator.

        Delegates to the executor's ``_return``, attaches the finish's
        messages, fires ``on_chain_end`` and then prepares the final chunk.
        """
        returned_output = self.agent_executor._return(
            output, self.intermediate_steps, run_manager=run_manager
        )
        returned_output["messages"] = output.messages
        run_manager.on_chain_end(returned_output)
        return self.make_final_outputs(returned_output, run_manager)

    async def _areturn(
        self, output: AgentFinish, run_manager: AsyncCallbackManagerForChainRun
    ) -> AddableDict:
        """
        Return the final output of the async iterator.

        Async counterpart of ``_return``.
        """
        returned_output = await self.agent_executor._areturn(
            output, self.intermediate_steps, run_manager=run_manager
        )
        returned_output["messages"] = output.messages
        await run_manager.on_chain_end(returned_output)
        return self.make_final_outputs(returned_output, run_manager)