"""Demo: several processes sharing one ReplayBuffer backed by MultiProcMemoryStorage.

Writer processes store experiences for their own task while a reader process
repeatedly samples from the same buffer.  The reader loops forever, so the
script runs until it is interrupted with Ctrl-C.
"""
import multiprocessing
import time
import traceback

from aworld import replay_buffer
from aworld.core.common import ActionModel, Observation
from aworld.logs.util import logger
from aworld.replay_buffer.base import ReplayBuffer, DataRow, ExpMeta, Experience
from aworld.replay_buffer.query_filter import QueryBuilder
from aworld.replay_buffer.storage.multi_proc_mem import MultiProcMemoryStorage


def write_processing(replay_buffer: ReplayBuffer, task_id: str):
    """Store ten experience rows for ``task_id``, one per second.

    Row ``i`` (0..9) is written with ``agent_id`` ``agent_{i+1}`` and
    ``step`` ``i``.  A failure while building or storing a row is logged with
    its stack trace and the loop moves on to the next row.

    Args:
        replay_buffer: Buffer shared between the processes.
        task_id: Used as both ``task_id`` and ``task_name`` of every row.
    """
    for i in range(10):
        try:
            data = DataRow(
                exp_meta=ExpMeta(
                    task_id=task_id,
                    task_name=task_id,
                    agent_id=f"agent_{i+1}",
                    step=i,
                    execute_time=time.time()
                ),
                exp_data=Experience(state=Observation(),
                                    actions=[ActionModel()])
            )
            replay_buffer.store(data)
        except Exception as e:
            stack_trace = traceback.format_exc()
            logger.error(
                f"write_processing error: {e}\nStack trace:\n{stack_trace}")
        # Pace the writes so readers can observe the buffer filling up.
        time.sleep(1)


def read_processing_by_task(replay_buffer: ReplayBuffer, task_id: str):
    """Sample rows whose ``exp_meta.task_id`` equals ``task_id``, forever.

    Once per second, calls ``replay_buffer.sample_task`` with ``batch_size=2``
    and logs the result.  Errors are logged with their stack trace and the
    loop continues; the function never returns on its own.

    Args:
        replay_buffer: Buffer shared between the processes.
        task_id: Task id to filter on.
    """
    while True:
        try:
            query_condition = QueryBuilder().eq(
                "exp_meta.task_id", task_id).build()
            data = replay_buffer.sample_task(
                query_condition=query_condition, batch_size=2)
            logger.info(f"read data of task[{task_id}]: {data}")
        except Exception as e:
            stack_trace = traceback.format_exc()
            logger.error(
                f"read_processing_by_task error: {e}\nStack trace:\n{stack_trace}")
        time.sleep(1)


def read_processing_by_agent(replay_buffer: ReplayBuffer, agent_id: str):
    """Sample rows whose ``exp_meta.agent_id`` equals ``agent_id``, forever.

    Once per second, calls ``replay_buffer.sample_task`` with ``batch_size=2``
    and logs the result.  Errors are logged with their stack trace and the
    loop continues; the function never returns on its own.

    Args:
        replay_buffer: Buffer shared between the processes.
        agent_id: Agent id to filter on.
    """
    while True:
        try:
            query_condition = QueryBuilder().eq(
                "exp_meta.agent_id", agent_id).build()
            data = replay_buffer.sample_task(
                query_condition=query_condition, batch_size=2)
            logger.info(f"read data of agent[{agent_id}]: {data}")
        except Exception as e:
            # Report failures at error level with the stack trace, the same
            # way the other workers in this file do.
            stack_trace = traceback.format_exc()
            logger.error(
                f"read_processing_by_agent error: {e}\nStack trace:\n{stack_trace}")
        time.sleep(1)


def main() -> None:
    """Start four writers and one reader on a shared buffer and wait for them.

    The reader never finishes, so joining blocks until Ctrl-C, at which point
    every child process is terminated and joined.
    """
    # The context manager shuts the manager process down on the way out.
    with multiprocessing.Manager() as manager:
        # Local name on purpose: it must not shadow the imported
        # ``aworld.replay_buffer`` module at module level.
        shared_buffer = ReplayBuffer(storage=MultiProcMemoryStorage(
            data_dict=manager.dict(),
            fifo_queue=manager.list(),
            lock=manager.Lock(),
            max_capacity=10000
        ))

        processes = [
            multiprocessing.Process(
                target=write_processing, args=(shared_buffer, "task_1",)),
            multiprocessing.Process(
                target=write_processing, args=(shared_buffer, "task_2",)),
            multiprocessing.Process(
                target=write_processing, args=(shared_buffer, "task_3",)),
            multiprocessing.Process(
                target=write_processing, args=(shared_buffer, "task_4",)),
            # Optional reader filtered by task instead of by agent:
            # multiprocessing.Process(
            #     target=read_processing_by_task, args=(shared_buffer, "task_1",)),
            multiprocessing.Process(
                target=read_processing_by_agent, args=(shared_buffer, "agent_3",)),
        ]
        for p in processes:
            p.start()

        try:
            for p in processes:
                p.join()
        except KeyboardInterrupt:
            for p in processes:
                p.terminate()
            for p in processes:
                p.join()
        finally:
            logger.info("Processes terminated.")


if __name__ == "__main__":
    multiprocessing.freeze_support()
    multiprocessing.set_start_method('spawn')
    main()