TDAgent / tdagent /grchat.py
Josep Pon Farreny
feat: Add langgraph chat with bedrock
4a76b1f
raw
history blame
5.43 kB
from __future__ import annotations
from collections.abc import Mapping
from types import MappingProxyType
import boto3
import botocore
import botocore.exceptions
import gradio as gr
from langchain_aws import ChatBedrock
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langgraph.graph.graph import CompiledGraph
from langgraph.prebuilt import create_react_agent
#### Constants ####
# System prompt handed to the ReAct agent when it is created.
SYSTEM_MESSAGE = SystemMessage("You are a helpful assistant.")

# Read-only lookup from a Gradio chat role to the LangChain message class used
# to rebuild that turn of the conversation for the agent.
GRADIO_ROLE_TO_LG_MESSAGE_TYPE = MappingProxyType(
    dict(user=HumanMessage, assistant=AIMessage),
)

#### Shared variables ####
# Agent assigned by ``gr_connect_to_bedrock``; stays ``None`` until a
# successful connection has been made.
llm_agent: CompiledGraph | None = None
#### Utility functions ####
def create_bedrock_llm(
    bedrock_model_id: str,
    aws_access_key: str,
    aws_secret_key: str,
    aws_session_token: str,
    aws_region: str,
    *,
    temperature: float = 0.7,
) -> tuple[ChatBedrock | None, str]:
    """Create a ``ChatBedrock`` model after checking the AWS credentials.

    The credentials are validated with an STS ``GetCallerIdentity`` call before
    the Bedrock runtime client and chat model are built.

    Args:
        bedrock_model_id: Bedrock model identifier passed to ``ChatBedrock``.
        aws_access_key: AWS access key ID.
        aws_secret_key: AWS secret access key.
        aws_session_token: AWS session token; an empty string means "none".
        aws_region: AWS region used for both the STS and Bedrock clients.
        temperature: Sampling temperature forwarded via ``model_kwargs``.

    Returns:
        ``(llm, "")`` on success, or ``(None, error_message)`` when credential
        verification or client/model construction fails. Errors are returned
        rather than raised so the caller can show them as a status message.
    """
    boto3_config = {
        "aws_access_key_id": aws_access_key,
        "aws_secret_access_key": aws_secret_key,
        # An empty token from the UI is passed to boto3 as ``None``.
        "aws_session_token": aws_session_token or None,
        "region_name": aws_region,
    }

    # Verify credentials. BotoCoreError is caught as well as ClientError:
    # client-side failures (e.g. no network, invalid region name, missing
    # credentials) are not ClientErrors and would otherwise escape to the UI.
    try:
        sts = boto3.client("sts", **boto3_config)
        sts.get_caller_identity()
    except (
        botocore.exceptions.ClientError,
        botocore.exceptions.BotoCoreError,
    ) as err:
        return None, str(err)

    # Deliberately broad: any failure building the client or the model is
    # reported back to the caller instead of raised.
    try:
        bedrock_client = boto3.client("bedrock-runtime", **boto3_config)
        llm = ChatBedrock(
            model=bedrock_model_id,
            client=bedrock_client,
            model_kwargs={"temperature": temperature},
        )
    except Exception as e:  # noqa: BLE001
        return None, str(e)

    return llm, ""
#### UI functionality ####
async def gr_connect_to_bedrock(
    model_id: str,
    access_key: str,
    secret_key: str,
    session_token: str,
    region: str,
) -> str:
    """Validate AWS credentials and (re)build the Bedrock ReAct agent.

    Args:
        model_id: Bedrock model identifier entered in the UI.
        access_key: AWS access key ID; surrounding whitespace is ignored.
        secret_key: AWS secret access key; surrounding whitespace is ignored.
        session_token: Optional AWS session token; may be empty.
        region: AWS region selected in the UI.

    Returns:
        A human-readable status message for the "Connection Status" textbox.
    """
    # NOTE(review): ``llm_agent`` is a module-level global, so the agent (and
    # the credentials it was built with) is shared by every session served by
    # this process, not scoped to the user who entered them.
    global llm_agent  # noqa: PLW0603

    # Normalize before validating so whitespace-only keys are rejected here
    # instead of being forwarded to AWS as empty strings.
    access_key = (access_key or "").strip()
    secret_key = (secret_key or "").strip()
    session_token = (session_token or "").strip()
    if not access_key or not secret_key:
        return "❌ Please provide both Access Key ID and Secret Access Key"

    llm, error = create_bedrock_llm(
        model_id,
        access_key,
        secret_key,
        session_token,
        region,
    )
    if llm is None:
        # Any previously created agent is left in place on failure.
        return f"❌ Connection failed: {error}"

    llm_agent = create_react_agent(
        model=llm,
        tools=[],
        prompt=SYSTEM_MESSAGE,
    )
    return "✅ Successfully connected to AWS Bedrock!"
