Spaces:
Running
Running
from __future__ import annotations | |
import dataclasses | |
import enum | |
import os | |
from collections import OrderedDict | |
from collections.abc import Mapping, Sequence | |
from pathlib import Path | |
from types import MappingProxyType | |
from typing import TYPE_CHECKING, Any | |
import boto3 | |
import botocore | |
import botocore.exceptions | |
import gradio as gr | |
import gradio.themes as gr_themes | |
import markdown | |
from langchain_aws import ChatBedrock | |
from langchain_core.callbacks import BaseCallbackHandler | |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage | |
from langchain_core.tools import BaseTool | |
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint | |
from langchain_mcp_adapters.client import MultiServerMCPClient | |
from langchain_openai import AzureChatOpenAI | |
from langgraph.prebuilt import create_react_agent | |
from openai import OpenAI | |
from openai.types.chat import ChatCompletion | |
from tdagent.grcomponents import MutableCheckBoxGroup, MutableCheckBoxGroupEntry | |
if TYPE_CHECKING: | |
from langgraph.graph.graph import CompiledGraph | |
#### Constants #### | |
class AgentType(str, enum.Enum):
    """Kinds of agent persona that TDAgent can be configured as.

    Members are also plain strings, so a member compares and hashes equal to
    its display label.
    """

    DATA_ENRICHER = "Data enricher"
    INCIDENT_HANDLER = "Incident handler"
    PEN_TESTER = "PenTester"

    def __str__(self) -> str:  # noqa: D105
        # Render as the human readable label rather than "AgentType.NAME".
        label = self.value
        return str(label)
# System prompt for each agent type. Insertion order is the order in which the
# agent types are listed, and the first entry is the default selection.
AGENT_SYSTEM_MESSAGES = OrderedDict(
    [
        (
            AgentType.DATA_ENRICHER,
            """
You are a cybersecurity incidence data enriching assistant. Analysts
will present information about security incidents and you must use
all the tools at your disposal to enrich the data as much as possible.
""".strip(),
        ),
        (
            AgentType.INCIDENT_HANDLER,
            """
You are a security analyst assistant responsible for collecting, analyzing
and disseminating actionable intelligence related to cyber threats,
vulnerabilities and threat actors.
When presented with potential incidents information or tickets, you should
evaluate the presented evidence, gather additional data using any tool at
your disposal and take corrective actions if possible.
Afterwards, generate a cybersecurity report including: key findings, challenges,
actions taken and recommendations.
Never use external means of communication, like emails or SMS, unless
instructed to do so.
""".strip(),
        ),
        (
            AgentType.PEN_TESTER,
            """
You are a cybersecurity pentester. You use tools to analyze domain to try to discover system vulnerabilities.
Always report you findings and suggest next steps to deep dive where applicable.
""".strip(),
        ),
    ],
)
# Read-only lookup from a Gradio chat-history role to the message class used
# to rebuild that turn for the agent.
GRADIO_ROLE_TO_LG_MESSAGE_TYPE = MappingProxyType(
    dict(user=HumanMessage, assistant=AIMessage),
)
# Known models per provider: provider name -> {display label -> model id}.
# Built from a list of pairs so the providers keep this exact order.
MODEL_OPTIONS = OrderedDict(
    [
        (
            "HuggingFace",
            {
                "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.3",
                "Llama 3.1 8B Instruct": "meta-llama/Llama-3.1-8B-Instruct",
                # "Qwen3 235B A22B": "Qwen/Qwen3-235B-A22B",  # Slow inference
                "Microsoft Phi-3.5-mini Instruct": "microsoft/Phi-3.5-mini-instruct",
                # "Deepseek R1 distill-llama 70B": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",  # noqa: E501
                # "Deepseek V3": "deepseek-ai/DeepSeek-V3",
            },
        ),
        (
            "AWS Bedrock",
            {
                "Anthropic Claude 3.5 Sonnet (EU)": "eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
                "Anthropic Claude 3.7 Sonnet": "anthropic.claude-3-7-sonnet-20250219-v1:0",
                "Claude Sonnet 4": "anthropic.claude-sonnet-4-20250514-v1:0",
            },
        ),
        (
            "Azure OpenAI",
            {
                # NOTE(review): "ggpt-..." may be a typo for "gpt-..." or may be
                # the real deployment name - confirm before changing.
                "GPT-4o": "ggpt-4o-global-standard",
                # NOTE(review): the label says "GPT-4o Mini" but the id is
                # "o4-mini" - confirm which model is intended.
                "GPT-4o Mini": "o4-mini",
                "GPT-4.5 Preview": "gpt-4.5-preview",
            },
        ),
    ],
)
# "Connected" flag: the gr_connect_to_* handlers set `.value` to True and
# _on_change_model_configuration reads and resets it.
# NOTE(review): this is a single module-level object, so the flag looks shared
# by every session of the app - confirm that is intended.
CONNECT_STATE_DEFAULT = gr.State()
@dataclasses.dataclass
class ToolInvocationInfo:
    """Information related to a tool invocation by the LLM.

    The dataclass decorator generates the keyword constructor that
    ``ToolsTracerCallback.on_tool_start`` relies on
    (``ToolInvocationInfo(name=..., inputs=...)``).
    """

    # Name of the invoked tool.
    name: str
    # Arguments the tool was invoked with (empty mapping when none were given).
    inputs: Mapping[str, Any]
