File size: 2,535 Bytes
7d18ad1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import json
from typing import List, Tuple, Dict, Any

from aworld.core.tool.base import AsyncTool
from aworld.core.common import Observation, ActionModel, Config
from aworld.logs.util import logger
from aworld.tools.utils import build_observation


class TemplateTool(AsyncTool):
    def __init__(self, conf: Config, **kwargs) -> None:
        super(TemplateTool, self).__init__(conf, **kwargs)

    async def reset(self, *, seed: int | None = None, options: Dict[str, str] | None = None) -> Tuple[
        Observation, dict[str, Any]]:
        # from options obtain user query
        return build_observation(observer=self.name(),
                                 ability='',
                                 content=options.get("query", None) if options else None), {}

    async def do_step(self,
                      action: List[ActionModel],
                      **kwargs) -> Tuple[Observation, float, bool, bool, Dict[str, Any]]:
        reward = 0
        fail_error = ""
        action_result = None

        invalid_acts: List[int] = []
        for i, act in enumerate(action):
            if act.tool_name != self.name():
                logger.warning(f"tool {act.tool_name} is not a {self.name()} tool!")
                invalid_acts.append(i)

        if invalid_acts:
            for i in invalid_acts:
                action[i] = None

        resp = ""
        try:
            action_result, resp = await self.action_executor.async_execute_action(action, **kwargs)
            reward = 1
        except Exception as e:
            fail_error = str(e)

        terminated = kwargs.get("terminated", False)
        for res in action_result:
            if res.is_done:
                terminated = res.is_done
                self._finished = True

        info = {"exception": fail_error}
        info.update(kwargs)
        if resp:
            resp = json.dumps(resp)
        else:
            resp = action_result[0].content

        observation = build_observation(observer=self.name(),
                                        action_result=action_result,
                                        ability=action[-1].action_name,
                                        content=resp)
        return (observation,
                reward,
                terminated,
                kwargs.get("truncated", False),
                info)

    async def close(self) -> None:
        pass

    async def finished(self) -> bool:
        # one time
        return True