Duibonduil's picture
Upload 5 files
4b677a1 verified
import time
import traceback
import multiprocessing
from aworld import replay_buffer
from aworld.core.common import ActionModel, Observation
from aworld.replay_buffer.base import ReplayBuffer, DataRow, ExpMeta, Experience
from aworld.replay_buffer.query_filter import QueryBuilder
from aworld.replay_buffer.storage.multi_proc_mem import MultiProcMemoryStorage
from aworld.logs.util import logger
def write_processing(replay_buffer: ReplayBuffer, task_id: str):
    """Store ten synthetic experience rows for ``task_id``, pausing 1s after each.

    Row ``k`` (0-based) is attributed to ``agent_{k+1}`` and carries ``k`` as
    its step number. If building or storing a row raises, the error and its
    stack trace are logged and the loop moves on to the next row.

    Args:
        replay_buffer: Buffer that receives the rows via ``store``.
        task_id: Used as both the task id and the task name of every row.
    """
    for step in range(10):
        try:
            meta = ExpMeta(
                task_id=task_id,
                task_name=task_id,
                agent_id=f"agent_{step+1}",
                step=step,
                execute_time=time.time()
            )
            experience = Experience(state=Observation(), actions=[ActionModel()])
            replay_buffer.store(DataRow(exp_meta=meta, exp_data=experience))
        except Exception as e:
            stack_trace = traceback.format_exc()
            logger.error(
                f"write_processing error: {e}\nStack trace:\n{stack_trace}")
        time.sleep(1)
def read_processing_by_task(replay_buffer: ReplayBuffer, task_id: str):
    """Sample experiences of one task from the buffer once per second, forever.

    Each round builds an equality filter on ``exp_meta.task_id``, asks the
    buffer's ``sample_task`` for a batch of 2 and logs the result. Any error
    is logged with its stack trace and polling continues.

    Args:
        replay_buffer: Buffer to sample from.
        task_id: Value the ``exp_meta.task_id`` field must equal.
    """
    while True:
        try:
            condition = QueryBuilder().eq("exp_meta.task_id", task_id).build()
            batch = replay_buffer.sample_task(
                query_condition=condition,
                batch_size=2,
            )
            logger.info(f"read data of task[{task_id}]: {batch}")
        except Exception as err:
            trace_text = traceback.format_exc()
            logger.error(
                f"read_processing_by_task error: {err}\nStack trace:\n{trace_text}")
        time.sleep(1)
def read_processing_by_agent(replay_buffer: ReplayBuffer, agent_id: str):
    """Sample experiences of one agent from the buffer once per second, forever.

    Each round builds an equality filter on ``exp_meta.agent_id``, asks the
    buffer's ``sample_task`` for a batch of 2 and logs the result.

    Args:
        replay_buffer: Buffer to sample from.
        agent_id: Value the ``exp_meta.agent_id`` field must equal.
    """
    while True:
        try:
            query_condition = QueryBuilder().eq("exp_meta.agent_id", agent_id).build()
            data = replay_buffer.sample_task(
                query_condition=query_condition, batch_size=2)
            logger.info(f"read data of agent[{agent_id}]: {data}")
        except Exception as e:
            # Report failures at error level with the full stack trace, like the
            # other worker functions, instead of an info-level one-liner that
            # hides where the failure came from. Polling still continues.
            stack_trace = traceback.format_exc()
            logger.error(
                f"read_processing_by_agent error: {e}\nStack trace:\n{stack_trace}")
        time.sleep(1)
if __name__ == "__main__":
    # Support for frozen Windows executables; a no-op elsewhere.
    multiprocessing.freeze_support()
    # With 'spawn', children start from a fresh interpreter, so shared state
    # reaches them only through the pickled args (the Manager proxies below).
    multiprocessing.set_start_method('spawn')
    # The context manager shuts the Manager's server process down on exit.
    with multiprocessing.Manager() as manager:
        replay_buffer = ReplayBuffer(storage=MultiProcMemoryStorage(
            data_dict=manager.dict(),
            fifo_queue=manager.list(),
            lock=manager.Lock(),
            max_capacity=10000
        ))
        # Writers finish on their own after a fixed number of rows.
        writers = [
            multiprocessing.Process(target=write_processing,
                                    args=(replay_buffer, f"task_{n}",))
            for n in range(1, 5)
        ]
        # Readers loop forever (`while True`), so they must be stopped
        # explicitly; joining them would block the script indefinitely.
        readers = [
            # multiprocessing.Process(
            #     target=read_processing_by_task, args=(replay_buffer, "task_1",)),
            multiprocessing.Process(
                target=read_processing_by_agent, args=(replay_buffer, "agent_3",)),
        ]
        processes = writers + readers
        for p in processes:
            p.start()
        try:
            # Wait only for the finite work; readers are stopped afterwards.
            for p in writers:
                p.join()
        except KeyboardInterrupt:
            logger.info("Interrupted, stopping processes.")
        finally:
            for p in processes:
                if p.is_alive():
                    p.terminate()
            for p in processes:
                p.join()
            logger.info("Processes terminated.")