File size: 2,042 Bytes
cc54e11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
from typing import Any, Callable, Dict, List, Optional

from aworld.agents.llm_agent import Agent
from aworld.config import RunConfig
from aworld.core.common import Observation, ActionModel, Config


class ParallelizableAgent(Agent):
    """Agent that fans one observation out to several sub-agents and runs them in parallel.

    Each sub-agent in ``agents`` receives the same observation wrapped in its own ``Task``.
    The tasks are handed to the runner utilities and executed together. The results are
    then either passed to ``aggregate_func`` or turned into one ``ActionModel`` per result.

    Extension functions receive this agent as their first argument, so they can read its
    internal state. Example ``aggregate_func``:

    >>> def agg(agent: ParallelizableAgent, res: Dict[str, Any]) -> List[ActionModel]:
    ...     ...

    Each value of ``res`` exposes an ``.answer`` attribute (the default path reads it).
    The keys presumably identify the executed tasks; confirm against ``execute_runner``.
    """

    def __init__(self,
                 conf: Config,
                 resp_parse_func: Optional[Callable[..., Any]] = None,
                 agents: Optional[List[Agent]] = None,
                 aggregate_func: Optional[Callable[..., Any]] = None,
                 **kwargs):
        """Create the parallel agent.

        Args:
            conf: Agent configuration, forwarded to ``Agent.__init__``.
            resp_parse_func: Optional response parser, forwarded to ``Agent.__init__``.
            agents: Sub-agents to run in parallel. ``None`` (the default) means none.
            aggregate_func: Optional ``(agent, results)`` hook that combines the parallel
                results. When omitted, one ``ActionModel`` is produced per result.
            **kwargs: Forwarded to ``Agent.__init__``.
        """
        super().__init__(conf=conf, resp_parse_func=resp_parse_func, **kwargs)
        # A literal ``[]`` default would be a single list shared by every instance created
        # without ``agents``, so build a fresh one per instance instead.
        self.agents = agents if agents is not None else []
        # The function of aggregating the results of the parallel execution of agents.
        self.aggregate_func = aggregate_func

    async def async_policy(self,
                           observation: Observation,
                           info: Optional[Dict[str, Any]] = None,
                           **kwargs) -> List[ActionModel]:
        """Run every sub-agent on ``observation`` in parallel and collect their results.

        Args:
            observation: Input handed unchanged to each sub-agent's ``Task``.
            info: Not used by this implementation.
            **kwargs: Not used by this implementation.

        Returns:
            ``aggregate_func(self, results)`` when an aggregator is configured. Otherwise,
            one ``ActionModel`` per result, attributed to this agent, whose ``policy_info``
            is that result's ``answer``.

        Raises:
            RuntimeError: If there are no sub-agents to run.
        """
        # Imported inside the method; presumably to avoid a circular import with the
        # task/runner modules (assumption, confirm before hoisting).
        from aworld.core.task import Task
        from aworld.runners.utils import choose_runners, execute_runner

        # Every sub-task shares this agent's context.
        tasks = [Task(input=observation, agent=agent, context=self.context)
                 for agent in (self.agents or [])]
        if not tasks:
            raise RuntimeError("no task need to run in parallelizable agent.")

        runners = await choose_runners(tasks)
        res = await execute_runner(runners, RunConfig(reuse_process=False))

        if self.aggregate_func:
            return self.aggregate_func(self, res)

        # Only the per-task results are needed; their keys are ignored.
        return [ActionModel(agent_name=self.id(), policy_info=v.answer) for v in res.values()]

    @property
    def finished(self) -> bool:
        """Whether every sub-agent reports finished (True when there are no sub-agents).

        Exposed as a property so it is read as ``agent.finished``. This is the same way it
        is read on each sub-agent below, so nested ``ParallelizableAgent`` instances
        contribute their real state instead of an always-truthy bound method.
        """
        return all(agent.finished for agent in (self.agents or []))