File size: 2,284 Bytes
4b677a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import time
from aworld.replay_buffer.base import DataRow, ExpMeta, Experience
from aworld.replay_buffer.storage.redis import RedisStorage
from aworld.replay_buffer.query_filter import QueryBuilder
from aworld.core.common import Observation, ActionModel
from aworld.logs.util import logger

# Module-level storage handle shared by the write and read helpers below.
# It targets a Redis server on localhost:6379.
# NOTE(review): recreate_idx_if_exists=False presumably keeps an existing
# index rather than rebuilding it — confirm against RedisStorage.
storage = RedisStorage(
    host="localhost",
    port=6379,
    recreate_idx_if_exists=False,
)


def generate_data_row(num_tasks: int = 5, num_agents: int = 5,
                      num_steps: int = 5) -> list[DataRow]:
    """Build sample rows for every task/agent/step combination.

    Rows are produced task-major, then by agent, then by step, using the ids
    ``task_1..task_N`` and ``agent_1..agent_M`` and the steps ``1..K``. Every
    row holds a default-constructed ``Observation`` as its state and a list
    with one default-constructed ``ActionModel`` as its actions.

    Args:
        num_tasks: Number of distinct task ids to generate.
        num_agents: Number of agent ids generated for each task.
        num_steps: Number of steps generated for each (task, agent) pair.

    Returns:
        A list of ``num_tasks * num_agents * num_steps`` rows (125 with the
        defaults). A non-positive count yields an empty list.
    """
    rows: list[DataRow] = []
    for task_no in range(1, num_tasks + 1):
        task_id = f"task_{task_no}"
        for agent_no in range(1, num_agents + 1):
            agent_id = f"agent_{agent_no}"
            for step in range(1, num_steps + 1):
                # Offset the current time by the zero-based step index so
                # later steps of the same agent get later timestamps.
                execute_time = time.time() + (step - 1)
                row = DataRow(
                    exp_meta=ExpMeta(
                        task_id=task_id,
                        task_name="default_task_name",
                        agent_id=agent_id,
                        step=step,
                        execute_time=execute_time,
                        pre_agent="pre_agent_id"
                    ),
                    exp_data=Experience(state=Observation(),
                                        actions=[ActionModel()])
                )
                rows.append(row)
    return rows


def write_data():
    """Reset the storage and load it with the generated sample rows.

    Calls ``storage.clear()`` first, then inserts the rows returned by
    ``generate_data_row()`` in a single ``add_batch`` call and logs how many
    rows were added.
    """
    storage.clear()
    rows = generate_data_row()
    storage.add_batch(rows)
    logger.info(f"Add {len(rows)} rows to storage.")


def wriete_data():
    """Call ``write_data``; kept so callers of the misspelled name still work."""
    write_data()


def read_data():
    """Query the storage with a compound filter and log the results.

    The filter is built as ``task_id == task_1 AND agent_id == agent_1 OR
    (task_id == task_4 AND agent_id == agent_3 AND step > 4)``. The same
    condition is used twice: once to fetch and log every matching row, and
    once to fetch and log page 2 with a page size of 2.

    NOTE(review): how AND/OR precedence is resolved is up to QueryBuilder —
    confirm the grouping matches what is intended.
    """
    # Sub-query passed un-built to .nested(); only the outer builder is built.
    late_task4_agent3 = (
        QueryBuilder()
        .eq("exp_meta.task_id", "task_4")
        .and_()
        .eq("exp_meta.agent_id", "agent_3")
        .and_()
        .gt("exp_meta.step", 4)
    )
    query_condition = (
        QueryBuilder()
        .eq("exp_meta.task_id", "task_1")
        .and_()
        .eq("exp_meta.agent_id", "agent_1")
        .or_()
        .nested(late_task4_agent3)
        .build()
    )

    # Full result set for the condition.
    all_matches = storage.get_all(query_condition)
    for match in all_matches:
        logger.info(match)

    # Second page of the same result set, two rows per page.
    second_page = storage.get_paginated(
        page=2,
        page_size=2,
        query_condition=query_condition,
    )
    for paged_row in second_page:
        logger.info(f"get_paginated: {paged_row}")


if __name__ == "__main__":
    # Script entry point: only the read path runs by default.
    # Uncomment the next line to clear the storage and reload the sample
    # rows before reading.
    # wriete_data()
    read_data()