import time
import traceback
import multiprocessing
from aworld.core.common import ActionModel, Observation
from aworld.replay_buffer.base import ReplayBuffer, DataRow, ExpMeta, Experience
from aworld.replay_buffer.query_filter import QueryBuilder
from aworld.replay_buffer.storage.multi_proc_mem import MultiProcMemoryStorage
from aworld.logs.util import logger


def write_processing(replay_buffer: ReplayBuffer, task_id: str):
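    """Store ten DataRow entries for `task_id`, one per second."""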
    for i in range(10):
        try:
            data = DataRow(
                exp_meta=ExpMeta(
                    task_id=task_id,
                    task_name=task_id,
                    agent_id=f"agent_{i+1}",
                    step=i,
                    execute_time=time.time()
                ),
                exp_data=Experience(state=Observation(),
                                    actions=[ActionModel()])
            )
            replay_buffer.store(data)
        except Exception as e:
            stack_trace = traceback.format_exc()
            logger.error(
                f"write_processing error: {e}\nStack trace:\n{stack_trace}")
        time.sleep(1)


def read_processing_by_task(replay_buffer: ReplayBuffer, task_id: str):
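    """Poll once per second, sampling batches of rows that match `task_id`."""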
    while True:
        try:
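            # Filter on rows whose exp_meta.task_id equals this task.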
            query_condition = QueryBuilder().eq("exp_meta.task_id", task_id).build()
            data = replay_buffer.sample_task(
                query_condition=query_condition, batch_size=2)
            logger.info(f"read data of task[{task_id}]: {data}")
        except Exception as e:
            stack_trace = traceback.format_exc()
            logger.error(
                f"read_processing_by_task error: {e}\nStack trace:\n{stack_trace}")
        time.sleep(1)


def read_processing_by_agent(replay_buffer: ReplayBuffer, agent_id: str):
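    """Poll once per second, sampling batches of rows that match `agent_id`."""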
    while True:
        try:
            query_condition = QueryBuilder().eq("exp_meta.agent_id", agent_id).build()
            data = replay_buffer.sample_task(
                query_condition=query_condition, batch_size=2)
            logger.info(f"read data of agent[{agent_id}]: {data}")
        except Exception as e:
            stack_trace = traceback.format_exc()
            logger.error(
                f"read_processing_by_agent error: {e}\nStack trace:\n{stack_trace}")
        time.sleep(1)


if __name__ == "__main__":
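    # freeze_support() keeps frozen Windows executables working; 'spawn' is
    # set explicitly so the example behaves the same on every platform.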
    multiprocessing.freeze_support()
    multiprocessing.set_start_method('spawn')
    manager = multiprocessing.Manager()

    # Manager-owned proxies let every child process share one dict,
    # FIFO queue and lock behind the storage.
    replay_buffer = ReplayBuffer(storage=MultiProcMemoryStorage(
        data_dict=manager.dict(),
        fifo_queue=manager.list(),
        lock=manager.Lock(),
        max_capacity=10000
    ))

    processes = [
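        # Four writers, one per task, plus a reader that samples by agent id.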
        multiprocessing.Process(target=write_processing,
                                args=(replay_buffer, "task_1",)),
        multiprocessing.Process(target=write_processing,
                                args=(replay_buffer, "task_2",)),
        multiprocessing.Process(target=write_processing,
                                args=(replay_buffer, "task_3",)),
        multiprocessing.Process(target=write_processing,
                                args=(replay_buffer, "task_4",)),
        # multiprocessing.Process(
        #     target=read_processing_by_task, args=(replay_buffer, "task_1",)),
        multiprocessing.Process(
            target=read_processing_by_agent, args=(replay_buffer, "agent_3",))
    ]
    for p in processes:
        p.start()

    try:
        for p in processes:
            p.join()
    except KeyboardInterrupt:
        for p in processes:
            p.terminate()
        for p in processes:
            p.join()
    finally:
        logger.info("Processes terminated.")