Fix Huggingface Inference

#12 opened by jponf
pyproject.toml CHANGED
@@ -13,6 +13,7 @@ requires-python = ">=3.10,<4"
 readme = "README.md"
 license = ""
 dependencies = [
+    "aiohttp>=3.12.9",
     "gradio[mcp]~=5.31",
    "huggingface-hub>=0.32.3",
    "langchain-aws>=0.2.24",
requirements-dev.txt CHANGED
@@ -1,8 +1,13 @@
 # This file was autogenerated by uv via the following command:
 # uv export --format requirements-txt --no-hashes --no-annotate --group dev --group test -o requirements-dev.txt
 aiofiles==24.1.0
+aiohappyeyeballs==2.6.1
+aiohttp==3.12.9
+aiosignal==1.3.2
 annotated-types==0.7.0
 anyio==4.9.0
+async-timeout==5.0.1 ; python_full_version < '3.11'
+attrs==25.3.0
 audioop-lts==0.2.1 ; python_full_version >= '3.13'
 boolean-py==5.0
 boto3==1.38.27
@@ -23,6 +28,7 @@ exceptiongroup==1.3.0 ; python_full_version < '3.11'
 fastapi==0.115.12
 ffmpy==0.6.0
 filelock==3.18.0
+frozenlist==1.6.2
 fsspec==2025.5.1
 gradio==5.32.1
 gradio-client==1.10.2
@@ -58,6 +64,7 @@ mcp==1.9.0
 mdurl==0.1.2
 mpmath==1.3.0
 msgpack==1.1.0
+multidict==6.4.4
 mypy==1.16.0
 mypy-extensions==1.1.0
 networkx==3.4.2 ; python_full_version < '3.11'
@@ -94,6 +101,7 @@ pip-requirements-parser==32.0.1
 platformdirs==4.3.8
 pluggy==1.6.0
 pre-commit==3.8.0
+propcache==0.3.1
 py-serializable==2.0.0
 pycparser==2.22 ; platform_python_implementation == 'PyPy'
 pydantic==2.11.5
@@ -150,4 +158,5 @@ virtualenv==20.31.2
 websockets==15.0.1
 xdoctest==1.2.0
 xxhash==3.5.0
+yarl==1.20.0
 zstandard==0.23.0
requirements.txt CHANGED
@@ -1,8 +1,13 @@
 # This file was autogenerated by uv via the following command:
 # uv export --format requirements-txt --no-hashes --no-annotate --no-dev -o requirements.txt
 aiofiles==24.1.0
+aiohappyeyeballs==2.6.1
+aiohttp==3.12.9
+aiosignal==1.3.2
 annotated-types==0.7.0
 anyio==4.9.0
+async-timeout==5.0.1 ; python_full_version < '3.11'
+attrs==25.3.0
 audioop-lts==0.2.1 ; python_full_version >= '3.13'
 boto3==1.38.27
 botocore==1.38.27
@@ -17,6 +22,7 @@ exceptiongroup==1.3.0 ; python_full_version < '3.11'
 fastapi==0.115.12
 ffmpy==0.6.0
 filelock==3.18.0
+frozenlist==1.6.2
 fsspec==2025.5.1
 gradio==5.32.1
 gradio-client==1.10.2
@@ -49,6 +55,7 @@ markupsafe==3.0.2
 mcp==1.9.0
 mdurl==0.1.2 ; sys_platform != 'emscripten'
 mpmath==1.3.0
+multidict==6.4.4
 networkx==3.4.2 ; python_full_version < '3.11'
 networkx==3.5 ; python_full_version >= '3.11'
 numpy==1.26.4 ; python_full_version < '3.12'
@@ -74,6 +81,7 @@ packaging==24.2
 pandas==2.2.3
 pillow==11.2.1
 pluggy==1.6.0
+propcache==0.3.1
 pycparser==2.22 ; platform_python_implementation == 'PyPy'
 pydantic==2.11.5
 pydantic-core==2.33.2
@@ -125,4 +133,5 @@ uvicorn==0.34.3 ; sys_platform != 'emscripten'
 websockets==15.0.1
 xdoctest==1.2.0
 xxhash==3.5.0
+yarl==1.20.0
 zstandard==0.23.0
tdagent/grchat.py CHANGED
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from collections import OrderedDict
 from collections.abc import Mapping, Sequence
 from types import MappingProxyType
 from typing import TYPE_CHECKING, Any
@@ -8,9 +9,10 @@ import boto3
 import botocore
 import botocore.exceptions
 import gradio as gr
+import gradio.themes as gr_themes
 from langchain_aws import ChatBedrock
 from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
-from langchain_huggingface import HuggingFaceEndpoint
+from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 from langchain_mcp_adapters.client import MultiServerMCPClient
 from langgraph.prebuilt import create_react_agent
 from openai import OpenAI
@@ -51,15 +53,32 @@ GRADIO_ROLE_TO_LG_MESSAGE_TYPE = MappingProxyType(
     },
 )
 
-MODEL_OPTIONS = {
-    "AWS Bedrock": {
-        "Anthropic Claude 3.5 Sonnet": "eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
-        # "Anthropic Claude 3.7 Sonnet": "anthropic.claude-3-7-sonnet-20250219-v1:0",
-    },
-    "HuggingFace": {
-        "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct",
-    },
-}
+MODEL_OPTIONS = OrderedDict(  # Initialize with tuples to preserve options order
+    (
+        (
+            "HuggingFace",
+            {
+                "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.3",
+                "Llama 3.1 8B Instruct": "meta-llama/Llama-3.1-8B-Instruct",
+                # "Qwen3 235B A22B": "Qwen/Qwen3-235B-A22B",  # Slow inference
+                "Microsoft Phi-3.5-mini Instruct": "microsoft/Phi-3.5-mini-instruct",
+                # "Deepseek R1 distill-llama 70B": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",  # noqa: E501
+                # "Deepseek V3": "deepseek-ai/DeepSeek-V3",
+            },
+        ),
+        (
+            "AWS Bedrock",
+            {
+                "Anthropic Claude 3.5 Sonnet (EU)": (
+                    "eu.anthropic.claude-3-5-sonnet-20240620-v1:0"
+                ),
+                # "Anthropic Claude 3.7 Sonnet": (
+                #     "anthropic.claude-3-7-sonnet-20250219-v1:0"
+                # ),
+            },
+        ),
+    ),
+)
 
 #### Shared variables ####
 
@@ -109,18 +128,20 @@ def create_bedrock_llm(
 def create_hf_llm(
     hf_model_id: str,
     huggingfacehub_api_token: str | None = None,
-) -> tuple[HuggingFaceEndpoint | None, str]:
+) -> tuple[ChatHuggingFace | None, str]:
     """Create a LangGraph Hugging Face agent."""
     try:
         llm = HuggingFaceEndpoint(
             model=hf_model_id,
-            huggingfacehub_api_token=huggingfacehub_api_token,
             temperature=0.8,
+            task="text-generation",
+            huggingfacehub_api_token=huggingfacehub_api_token,
         )
+        chat_llm = ChatHuggingFace(llm=llm)
     except Exception as e:  # noqa: BLE001
         return None, str(e)
 
-    return llm, ""
+    return chat_llm, ""
 
 
 ## OpenAI LLM creation ##
@@ -286,14 +307,18 @@ async def gr_chat_function(  # noqa: D103
         messages.append(message_type(content=hist_msg["content"]))
 
     messages.append(HumanMessage(content=message))
-
-    llm_response = await llm_agent.ainvoke(
-        {
-            "messages": messages,
-        },
-    )
-
-    return llm_response["messages"][-1].content
+    try:
+        llm_response = await llm_agent.ainvoke(
+            {
+                "messages": messages,
+            },
+        )
+        return llm_response["messages"][-1].content
+    except Exception as err:
+        raise gr.Error(
+            f"We encountered an error while invoking the model:\n{err}",
+            print_exception=True,
+        ) from err
 
 
 ## UI components ##
@@ -314,7 +339,12 @@ def toggle_model_fields(
     # Update model choices based on the selected provider
     if provider in MODEL_OPTIONS:
         model_choices = list(MODEL_OPTIONS[provider].keys())
-        model_pretty = gr.update(choices=model_choices, visible=True, interactive=True)
+        model_pretty = gr.update(
+            choices=model_choices,
+            value=model_choices[0],
+            visible=True,
+            interactive=True,
+        )
     else:
         model_pretty = gr.update(choices=[], visible=False)
 
@@ -346,7 +376,9 @@ async def update_connection_status(  # noqa: PLR0913
     """Update the connection status based on the selected provider and model."""
     if not provider or not pretty_model:
         return "❌ Please select a provider and model."
+
     model_id = MODEL_OPTIONS.get(provider, {}).get(pretty_model)
+    connection = "❌ Invalid provider"
     if model_id:
         if provider == "AWS Bedrock":
             connection = await gr_connect_to_bedrock(
@@ -363,15 +395,21 @@ async def update_connection_status(  # noqa: PLR0913
         connection = await gr_connect_to_hf(model_id, hf_token, mcp_list_state)
     elif provider == "Nebius":
         connection = await gr_connect_to_nebius(model_id, hf_token, mcp_list_state)
-    else:
-        return "❌ Invalid provider"
-    return connection if connection else "❌ Invalid provider"
+
+    return connection
 
 
-with gr.Blocks(
-    theme=gr.themes.Origin(primary_hue="teal", spacing_size="sm", font="sans-serif"),
-    title="TDAgent",
-) as gr_app, gr.Row():
+with (
+    gr.Blocks(
+        theme=gr_themes.Origin(
+            primary_hue="teal",
+            spacing_size="sm",
+            font="sans-serif",
+        ),
+        title="TDAgent",
+    ) as gr_app,
+    gr.Row(),
+):
     with gr.Column(scale=1):
         with gr.Accordion("πŸ”Œ MCP Servers", open=False):
             mcp_list = MutableCheckBoxGroup(
@@ -382,6 +420,10 @@ with gr.Blocks(
                 ),
             ],
             label="MCP Servers",
+            new_value_label="MCP endpoint",
+            new_name_label="MCP endpoint name",
+            new_value_placeholder="https://my-cool-mcp-server.com/mcp/sse",
+            new_name_placeholder="Swiss army knife of MCPs",
            )

        with gr.Accordion("βš™οΈ Provider Configuration", open=True):
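
For context, a minimal sketch of the inference path this diff sets up: `create_hf_llm` now wraps the raw `HuggingFaceEndpoint` in `ChatHuggingFace`, so the agent receives a chat-capable model instead of a bare text-generation endpoint. This is not code from the PR; the model id, the placeholder token, and the empty tools list (which stands in for the MCP tools loaded elsewhere in `grchat.py`) are illustrative assumptions.

```python
# Sketch only: mirrors the HF wiring pattern introduced above.
from langchain_core.messages import HumanMessage
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langgraph.prebuilt import create_react_agent

llm = HuggingFaceEndpoint(
    model="mistralai/Mistral-7B-Instruct-v0.3",  # any id from MODEL_OPTIONS
    temperature=0.8,
    task="text-generation",
    huggingfacehub_api_token="hf_...",  # placeholder; grchat.py reads it from the UI
)
chat_llm = ChatHuggingFace(llm=llm)  # adds chat-message formatting over the endpoint

agent = create_react_agent(chat_llm, tools=[])  # tools=[] stands in for MCP tools
response = agent.invoke({"messages": [HumanMessage(content="Hello!")]})
print(response["messages"][-1].content)
```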
tdagent/grcomponents/mcbgroup.py CHANGED
@@ -19,7 +19,7 @@ class MutableCheckBoxGroupEntry(NamedTuple):
 class MutableCheckBoxGroup(gr.Blocks):
     """Check box group with controls to add or remove values."""
 
-    def __init__(  # noqa: PLR0913
+    def __init__(
         self,
         values: list[MutableCheckBoxGroupEntry] | None = None,
         label: str = "Extendable List",
@@ -68,16 +68,24 @@ class MutableCheckBoxGroup(gr.Blocks):
             self.input_value = gr.Textbox(
                 label=self.new_value_label,
                 placeholder=self.new_value_placeholder,
-                scale=4,
+                scale=3,
             )
             self.input_name = gr.Textbox(
                 label=self.new_name_label,
                 placeholder=self.new_name_placeholder,
                 scale=2,
             )
-            with gr.Column():
-                self.add_btn = gr.Button("Add", variant="primary", scale=1)
-                self.delete_btn = gr.Button("Delete Selected", variant="stop")
+            with gr.Row():
+                self.add_btn = gr.Button(
+                    "Add",
+                    variant="primary",
+                    scale=1,
+                )
+                self.delete_btn = gr.Button(
+                    "Delete Selected",
+                    variant="stop",
+                    scale=1,
+                )
 
             # Vertical checkbox group
             self.items_group = gr.CheckboxGroup(
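
A small usage sketch of the component after this change, using the keyword arguments that `grchat.py` now passes (the import path assumes the package layout matches the file path; the snippet itself is illustrative, not part of the PR):

```python
# Sketch only: instantiating MutableCheckBoxGroup with the new label/placeholder kwargs.
import gradio as gr

from tdagent.grcomponents.mcbgroup import MutableCheckBoxGroup

with gr.Blocks() as demo:
    mcp_list = MutableCheckBoxGroup(
        label="MCP Servers",
        new_value_label="MCP endpoint",
        new_name_label="MCP endpoint name",
        new_value_placeholder="https://my-cool-mcp-server.com/mcp/sse",
        new_name_placeholder="Swiss army knife of MCPs",
    )
```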
uv.lock CHANGED
The diff for this file is too large to render. See raw diff