class ToolsTracerCallback(BaseCallbackHandler):
    """Callback that registers tools invoked by the Agent.

    Each ``on_tool_start`` event appends a ``ToolInvocationInfo`` to an
    internal list that is exposed through the ``tools_trace`` property and
    emptied with ``clear()``.
    """

    def __init__(self) -> None:
        self._tools_trace: list[ToolInvocationInfo] = []

    def on_tool_start(  # noqa: D102
        self,
        serialized: dict[str, Any],
        *args: Any,
        inputs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        # Record the call first, then defer to the base implementation.
        self._tools_trace.append(
            ToolInvocationInfo(
                name=serialized.get("name", "<unknown-function-name>"),
                inputs=inputs if inputs else {},
            ),
        )
        return super().on_tool_start(serialized, *args, inputs=inputs, **kwargs)

    @property
    def tools_trace(self) -> Sequence[ToolInvocationInfo]:
        """Tools trace information.

        Exposed as a property because the rest of the module reads it as an
        attribute (truthiness test and iteration); as a plain method the
        truthiness test is always true and iterating it raises ``TypeError``.
        """
        return self._tools_trace

    def clear(self) -> None:
        """Clear tools trace."""
        self._tools_trace.clear()
#### Shared variables #### | |
# Agent currently in use; None until one of the gr_connect_to_* handlers has
# built it. Read by gr_chat_function.
llm_agent: CompiledGraph | None = None
# Tool-call tracer attached to the fetched tools by gr_fetch_mcp_tools, or None
# when tracing was not requested.
llm_tools_tracer: ToolsTracerCallback | None = None
#### Utility functions #### | |
## Bedrock LLM creation ## | |
def create_bedrock_llm(
    bedrock_model_id: str,
    aws_access_key: str,
    aws_secret_key: str,
    aws_session_token: str,
    aws_region: str,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> tuple[ChatBedrock | None, str]:
    """Create a Bedrock chat model after verifying the AWS credentials.

    Args:
        bedrock_model_id: Bedrock model identifier.
        aws_access_key: AWS access key ID.
        aws_secret_key: AWS secret access key.
        aws_session_token: Optional session token; an empty value is sent as
            ``None``.
        aws_region: AWS region name used for both clients.
        temperature: Sampling temperature forwarded in ``model_kwargs``.
        max_tokens: Token limit forwarded in ``model_kwargs``.

    Returns:
        ``(llm, "")`` on success, or ``(None, error_message)`` on failure.
    """
    boto3_config = {
        "aws_access_key_id": aws_access_key,
        "aws_secret_access_key": aws_secret_key,
        "aws_session_token": aws_session_token or None,
        "region_name": aws_region,
    }
    # Verify credentials with a cheap STS call. ClientError covers errors
    # returned by the service; BotoCoreError covers client-side failures
    # (missing credentials, unreachable endpoint, invalid region), which must
    # also be reported through the (None, message) contract instead of raising.
    try:
        sts = boto3.client("sts", **boto3_config)
        sts.get_caller_identity()
    except (
        botocore.exceptions.BotoCoreError,
        botocore.exceptions.ClientError,
    ) as err:
        return None, str(err)
    try:
        bedrock_client = boto3.client("bedrock-runtime", **boto3_config)
        llm = ChatBedrock(
            model=bedrock_model_id,
            client=bedrock_client,
            model_kwargs={"temperature": temperature, "max_tokens": max_tokens},
        )
    except Exception as e:  # noqa: BLE001
        return None, str(e)
    return llm, ""
## Hugging Face LLM creation ## | |
def create_hf_llm(
    hf_model_id: str,
    huggingfacehub_api_token: str | None = None,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> tuple[ChatHuggingFace | None, str]:
    """Build a Hugging Face chat model for ``hf_model_id``.

    Returns ``(chat_llm, "")`` on success, or ``(None, error_message)`` when
    creating the endpoint or the chat wrapper raises.
    """
    endpoint_options = {
        "model": hf_model_id,
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        "task": "text-generation",
        "huggingfacehub_api_token": huggingfacehub_api_token,
    }
    try:
        endpoint = HuggingFaceEndpoint(**endpoint_options)
        wrapped = ChatHuggingFace(llm=endpoint)
    except Exception as exc:  # noqa: BLE001
        return None, str(exc)
    else:
        return wrapped, ""
## OpenAI LLM creation ## | |
def create_openai_llm(
    model_id: str,
    token_id: str,
) -> tuple[ChatCompletion | None, str]:
    """Request a chat completion from the Nebius OpenAI-compatible endpoint.

    Returns ``(completion, "")`` on success, or ``(None, error_message)`` when
    anything raises.

    NOTE(review): this issues a completion request with an empty message list
    and returns the completion object, not a reusable chat model. The original
    author flagged it as "needs to be fixed".
    """
    nebius_url = "https://api.studio.nebius.com/v1/"
    try:
        api = OpenAI(base_url=nebius_url, api_key=token_id)
        completion = api.chat.completions.create(
            messages=[],  # needs to be fixed
            model=model_id,
            max_tokens=512,
            temperature=0.8,
        )
    except Exception as exc:  # noqa: BLE001
        return None, str(exc)
    else:
        return completion, ""
def create_azure_llm(
    model_id: str,
    api_version: str,
    endpoint: str,
    token_id: str,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> tuple[AzureChatOpenAI | None, str]:
    """Create an Azure OpenAI chat model.

    The endpoint and API key are passed straight to ``AzureChatOpenAI``
    instead of being written to ``os.environ``: environment variables are
    process-wide, so the previous approach left one caller's credentials
    visible to everything else running in the same process.

    Args:
        model_id: Azure deployment name.
        api_version: Azure OpenAI API version string.
        endpoint: Azure OpenAI endpoint URL.
        token_id: Azure OpenAI API key.
        temperature: Sampling temperature.
        max_tokens: Token limit for the completion.

    Returns:
        ``(llm, "")`` on success, or ``(None, error_message)`` on failure.
    """
    # Deployments whose name contains "o4-mini" get the limit under
    # ``max_completion_tokens``; every other deployment under ``max_tokens``.
    limit_param = "max_completion_tokens" if "o4-mini" in model_id else "max_tokens"
    try:
        llm = AzureChatOpenAI(
            azure_endpoint=endpoint,
            azure_deployment=model_id,
            api_key=token_id,
            api_version=api_version,
            temperature=temperature,
            **{limit_param: max_tokens},
        )
    except Exception as e:  # noqa: BLE001
        return None, str(e)
    return llm, ""
#### UI functionality #### | |
async def gr_fetch_mcp_tools(
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
    *,
    trace_tools: bool,
) -> list[BaseTool]:
    """Collect the tools exposed by the given MCP servers.

    Returns an empty list, leaving the global tracer untouched, when no server
    is given. Otherwise every server is contacted over SSE and the global
    ``llm_tools_tracer`` is either replaced by a fresh tracer attached to every
    tool (``trace_tools`` true) or reset to ``None``.
    """
    global llm_tools_tracer  # noqa: PLW0603

    if not mcp_servers:
        return []

    # Connection names are the entry names with spaces turned into dashes.
    connections = {}
    for entry in mcp_servers:
        connections[entry.name.replace(" ", "-")] = {
            "url": entry.value,
            "transport": "sse",
        }
    fetched = await MultiServerMCPClient(connections).get_tools()

    if not trace_tools:
        llm_tools_tracer = None
        return fetched

    llm_tools_tracer = ToolsTracerCallback()
    for mcp_tool in fetched:
        existing = mcp_tool.callbacks
        if existing is None:
            mcp_tool.callbacks = [llm_tools_tracer]
        elif isinstance(existing, list):
            existing.append(llm_tools_tracer)
        else:
            # Not a list: use the object's add_handler() instead.
            existing.add_handler(llm_tools_tracer)
    return fetched
def gr_make_system_message(
    agent_type: AgentType,
) -> SystemMessage:
    """Return the system message configured for ``agent_type``.

    Raises:
        gr.Error: if ``agent_type`` has no entry in ``AGENT_SYSTEM_MESSAGES``.
    """
    try:
        prompt_text = AGENT_SYSTEM_MESSAGES[agent_type]
    except KeyError as lookup_err:
        raise gr.Error(f"Unknown agent type '{agent_type}'") from lookup_err
    else:
        return SystemMessage(prompt_text)
async def gr_connect_to_bedrock(  # noqa: PLR0913
    model_id: str,
    access_key: str,
    secret_key: str,
    session_token: str,
    region: str,
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
    agent_type: AgentType,
    trace_tool_calls: bool,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> str:
    """Build the global agent on top of an AWS Bedrock model.

    Returns a status line for the UI: a success message, a request to supply
    the missing credentials, or the connection error.
    """
    global llm_agent  # noqa: PLW0603

    # The flag is raised as soon as a connection is attempted.
    CONNECT_STATE_DEFAULT.value = True

    credentials_present = bool(access_key) and bool(secret_key)
    if not credentials_present:
        return "❌ Please provide both Access Key ID and Secret Access Key"

    bedrock_llm, failure = create_bedrock_llm(
        model_id,
        access_key.strip(),
        secret_key.strip(),
        session_token.strip(),
        region,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    if bedrock_llm is None:
        return f"❌ Connection failed: {failure}"

    agent_tools = await gr_fetch_mcp_tools(mcp_servers, trace_tools=trace_tool_calls)
    system_prompt = gr_make_system_message(agent_type=agent_type)
    llm_agent = create_react_agent(
        model=bedrock_llm,
        tools=agent_tools,
        prompt=system_prompt,
    )
    return "✅ Successfully connected to AWS Bedrock!"
async def gr_connect_to_hf(
    model_id: str,
    hf_access_token_textbox: str | None,
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
    agent_type: AgentType,
    trace_tool_calls: bool,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> str:
    """Build the global agent on top of a Hugging Face model.

    Returns a status line for the UI: a success message or the connection
    error.
    """
    global llm_agent  # noqa: PLW0603

    # The flag is raised as soon as a connection is attempted.
    CONNECT_STATE_DEFAULT.value = True

    hf_llm, failure = create_hf_llm(
        model_id,
        hf_access_token_textbox,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    if hf_llm is None:
        return f"❌ Connection failed: {failure}"

    agent_tools = await gr_fetch_mcp_tools(mcp_servers, trace_tools=trace_tool_calls)
    system_prompt = gr_make_system_message(agent_type=agent_type)
    llm_agent = create_react_agent(
        model=hf_llm,
        tools=agent_tools,
        prompt=system_prompt,
    )
    return "✅ Successfully connected to Hugging Face!"
async def gr_connect_to_azure(  # noqa: PLR0913
    model_id: str,
    azure_endpoint: str,
    api_key: str,
    api_version: str,
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
    agent_type: AgentType,
    trace_tool_calls: bool,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> str:
    """Build the global agent on top of an Azure OpenAI deployment.

    Returns a status line for the UI: a success message or the connection
    error.
    """
    global llm_agent  # noqa: PLW0603

    # The flag is raised as soon as a connection is attempted.
    CONNECT_STATE_DEFAULT.value = True

    azure_llm, failure = create_azure_llm(
        model_id,
        api_version=api_version,
        endpoint=azure_endpoint,
        token_id=api_key,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    if azure_llm is None:
        return f"❌ Connection failed: {failure}"

    agent_tools = await gr_fetch_mcp_tools(mcp_servers, trace_tools=trace_tool_calls)
    system_prompt = gr_make_system_message(agent_type=agent_type)
    llm_agent = create_react_agent(
        model=azure_llm,
        tools=agent_tools,
        prompt=system_prompt,
    )
    return "✅ Successfully connected to Azure OpenAI!"
# async def gr_connect_to_nebius( | |
# model_id: str, | |
# nebius_access_token_textbox: str, | |
# mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None, | |
# ) -> str: | |
# """Initialize Hugging Face agent.""" | |
# global llm_agent | |
# connected_state.value = True | |
# llm, error = create_openai_llm(model_id, nebius_access_token_textbox) | |
# if llm is None: | |
# return f"❌ Connection failed: {error}" | |
# tools = [] | |
# if mcp_servers: | |
# client = MultiServerMCPClient( | |
# { | |
# server.name.replace(" ", "-"): { | |
# "url": server.value, | |
# "transport": "sse", | |
# } | |
# for server in mcp_servers | |
# }, | |
# ) | |
# tools = await client.get_tools() | |
# llm_agent = create_react_agent( | |
# model=str(llm), | |
# tools=tools, | |
# prompt=SYSTEM_MESSAGE, | |
# ) | |
# return "✅ Successfully connected to nebius!" | |
# Example tickets offered as one-click prompts in the chat interface. They are
# read once at import time, relative to the current working directory. The
# encoding is explicit (UTF-8, matching how README.md is read in this module)
# so the result does not depend on the platform's locale default.
exfiltration_ticket = Path("exfiltration_ticket.txt").read_text(encoding="utf-8")
service_discovery_ticket = Path("sample_kali_linux_1.txt").read_text(
    encoding="utf-8",
)
async def gr_chat_function(  # noqa: D103
    message: str,
    history: list[Mapping[str, str]],
) -> str:
    # Chat handler: replays the history plus the new message through the
    # global agent and returns the content of the agent's last message, with
    # the tools trace appended when tracing is active.
    if llm_agent is None:
        return "Please configure your credentials first."

    # Rebuild the conversation with the message class matching each role.
    conversation = [
        GRADIO_ROLE_TO_LG_MESSAGE_TYPE[turn["role"]](content=turn["content"])
        for turn in history
    ]
    conversation.append(HumanMessage(content=message))

    try:
        # Start every turn with an empty trace.
        if llm_tools_tracer is not None:
            llm_tools_tracer.clear()
        agent_output = await llm_agent.ainvoke({"messages": conversation})
        final_reply = agent_output["messages"][-1].content
        return _add_tools_trace_to_message(final_reply)
    except Exception as err:
        raise gr.Error(
            f"We encountered an error while invoking the model:\n{err}",
            print_exception=True,
        ) from err
def _add_tools_trace_to_message(message: str) -> str:
    """Append a "Tools Trace" section listing the traced tool calls.

    Returns ``message`` unchanged when there is no tracer or nothing was
    traced. Otherwise each traced call is listed with its index and name and,
    when it had arguments, a JSON rendering of them.
    """
    if not llm_tools_tracer or not llm_tools_tracer.tools_trace:
        return message
    import json

    traces = []
    for index, tool_info in enumerate(llm_tools_tracer.tools_trace):
        trace_msg = f" {index}. {tool_info.name}"
        if tool_info.inputs:
            # default=str: the trace is display-only, so an argument that is
            # not JSON serialisable is shown via str() instead of raising and
            # costing the user the whole reply.
            rendered_args = json.dumps(tool_info.inputs, indent=4, default=str)
            trace_msg += "\n"
            trace_msg += " * Arguments:\n"
            trace_msg += " ```json\n"
            trace_msg += f" {rendered_args}\n"
            trace_msg += " ```\n"
        traces.append(trace_msg)
    return f"{message}\n\n# Tools Trace\n\n" + "\n".join(traces)
def _read_markdown_body_as_html(path: str = "README.md") -> str:
    """Render the Markdown file at ``path`` to HTML, minus any front matter.

    A front matter block is recognised only when the first line is ``---``; it
    ends at the next ``---`` line. If that closing line is missing, the whole
    file is rendered.
    """
    with Path(path).open(encoding="utf-8") as handle:  # Default mode is "r"
        all_lines = list(handle)

    body_lines = all_lines
    if all_lines and all_lines[0].strip() == "---":
        # Index of the line closing the metadata block, if there is one.
        closing_index = next(
            (
                position
                for position, text in enumerate(all_lines[1:], start=1)
                if text.strip() == "---"
            ),
            None,
        )
        if closing_index is not None:
            body_lines = all_lines[closing_index + 1 :]  # skip metadata block

    return markdown.markdown("".join(body_lines).strip())
## UI components ## | |
# Extra CSS handed to gr.Blocks(css=...): styles the "main-header" banner
# rendered at the top of the page.
custom_css = """
.main-header {
background: linear-gradient(135deg, #00a388 0%, #ffae00 100%);
padding: 30px;
border-radius: 5px;
margin-bottom: 20px;
text-align: center;
}
"""
# Application UI. Everything below is declared inside the Blocks context: the
# page layout first, then the event handlers, then the listener wiring.
# NOTE(review): SOURCE carried no indentation, so the nesting of the layout
# containers below is reconstructed - confirm against the real file.
with (
    gr.Blocks(
        theme=gr_themes.Origin(
            primary_hue="teal",
            spacing_size="sm",
            font="sans-serif",
        ),
        title="TDAgent",
        fill_height=True,
        fill_width=True,
        css=custom_css,
    ) as gr_app,
):
    # Page banner, styled by the ".main-header" rule in custom_css.
    gr.HTML(
        """
<div class="main-header">
<h1>👩💻 TDAgentTools & TDAgent 👨💻</h1>
<p style="font-size: 1.2em; margin: 10px 0 0 0;">
Empowering Cybersecurity with Agentic AI
</p>
</div>
""",
    )
    with gr.Tabs():
        # "About" tab: README.md rendered to HTML (front matter stripped).
        with gr.TabItem("About"), gr.Row():
            html_content = _read_markdown_body_as_html("README.md")
            gr.Markdown(html_content)
        # "TDAgent" tab: configuration column (scale 1) and chat column (scale 2).
        with gr.TabItem("TDAgent"), gr.Row():
            with gr.Column(scale=1):
                with gr.Accordion("🔌 MCP Servers", open=False):
                    # Editable list of MCP endpoints; its state feeds the
                    # connect handlers.
                    mcp_list = MutableCheckBoxGroup(
                        values=[
                            MutableCheckBoxGroupEntry(
                                name="TDAgent tools",
                                value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
                            ),
                        ],
                        label="MCP Servers",
                        new_value_label="MCP endpoint",
                        new_name_label="MCP endpoint name",
                        new_value_placeholder="https://my-cool-mcp-server.com/mcp/sse",
                        new_name_placeholder="Swiss army knife of MCPs",
                    )
                with gr.Accordion("⚙️ Provider Configuration", open=True):
                    model_provider = gr.Dropdown(
                        choices=list(MODEL_OPTIONS.keys()),
                        value=None,
                        label="Select Model Provider",
                    )
                    # The provider groups below start hidden; the
                    # model_provider.change listeners toggle their visibility.
                    ## Amazon Bedrock Configuration ##
                    with gr.Group(visible=False) as aws_bedrock_conf_group:
                        aws_access_key_textbox = gr.Textbox(
                            label="AWS Access Key ID",
                            type="password",
                            placeholder="Enter your AWS Access Key ID",
                        )
                        aws_secret_key_textbox = gr.Textbox(
                            label="AWS Secret Access Key",
                            type="password",
                            placeholder="Enter your AWS Secret Access Key",
                        )
                        aws_region_dropdown = gr.Dropdown(
                            label="AWS Region",
                            choices=[
                                "us-east-1",
                                "us-west-2",
                                "eu-west-1",
                                "eu-central-1",
                                "ap-southeast-1",
                            ],
                            value="eu-west-1",
                        )
                        aws_session_token_textbox = gr.Textbox(
                            label="AWS Session Token",
                            type="password",
                            placeholder="Enter your AWS session token",
                        )
                    ## Huggingface Configuration ##
                    with gr.Group(visible=False) as hf_conf_group:
                        hf_token = gr.Textbox(
                            label="HuggingFace Token",
                            type="password",
                            placeholder="Enter your Hugging Face Access Token",
                        )
                    ## Azure Configuration ##
                    with gr.Group(visible=False) as azure_conf_group:
                        azure_endpoint = gr.Textbox(
                            label="Azure OpenAI Endpoint",
                            type="text",
                            placeholder="Enter your Azure OpenAI Endpoint",
                        )
                        azure_api_token = gr.Textbox(
                            label="Azure Access Token",
                            type="password",
                            placeholder="Enter your Azure OpenAI Access Token",
                        )
                        azure_api_version = gr.Textbox(
                            label="Azure OpenAI API Version",
                            type="text",
                            placeholder="Enter your Azure OpenAI API Version",
                            value="2024-12-01-preview",
                        )
                with gr.Accordion("🧠 Model Configuration", open=True):
                    # Hidden until a provider is chosen; picking an entry
                    # fills model_id_textbox with the matching model id.
                    model_id_dropdown = gr.Dropdown(
                        label="Select known model id or type your own below",
                        choices=[],
                        visible=False,
                    )
                    # The value of this textbox is what the connect handlers
                    # receive as model_id.
                    model_id_textbox = gr.Textbox(
                        label="Model ID",
                        type="text",
                        placeholder="Enter the model ID",
                        visible=False,
                        interactive=True,
                    )
                    # Agent configuration options
                    with gr.Group():
                        agent_system_message_radio = gr.Radio(
                            choices=list(AGENT_SYSTEM_MESSAGES.keys()),
                            value=next(iter(AGENT_SYSTEM_MESSAGES.keys())),
                            label="Agent type",
                            info=(
                                "Changes the system message to pre-condition the agent"
                                " to act in a desired way."
                            ),
                        )
                        agent_trace_tools_checkbox = gr.Checkbox(
                            value=False,
                            label="Trace tool calls",
                            info=(
                                "Add the invoked tools trace at the end of the"
                                " message"
                            ),
                        )
                    # Initialize the temperature and max tokens based on model specs
                    temperature = gr.Slider(
                        label="Temperature",
                        minimum=0.0,
                        maximum=1.0,
                        value=0.8,
                        step=0.1,
                    )
                    max_tokens = gr.Slider(
                        label="Max Tokens",
                        minimum=128,
                        maximum=8192,
                        value=2048,
                        step=64,
                    )
                # One connect button per provider; all start hidden and are
                # shown by the same listeners that show the provider groups.
                connect_aws_bedrock_btn = gr.Button(
                    "🔌 Connect to Bedrock",
                    variant="primary",
                    visible=False,
                )
                connect_hf_btn = gr.Button(
                    "🔌 Connect to Huggingface 🤗",
                    variant="primary",
                    visible=False,
                )
                connect_azure_btn = gr.Button(
                    "🔌 Connect to Azure",
                    variant="primary",
                    visible=False,
                )
                # Shows the status string returned by the connect handlers.
                status_textbox = gr.Textbox(
                    label="Connection Status",
                    interactive=False,
                )
            with gr.Column(scale=2):
                # Chat panel backed by gr_chat_function; the two example
                # prompts are the tickets loaded at module import.
                chat_interface = gr.ChatInterface(
                    fn=gr_chat_function,
                    type="messages",
                    examples=[exfiltration_ticket, service_discovery_ticket],
                    example_labels=[
                        "Enrich & Handle exfiltration ticket 🕵️♂️",
                        "Handle service discovery ticket 🤖💻"],
                    description="A simple threat analyst agent with MCP tools.",
                )
        # "Demo" tab: short description plus an embedded video.
        with gr.TabItem("Demo"):
            gr.Markdown(
                """
This is a demo of TDAgent, a simple threat analyst agent with MCP tools.
You can configure the agent to use different LLM providers and connect to
various MCP servers to access tools.
""",
            )
            gr.HTML(
                """<iframe width="560" height="315" src="https://www.youtube.com/embed/C6Z9EOW-3lE" frameborder="0" allowfullscreen></iframe>""",  # noqa: E501
            )

    ## UI Events ##
    def _toggle_model_choices_ui(
        provider: str,
    ) -> dict[str, Any]:
        # Known provider: show the dropdown with that provider's model labels
        # and preselect the first one. Anything else: empty and hide it.
        if provider in MODEL_OPTIONS:
            model_choices = list(MODEL_OPTIONS[provider].keys())
            return gr.update(
                choices=model_choices,
                value=model_choices[0],
                visible=True,
                interactive=True,
            )
        return gr.update(choices=[], visible=False)

    def _toggle_model_aws_bedrock_conf_ui(
        provider: str,
    ) -> tuple[dict[str, Any], ...]:
        # Two updates: one for the Bedrock group, one for its connect button.
        is_aws = provider == "AWS Bedrock"
        return gr.update(visible=is_aws), gr.update(visible=is_aws)

    def _toggle_model_hf_conf_ui(
        provider: str,
    ) -> tuple[dict[str, Any], ...]:
        # Two updates: one for the Hugging Face group, one for its button.
        is_hf = provider == "HuggingFace"
        return gr.update(visible=is_hf), gr.update(visible=is_hf)

    def _toggle_model_azure_conf_ui(
        provider: str,
    ) -> tuple[dict[str, Any], ...]:
        # Two updates: one for the Azure group, one for its connect button.
        is_azure = provider == "Azure OpenAI"
        return gr.update(visible=is_azure), gr.update(visible=is_azure)

    # Warn once when the model configuration changes after a connection was
    # attempted, using the flag stored in CONNECT_STATE_DEFAULT.
    def _on_change_model_configuration(*args: str) -> Any:  # noqa: ARG001
        # If model configuration changes after connecting, issue a warning
        if CONNECT_STATE_DEFAULT.value:
            CONNECT_STATE_DEFAULT.value = False  # Reset the state
            return gr.Warning(
                "When changing model configuration, you need to reconnect.",
                duration=5,
            )
        return gr.update()

    ## Connect Event Listeners ##
    # Provider change: refresh the model dropdown, then show only the
    # configuration group and connect button of the selected provider.
    model_provider.change(
        _toggle_model_choices_ui,
        inputs=[model_provider],
        outputs=[model_id_dropdown],
    )
    model_provider.change(
        _toggle_model_aws_bedrock_conf_ui,
        inputs=[model_provider],
        outputs=[aws_bedrock_conf_group, connect_aws_bedrock_btn],
    )
    model_provider.change(
        _toggle_model_hf_conf_ui,
        inputs=[model_provider],
        outputs=[hf_conf_group, connect_hf_btn],
    )
    model_provider.change(
        _toggle_model_azure_conf_ui,
        inputs=[model_provider],
        outputs=[azure_conf_group, connect_azure_btn],
    )
    # Connect buttons: the inputs are listed in the positional order of the
    # corresponding gr_connect_to_* parameters.
    connect_aws_bedrock_btn.click(
        gr_connect_to_bedrock,
        inputs=[
            model_id_textbox,
            aws_access_key_textbox,
            aws_secret_key_textbox,
            aws_session_token_textbox,
            aws_region_dropdown,
            mcp_list.state,
            agent_system_message_radio,
            agent_trace_tools_checkbox,
            temperature,
            max_tokens,
        ],
        outputs=[status_textbox],
    )
    connect_hf_btn.click(
        gr_connect_to_hf,
        inputs=[
            model_id_textbox,
            hf_token,
            mcp_list.state,
            agent_system_message_radio,
            agent_trace_tools_checkbox,
            temperature,
            max_tokens,
        ],
        outputs=[status_textbox],
    )
    connect_azure_btn.click(
        gr_connect_to_azure,
        inputs=[
            model_id_textbox,
            azure_endpoint,
            azure_api_token,
            azure_api_version,
            mcp_list.state,
            agent_system_message_radio,
            agent_trace_tools_checkbox,
            temperature,
            max_tokens,
        ],
        outputs=[status_textbox],
    )
    # Model picked in the dropdown (x) for the current provider (y): show the
    # textbox holding the matching model id from MODEL_OPTIONS.
    # NOTE(review): with an empty selection this returns the `.value`
    # attribute of the component object, which may not be the text currently
    # typed by the user - confirm this is intended.
    model_id_dropdown.change(
        lambda x, y: (
            gr.update(
                value=MODEL_OPTIONS.get(y, {}).get(x),
                visible=True,
            )
            if x
            else model_id_textbox.value
        ),
        inputs=[model_id_dropdown, model_provider],
        outputs=[model_id_textbox],
    )
    # Changing the provider or the model runs the reconnect-warning handler.
    # These two listeners declare no outputs.
    model_provider.change(
        _on_change_model_configuration,
        inputs=[model_provider],
    )
    model_id_dropdown.change(
        _on_change_model_configuration,
        inputs=[model_id_dropdown, model_provider],
    )
## Entry Point ## | |
if __name__ == "__main__":
    # Launch the Gradio app only when this file is run as a script.
    gr_app.launch()