Spaces:
Running
Running
from __future__ import annotations | |
from collections import OrderedDict | |
from collections.abc import Mapping, Sequence | |
from types import MappingProxyType | |
from typing import TYPE_CHECKING, Any | |
import boto3 | |
import botocore | |
import botocore.exceptions | |
import gradio as gr | |
import gradio.themes as gr_themes | |
from langchain_aws import ChatBedrock | |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage | |
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint | |
from langchain_mcp_adapters.client import MultiServerMCPClient | |
from langgraph.prebuilt import create_react_agent | |
from openai import OpenAI | |
from openai.types.chat import ChatCompletion | |
from tdagent.grcomponents import MutableCheckBoxGroup, MutableCheckBoxGroupEntry | |
if TYPE_CHECKING: | |
from langgraph.graph.graph import CompiledGraph | |
#### Constants ####

_SYSTEM_PROMPT = """
You are a security analyst assistant responsible for collecting, analyzing
and disseminating actionable intelligence related to cyber threats,
vulnerabilities and threat actors.
When presented with potential incidents information or tickets, you should
evaluate the presented evidence, decide what is missing and gather
additional data using any tool at your disposal. After gathering more
information you must evaluate if the incident is a threat or
not and, if possible, remediation actions.
You must always present the conducted analysis and final conclusion.
Never use external means of communication, like emails or SMS, unless
instructed to do so.
""".strip()

# System prompt given to every ReAct agent this module builds.
SYSTEM_MESSAGE = SystemMessage(_SYSTEM_PROMPT)

# Read-only lookup from a Gradio chat role to the LangChain message class
# used when replaying the chat history to the agent.
GRADIO_ROLE_TO_LG_MESSAGE_TYPE = MappingProxyType(
    {"user": HumanMessage, "assistant": AIMessage},
)

# Hugging Face models offered in the UI: display name -> Hub model id.
_HF_MODELS = {
    "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.3",
    "Llama 3.1 8B Instruct": "meta-llama/Llama-3.1-8B-Instruct",
    # "Qwen3 235B A22B": "Qwen/Qwen3-235B-A22B",  # Slow inference
    "Microsoft Phi-3.5-mini Instruct": "microsoft/Phi-3.5-mini-instruct",
    # "Deepseek R1 distill-llama 70B": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",  # noqa: E501
    # "Deepseek V3": "deepseek-ai/DeepSeek-V3",
}

# AWS Bedrock models offered in the UI: display name -> Bedrock model id.
_BEDROCK_MODELS = {
    "Anthropic Claude 3.5 Sonnet (EU)": "eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
    # "Anthropic Claude 3.7 Sonnet": "anthropic.claude-3-7-sonnet-20250219-v1:0",
}

# Provider name -> {display name -> model id}. Built from an ordered list of
# pairs so the provider dropdown keeps this order.
MODEL_OPTIONS = OrderedDict(
    [
        ("HuggingFace", _HF_MODELS),
        ("AWS Bedrock", _BEDROCK_MODELS),
    ],
)

#### Shared variables ####

