File size: 4,397 Bytes
64a1e64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import asyncio
import sys
import streamlit as st
from dotenv import load_dotenv
import logging
import os
import traceback
import importlib.util
import utils
import aworld.trace as trace
from trace_net import generate_trace_graph_full
from aworld.trace.base import get_tracer_provider

# Environment: read variables from a ".env" file in the directory the app was
# launched from (the current working directory, not this script's folder).
load_dotenv(dotenv_path=os.path.join(os.getcwd(), ".env"))

# Logging: INFO-level root configuration plus a logger scoped to this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Prepend the launch directory to the import path. Presumably this lets
# dynamically loaded agent files import modules that live there — confirm.
sys.path[:0] = [os.getcwd()]


def _apply_page_style():
    """Inject CSS that hides the app header and caps code-block/image heights."""
    st.markdown(
        """\
        <style> 
        .stAppHeader { display: none; }
        
        div[data-testid="stMarkdownContainer"] pre {
            max-height: 300px;
            overflow-y: auto;
        }
        div[data-testid="stMarkdownContainer"] img {
            max-height: 500px;
        }
        </style>""",
        unsafe_allow_html=True,
    )


def _sync_selected_agent():
    """Make ``st.session_state.selected_agent`` follow the ``agent`` URL query parameter."""
    selected_agent_from_url = st.query_params.get("agent", None)

    if "selected_agent" not in st.session_state:
        st.session_state.selected_agent = selected_agent_from_url
        logger.info(f"Initialized selected_agent from URL: {selected_agent_from_url}")

    # The URL wins on every rerun, so a changed link overrides a stale session value.
    if selected_agent_from_url != st.session_state.selected_agent:
        st.session_state.selected_agent = selected_agent_from_url


def _render_sidebar():
    """List agents in the sidebar; clicking one selects it and records it in the URL."""
    with st.sidebar:
        st.title("AWorld Agents List")
        for agent in utils.list_agents():
            if st.button(agent):
                st.query_params["agent"] = agent
                st.session_state.selected_agent = agent
                logger.info(f"selected_agent={st.session_state.selected_agent}")


def _load_agent_module(agent_name):
    """Import ``agent.py`` from the package directory of *agent_name*.

    Returns:
        The executed module object, or ``None`` if it could not be loaded. In
        that case the failure has already been logged and shown in the page.
    """
    agent_package_path = utils.get_agent_package_path(agent_name)
    agent_module_file = os.path.join(agent_package_path, "agent.py")
    try:
        spec = importlib.util.spec_from_file_location(agent_name, agent_module_file)
        if spec is None or spec.loader is None:
            logger.error(
                f"Could not load spec for agent {agent_name} from {agent_module_file}"
            )
            st.error(f"Error: Could not load agent! {agent_name}")
            return None
        agent_module = importlib.util.module_from_spec(spec)
        # exec_module runs the agent file's top-level code, so any exception
        # type can surface here; hence the broad handler below.
        spec.loader.exec_module(agent_module)
    except Exception:
        logger.exception(
            f"Error loading agent {agent_name}, cwd:{os.getcwd()}, sys.path:{sys.path}"
        )
        st.error(f"Error: Could not load agent! {agent_name}")
        return None
    return agent_module


async def _stream_agent_reply(agent, prompt):
    """Write each chunk *agent* yields for *prompt*, then a link to its trace graph."""
    trace_id = None
    async with trace.span("start") as span:
        trace_id = span.get_trace_id()
        async for line in agent.run(prompt):
            st.write(line)
            # Short pause between chunks, presumably for a visible streaming effect.
            await asyncio.sleep(0.1)

    # Flush pending spans before building the graph from the trace data
    # (5000 is presumably a timeout in milliseconds — confirm).
    get_tracer_provider().force_flush(5000)
    file_name = f"graph.{trace_id}.html"
    folder_name = "trace_data"
    generate_trace_graph_full(trace_id, folder_name=folder_name, file_name=file_name)
    view_page_url = f"/trace?trace_id={trace_id}"
    st.write(f"\n---\n[View Trace]({view_page_url})\n")


def agent_page():
    """Render the AWorld agent chat page.

    The selected agent comes from the ``agent`` query parameter or a sidebar
    click. When the user submits a message, the agent's ``agent.py`` is loaded,
    an ``AWorldAgent`` is instantiated and run, its output is streamed into the
    page, and a link to the generated trace graph is shown.
    """
    # Keep set_page_config as the first Streamlit call; older Streamlit
    # versions reject it after any other st.* command.
    st.set_page_config(
        page_title="AWorld Agent",
        page_icon=":robot_face:",
        layout="wide",
    )
    _apply_page_style()
    _sync_selected_agent()
    _render_sidebar()

    agent_name = st.session_state.selected_agent
    if not agent_name:
        st.title("AWorld Agent Chat Assistant")
        st.info("Please select an Agent from the left sidebar to start")
        return

    st.title(f"AWorld Agent: {agent_name}")
    prompt = st.chat_input("Input message here~")
    if not prompt:
        return

    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        agent_module = _load_agent_module(agent_name)
        if agent_module is None:
            return
        agent = agent_module.AWorldAgent()
        asyncio.run(_stream_agent_reply(agent, prompt))


# Script entry point: Streamlit executes this module top to bottom on each run.
# Errors escaping agent_page() are logged with their traceback and shown in the
# page as a short message instead of propagating further.
try:
    agent_page()
except Exception as e:
    logger.exception(">>> Error")
    st.error(f"Error: {e}")