# coding: utf-8
"""
processor.py
Cleans raw trace spans into a standard storage structure for reinforcement learning training.
"""
import json
import os
import datetime
from typing import Any
import threading
from aworld.utils import import_package
from aworld.replay_buffer.base import DataRow, Experience, ExpMeta
from aworld.logs.util import logger
from aworld.utils.common import get_local_ip
class ReplayBufferExporter:
def __init__(self):
"""Initialize ReplayBufferExporter instance"""
        self._file_locks = {}  # per-output-file locks guarding concurrent read/write
        self._lock_dict_lock = threading.Lock()  # protects _file_locks itself
        self._task_output_paths = {}  # cached output file path per task_id
def _get_file_lock(self, file_path):
"""Get the lock for the specified file"""
with self._lock_dict_lock:
if file_path not in self._file_locks:
self._file_locks[file_path] = threading.Lock()
return self._file_locks[file_path]
def replay_buffer_exporter(self, spans: list[dict[str, Any]], output_dir: str):
"""
        Process spans with the 'step_execution_' prefix, grouping them by task_id and writing each group to a separate file.
        Args:
            spans: list of span data dicts
            output_dir: output directory path
"""
        # Lazily import the optional oss2 dependency used for uploading results
        import_package("oss2")
        import oss2
        # Ensure the output directory exists
        os.makedirs(output_dir, exist_ok=True)
# Get OSS credentials from environment variables
enable_oss_export = os.getenv("EXPORT_REPLAY_TRACE_TO_OSS", "false").lower() == "true"
access_key_id = os.getenv('OSS_ACCESS_KEY_ID')
access_key_secret = os.getenv('OSS_ACCESS_KEY_SECRET')
endpoint = os.getenv('OSS_ENDPOINT')
bucket_name = os.getenv('OSS_BUCKET_NAME')
bucket = None
        if not all([access_key_id, access_key_secret, endpoint, bucket_name]):
            if enable_oss_export:
                logger.warn("OSS export is enabled but required OSS environment variables are missing; disabling OSS upload")
            enable_oss_export = False
else:
try:
# Initialize OSS client
auth = oss2.Auth(access_key_id, access_key_secret)
bucket = oss2.Bucket(auth, endpoint, bucket_name)
except Exception as e:
enable_oss_export = False
logger.warn(f"Failed to initialize OSS client, endpoint: {endpoint}, bucket: {bucket_name}. Error: {str(e)}")
# Group by task_id
task_groups = {}
for span_data in spans:
# Only process spans with 'step_execution_' prefix
            if not span_data.get('name', '').startswith('step_execution_'):
continue
attr = span_data.get('attributes', {})
exp_id = attr.get('exp_id')
task_id = attr.get('task_id', '')
if not exp_id or not task_id:
continue
if task_id not in task_groups:
task_groups[task_id] = {}
if exp_id not in task_groups[task_id]:
task_groups[task_id][exp_id] = {
'exp_meta': None,
'exp_data': None
}
# Process step_execution span
task_name = attr.get('task_name', '')
agent_id = attr.get('agent_id', '')
step = attr.get('step', 0)
            # Convert the span start_time ("YYYY-MM-DD HH:MM:SS.ffffff") into a numeric
            # YYYYMMDDHHMMSS value used to order experiences
            raw_start_time = span_data.get('start_time', '0')
            execute_time = float(raw_start_time.split('.')[0].replace(' ', '').replace('-', '').replace(':', ''))
observation = {}
action = []
messages = []
pre_agent = None
            # Attributes may be stored as JSON strings; fall back to the raw value when parsing fails
            if 'observation' in attr:
                try:
                    observation = json.loads(attr['observation'])
                except (json.JSONDecodeError, TypeError):
                    observation = attr['observation']
            if 'actions' in attr:
                try:
                    action = json.loads(attr['actions'])
                except (json.JSONDecodeError, TypeError):
                    action = attr['actions']
            if 'messages' in attr:
                try:
                    messages = json.loads(attr['messages'])
                except (json.JSONDecodeError, TypeError):
                    messages = attr['messages']
pre_agent = attr.get('pre_agent', '')
reward = attr.get('reward', 0.0)
adv = attr.get('adv_t', 0.0)
v = attr.get('v_t', 0.0)
exp_meta = ExpMeta(task_id, task_name, agent_id, step, execute_time, pre_agent)
exp_data = Experience(observation, action, reward, adv, v, messages)
task_groups[task_id][exp_id]['exp_meta'] = exp_meta
task_groups[task_id][exp_id]['exp_data'] = exp_data
# Process data for each task_id
for task_id, exp_groups in task_groups.items():
            # Merge existing and newly parsed experiences into DataRow objects
data_rows = []
            # Resolve (and cache) the output file path for this task_id
            output_path = self._task_output_paths.get(task_id)
if not output_path:
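                # Default layout: <output_dir or ./trace_data>/<YYYYMMDD>/<local-ip>/replays/task_replay_<task_id>.json,
                # unless REPLAY_TRACE_DATASET_PATH points the export somewhere else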
timestamp = datetime.datetime.now().strftime("%Y%m%d")
replay_dir = os.path.join(output_dir or "./trace_data", timestamp, get_local_ip(), "replays")
replay_dataset_path = os.getenv("REPLAY_TRACE_DATASET_PATH", replay_dir)
export_dir = os.path.abspath(replay_dataset_path)
os.makedirs(export_dir, exist_ok=True)
output_path = os.path.join(export_dir, f"task_replay_{task_id}.json")
self._task_output_paths[task_id] = output_path
# Use thread lock to protect read and write operations
file_lock = self._get_file_lock(output_path)
with file_lock:
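                # Load rows already exported for this task (if any) so repeated calls
                # merge new experiences instead of overwriting the file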
if os.path.exists(output_path):
try:
with open(output_path, 'r', encoding='utf-8') as f:
existing_data = json.load(f)
data_rows.extend([DataRow(
ExpMeta(**row['exp_meta']),
Experience(**row['exp_data']),
row['id']
) for row in existing_data])
except Exception as e:
                        logger.warn(f"Failed to read existing file {output_path}: {str(e)}")
# Add new data
for exp_id, group in exp_groups.items():
if group['exp_meta'] and group['exp_data']:
row = DataRow(group['exp_meta'], group['exp_data'], exp_id)
data_rows.append(row)
# Sort by execute_time
data_rows.sort(key=lambda x: x.exp_meta.execute_time)
# Export to json
with open(output_path, 'w', encoding='utf-8') as f:
json.dump([row.to_dict() for row in data_rows], f, ensure_ascii=False, indent=2)
logger.info(f"Processing completed, exported {len(data_rows)} experiences to {output_path}")
if enable_oss_export:
# Upload to OSS
try:
# Get the relative path
abs_path = os.path.abspath(output_path)
path_parts = abs_path.split(os.sep)
if len(path_parts) >= 4:
# Get the last 4 parts of the path
relative_path = os.sep.join(path_parts[-4:])
oss_key = relative_path
else:
oss_key = f"replay_buffer/{os.path.basename(output_path)}"
bucket.put_object_from_file(oss_key, output_path)
logger.info(f"Successfully uploaded {output_path} to OSS: {oss_key}")
except Exception as e:
logger.warn(f"Failed to upload {output_path} to OSS: {str(e)}")
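if __name__ == "__main__":
    # Minimal usage sketch, not part of the exporter itself. The span below only mimics the
    # fields this module reads ('name', 'start_time', 'attributes'); real spans come from the
    # tracing pipeline and their attribute values will differ. Running this requires the
    # optional oss2 dependency to be importable (import_package may install it).
    example_span = {
        "name": "step_execution_demo",
        "start_time": "2024-01-01 12:00:00.000000",
        "attributes": {
            "exp_id": "exp_0",
            "task_id": "task_0",
            "task_name": "demo_task",
            "agent_id": "agent_0",
            "step": 1,
            "pre_agent": "",
            "observation": json.dumps({"content": "hello"}),
            "actions": json.dumps([{"action_name": "noop"}]),
            "messages": json.dumps([]),
            "reward": 0.0,
            "adv_t": 0.0,
            "v_t": 0.0,
        },
    }
    exporter = ReplayBufferExporter()
    # Writes <output_dir>/<YYYYMMDD>/<local-ip>/replays/task_replay_task_0.json by default,
    # unless REPLAY_TRACE_DATASET_PATH overrides the export directory
    exporter.replay_buffer_exporter([example_span], "./trace_data")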