File size: 2,172 Bytes
4b677a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import time
from aworld.core.common import ActionModel, Observation
from aworld.replay_buffer.base import (
    DataRow,
    DefaultConverter,
    ReplayBuffer,
    ExpMeta,
    Experience,
    RandomTaskSample
)
from aworld.replay_buffer.query_filter import QueryBuilder
from aworld.logs.util import logger
from aworld.replay_buffer.storage.odps import OdpsStorage


# Module-level replay buffer backed by an ODPS (MaxCompute) table.
# NOTE(review): endpoint/access_id/access_key are blank placeholders —
# presumably real credentials must be supplied before this script can
# actually connect; confirm against OdpsStorage's requirements.
buffer = ReplayBuffer(storage=OdpsStorage(
    table_name="adm_aworld_replay_buffer",
    project="alifin_jtest_dev",
    endpoint="",
    access_id="",
    access_key=""
))


def write_data():
    """Seed the replay buffer with a 5x5x5 grid of sample rows.

    Builds 5 tasks x 5 agents x 5 steps = 125 DataRow records
    (task_1..task_5, agent_1..agent_5, step 1..5) and stores them
    in a single batch write via ``buffer.store_batch``.
    """
    rows = []
    # `task_idx` instead of `id` — the original shadowed the builtin.
    for task_idx in range(5):
        task_id = f"task_{task_idx + 1}"
        for agent_idx in range(5):
            agent_id = f"agent_{agent_idx + 1}"
            for step_idx in range(5):
                step = step_idx + 1
                # Stagger execute_time by step so time-ordered queries
                # see distinct, increasing timestamps per task/agent.
                execute_time = time.time() + step_idx
                rows.append(DataRow(
                    exp_meta=ExpMeta(
                        task_id=task_id,
                        task_name="default_task_name",
                        agent_id=agent_id,
                        step=step,
                        execute_time=execute_time,
                        pre_agent="pre_agent_id",
                    ),
                    exp_data=Experience(state=Observation(),
                                        actions=[ActionModel()]),
                ))
    buffer.store_batch(rows)


def _sample_and_log(label, query, batch_size):
    """Sample tasks matching *query* and log each row prefixed with *label*."""
    datas = buffer.sample_task(query_condition=query,
                               sampler=RandomTaskSample(),
                               converter=DefaultConverter(),
                               batch_size=batch_size)
    for data in datas:
        logger.info(f"{label} data: {data}")


def read_data():
    """Demonstrate two filtered reads from the replay buffer.

    1. All rows for task "task_1" (batch_size=1).
    2. All rows produced by agent "agent_5" across tasks (batch_size=2).
    """
    _sample_and_log("task_1",
                    QueryBuilder().eq("exp_meta.task_id", "task_1").build(),
                    batch_size=1)
    _sample_and_log("agent_5",
                    QueryBuilder().eq("exp_meta.agent_id", "agent_5").build(),
                    batch_size=2)


if __name__ == "__main__":
    # write_data() is left commented out — presumably to avoid
    # re-inserting the same 125 sample rows on every run; uncomment
    # once to seed the table, then run read_data() to query it.
    # write_data()
    read_data()