from __future__ import annotations
import dataclasses
import enum
import json
from collections import OrderedDict
from collections.abc import Mapping, Sequence
from pathlib import Path
from types import MappingProxyType
from typing import TYPE_CHECKING, Any
import boto3
import botocore.exceptions
import gradio as gr
import gradio.themes as gr_themes
import markdown
from langchain_aws import ChatBedrock
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from langgraph.prebuilt import create_react_agent
from tdagent.grcomponents import MutableCheckBoxGroup, MutableCheckBoxGroupEntry
if TYPE_CHECKING:
from langgraph.graph.graph import CompiledGraph
#### Constants ####
class AgentType(str, enum.Enum):
"""TDAgent type."""
DATA_ENRICHER = "Data enricher"
INCIDENT_HANDLER = "Incident handler"
PEN_TESTER = "PenTester"
def __str__(self) -> str: # noqa: D105
return self.value
AGENT_SYSTEM_MESSAGES = OrderedDict(
(
(
AgentType.DATA_ENRICHER,
"""
You are a cybersecurity incident data enrichment assistant. Analysts
will present information about security incidents and you must use
all the tools at your disposal to enrich the data as much as possible.
""".strip(),
),
(
AgentType.INCIDENT_HANDLER,
"""
You are a security analyst assistant responsible for collecting, analyzing
and disseminating actionable intelligence related to cyber threats,
vulnerabilities and threat actors.
When presented with potential incident information or tickets, you should
evaluate the presented evidence, gather additional data using any tool at
your disposal and take corrective actions if possible.
Afterwards, generate a cybersecurity report including: key findings, challenges,
actions taken and recommendations.
Never use external means of communication, like emails or SMS, unless
instructed to do so.
""".strip(),
),
(
AgentType.PEN_TESTER,
"""
You are a cybersecurity pentester. You use tools to analyze domains and
try to discover system vulnerabilities.
Always report your findings and suggest next steps for deeper analysis
where applicable.
""".strip(),
),
),
)
GRADIO_ROLE_TO_LG_MESSAGE_TYPE = MappingProxyType(
{
"user": HumanMessage,
"assistant": AIMessage,
},
)
MODEL_OPTIONS = OrderedDict( # Initialize with tuples to preserve options order
(
(
"HuggingFace",
{
"Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.3",
"Llama 3.1 8B Instruct": "meta-llama/Llama-3.1-8B-Instruct",
# "Qwen3 235B A22B": "Qwen/Qwen3-235B-A22B", # Slow inference
"Microsoft Phi-3.5-mini Instruct": "microsoft/Phi-3.5-mini-instruct",
# "Deepseek R1 distill-llama 70B": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", # noqa: E501
# "Deepseek V3": "deepseek-ai/DeepSeek-V3",
},
),
(
"AWS Bedrock",
{
"Anthropic Claude 3.5 Sonnet (EU)": (
"eu.anthropic.claude-3-5-sonnet-20240620-v1:0"
),
"Anthropic Claude 3.7 Sonnet": (
"anthropic.claude-3-7-sonnet-20250219-v1:0"
),
"Claude Sonnet 4": (
"anthropic.claude-sonnet-4-20250514-v1:0"
),
},
),
(
"Azure OpenAI",
{
"GPT-4o": ("ggpt-4o-global-standard"),
"GPT-4o Mini": ("o4-mini"),
"GPT-4.5 Preview": ("gpt-4.5-preview"),
},
),
),
)
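# A bare gr.State used as a module-level connection flag: its .value records
# whether an agent is currently connected, so later configuration changes can
# surface a "please reconnect" warning. It is never rendered in the UI.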
CONNECT_STATE_DEFAULT = gr.State()
@dataclasses.dataclass
class ToolInvocationInfo:
"""Information related to a tool invocation by the LLM."""
name: str
inputs: Mapping[str, Any]
class ToolsTracerCallback(BaseCallbackHandler):
"""Callback that registers tools invoked by the Agent."""
def __init__(self) -> None:
self._tools_trace: list[ToolInvocationInfo] = []
def on_tool_start( # noqa: D102
self,
serialized: dict[str, Any],
*args: Any,
inputs: dict[str, Any] | None = None,
**kwargs: Any,
) -> Any:
self._tools_trace.append(
ToolInvocationInfo(
name=serialized.get("name", "<unknown-function-name>"),
inputs=inputs if inputs else {},
),
)
return super().on_tool_start(serialized, *args, inputs=inputs, **kwargs)
@property
def tools_trace(self) -> Sequence[ToolInvocationInfo]:
"""Tools trace information."""
return self._tools_trace
def clear(self) -> None:
"""Clear tools trace."""
self._tools_trace.clear()
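# Usage sketch (hypothetical names): attach a tracer to a tool's callbacks,
# run the agent, then read back the captured invocations.
#
#   tracer = ToolsTracerCallback()
#   tool.callbacks = [tracer]
#   ... run the agent ...
#   for info in tracer.tools_trace:
#       print(info.name, dict(info.inputs))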
#### Shared variables ####
llm_agent: CompiledGraph | None = None
llm_tools_tracer: ToolsTracerCallback | None = None
#### Utility functions ####
## Bedrock LLM creation ##
def create_bedrock_llm(
bedrock_model_id: str,
aws_access_key: str,
aws_secret_key: str,
aws_session_token: str,
aws_region: str,
temperature: float = 0.8,
max_tokens: int = 512,
) -> tuple[ChatBedrock | None, str]:
"""Create a LangGraph Bedrock agent."""
boto3_config = {
"aws_access_key_id": aws_access_key,
"aws_secret_access_key": aws_secret_key,
"aws_session_token": aws_session_token if aws_session_token else None,
"region_name": aws_region,
}
# Verify credentials with a cheap STS call: get_caller_identity fails fast on invalid keys
try:
sts = boto3.client("sts", **boto3_config)
sts.get_caller_identity()
except botocore.exceptions.ClientError as err:
return None, str(err)
try:
bedrock_client = boto3.client("bedrock-runtime", **boto3_config)
llm = ChatBedrock(
model=bedrock_model_id,
client=bedrock_client,
model_kwargs={"temperature": temperature, "max_tokens": max_tokens},
)
except Exception as e: # noqa: BLE001
return None, str(e)
return llm, ""
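# Usage sketch (placeholder credentials):
#
#   llm, error = create_bedrock_llm(
#       "anthropic.claude-3-7-sonnet-20250219-v1:0",
#       aws_access_key="AKIA...",
#       aws_secret_key="...",
#       aws_session_token="",
#       aws_region="eu-west-1",
#   )
#   if llm is None:
#       print(f"Bedrock connection failed: {error}")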
## Hugging Face LLM creation ##
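# HuggingFaceEndpoint calls the model through the Hugging Face Inference API;
# ChatHuggingFace wraps it so the model's chat template is applied to messages.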
def create_hf_llm(
hf_model_id: str,
huggingfacehub_api_token: str | None = None,
temperature: float = 0.8,
max_tokens: int = 512,
) -> tuple[ChatHuggingFace | None, str]:
"""Create a LangGraph Hugging Face agent."""
try:
llm = HuggingFaceEndpoint(
model=hf_model_id,
temperature=temperature,
max_new_tokens=max_tokens,
task="text-generation",
huggingfacehub_api_token=huggingfacehub_api_token,
)
chat_llm = ChatHuggingFace(llm=llm)
except Exception as e: # noqa: BLE001
return None, str(e)
return chat_llm, ""
## OpenAI-compatible (Nebius) LLM creation ##
def create_openai_llm(
model_id: str,
token_id: str,
temperature: float = 0.8,
max_tokens: int = 512,
) -> tuple[ChatOpenAI | None, str]:
"""Create a chat model backed by Nebius' OpenAI-compatible API."""
try:
# Return a configured chat model; invocation happens later via the agent,
# so no completion request is issued here.
llm = ChatOpenAI(
base_url="https://api.studio.nebius.com/v1/",
api_key=token_id,
model=model_id,
temperature=temperature,
max_tokens=max_tokens,
)
except Exception as e:  # noqa: BLE001
return None, str(e)
return llm, ""
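# Usage sketch (placeholder token; the model id is illustrative and must be
# one served by the Nebius endpoint):
#
#   llm, error = create_openai_llm("meta-llama/Meta-Llama-3.1-8B-Instruct", "nb-...")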
def create_azure_llm(
model_id: str,
api_version: str,
endpoint: str,
token_id: str,
temperature: float = 0.8,
max_tokens: int = 512,
) -> tuple[AzureChatOpenAI | None, str]:
"""Create a LangGraph Azure OpenAI agent."""
try:
os.environ["AZURE_OPENAI_ENDPOINT"] = endpoint
os.environ["AZURE_OPENAI_API_KEY"] = token_id
if "o4-mini" in model_id:
kwargs = {"max_completion_tokens": max_tokens}
else:
kwargs = {"max_tokens": max_tokens}
llm = AzureChatOpenAI(
azure_deployment=model_id,
api_key=token_id,
api_version=api_version,
temperature=temperature,
**kwargs,
)
except Exception as e: # noqa: BLE001
return None, str(e)
return llm, ""
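# Usage sketch (placeholder endpoint and key; the deployment name must match
# your Azure resource):
#
#   llm, error = create_azure_llm(
#       "gpt-4o-global-standard",
#       api_version="2024-12-01-preview",
#       endpoint="https://my-resource.openai.azure.com/",
#       token_id="...",
#   )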
#### UI functionality ####
async def gr_fetch_mcp_tools(
mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
*,
trace_tools: bool,
) -> list[BaseTool]:
"""Fetch tools from MCP servers."""
global llm_tools_tracer # noqa: PLW0603
if mcp_servers:
client = MultiServerMCPClient(
{
server.name.replace(" ", "-"): {
"url": server.value,
"transport": "sse",
}
for server in mcp_servers
},
)
tools = await client.get_tools()
if trace_tools:
llm_tools_tracer = ToolsTracerCallback()
for tool in tools:
if tool.callbacks is None:
tool.callbacks = [llm_tools_tracer]
elif isinstance(tool.callbacks, list):
tool.callbacks.append(llm_tools_tracer)
else:
tool.callbacks.add_handler(llm_tools_tracer)
else:
llm_tools_tracer = None
return tools
return []
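# Note: gr_fetch_mcp_tools assumes SSE transport for every MCP server and
# slugifies display names (spaces -> dashes) to build client keys, so names
# that differ only in whitespace would collide.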
def gr_make_system_message(
agent_type: AgentType,
) -> SystemMessage:
"""Make agent's system message."""
try:
system_msg = AGENT_SYSTEM_MESSAGES[agent_type]
except KeyError as err:
raise gr.Error(f"Unknown agent type '{agent_type}'") from err
return SystemMessage(system_msg)
async def gr_connect_to_bedrock( # noqa: PLR0913
model_id: str,
access_key: str,
secret_key: str,
session_token: str,
region: str,
mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
agent_type: AgentType,
trace_tool_calls: bool,
temperature: float = 0.8,
max_tokens: int = 512,
) -> str:
"""Initialize Bedrock agent."""
global llm_agent # noqa: PLW0603
CONNECT_STATE_DEFAULT.value = True
if not access_key or not secret_key:
return "❌ Please provide both Access Key ID and Secret Access Key"
llm, error = create_bedrock_llm(
model_id,
access_key.strip(),
secret_key.strip(),
session_token.strip(),
region,
temperature=temperature,
max_tokens=max_tokens,
)
if llm is None:
return f"❌ Connection failed: {error}"
llm_agent = create_react_agent(
model=llm,
tools=await gr_fetch_mcp_tools(
mcp_servers,
trace_tools=trace_tool_calls,
),
prompt=gr_make_system_message(agent_type=agent_type),
)
return "✅ Successfully connected to AWS Bedrock!"
async def gr_connect_to_hf(
model_id: str,
hf_access_token_textbox: str | None,
mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
agent_type: AgentType,
trace_tool_calls: bool,
temperature: float = 0.8,
max_tokens: int = 512,
) -> str:
"""Initialize Hugging Face agent."""
global llm_agent # noqa: PLW0603
CONNECT_STATE_DEFAULT.value = True
llm, error = create_hf_llm(
model_id,
hf_access_token_textbox,
temperature=temperature,
max_tokens=max_tokens,
)
if llm is None:
return f"❌ Connection failed: {error}"
llm_agent = create_react_agent(
model=llm,
tools=await gr_fetch_mcp_tools(
mcp_servers,
trace_tools=trace_tool_calls,
),
prompt=gr_make_system_message(agent_type=agent_type),
)
return "✅ Successfully connected to Hugging Face!"
async def gr_connect_to_azure( # noqa: PLR0913
model_id: str,
azure_endpoint: str,
api_key: str,
api_version: str,
mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
agent_type: AgentType,
trace_tool_calls: bool,
temperature: float = 0.8,
max_tokens: int = 512,
) -> str:
"""Initialize Hugging Face agent."""
global llm_agent # noqa: PLW0603
CONNECT_STATE_DEFAULT.value = True
llm, error = create_azure_llm(
model_id,
api_version=api_version,
endpoint=azure_endpoint,
token_id=api_key,
temperature=temperature,
max_tokens=max_tokens,
)
if llm is None:
return f"❌ Connection failed: {error}"
llm_agent = create_react_agent(
model=llm,
tools=await gr_fetch_mcp_tools(mcp_servers, trace_tools=trace_tool_calls),
prompt=gr_make_system_message(agent_type=agent_type),
)
return "✅ Successfully connected to Azure OpenAI!"
# async def gr_connect_to_nebius(
# model_id: str,
# nebius_access_token_textbox: str,
# mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
# ) -> str:
# """Initialize Hugging Face agent."""
# global llm_agent
# connected_state.value = True
# llm, error = create_openai_llm(model_id, nebius_access_token_textbox)
# if llm is None:
# return f"❌ Connection failed: {error}"
# tools = []
# if mcp_servers:
# client = MultiServerMCPClient(
# {
# server.name.replace(" ", "-"): {
# "url": server.value,
# "transport": "sse",
# }
# for server in mcp_servers
# },
# )
# tools = await client.get_tools()
# llm_agent = create_react_agent(
#         model=llm,
# tools=tools,
# prompt=SYSTEM_MESSAGE,
# )
# return "✅ Successfully connected to nebius!"
with open("exfiltration_ticket.txt") as fhandle: # noqa: PTH123
exfiltration_ticket = fhandle.read()
with open("sample_kali_linux_1.txt") as fhandle1: # noqa: PTH123
service_discovery_ticket = fhandle1.read()
async def gr_chat_function( # noqa: D103
message: str,
history: list[Mapping[str, str]],
) -> str:
if llm_agent is None:
return "Please configure your credentials first."
messages = []
for hist_msg in history:
role = hist_msg["role"]
message_type = GRADIO_ROLE_TO_LG_MESSAGE_TYPE[role]
messages.append(message_type(content=hist_msg["content"]))
messages.append(HumanMessage(content=message))
try:
if llm_tools_tracer is not None:
llm_tools_tracer.clear()
llm_response = await llm_agent.ainvoke(
{
"messages": messages,
},
)
return _add_tools_trace_to_message(
llm_response["messages"][-1].content,
)
except Exception as err:
raise gr.Error(
f"We encountered an error while invoking the model:\n{err}",
print_exception=True,
) from err
def _add_tools_trace_to_message(message: str) -> str:
if not llm_tools_tracer or not llm_tools_tracer.tools_trace:
return message
traces = []
for index, tool_info in enumerate(llm_tools_tracer.tools_trace):
trace_msg = f" {index}. {tool_info.name}"
if tool_info.inputs:
trace_msg += "\n"
trace_msg += " * Arguments:\n"
trace_msg += " ```json\n"
trace_msg += f" {json.dumps(tool_info.inputs, indent=4)}\n"
trace_msg += " ```\n"
traces.append(trace_msg)
return f"{message}\n\n# Tools Trace\n\n" + "\n".join(traces)
def _read_markdown_body_as_html(path: str = "README.md") -> str:
with Path(path).open(encoding="utf-8") as f: # Default mode is "r"
lines = f.readlines()
# Skip YAML front matter if present (e.g. the metadata header of a Hugging Face Spaces README)
if lines and lines[0].strip() == "---":
for i in range(1, len(lines)):
if lines[i].strip() == "---":
lines = lines[i + 1 :] # skip metadata block
break
markdown_body = "".join(lines).strip()
return markdown.markdown(markdown_body)
## UI components ##
custom_css = """
.main-header {
background: linear-gradient(135deg, #00a388 0%, #ffae00 100%);
padding: 30px;
border-radius: 5px;
margin-bottom: 20px;
text-align: center;
}
"""
with gr.Blocks(
theme=gr_themes.Origin(
primary_hue="teal",
spacing_size="sm",
font="sans-serif",
),
title="TDAgent",
fill_height=True,
fill_width=True,
css=custom_css,
) as gr_app:
gr.HTML(
"""
<div class="main-header">
<h1>👩‍💻 TDAgentTools & TDAgent 👨‍💻</h1>
<p style="font-size: 1.2em; margin: 10px 0 0 0;">
Empowering Cybersecurity with Agentic AI
</p>
</div>
""",
)
with gr.Tabs():
with gr.TabItem("About"), gr.Row():
html_content = _read_markdown_body_as_html("README.md")
gr.Markdown(html_content)
with gr.TabItem("TDAgent"), gr.Row():
with gr.Column(scale=1):
with gr.Accordion("🔌 MCP Servers", open=False):
mcp_list = MutableCheckBoxGroup(
values=[
MutableCheckBoxGroupEntry(
name="TDAgent tools",
value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
),
],
label="MCP Servers",
new_value_label="MCP endpoint",
new_name_label="MCP endpoint name",
new_value_placeholder="https://my-cool-mcp-server.com/mcp/sse",
new_name_placeholder="Swiss army knife of MCPs",
)
with gr.Accordion("⚙️ Provider Configuration", open=True):
model_provider = gr.Dropdown(
choices=list(MODEL_OPTIONS.keys()),
value=None,
label="Select Model Provider",
)
## Amazon Bedrock Configuration ##
with gr.Group(visible=False) as aws_bedrock_conf_group:
aws_access_key_textbox = gr.Textbox(
label="AWS Access Key ID",
type="password",
placeholder="Enter your AWS Access Key ID",
)
aws_secret_key_textbox = gr.Textbox(
label="AWS Secret Access Key",
type="password",
placeholder="Enter your AWS Secret Access Key",
)
aws_region_dropdown = gr.Dropdown(
label="AWS Region",
choices=[
"us-east-1",
"us-west-2",
"eu-west-1",
"eu-central-1",
"ap-southeast-1",
],
value="eu-west-1",
)
aws_session_token_textbox = gr.Textbox(
label="AWS Session Token",
type="password",
placeholder="Enter your AWS session token",
)
## Huggingface Configuration ##
with gr.Group(visible=False) as hf_conf_group:
hf_token = gr.Textbox(
label="HuggingFace Token",
type="password",
placeholder="Enter your Hugging Face Access Token",
)
## Azure Configuration ##
with gr.Group(visible=False) as azure_conf_group:
azure_endpoint = gr.Textbox(
label="Azure OpenAI Endpoint",
type="text",
placeholder="Enter your Azure OpenAI Endpoint",
)
azure_api_token = gr.Textbox(
label="Azure Access Token",
type="password",
placeholder="Enter your Azure OpenAI Access Token",
)
azure_api_version = gr.Textbox(
label="Azure OpenAI API Version",
type="text",
placeholder="Enter your Azure OpenAI API Version",
value="2024-12-01-preview",
)
with gr.Accordion("🧠 Model Configuration", open=True):
model_id_dropdown = gr.Dropdown(
label="Select known model id or type your own below",
choices=[],
visible=False,
)
model_id_textbox = gr.Textbox(
label="Model ID",
type="text",
placeholder="Enter the model ID",
visible=False,
interactive=True,
)
# Agent configuration options
with gr.Group():
agent_system_message_radio = gr.Radio(
choices=list(AGENT_SYSTEM_MESSAGES.keys()),
value=next(iter(AGENT_SYSTEM_MESSAGES.keys())),
label="Agent type",
info=(
"Changes the system message to pre-condition the agent"
" to act in a desired way."
),
)
agent_trace_tools_checkbox = gr.Checkbox(
value=False,
label="Trace tool calls",
info=(
"Add the invoked tools trace at the end of the"
" message"
),
)
# Initialize the temperature and max tokens based on model specs
temperature = gr.Slider(
label="Temperature",
minimum=0.0,
maximum=1.0,
value=0.8,
step=0.1,
)
max_tokens = gr.Slider(
label="Max Tokens",
minimum=128,
maximum=8192,
value=2048,
step=64,
)
connect_aws_bedrock_btn = gr.Button(
"🔌 Connect to Bedrock",
variant="primary",
visible=False,
)
connect_hf_btn = gr.Button(
"🔌 Connect to Huggingface 🤗",
variant="primary",
visible=False,
)
connect_azure_btn = gr.Button(
"🔌 Connect to Azure",
variant="primary",
visible=False,
)
status_textbox = gr.Textbox(
label="Connection Status",
interactive=False,
)
with gr.Column(scale=2):
chat_interface = gr.ChatInterface(
fn=gr_chat_function,
type="messages",
examples=[exfiltration_ticket, service_discovery_ticket],
example_labels=[
"Enrich & Handle exfiltration ticket 🕵️‍♂️",
"Handle service discovery ticket 🤖💻"],
description="A simple threat analyst agent with MCP tools.",
)
with gr.TabItem("Demo"):
gr.Markdown(
"""
This is a demo of TDAgent, a simple threat analyst agent with MCP tools.
You can configure the agent to use different LLM providers and connect to
various MCP servers to access tools.
""",
)
gr.HTML(
"""<iframe width="560" height="315" src="https://www.youtube.com/embed/C6Z9EOW-3lE" frameborder="0" allowfullscreen></iframe>""", # noqa: E501
)
## UI Events ##
def _toggle_model_choices_ui(
provider: str,
) -> dict[str, Any]:
if provider in MODEL_OPTIONS:
model_choices = list(MODEL_OPTIONS[provider].keys())
return gr.update(
choices=model_choices,
value=model_choices[0],
visible=True,
interactive=True,
)
return gr.update(choices=[], visible=False)
def _toggle_model_aws_bedrock_conf_ui(
provider: str,
) -> tuple[dict[str, Any], ...]:
is_aws = provider == "AWS Bedrock"
return gr.update(visible=is_aws), gr.update(visible=is_aws)
def _toggle_model_hf_conf_ui(
provider: str,
) -> tuple[dict[str, Any], ...]:
is_hf = provider == "HuggingFace"
return gr.update(visible=is_hf), gr.update(visible=is_hf)
def _toggle_model_azure_conf_ui(
provider: str,
) -> tuple[dict[str, Any], ...]:
is_azure = provider == "Azure OpenAI"
return gr.update(visible=is_azure), gr.update(visible=is_azure)
def _on_change_model_configuration(*args: str) -> Any:  # noqa: ARG001
# If the model configuration changes after a successful connection, warn
# that the user needs to reconnect
if CONNECT_STATE_DEFAULT.value:
CONNECT_STATE_DEFAULT.value = False # Reset the state
return gr.Warning(
"When changing model configuration, you need to reconnect.",
duration=5,
)
return gr.update()
## Connect Event Listeners ##
model_provider.change(
_toggle_model_choices_ui,
inputs=[model_provider],
outputs=[model_id_dropdown],
)
model_provider.change(
_toggle_model_aws_bedrock_conf_ui,
inputs=[model_provider],
outputs=[aws_bedrock_conf_group, connect_aws_bedrock_btn],
)
model_provider.change(
_toggle_model_hf_conf_ui,
inputs=[model_provider],
outputs=[hf_conf_group, connect_hf_btn],
)
model_provider.change(
_toggle_model_azure_conf_ui,
inputs=[model_provider],
outputs=[azure_conf_group, connect_azure_btn],
)
connect_aws_bedrock_btn.click(
gr_connect_to_bedrock,
inputs=[
model_id_textbox,
aws_access_key_textbox,
aws_secret_key_textbox,
aws_session_token_textbox,
aws_region_dropdown,
mcp_list.state,
agent_system_message_radio,
agent_trace_tools_checkbox,
temperature,
max_tokens,
],
outputs=[status_textbox],
)
connect_hf_btn.click(
gr_connect_to_hf,
inputs=[
model_id_textbox,
hf_token,
mcp_list.state,
agent_system_message_radio,
agent_trace_tools_checkbox,
temperature,
max_tokens,
],
outputs=[status_textbox],
)
connect_azure_btn.click(
gr_connect_to_azure,
inputs=[
model_id_textbox,
azure_endpoint,
azure_api_token,
azure_api_version,
mcp_list.state,
agent_system_message_radio,
agent_trace_tools_checkbox,
temperature,
max_tokens,
],
outputs=[status_textbox],
)
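# Keep the free-text Model ID box in sync with the dropdown: selecting a known
# model fills in its provider-specific id.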
model_id_dropdown.change(
lambda model_name, provider: (
gr.update(
value=MODEL_OPTIONS.get(provider, {}).get(model_name),
visible=True,
)
if model_name
else gr.update()  # nothing selected: leave the textbox unchanged
),
inputs=[model_id_dropdown, model_provider],
outputs=[model_id_textbox],
)
model_provider.change(
_on_change_model_configuration,
inputs=[model_provider],
)
model_id_dropdown.change(
_on_change_model_configuration,
inputs=[model_id_dropdown, model_provider],
)
## Entry Point ##
if __name__ == "__main__":
gr_app.launch()