Duibonduil's picture
Upload 5 files
4b677a1 verified
import time
from aworld.core.common import ActionModel, Observation
from aworld.replay_buffer.base import (
DataRow,
DefaultConverter,
ReplayBuffer,
ExpMeta,
Experience,
RandomTaskSample
)
from aworld.replay_buffer.query_filter import QueryBuilder
from aworld.logs.util import logger
from aworld.replay_buffer.storage.odps import OdpsStorage
buffer = ReplayBuffer(storage=OdpsStorage(
table_name="adm_aworld_replay_buffer",
project="alifin_jtest_dev",
endpoint="",
access_id="",
access_key=""
))
def write_data():
rows = []
for id in range(5):
task_id = f"task_{id+1}"
for i in range(5):
agent_id = f"agent_{i+1}"
for j in range(5):
step = j + 1
execute_time = time.time() + j
row = DataRow(
exp_meta=ExpMeta(
task_id=task_id,
task_name="default_task_name",
agent_id=agent_id,
step=step,
execute_time=execute_time,
pre_agent="pre_agent_id"
),
exp_data=Experience(state=Observation(),
actions=[ActionModel()])
)
rows.append(row)
buffer.store_batch(rows)
def read_data():
query = QueryBuilder().eq("exp_meta.task_id", "task_1").build()
datas = buffer.sample_task(query_condition=query,
sampler=RandomTaskSample(),
converter=DefaultConverter(),
batch_size=1)
for data in datas:
logger.info(f"task_1 data: {data}")
query = QueryBuilder().eq("exp_meta.agent_id", "agent_5").build()
datas = buffer.sample_task(query_condition=query,
sampler=RandomTaskSample(),
converter=DefaultConverter(),
batch_size=2)
for data in datas:
logger.info(f"agent_5 data: {data}")
if __name__ == "__main__":
# write_data()
read_data()