Spaces:
Sleeping
Sleeping
import time | |
import traceback | |
import multiprocessing | |
from aworld import replay_buffer | |
from aworld.core.common import ActionModel, Observation | |
from aworld.replay_buffer.base import ReplayBuffer, DataRow, ExpMeta, Experience | |
from aworld.replay_buffer.query_filter import QueryBuilder | |
from aworld.replay_buffer.storage.multi_proc_mem import MultiProcMemoryStorage | |
from aworld.logs.util import logger | |
def write_processing(replay_buffer: ReplayBuffer, task_id: str):
    """Store ten synthetic rows for ``task_id``, pausing one second after each.

    Row number ``step`` runs 0..9 and its ``agent_id`` is ``agent_<step + 1>``.
    Each row carries an empty ``Observation`` and a single default
    ``ActionModel``. If building or storing a row raises, the error and its
    stack trace are logged and the loop continues with the next step.
    """
    for step in range(10):
        try:
            meta = ExpMeta(
                task_id=task_id,
                task_name=task_id,
                agent_id=f"agent_{step + 1}",
                step=step,
                execute_time=time.time(),
            )
            experience = Experience(state=Observation(), actions=[ActionModel()])
            replay_buffer.store(DataRow(exp_meta=meta, exp_data=experience))
        except Exception as err:
            logger.error(
                f"write_processing error: {err}\nStack trace:\n{traceback.format_exc()}")
        time.sleep(1)
def read_processing_by_task(replay_buffer: ReplayBuffer, task_id: str):
    """Repeatedly sample rows whose ``exp_meta.task_id`` equals ``task_id``.

    Every iteration:
      1. builds an equality filter;
      2. calls ``replay_buffer.sample_task`` with ``batch_size=2``;
      3. logs whatever comes back.

    Any exception in those steps is logged with its stack trace. The loop then
    sleeps one second. There is no exit condition, so the owning process has
    to be terminated from outside.
    """
    while True:
        try:
            condition = (
                QueryBuilder()
                .eq("exp_meta.task_id", task_id)
                .build()
            )
            sampled = replay_buffer.sample_task(
                query_condition=condition, batch_size=2)
            logger.info(f"read data of task[{task_id}]: {sampled}")
        except Exception as err:
            trace = traceback.format_exc()
            logger.error(
                f"read_processing_by_task error: {err}\nStack trace:\n{trace}")
        time.sleep(1)
def read_processing_by_agent(replay_buffer: ReplayBuffer, agent_id: str):
    """Repeatedly sample rows whose ``exp_meta.agent_id`` equals ``agent_id``.

    Every iteration:
      1. builds an equality filter;
      2. calls ``replay_buffer.sample_task`` with ``batch_size=2``;
      3. logs whatever comes back.

    The loop then sleeps one second. There is no exit condition, so the owning
    process has to be terminated from outside.
    """
    while True:
        try:
            query_condition = QueryBuilder().eq("exp_meta.agent_id", agent_id).build()
            data = replay_buffer.sample_task(
                query_condition=query_condition, batch_size=2)
            logger.info(f"read data of agent[{agent_id}]: {data}")
        except Exception as e:
            # Report failures at error level with the full stack trace, the
            # same way write_processing and read_processing_by_task do. An
            # info-level one-liner is dropped by error-level log filters and
            # hides where the failure happened.
            stack_trace = traceback.format_exc()
            logger.error(
                f"read_processing_by_agent error: {e}\nStack trace:\n{stack_trace}")
        time.sleep(1)
if __name__ == "__main__":
    # Demo: four writer processes fill a shared in-memory replay buffer while a
    # reader process repeatedly samples it.
    multiprocessing.freeze_support()
    multiprocessing.set_start_method('spawn')
    # The Manager serves the shared dict/list/lock handed to
    # MultiProcMemoryStorage; the `with` block shuts its server process down
    # on exit.
    with multiprocessing.Manager() as manager:
        # Named shared_buffer so it does not rebind the module-level
        # `replay_buffer` name imported from aworld.
        shared_buffer = ReplayBuffer(storage=MultiProcMemoryStorage(
            data_dict=manager.dict(),
            fifo_queue=manager.list(),
            lock=manager.Lock(),
            max_capacity=10000
        ))
        # Writers return on their own after storing a fixed number of rows.
        writers = [
            multiprocessing.Process(target=write_processing,
                                    args=(shared_buffer, task_id))
            for task_id in ("task_1", "task_2", "task_3", "task_4")
        ]
        # Readers loop forever, so they are terminated below instead of being
        # joined; joining them would block the script indefinitely.
        readers = [
            # Uncomment to also sample by task:
            # multiprocessing.Process(
            #     target=read_processing_by_task, args=(shared_buffer, "task_1")),
            multiprocessing.Process(
                target=read_processing_by_agent, args=(shared_buffer, "agent_3")),
        ]
        processes = writers + readers
        # Only processes that actually started can be terminated/joined;
        # Process.join() on an unstarted process raises.
        started = []
        try:
            for p in processes:
                p.start()
                started.append(p)
            for p in writers:
                p.join()
        except KeyboardInterrupt:
            logger.info("Interrupted; stopping processes.")
        finally:
            for p in started:
                if p.is_alive():
                    p.terminate()
            for p in started:
                p.join()
            logger.info("Processes terminated.")