File size: 1,797 Bytes
912a768
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import time
from aworld.core.common import ActionModel, Observation
from aworld.replay_buffer.base import (
    DataRow,
    DefaultConverter,
    ReplayBuffer,
    ExpMeta,
    Experience,
    RandomTaskSample
)
from aworld.replay_buffer.query_filter import QueryBuilder
from aworld.logs.util import logger


# Module-level replay buffer shared by this script: write_data() stores rows
# into it and read_data() samples from it. Constructed with default arguments;
# NOTE(review): the storage backend is whatever ReplayBuffer() defaults to.
buffer: ReplayBuffer = ReplayBuffer()


def write_data(num_tasks=5, steps_per_task=10):
    """Populate the module-level ``buffer`` with synthetic experience rows.

    Stores ``num_tasks * steps_per_task`` rows. Every row of task number ``k``
    (0-based) carries ``task_id == f"task_{k}"``. Within a task, step ``s``
    (1-based) is attributed to ``agent_{s}`` and stamped with an
    ``execute_time`` of the current ``time.time()`` plus ``s - 1`` seconds.
    Each row holds a default-constructed ``Observation`` as its state and a
    single default-constructed ``ActionModel`` as its action list.

    Args:
        num_tasks: Number of distinct tasks to generate (default 5).
        steps_per_task: Number of steps, i.e. rows, per task (default 10).
    """
    for task_index in range(num_tasks):
        # Derive the id once per task into its own name. Reassigning the outer
        # loop variable inside the inner loop would compound the prefix on
        # every step ("task_0", "task_task_0", "task_task_task_0", ...).
        task_id = f"task_{task_index}"
        for i in range(steps_per_task):
            step = i + 1
            row = DataRow(
                exp_meta=ExpMeta(
                    task_id=task_id,
                    task_name="default_task_name",
                    agent_id=f"agent_{step}",
                    step=step,
                    # Offset by the 0-based step index so timestamps within a
                    # task are strictly increasing.
                    execute_time=time.time() + i,
                ),
                exp_data=Experience(state=Observation(),
                                    actions=[ActionModel()]),
            )
            buffer.store(row)


def read_data():
    """Sample two example batches from the module-level ``buffer`` and log them.

    Two equality filters are applied in turn:
    - ``exp_meta.task_id == "task_1"``
    - ``exp_meta.agent_id == "agent_5"``

    Each one is sampled with a fresh ``RandomTaskSample`` and ``DefaultConverter``
    at ``batch_size=2``. Every returned item is logged, prefixed by the filter value.
    """
    # (field to filter on, value it must equal). The value doubles as the log prefix.
    filters = (
        ("exp_meta.task_id", "task_1"),
        ("exp_meta.agent_id", "agent_5"),
    )
    for field, expected in filters:
        condition = QueryBuilder().eq(field, expected).build()
        sampled = buffer.sample_task(
            query_condition=condition,
            sampler=RandomTaskSample(),
            converter=DefaultConverter(),
            batch_size=2,
        )
        for item in sampled:
            logger.info(f"{expected} data: {item}")


if __name__ == "__main__":
    # Run the stages in order: populate the shared buffer first so the
    # queries in read_data() have rows to match.
    for stage in (write_data, read_data):
        stage()