Spaces:
Sleeping
Sleeping
import time | |
from aworld.core.common import ActionModel, Observation | |
from aworld.replay_buffer.base import ( | |
DataRow, | |
DefaultConverter, | |
ReplayBuffer, | |
ExpMeta, | |
Experience, | |
RandomTaskSample | |
) | |
from aworld.replay_buffer.query_filter import QueryBuilder | |
from aworld.logs.util import logger | |
from aworld.replay_buffer.storage.odps import OdpsStorage | |
# Module-level replay buffer used by write_data() and read_data() below.
# It is configured against the ODPS table "adm_aworld_replay_buffer" in the
# "alifin_jtest_dev" project.  The endpoint and access credentials are empty
# strings in this example -- presumably they must be filled in before the
# script can talk to ODPS (confirm against OdpsStorage).
_odps_storage = OdpsStorage(
    table_name="adm_aworld_replay_buffer",
    project="alifin_jtest_dev",
    endpoint="",
    access_id="",
    access_key="",
)
buffer = ReplayBuffer(storage=_odps_storage)
def _make_row(task_id, agent_id, step):
    """Build one synthetic DataRow for the given task, agent and 1-based step.

    The execution time is the current wall-clock time shifted by
    ``step - 1`` seconds; the experience payload is a default
    ``Observation`` with a single default ``ActionModel``.
    """
    meta = ExpMeta(
        task_id=task_id,
        task_name="default_task_name",
        agent_id=agent_id,
        step=step,
        execute_time=time.time() + (step - 1),
        pre_agent="pre_agent_id",
    )
    experience = Experience(state=Observation(), actions=[ActionModel()])
    return DataRow(exp_meta=meta, exp_data=experience)


def write_data():
    """Generate 5 tasks x 5 agents x 5 steps of rows and store them in one batch.

    Task ids run from ``task_1`` to ``task_5``, agent ids from ``agent_1`` to
    ``agent_5`` and steps from 1 to 5.  All 125 rows are collected first and
    then handed to the module-level ``buffer`` with a single ``store_batch``
    call.
    """
    batch = []
    for task_no in range(1, 6):
        task_id = f"task_{task_no}"
        for agent_no in range(1, 6):
            agent_id = f"agent_{agent_no}"
            for step in range(1, 6):
                batch.append(_make_row(task_id, agent_id, step))
    buffer.store_batch(batch)
def read_data():
    """Sample from the module-level buffer twice and log every returned item.

    The first pass filters on ``exp_meta.task_id == "task_1"`` with a batch
    size of 1; the second filters on ``exp_meta.agent_id == "agent_5"`` with a
    batch size of 2.  Each pass uses a fresh ``RandomTaskSample`` sampler and
    ``DefaultConverter``.
    """

    def draw(condition, batch_size):
        # Both passes differ only in the filter and the batch size.
        return buffer.sample_task(
            query_condition=condition,
            sampler=RandomTaskSample(),
            converter=DefaultConverter(),
            batch_size=batch_size,
        )

    task_condition = QueryBuilder().eq("exp_meta.task_id", "task_1").build()
    for sample in draw(task_condition, 1):
        logger.info(f"task_1 data: {sample}")

    agent_condition = QueryBuilder().eq("exp_meta.agent_id", "agent_5").build()
    for sample in draw(agent_condition, 2):
        logger.info(f"agent_5 data: {sample}")
if __name__ == "__main__":
    # Only the read path runs by default.  Call write_data() here first if
    # the table has not been seeded with the example rows yet.
    read_data()