File size: 2,477 Bytes
7d18ad1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
import json
from typing import List, Tuple, Dict, Any

from aworld.core.tool.base import Tool
from aworld.core.common import Observation, ActionModel, Config
from aworld.logs.util import logger
from aworld.tools.utils import build_observation


class TemplateTool(Tool):
    def __init__(self, conf: Config, **kwargs) -> None:
        super(TemplateTool, self).__init__(conf, **kwargs)

    def reset(self, *, seed: int | None = None, options: Dict[str, str] | None = None) -> Tuple[
        Observation, dict[str, Any]]:
        # from options obtain user query
        return build_observation(observer=self.name(),
                                 ability='',
                                 content=options.get("query", None) if options else None), {}

    def do_step(self,
                action: List[ActionModel],
                **kwargs) -> Tuple[Observation, float, bool, bool, Dict[str, Any]]:
        reward = 0
        fail_error = ""
        action_result = None

        invalid_acts: List[int] = []
        for i, act in enumerate(action):
            if act.tool_name != self.name():
                logger.warning(f"tool {act.tool_name} is not a {self.name()} tool!")
                invalid_acts.append(i)

        if invalid_acts:
            for i in invalid_acts:
                action[i] = None

        resp = ""
        try:
            action_result, resp = self.action_executor.execute_action(action, **kwargs)
            reward = 1
        except Exception as e:
            fail_error = str(e)

        terminated = kwargs.get("terminated", False)
        for res in action_result:
            if res.is_done:
                terminated = res.is_done
                self._finished = True

        info = {"exception": fail_error}
        info.update(kwargs)
        if resp:
            resp = json.dumps(resp)
        else:
            resp = action_result[0].content

        observation = build_observation(observer=self.name(),
                                        action_result=action_result,
                                        ability=action[-1].action_name,
                                        content=resp)
        return (observation,
                reward,
                terminated,
                kwargs.get("truncated", False),
                info)

    def close(self) -> None:
        pass

    def finished(self) -> bool:
        # one time
        return True