async def gr_chat_function(
    message: str,
    history: list[Mapping[str, str]],
) -> str:
    """Answer one chat turn with the configured Bedrock agent.

    Args:
        message: The new user message from the chat box.
        history: Previous turns in Gradio "messages" format. Each entry's
            ``"role"`` must be a key of ``GRADIO_ROLE_TO_LG_MESSAGE_TYPE``;
            any other role raises ``KeyError``.

    Returns:
        The ``content`` of the last message in the agent's output, or a prompt
        to configure credentials when no agent has been created yet.
    """
    agent = llm_agent
    if agent is None:
        return "Please configure your credentials first."

    # Rebuild the whole conversation as LangChain messages, newest turn last.
    conversation = [
        GRADIO_ROLE_TO_LG_MESSAGE_TYPE[turn["role"]](content=turn["content"])
        for turn in history
    ]
    conversation.append(HumanMessage(content=message))

    result = await agent.ainvoke({"messages": conversation})
    return result["messages"][-1].content
## UI components ##
with gr.Blocks() as gr_app:
    gr.Markdown("# 🔐 Secure Bedrock Chatbot")

    # Credentials section (collapsible)
    with gr.Accordion("🔑 Bedrock Configuration", open=True):
        # NOTE(review): the agent built from these credentials is stored in the
        # module-level ``llm_agent`` global, i.e. shared across sessions of this
        # process; the note below may overstate per-session isolation — confirm.
        gr.Markdown(
            "**Note**: Credentials are only stored in memory during your session.",
        )
        with gr.Row():
            bedrock_model_id_textbox = gr.Textbox(
                label="Bedrock Model Id",
                value="eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
            )
        with gr.Row():
            aws_access_key_textbox = gr.Textbox(
                label="AWS Access Key ID",
                type="password",
                placeholder="Enter your AWS Access Key ID",
            )
            aws_secret_key_textbox = gr.Textbox(
                label="AWS Secret Access Key",
                type="password",
                placeholder="Enter your AWS Secret Access Key",
            )
        with gr.Row():
            aws_session_token_textbox = gr.Textbox(
                label="AWS Session Token",
                type="password",
                placeholder="Enter your AWS session token",
            )
        with gr.Row():
            aws_region_dropdown = gr.Dropdown(
                label="AWS Region",
                choices=[
                    "us-east-1",
                    "us-west-2",
                    "eu-west-1",
                    "eu-central-1",
                    "ap-southeast-1",
                ],
                value="eu-west-1",
            )

        connect_btn = gr.Button("🔌 Connect to Bedrock", variant="primary")
        status_textbox = gr.Textbox(label="Connection Status", interactive=False)

    # Validate the credentials and (re)build the agent on click; the returned
    # status string is shown in ``status_textbox``.
    connect_btn.click(
        gr_connect_to_bedrock,
        inputs=[
            bedrock_model_id_textbox,
            aws_access_key_textbox,
            aws_secret_key_textbox,
            aws_session_token_textbox,
            aws_region_dropdown,
        ],
        outputs=[status_textbox],
    )

    # NOTE(review): ``gr_connect_to_bedrock`` currently builds the agent with
    # ``tools=[]``, so the "MCP tools" title/description do not yet match.
    chat_interface = gr.ChatInterface(
        fn=gr_chat_function,
        type="messages",
        examples=[],
        title="Agent with MCP Tools",
        description="This is a simple agent that uses MCP tools.",
    )


if __name__ == "__main__":
    # Serve the Gradio app when the module is executed as a script.
    gr_app.launch()