# coding: utf-8
"""
processor.py
Cleans raw trace data into the standard storage structure used for reinforcement learning training.
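
Example (a minimal sketch; see the __main__ block at the bottom of this file
for a runnable span layout):

    exporter = ReplayBufferExporter()
    exporter.replay_buffer_exporter(spans, output_dir="./trace_data")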
"""
import json
import os
import datetime
from typing import Any
import threading

from aworld.utils import import_package
from aworld.replay_buffer.base import DataRow, Experience, ExpMeta
from aworld.logs.util import logger
from aworld.utils.common import get_local_ip


class ReplayBufferExporter:
    def __init__(self):
        """Initialize ReplayBufferExporter instance"""
        self._file_locks = {}
        self._lock_dict_lock = threading.Lock()
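        # Cache of task_id -> resolved output path, so repeated exports for a
        # task keep appending to the same file.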
        self._task_output_paths = {}

    def _get_file_lock(self, file_path):
        """Get the lock for the specified file"""
        with self._lock_dict_lock:
            if file_path not in self._file_locks:
                self._file_locks[file_path] = threading.Lock()
            return self._file_locks[file_path]

    def replay_buffer_exporter(self, spans: list[dict[str, Any]], output_dir: str):
        """
        Process spans, only process spans with 'step_execution_' prefix, and group by task_id to output to different files

        Args:
            spans: span data list
            output_dir: output directory path
        """
        # Ensure output directory exists
        os.makedirs(output_dir, exist_ok=True)

        # OSS export is opt-in via environment variables; only pull in oss2
        # and validate credentials when it is actually enabled.
        enable_oss_export = os.getenv("EXPORT_REPLAY_TRACE_TO_OSS", "false").lower() == "true"
        bucket = None

        if enable_oss_export:
            import_package("oss2")
            import oss2

            access_key_id = os.getenv('OSS_ACCESS_KEY_ID')
            access_key_secret = os.getenv('OSS_ACCESS_KEY_SECRET')
            endpoint = os.getenv('OSS_ENDPOINT')
            bucket_name = os.getenv('OSS_BUCKET_NAME')

            if not all([access_key_id, access_key_secret, endpoint, bucket_name]):
                enable_oss_export = False
                logger.warn("Missing required OSS environment variables, disabling OSS export")
            else:
                try:
                    # Initialize OSS client
                    auth = oss2.Auth(access_key_id, access_key_secret)
                    bucket = oss2.Bucket(auth, endpoint, bucket_name)
                except Exception as e:
                    enable_oss_export = False
                    logger.warn(f"Failed to initialize OSS client, endpoint: {endpoint}, bucket: {bucket_name}. Error: {str(e)}")

        # Group by task_id
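        # Shape: task_id -> exp_id -> {'exp_meta': ExpMeta, 'exp_data': Experience}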
        task_groups = {}

        for span_data in spans:
            # Only process spans with 'step_execution_' prefix
            if not span_data.get('name', '').startswith('step_execution_'):
                continue

            attr = span_data.get('attributes', {})
            exp_id = attr.get('exp_id')
            task_id = attr.get('task_id', '')

            if not exp_id or not task_id:
                continue

            if task_id not in task_groups:
                task_groups[task_id] = {}

            if exp_id not in task_groups[task_id]:
                task_groups[task_id][exp_id] = {
                    'exp_meta': None,
                    'exp_data': None
                }

            # Process step_execution span
            task_name = attr.get('task_name', '')
            agent_id = attr.get('agent_id', '')
            step = attr.get('step', 0)
            # Build a sortable numeric timestamp from 'YYYY-MM-DD HH:MM:SS[.ffffff]'
            # (e.g. 20240101120000.0); str() guards against a missing start_time,
            # which would otherwise raise AttributeError on int.split().
            start_time = str(span_data.get('start_time', '0'))
            execute_time = float(start_time.split('.')[0].replace(' ', '').replace('-', '').replace(':', ''))

            # These attributes arrive as JSON-encoded strings; fall back to
            # the raw value if decoding fails.
            observation = {}
            action = []
            messages = []
            if 'observation' in attr:
                try:
                    observation = json.loads(attr['observation'])
                except (json.JSONDecodeError, TypeError):
                    observation = attr['observation']

            if 'actions' in attr:
                try:
                    action = json.loads(attr['actions'])
                except (json.JSONDecodeError, TypeError):
                    action = attr['actions']

            if 'messages' in attr:
                try:
                    messages = json.loads(attr['messages'])
                except (json.JSONDecodeError, TypeError):
                    messages = attr['messages']

            pre_agent = attr.get('pre_agent', '')
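            # RL training signals; spans may omit these, in which case they
            # default to 0.0 (adv_t/v_t read as advantage and value estimates).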
            reward = attr.get('reward', 0.0)
            adv = attr.get('adv_t', 0.0)
            v = attr.get('v_t', 0.0)

            exp_meta = ExpMeta(task_id, task_name, agent_id, step, execute_time, pre_agent)
            exp_data = Experience(observation, action, reward, adv, v, messages)

            task_groups[task_id][exp_id]['exp_meta'] = exp_meta
            task_groups[task_id][exp_id]['exp_data'] = exp_data

        # Process data for each task_id
        for task_id, exp_groups in task_groups.items():
            # Merge new and previously exported experiences into one list
            data_rows = []

            # Resolve (and cache) the output path for this task_id
            output_path = self._task_output_paths.get(task_id)
            if not output_path:
                timestamp = datetime.datetime.now().strftime("%Y%m%d")
                replay_dir = os.path.join(output_dir or "./trace_data", timestamp, get_local_ip(), "replays")
                replay_dataset_path = os.getenv("REPLAY_TRACE_DATASET_PATH", replay_dir)
                export_dir = os.path.abspath(replay_dataset_path)
                os.makedirs(export_dir, exist_ok=True)
                output_path = os.path.join(export_dir, f"task_replay_{task_id}.json")
                self._task_output_paths[task_id] = output_path

            # Use thread lock to protect read and write operations
            file_lock = self._get_file_lock(output_path)
            with file_lock:
                if os.path.exists(output_path):
                    try:
                        with open(output_path, 'r', encoding='utf-8') as f:
                            existing_data = json.load(f)
                            data_rows.extend([DataRow(
                                ExpMeta(**row['exp_meta']),
                                Experience(**row['exp_data']),
                                row['id']
                            ) for row in existing_data])
                    except Exception as e:
                        print(f"Failed to read existing file {output_path}: {str(e)}")

                # Add new data
                for exp_id, group in exp_groups.items():
                    if group['exp_meta'] and group['exp_data']:
                        row = DataRow(group['exp_meta'], group['exp_data'], exp_id)
                        data_rows.append(row)

                # Sort by execute_time
                data_rows.sort(key=lambda x: x.exp_meta.execute_time)

                # Export to json
                with open(output_path, 'w', encoding='utf-8') as f:
                    json.dump([row.to_dict() for row in data_rows], f, ensure_ascii=False, indent=2)
                logger.info(f"Processing completed, exported {len(data_rows)} experiences to {output_path}")

                if enable_oss_export:
                    # Upload to OSS
                    try:
                        # Mirror the local layout (typically <date>/<ip>/replays/<file>)
                        # in the object key by keeping the last four path components;
                        # object keys use '/' regardless of the local separator.
                        abs_path = os.path.abspath(output_path)
                        path_parts = abs_path.split(os.sep)
                        if len(path_parts) >= 4:
                            oss_key = "/".join(path_parts[-4:])
                        else:
                            oss_key = f"replay_buffer/{os.path.basename(output_path)}"
                        bucket.put_object_from_file(oss_key, output_path)
                        logger.info(f"Successfully uploaded {output_path} to OSS: {oss_key}")
                    except Exception as e:
                        logger.warn(f"Failed to upload {output_path} to OSS: {str(e)}")