# coding: utf-8
"""
processor.py
Used to clean raw trace data into standard storage structure for reinforcement learning training.
"""
import datetime
import json
import os
import threading
from typing import Any

from aworld.logs.util import logger
from aworld.replay_buffer.base import DataRow, Experience, ExpMeta
from aworld.utils import import_package
from aworld.utils.common import get_local_ip


class ReplayBufferExporter:
def __init__(self):
"""Initialize ReplayBufferExporter instance"""
self._file_locks = {}
self._lock_dict_lock = threading.Lock()
self._task_output_paths = {}

    def _get_file_lock(self, file_path):
"""Get the lock for the specified file"""
with self._lock_dict_lock:
if file_path not in self._file_locks:
self._file_locks[file_path] = threading.Lock()
return self._file_locks[file_path]

    def replay_buffer_exporter(self, spans: list[dict[str, Any]], output_dir: str):
        """
        Process spans: only spans whose name starts with 'step_execution_' are kept;
        they are grouped by task_id and exported to a separate file per task.

        Args:
            spans: span data list
            output_dir: output directory path
        """
        # oss2 (Alibaba Cloud OSS SDK) is an optional dependency, so import it lazily
        import_package("oss2")
        import oss2

        # Ensure output directory exists
        os.makedirs(output_dir, exist_ok=True)

# Get OSS credentials from environment variables
enable_oss_export = os.getenv("EXPORT_REPLAY_TRACE_TO_OSS", "false").lower() == "true"
access_key_id = os.getenv('OSS_ACCESS_KEY_ID')
access_key_secret = os.getenv('OSS_ACCESS_KEY_SECRET')
endpoint = os.getenv('OSS_ENDPOINT')
bucket_name = os.getenv('OSS_BUCKET_NAME')
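        # The OSS_* variables configure the oss2 client; EXPORT_REPLAY_TRACE_TO_OSS
        # gates the upload step that runs after each task file is written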
        bucket = None
        if not all([access_key_id, access_key_secret, endpoint, bucket_name]):
            if enable_oss_export:
                logger.warn("OSS export is enabled but required OSS environment variables are missing")
            enable_oss_export = False
        elif enable_oss_export:
            try:
                # Initialize OSS client
                auth = oss2.Auth(access_key_id, access_key_secret)
                bucket = oss2.Bucket(auth, endpoint, bucket_name)
            except Exception as e:
                enable_oss_export = False
                logger.warn(f"Failed to initialize OSS client, endpoint: {endpoint}, bucket: {bucket_name}. Error: {str(e)}")

        # Group by task_id
task_groups = {}
for span_data in spans:
# Only process spans with 'step_execution_' prefix
            if not span_data.get('name', '').startswith('step_execution_'):
continue
attr = span_data.get('attributes', {})
exp_id = attr.get('exp_id')
task_id = attr.get('task_id', '')
if not exp_id or not task_id:
continue
if task_id not in task_groups:
task_groups[task_id] = {}
if exp_id not in task_groups[task_id]:
task_groups[task_id][exp_id] = {
'exp_meta': None,
'exp_data': None
}
# Process step_execution span
task_name = attr.get('task_name', '')
agent_id = attr.get('agent_id', '')
step = attr.get('step', 0)
            # start_time looks like 'YYYY-MM-DD HH:MM:SS.ffffff'; strip separators to get
            # a sortable numeric YYYYMMDDHHMMSS value
            execute_time = float(span_data.get('start_time', '0').split('.')[0].replace(' ', '').replace('-', '').replace(':', ''))
observation = {}
action = []
messages = []
pre_agent = None
            if 'observation' in attr:
                try:
                    observation = json.loads(attr['observation'])
                except (json.JSONDecodeError, TypeError):
                    # Fall back to the raw attribute when it is not a valid JSON string
                    observation = attr['observation']
            if 'actions' in attr:
                try:
                    action = json.loads(attr['actions'])
                except (json.JSONDecodeError, TypeError):
                    action = attr['actions']
            if 'messages' in attr:
                try:
                    messages = json.loads(attr['messages'])
                except (json.JSONDecodeError, TypeError):
                    messages = attr['messages']
pre_agent = attr.get('pre_agent', '')
reward = attr.get('reward', 0.0)
adv = attr.get('adv_t', 0.0)
v = attr.get('v_t', 0.0)
exp_meta = ExpMeta(task_id, task_name, agent_id, step, execute_time, pre_agent)
exp_data = Experience(observation, action, reward, adv, v, messages)
task_groups[task_id][exp_id]['exp_meta'] = exp_meta
task_groups[task_id][exp_id]['exp_data'] = exp_data
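
        # At this point each (task_id, exp_id) pair holds one ExpMeta/Experience pair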
# Process data for each task_id
for task_id, exp_groups in task_groups.items():
# Merge data and generate final Experience object
data_rows = []
            # Resolve (and cache) the output path for this task_id
            output_path = self._task_output_paths.get(task_id)
if not output_path:
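                # Default layout: <output_dir>/<YYYYMMDD>/<local_ip>/replays/task_replay_<task_id>.json,
                # unless REPLAY_TRACE_DATASET_PATH overrides the directory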
timestamp = datetime.datetime.now().strftime("%Y%m%d")
replay_dir = os.path.join(output_dir or "./trace_data", timestamp, get_local_ip(), "replays")
replay_dataset_path = os.getenv("REPLAY_TRACE_DATASET_PATH", replay_dir)
export_dir = os.path.abspath(replay_dataset_path)
os.makedirs(export_dir, exist_ok=True)
output_path = os.path.join(export_dir, f"task_replay_{task_id}.json")
self._task_output_paths[task_id] = output_path
            # Use the per-file lock to protect read and write operations
            file_lock = self._get_file_lock(output_path)
            with file_lock:
                # Read existing data (if any) so new rows are merged rather than overwriting it
                if os.path.exists(output_path):
try:
with open(output_path, 'r', encoding='utf-8') as f:
existing_data = json.load(f)
data_rows.extend([DataRow(
ExpMeta(**row['exp_meta']),
Experience(**row['exp_data']),
row['id']
) for row in existing_data])
except Exception as e:
print(f"Failed to read existing file {output_path}: {str(e)}")
# Add new data
for exp_id, group in exp_groups.items():
if group['exp_meta'] and group['exp_data']:
row = DataRow(group['exp_meta'], group['exp_data'], exp_id)
data_rows.append(row)
# Sort by execute_time
data_rows.sort(key=lambda x: x.exp_meta.execute_time)
# Export to json
with open(output_path, 'w', encoding='utf-8') as f:
json.dump([row.to_dict() for row in data_rows], f, ensure_ascii=False, indent=2)
logger.info(f"Processing completed, exported {len(data_rows)} experiences to {output_path}")
if enable_oss_export:
# Upload to OSS
try:
                        # OSS object keys use forward slashes regardless of the local OS separator
                        abs_path = os.path.abspath(output_path)
                        path_parts = abs_path.split(os.sep)
                        if len(path_parts) >= 4:
                            # Use the last 4 path components (date/ip/replays/filename) as the key
                            oss_key = '/'.join(path_parts[-4:])
                        else:
                            oss_key = f"replay_buffer/{os.path.basename(output_path)}"
bucket.put_object_from_file(oss_key, output_path)
logger.info(f"Successfully uploaded {output_path} to OSS: {oss_key}")
except Exception as e:
logger.warn(f"Failed to upload {output_path} to OSS: {str(e)}")