# Agent created by the most recent successful connection; None until then.
llm_agent: CompiledGraph | None = None
#### Utility functions #### | |
## Bedrock LLM creation ## | |
def create_bedrock_llm(
    bedrock_model_id: str,
    aws_access_key: str,
    aws_secret_key: str,
    aws_session_token: str,
    aws_region: str,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> tuple[ChatBedrock | None, str]:
    """Create a Bedrock chat model after validating the AWS credentials.

    The credentials are first checked with an STS ``get_caller_identity``
    call, so bad keys are reported before the Bedrock client is built.

    Args:
        bedrock_model_id: Bedrock model (or inference profile) identifier.
        aws_access_key: AWS access key ID.
        aws_secret_key: AWS secret access key.
        aws_session_token: Session token; an empty string means "no token".
        aws_region: Region used for both the STS and Bedrock clients.
        temperature: Sampling temperature forwarded to the model.
        max_tokens: Maximum number of tokens the model may generate.

    Returns:
        ``(llm, "")`` on success, or ``(None, error_message)`` on failure.
        This function reports failures through the tuple instead of raising.
    """
    boto3_config = {
        "aws_access_key_id": aws_access_key,
        "aws_secret_access_key": aws_secret_key,
        # Treat an empty session token as "no session token".
        "aws_session_token": aws_session_token or None,
        "region_name": aws_region,
    }

    # Verify credentials. ClientError covers service-side rejections
    # (e.g. invalid keys); BotoCoreError covers client-side failures such as
    # unreachable endpoints or missing credentials, which are NOT ClientErrors
    # and previously escaped the (None, error) contract.
    try:
        sts = boto3.client("sts", **boto3_config)
        sts.get_caller_identity()
    except (
        botocore.exceptions.ClientError,
        botocore.exceptions.BotoCoreError,
    ) as err:
        return None, str(err)

    try:
        bedrock_client = boto3.client("bedrock-runtime", **boto3_config)
        llm = ChatBedrock(
            model_id=bedrock_model_id,
            client=bedrock_client,
            model_kwargs={"temperature": temperature, "max_tokens": max_tokens},
        )
    except Exception as e:  # noqa: BLE001
        return None, str(e)
    return llm, ""
## Hugging Face LLM creation ## | |
def create_hf_llm(
    hf_model_id: str,
    huggingfacehub_api_token: str | None = None,
    temperature: float = 0.8,
) -> tuple[ChatHuggingFace | None, str]:
    """Create a Hugging Face chat model backed by an inference endpoint.

    Args:
        hf_model_id: Hugging Face Hub model repository id.
        huggingfacehub_api_token: Access token, passed unchanged to
            ``HuggingFaceEndpoint``.
        temperature: Sampling temperature. It defaults to 0.8, the value that
            was previously hard-coded, so existing callers are unaffected.

    Returns:
        ``(chat_llm, "")`` on success, or ``(None, error_message)`` if the
        endpoint or its chat wrapper could not be constructed.
    """
    try:
        llm = HuggingFaceEndpoint(
            model=hf_model_id,
            temperature=temperature,
            task="text-generation",
            huggingfacehub_api_token=huggingfacehub_api_token,
        )
        chat_llm = ChatHuggingFace(llm=llm)
    except Exception as e:  # noqa: BLE001
        return None, str(e)
    return chat_llm, ""
## OpenAI LLM creation ## | |
def create_openai_llm(
    model_id: str,
    token_id: str,
    *,
    base_url: str = "https://api.studio.nebius.com/v1/",
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> tuple[ChatCompletion | None, str]:
    """Send a request to an OpenAI-compatible endpoint (Nebius AI Studio by default).

    NOTE(review): despite the name, this does not build a chat model. It
    issues a single ``chat.completions.create`` request and returns the
    resulting ``ChatCompletion``, which is not something LangGraph can use as
    an agent model. The request is also sent with an empty ``messages`` list,
    which presumably makes the API reject it. In practice this likely always
    returns ``(None, error)``. A real fix needs a LangChain chat-model wrapper
    for OpenAI-compatible APIs. TODO.

    Args:
        model_id: Model name understood by the endpoint.
        token_id: API key for the endpoint. Blank or whitespace-only values
            are rejected without contacting the server.
        base_url: Base URL of the OpenAI-compatible API.
        temperature: Sampling temperature for the request.
        max_tokens: Maximum number of tokens to generate.

    Returns:
        ``(completion, "")`` on success, or ``(None, error_message)``.
    """
    api_key = (token_id or "").strip()
    if not api_key:
        # Avoid a pointless (and certainly failing) network round trip.
        return None, "An API token is required."
    try:
        client = OpenAI(base_url=base_url, api_key=api_key)
        completion = client.chat.completions.create(
            messages=[],  # TODO: needs real messages; see NOTE above.
            model=model_id,
            max_tokens=max_tokens,
            temperature=temperature,
        )
    except Exception as e:  # noqa: BLE001
        return None, str(e)
    return completion, ""
#### UI functionality #### | |
async def _load_mcp_tools(
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
) -> list[Any]:
    """Fetch the tools exposed by the selected MCP servers.

    Each entry is registered under its display name (spaces replaced by
    dashes) and reached through the SSE transport at ``entry.value``.

    Returns:
        The tools reported by the servers, or ``[]`` when none is selected.
    """
    if not mcp_servers:
        return []
    client = MultiServerMCPClient(
        {
            server.name.replace(" ", "-"): {
                "url": server.value,
                "transport": "sse",
            }
            for server in mcp_servers
        },
    )
    return await client.get_tools()


async def _install_agent(
    model: Any,
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
) -> str | None:
    """Build a ReAct agent around *model* and publish it as ``llm_agent``.

    The global is replaced only after both the MCP tools and the agent have
    been created. A failure therefore leaves the previously connected agent
    in place.

    Returns:
        ``None`` on success, otherwise a user-facing error message.
    """
    global llm_agent  # noqa: PLW0603
    try:
        tools = await _load_mcp_tools(mcp_servers)
        agent = create_react_agent(
            model=model,
            tools=tools,
            prompt=SYSTEM_MESSAGE,
        )
    except Exception as err:  # noqa: BLE001 - report setup failures in the UI
        return f"β Agent setup failed: {err}"
    llm_agent = agent
    return None


async def gr_connect_to_bedrock(  # noqa: PLR0913
    model_id: str,
    access_key: str,
    secret_key: str,
    session_token: str,
    region: str,
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
    temperature: float = 0.8,
    max_tokens: int = 512,
) -> str:
    """Connect to an AWS Bedrock model and install it as the active agent.

    Args:
        model_id: Bedrock model identifier.
        access_key: AWS access key ID (surrounding whitespace is ignored).
        secret_key: AWS secret access key (surrounding whitespace is ignored).
        session_token: Optional AWS session token.
        region: AWS region name.
        mcp_servers: MCP servers whose tools the agent may use.
        temperature: Sampling temperature for the model.
        max_tokens: Maximum number of tokens the model may generate.

    Returns:
        A status message for the UI.
    """
    access_key = (access_key or "").strip()
    secret_key = (secret_key or "").strip()
    # Check after stripping so whitespace-only input is rejected up front.
    if not access_key or not secret_key:
        return "β Please provide both Access Key ID and Secret Access Key"

    llm, error = create_bedrock_llm(
        model_id,
        access_key,
        secret_key,
        (session_token or "").strip(),
        region,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    if llm is None:
        return f"β Connection failed: {error}"

    setup_error = await _install_agent(llm, mcp_servers)
    if setup_error:
        return setup_error
    return "β Successfully connected to AWS Bedrock!"


async def gr_connect_to_hf(
    model_id: str,
    hf_access_token_textbox: str | None,
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
) -> str:
    """Connect to a Hugging Face model and install it as the active agent.

    Returns:
        A status message for the UI.
    """
    llm, error = create_hf_llm(model_id, hf_access_token_textbox)
    if llm is None:
        return f"β Connection failed: {error}"

    setup_error = await _install_agent(llm, mcp_servers)
    if setup_error:
        return setup_error
    return "β Successfully connected to Hugging Face!"


async def gr_connect_to_nebius(
    model_id: str,
    nebius_access_token_textbox: str,
    mcp_servers: Sequence[MutableCheckBoxGroupEntry] | None,
) -> str:
    """Connect to a Nebius (OpenAI-compatible) model as the active agent.

    NOTE(review): ``create_openai_llm`` returns a ``ChatCompletion``, not a
    chat model, and ``str(llm)`` below is not a valid model spec for
    ``create_react_agent``. This path is broken end to end; MODEL_OPTIONS
    currently has no "Nebius" entry, so the UI cannot select it. TODO.

    Returns:
        A status message for the UI.
    """
    llm, error = create_openai_llm(model_id, nebius_access_token_textbox)
    if llm is None:
        return f"β Connection failed: {error}"

    setup_error = await _install_agent(str(llm), mcp_servers)
    if setup_error:
        return setup_error
    return "β Successfully connected to nebius!"
async def gr_chat_function(
    message: str,
    history: list[Mapping[str, str]],
) -> str:
    """Answer *message* with the active agent, replaying the chat history.

    Args:
        message: The new user message.
        history: Earlier turns in Gradio "messages" format. Each entry is a
            mapping with ``role`` ("user"/"assistant") and ``content`` keys.

    Returns:
        The content of the last message produced by the agent, or a hint to
        configure credentials when no agent has been connected yet.

    Raises:
        gr.Error: If invoking the agent (or reading its reply) fails.
    """
    if llm_agent is None:
        return "Please configure your credentials first."

    # Rebuild the conversation as LangChain messages, oldest first.
    conversation = [
        GRADIO_ROLE_TO_LG_MESSAGE_TYPE[turn["role"]](content=turn["content"])
        for turn in history
    ]
    conversation.append(HumanMessage(content=message))

    try:
        result = await llm_agent.ainvoke({"messages": conversation})
        return result["messages"][-1].content
    except Exception as err:
        raise gr.Error(
            f"We encountered an error while invoking the model:\n{err}",
            print_exception=True,
        ) from err
## UI components ## | |
def toggle_model_fields(
    provider: str,
) -> tuple[
    dict[str, Any],
    dict[str, Any],
    dict[str, Any],
    dict[str, Any],
    dict[str, Any],
    dict[str, Any],
]:
    """Refresh the model dropdown and credential fields for *provider*.

    When the provider is known, the model dropdown lists its models and
    preselects the first one; otherwise the dropdown is emptied and hidden.
    AWS fields are shown only for "AWS Bedrock", the Hugging Face token only
    for "HuggingFace".

    Returns:
        Six ``gr.update`` payloads, in this order: model dropdown, AWS access
        key, AWS secret key, AWS session token, AWS region, Hugging Face token.
    """
    available = MODEL_OPTIONS.get(provider)
    if available is None:
        # Unknown or cleared provider: hide the model selector entirely.
        model_update = gr.update(choices=[], visible=False)
    else:
        names = list(available)
        model_update = gr.update(
            choices=names,
            value=names[0],
            visible=True,
            interactive=True,
        )

    show_aws = provider == "AWS Bedrock"
    show_hf = provider == "HuggingFace"
    # One independent update dict per AWS widget.
    aws_updates = [
        gr.update(visible=show_aws, interactive=show_aws) for _ in range(4)
    ]
    return (
        model_update,
        *aws_updates,
        gr.update(visible=show_hf, interactive=show_hf),
    )
async def update_connection_status(  # noqa: PLR0913
    provider: str,
    pretty_model: str,
    mcp_list_state: Sequence[MutableCheckBoxGroupEntry] | None,
    aws_access_key_textbox: str,
    aws_secret_key_textbox: str,
    aws_session_token_textbox: str,
    aws_region_dropdown: str,
    hf_token: str,
    temperature: float,
    max_tokens: int,
) -> str:
    """Connect to the model selected in the UI and return a status message.

    The display name *pretty_model* is resolved to a model id through
    MODEL_OPTIONS, then the matching provider-specific connect function is
    awaited. Temperature and max tokens are forwarded only to Bedrock.

    Returns:
        The status string from the connect function, or an error message
        when nothing is selected or the provider/model is not recognised.
    """
    if not (provider and pretty_model):
        return "β Please select a provider and model."

    model_id = MODEL_OPTIONS.get(provider, {}).get(pretty_model)
    if not model_id:
        return "β Invalid provider"

    if provider == "AWS Bedrock":
        return await gr_connect_to_bedrock(
            model_id,
            aws_access_key_textbox,
            aws_secret_key_textbox,
            aws_session_token_textbox,
            aws_region_dropdown,
            mcp_list_state,
            temperature,
            max_tokens,
        )
    if provider == "HuggingFace":
        return await gr_connect_to_hf(model_id, hf_token, mcp_list_state)
    if provider == "Nebius":
        # NOTE(review): MODEL_OPTIONS has no "Nebius" entry, so model_id is
        # never set for this provider and the branch is unreachable. It also
        # reuses the Hugging Face token field as the Nebius key. Confirm intent.
        return await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)
    return "β Invalid provider"
# Gradio UI: a configuration column (MCP servers, provider credentials,
# model settings, connect button) beside the chat column, inside one row.
with (
    gr.Blocks(
        theme=gr_themes.Origin(
            primary_hue="teal",
            spacing_size="sm",
            font="sans-serif",
        ),
        title="TDAgent",
    ) as gr_app,
    gr.Row(),
):
    with gr.Column(scale=1):
        with gr.Accordion("π MCP Servers", open=False):
            # Editable list of MCP SSE endpoints, pre-populated with the
            # TDAgent tools server. Its `.state` is an input of the connect
            # button below.
            mcp_list = MutableCheckBoxGroup(
                values=[
                    MutableCheckBoxGroupEntry(
                        name="TDAgent tools",
                        value="https://agents-mcp-hackathon-tdagenttools.hf.space/gradio_api/mcp/sse",
                    ),
                ],
                label="MCP Servers",
                new_value_label="MCP endpoint",
                new_name_label="MCP endpoint name",
                new_value_placeholder="https://my-cool-mcp-server.com/mcp/sse",
                new_name_placeholder="Swiss army knife of MCPs",
            )
        with gr.Accordion("βοΈ Provider Configuration", open=True):
            # Provider choices come straight from MODEL_OPTIONS; nothing is
            # selected initially.
            model_provider = gr.Dropdown(
                choices=list(MODEL_OPTIONS.keys()),
                value=None,
                label="Select Model Provider",
            )
            # Credential fields start hidden; toggle_model_fields reveals the
            # ones that apply to the selected provider.
            aws_access_key_textbox = gr.Textbox(
                label="AWS Access Key ID",
                type="password",
                placeholder="Enter your AWS Access Key ID",
                visible=False,
            )
            aws_secret_key_textbox = gr.Textbox(
                label="AWS Secret Access Key",
                type="password",
                placeholder="Enter your AWS Secret Access Key",
                visible=False,
            )
            aws_region_dropdown = gr.Dropdown(
                label="AWS Region",
                choices=[
                    "us-east-1",
                    "us-west-2",
                    "eu-west-1",
                    "eu-central-1",
                    "ap-southeast-1",
                ],
                value="eu-west-1",
                visible=False,
            )
            aws_session_token_textbox = gr.Textbox(
                label="AWS Session Token",
                type="password",
                placeholder="Enter your AWS session token",
                visible=False,
            )
            hf_token = gr.Textbox(
                label="HuggingFace Token",
                type="password",
                placeholder="Enter your Hugging Face Access Token",
                visible=False,
            )
        with gr.Accordion("π§ Model Configuration", open=True):
            # Populated (and made visible) when a provider is chosen.
            model_display_id = gr.Dropdown(
                label="Select Model ID",
                choices=[],
                visible=False,
            )
            # The `outputs` order must match the 6-tuple returned by
            # toggle_model_fields.
            model_provider.change(
                toggle_model_fields,
                inputs=[model_provider],
                outputs=[
                    model_display_id,
                    aws_access_key_textbox,
                    aws_secret_key_textbox,
                    aws_session_token_textbox,
                    aws_region_dropdown,
                    hf_token,
                ],
            )
            # Sampling controls with fixed defaults (not derived from the
            # selected model). update_connection_status forwards them only
            # to the Bedrock connector.
            temperature = gr.Slider(
                label="Temperature",
                minimum=0.0,
                maximum=1.0,
                value=0.8,
                step=0.1,
            )
            max_tokens = gr.Slider(
                label="Max Tokens",
                minimum=64,
                maximum=4096,
                value=512,
                step=64,
            )
        connect_btn = gr.Button("π Connect to Model", variant="primary")
        status_textbox = gr.Textbox(label="Connection Status", interactive=False)
        # The `inputs` order must match update_connection_status's parameters.
        connect_btn.click(
            update_connection_status,
            inputs=[
                model_provider,
                model_display_id,
                mcp_list.state,
                aws_access_key_textbox,
                aws_secret_key_textbox,
                aws_session_token_textbox,
                aws_region_dropdown,
                hf_token,
                temperature,
                max_tokens,
            ],
            outputs=[status_textbox],
        )
    with gr.Column(scale=2):
        # NOTE(review): `chat_interface` is not referenced elsewhere in the
        # visible source; the binding appears unused.
        chat_interface = gr.ChatInterface(
            fn=gr_chat_function,
            type="messages",
            examples=[],  # Add examples if needed
            title="π©βπ» TDAgent",
            description="This is a simple agent that uses MCP tools.",
        )
if __name__ == "__main__":
    gr_app.launch()