Spaces:
Sleeping
Sleeping
# coding: utf-8
"""
processor.py
Cleans raw trace data into the standard storage structure used for reinforcement learning training.
"""
import json | |
import os | |
import datetime | |
from typing import Any | |
import threading | |
from aworld.utils import import_package | |
from aworld.replay_buffer.base import DataRow, Experience, ExpMeta | |
from aworld.logs.util import logger | |
from aworld.utils.common import get_local_ip | |
class ReplayBufferExporter:
    """Export 'step_execution_' trace spans as per-task replay JSON files.

    Each task_id gets one JSON file holding a list of serialized DataRow
    objects sorted by execute_time. New rows are merged with whatever the
    file already contains. When the EXPORT_REPLAY_TRACE_TO_OSS environment
    variable is "true" and the OSS_* variables are set, every written file
    is also uploaded to the configured OSS bucket.
    """

    def __init__(self):
        """Initialize ReplayBufferExporter instance"""
        # file path -> threading.Lock guarding read/merge/write of that file
        self._file_locks = {}
        # Guards creation of entries in self._file_locks
        self._lock_dict_lock = threading.Lock()
        # task_id -> output file path, resolved once and then reused
        self._task_output_paths = {}

    def _get_file_lock(self, file_path):
        """Return the lock for ``file_path``, creating it on first use."""
        with self._lock_dict_lock:
            if file_path not in self._file_locks:
                self._file_locks[file_path] = threading.Lock()
            return self._file_locks[file_path]

    @staticmethod
    def _decode_json_attr(attr, key, default):
        """Return ``attr[key]`` JSON-decoded, the raw value if decoding fails,
        or ``default`` when the key is absent."""
        if key not in attr:
            return default
        raw = attr[key]
        try:
            return json.loads(raw)
        except (TypeError, ValueError):
            # json.loads raises TypeError for non-text input and
            # JSONDecodeError (a ValueError) for malformed text.
            return raw

    @staticmethod
    def _parse_execute_time(start_time):
        """Turn a span start time into a sortable float.

        A string such as "2024-01-02 03:04:05.678" loses its fractional part
        and its ' ', '-' and ':' separators and becomes 20240102030405.0.
        A missing value yields 0.0; a numeric value is converted as is.

        Raises:
            ValueError: if a string value does not reduce to a number.
        """
        if start_time is None:
            return 0.0
        if isinstance(start_time, (int, float)):
            return float(start_time)
        digits = start_time.split('.')[0].replace(' ', '').replace('-', '').replace(':', '')
        return float(digits)

    @staticmethod
    def _init_oss_bucket():
        """Return an OSS bucket client, or None when OSS export is unavailable.

        None is returned when EXPORT_REPLAY_TRACE_TO_OSS is not "true", when
        any OSS_* variable is missing, or when client construction fails.
        oss2 is only imported once export has actually been requested.
        """
        if os.getenv("EXPORT_REPLAY_TRACE_TO_OSS", "false").lower() != "true":
            return None

        access_key_id = os.getenv('OSS_ACCESS_KEY_ID')
        access_key_secret = os.getenv('OSS_ACCESS_KEY_SECRET')
        endpoint = os.getenv('OSS_ENDPOINT')
        bucket_name = os.getenv('OSS_BUCKET_NAME')
        if not all([access_key_id, access_key_secret, endpoint, bucket_name]):
            logger.warn("Missing required OSS environment variables")
            return None

        import_package("oss2")
        import oss2
        try:
            # Initialize OSS client
            auth = oss2.Auth(access_key_id, access_key_secret)
            return oss2.Bucket(auth, endpoint, bucket_name)
        except Exception as e:
            logger.warn(f"Failed to initialize OSS client, endpoint: {endpoint}, bucket: {bucket_name}. Error: {str(e)}")
            return None

    def _group_spans(self, spans):
        """Build ``{task_id: {exp_id: (ExpMeta, Experience)}}`` from ``spans``.

        Spans are skipped when their name lacks the 'step_execution_' prefix
        or when exp_id or task_id is empty. A later span with the same
        task_id and exp_id replaces the earlier one.
        """
        task_groups = {}
        for span_data in spans:
            # Only process spans with 'step_execution_' prefix
            name = span_data.get('name') or ''
            if not name.startswith('step_execution_'):
                continue
            attr = span_data.get('attributes') or {}
            exp_id = attr.get('exp_id')
            task_id = attr.get('task_id', '')
            if not exp_id or not task_id:
                continue

            exp_meta = ExpMeta(
                task_id,
                attr.get('task_name', ''),
                attr.get('agent_id', ''),
                attr.get('step', 0),
                self._parse_execute_time(span_data.get('start_time')),
                attr.get('pre_agent', ''),
            )
            exp_data = Experience(
                self._decode_json_attr(attr, 'observation', {}),
                self._decode_json_attr(attr, 'actions', []),
                attr.get('reward', 0.0),
                attr.get('adv_t', 0.0),
                attr.get('v_t', 0.0),
                self._decode_json_attr(attr, 'messages', []),
            )
            task_groups.setdefault(task_id, {})[exp_id] = (exp_meta, exp_data)
        return task_groups

    def _resolve_output_path(self, task_id, base_dir):
        """Return the JSON path for ``task_id``, creating its directory once.

        The directory is REPLAY_TRACE_DATASET_PATH when that variable is set,
        otherwise ``<base_dir>/<YYYYMMDD>/<local ip>/replays``. The result is
        cached so a task keeps writing to the same file.
        """
        output_path = self._task_output_paths.get(task_id)
        if not output_path:
            timestamp = datetime.datetime.now().strftime("%Y%m%d")
            replay_dir = os.path.join(base_dir, timestamp, get_local_ip(), "replays")
            replay_dataset_path = os.getenv("REPLAY_TRACE_DATASET_PATH", replay_dir)
            export_dir = os.path.abspath(replay_dataset_path)
            os.makedirs(export_dir, exist_ok=True)
            output_path = os.path.join(export_dir, f"task_replay_{task_id}.json")
            self._task_output_paths[task_id] = output_path
        return output_path

    @staticmethod
    def _load_existing_rows(output_path):
        """Return the DataRow list stored at ``output_path``.

        Returns an empty list when the file does not exist or cannot be read
        or parsed; the failure is logged and the export carries on.
        """
        if not os.path.exists(output_path):
            return []
        try:
            with open(output_path, 'r', encoding='utf-8') as f:
                existing_data = json.load(f)
            return [DataRow(
                ExpMeta(**row['exp_meta']),
                Experience(**row['exp_data']),
                row['id']
            ) for row in existing_data]
        except Exception as e:
            logger.warn(f"Failed to read existing file {output_path}: {str(e)}")
            return []

    @staticmethod
    def _upload_to_oss(bucket, output_path):
        """Upload ``output_path`` to ``bucket``; failures are logged, not raised."""
        try:
            path_parts = os.path.abspath(output_path).split(os.sep)
            if len(path_parts) >= 4:
                # Key is the last 4 path components, always '/'-separated so
                # the key looks the same regardless of the local os.sep.
                oss_key = "/".join(path_parts[-4:])
            else:
                oss_key = f"replay_buffer/{os.path.basename(output_path)}"
            bucket.put_object_from_file(oss_key, output_path)
            logger.info(f"Successfully uploaded {output_path} to OSS: {oss_key}")
        except Exception as e:
            logger.warn(f"Failed to upload {output_path} to OSS: {str(e)}")

    def replay_buffer_exporter(self, spans: list[dict[str, Any]], output_dir: str):
        """
        Process spans, only process spans with 'step_execution_' prefix, and group by task_id to output to different files
        Args:
            spans: span data list
            output_dir: output directory path; "./trace_data" is used when empty
        """
        # Resolve the fallback before touching the filesystem so an empty
        # output_dir does not make os.makedirs fail.
        base_dir = output_dir or "./trace_data"
        os.makedirs(base_dir, exist_ok=True)

        bucket = self._init_oss_bucket()
        task_groups = self._group_spans(spans)

        # Process data for each task_id
        for task_id, exp_groups in task_groups.items():
            output_path = self._resolve_output_path(task_id, base_dir)
            # Use thread lock to protect read and write operations
            with self._get_file_lock(output_path):
                data_rows = self._load_existing_rows(output_path)
                # Add new data
                data_rows.extend(
                    DataRow(exp_meta, exp_data, exp_id)
                    for exp_id, (exp_meta, exp_data) in exp_groups.items()
                )
                # Sort by execute_time
                data_rows.sort(key=lambda row: row.exp_meta.execute_time)
                # Export to json
                with open(output_path, 'w', encoding='utf-8') as f:
                    json.dump([row.to_dict() for row in data_rows], f, ensure_ascii=False, indent=2)
                logger.info(f"Processing completed, exported {len(data_rows)} experiences to {output_path}")
                if bucket is not None:
                    self._upload_to_oss(bucket, output_path)