# Spaces: Sleeping
# coding: utf-8
# Copyright (c) 2025 inclusionAI.
from typing import Dict, Any, List, Callable | |
from aworld.core.context.base import Context | |
from aworld.core.event import eventbus | |
from aworld.core.event.base import Constants, Message | |
class EventManager:
    """Wrap the event bus and keep recently emitted messages in memory.

    Emitted messages are grouped by ``Message.key()`` in ``self.messages`` so
    they can be looked up again later. Each per-key list is capped at
    ``max_len`` entries; the oldest entries are dropped first.
    """

    def __init__(self, context: Context, **kwargs):
        """Create a manager bound to ``context``.

        Args:
            context: Context whose ``session_id`` is used as the default
                session for emitted/consumed messages and whose ``task_id``
                is passed to the bus in ``done``.
            **kwargs: Optional settings. ``max_len`` (default 1000) is the
                maximum number of messages retained per key.
        """
        # The module-level `eventbus` object is used as the bus instance.
        self.event_bus = eventbus
        self.context = context
        # Record events in memory for re-consume.
        self.messages: Dict[str, List[Message]] = {'None': []}
        self.max_len = kwargs.get('max_len', 1000)

    async def emit(
            self,
            data: Any,
            sender: str,
            receiver: str = None,
            topic: str = None,
            session_id: str = None,
            event_type: str = Constants.TASK
    ) -> bool:
        """Send data to the event bus.

        Args:
            data: Message payload.
            sender: The sender name of the message.
            receiver: The receiver name of the message.
            topic: The topic to which the message belongs.
            session_id: Special session id; when falsy, the session id of
                the bound context is used instead.
            event_type: Event type, stored as the message category.

        Returns:
            The result of ``emit_message`` for the built message.
        """
        event = Message(
            payload=data,
            session_id=session_id if session_id else self.context.session_id,
            sender=sender,
            receiver=receiver,
            topic=topic,
            category=event_type,
        )
        return await self.emit_message(event)

    async def emit_message(self, event: Message) -> bool:
        """Record the message under its key and publish it to the event bus.

        Args:
            event: The message to record and publish.

        Returns:
            Always ``True`` once the publish call has completed.
        """
        history = self.messages.setdefault(event.key(), [])
        history.append(event)
        # Cap the per-key history. `self.messages` is a dict, so the bound has
        # to be applied to the list of this key; trimming in place keeps any
        # list previously handed out by `messages_by_key` consistent.
        overflow = len(history) - self.max_len
        if overflow > 0:
            del history[:overflow]

        await self.event_bus.publish(event)
        return True

    async def consume(self, nowait: bool = False):
        """Consume the next message from the event bus for the context session.

        A placeholder message carrying the context's session id (and the
        context itself) is handed to the bus as the consume request.

        Args:
            nowait: When true, call the bus's ``consume_nowait`` instead of
                ``consume``.

        Returns:
            Whatever the event bus consume call returns.
        """
        msg = Message(session_id=self.context.session_id, sender="", category="", payload="")
        msg.context = self.context
        if nowait:
            return await self.event_bus.consume_nowait(msg)
        return await self.event_bus.consume(msg)

    async def done(self) -> None:
        """Notify the event bus that the context's task id is done."""
        await self.event_bus.done(self.context.task_id)

    async def register(self, event_type: str, topic: str, handler: Callable[..., Any], **kwargs) -> None:
        """Subscribe ``handler`` on the event bus for ``event_type`` and ``topic``."""
        await self.event_bus.subscribe(event_type, topic, handler, **kwargs)

    async def unregister(self, event_type: str, topic: str, handler: Callable[..., Any], **kwargs) -> None:
        """Unsubscribe ``handler`` on the event bus for ``event_type`` and ``topic``."""
        await self.event_bus.unsubscribe(event_type, topic, handler, **kwargs)

    async def register_transformer(self, event_type: str, topic: str, handler: Callable[..., Any], **kwargs) -> None:
        """Subscribe ``handler`` with ``transformer=True`` on the event bus."""
        await self.event_bus.subscribe(event_type, topic, handler, transformer=True, **kwargs)

    async def unregister_transformer(self, event_type: str, topic: str, handler: Callable[..., Any], **kwargs) -> None:
        """Unsubscribe ``handler`` with ``transformer=True`` on the event bus."""
        await self.event_bus.unsubscribe(event_type, topic, handler, transformer=True, **kwargs)

    def messages_by_key(self, key: str) -> List[Message]:
        """Return the messages recorded under ``key`` (empty list if none)."""
        return self.messages.get(key, [])

    def messages_by_sender(self, sender: str, key: str) -> List[Message]:
        """Return the messages recorded under ``key`` whose sender is ``sender``."""
        return [msg for msg in self.messages.get(key, []) if msg.sender == sender]

    def messages_by_topic(self, topic: str, key: str) -> List[Message]:
        """Return the messages recorded under ``key`` whose topic is ``topic``."""
        return [msg for msg in self.messages.get(key, []) if msg.topic == topic]

    def session_messages(self, session_id: str) -> List[Message]:
        """Return every recorded message, across all keys, of ``session_id``."""
        return [
            msg
            for history in self.messages.values()
            for msg in history
            if msg.session_id == session_id
        ]

    @staticmethod
    def mark_valid(messages: List[Message]) -> None:
        """Set ``is_valid = True`` on every message in ``messages``."""
        for msg in messages:
            msg.is_valid = True

    @staticmethod
    def mark_invalid(messages: List[Message]) -> None:
        """Set ``is_valid = False`` on every message in ``messages``."""
        for msg in messages:
            msg.is_valid = False

    def clear_messages(self) -> None:
        """Drop all recorded messages.

        The store is reset to the same dict shape ``__init__`` creates so the
        lookup and emit methods keep working afterwards.
        """
        self.messages = {'None': []}