# R-help-chat / mods / tool_calling_llm.py
import re
import json
import uuid
import warnings
from abc import ABC
from typing import (
    Any,
    AsyncIterator,
    Dict,
    List,
    Optional,
    Tuple,
    cast,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    BaseMessageChunk,
    ToolCall,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.prompts import SystemMessagePromptTemplate
from langchain_core.runnables import RunnableConfig
from langchain_core.utils.function_calling import convert_to_openai_tool
DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools:
{tools}
You must always select one or more of the above tools and respond with only JSON objects matching the following schema:
{{
"tool": <name of selected tool 1>,
"tool_input": <parameters for selected tool 1, matching the tool's JSON schema>
}},
{{
"tool": <name of selected tool 2>,
"tool_input": <parameters for selected tool 2, matching the tool's JSON schema>
}}
""" # noqa: E501
def extract_think(content):
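    """Split model output into <think> reasoning and the text that follows it.

    Returns a (think_text, post_think) tuple. Illustrative example (assumed,
    not from the source):
    extract_think('<think>plan</think>{"tool": "x"}') == ('plan', '{"tool": "x"}')
    """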
# Added by Cursor 20250726 jmd
# Extract content within <think>...</think>
think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
think_text = think_match.group(1).strip() if think_match else ""
# Extract text after </think>
if think_match:
post_think = content[think_match.end() :].lstrip()
else:
# Check if content starts with <think> but missing closing tag
if content.strip().startswith("<think>"):
# Extract everything after <think>
think_start = content.find("<think>") + len("<think>")
think_text = content[think_start:].strip()
post_think = ""
else:
# No <think> found, so return entire content as post_think
post_think = content
return think_text, post_think
class ToolCallingLLM(BaseChatModel, ABC):
"""ToolCallingLLM mixin to enable tool calling features on non tool calling models.
Note: This is an incomplete mixin and should not be used directly. It must be used to extent an existing Chat Model.
Setup:
Install dependencies for your Chat Model.
    Any API keys or setup needed for your Chat Model are still applicable.
Key init args β€” completion params:
Refer to the documentation of the Chat Model you wish to extend with Tool Calling.
Key init args β€” client params:
Refer to the documentation of the Chat Model you wish to extend with Tool Calling.
See full list of supported init args and their descriptions in the params section.
Instantiate:
```
# Example implementation using LiteLLM
from langchain_community.chat_models import ChatLiteLLM
class LiteLLMFunctions(ToolCallingLLM, ChatLiteLLM):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
@property
def _llm_type(self) -> str:
return "litellm_functions"
llm = LiteLLMFunctions(model="ollama/phi3")
```
Invoke:
```
messages = [
("human", "What is the capital of France?")
]
llm.invoke(messages)
```
```
AIMessage(content='The capital of France is Paris.', id='run-497d0e1a-d63b-45e8-9c8b-5e76d99b9468-0')
```
Tool calling:
```
from pydantic import BaseModel, Field
class GetWeather(BaseModel):
'''Get the current weather in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
class GetPopulation(BaseModel):
'''Get the current population in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
ai_msg = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?")
ai_msg.tool_calls
```
```
[{'name': 'GetWeather', 'args': {'location': 'Austin, TX'}, 'id': 'call_25ed526917b94d8fa5db3fe30a8cf3c0'}]
```
    Response metadata:
Refer to the documentation of the Chat Model you wish to extend with Tool Calling.
""" # noqa: E501
tool_system_prompt_template: str = DEFAULT_SYSTEM_TEMPLATE
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
def _generate_system_message_and_functions(
self,
kwargs: Dict[str, Any],
) -> Tuple[BaseMessage, List]:
functions = kwargs.get("tools", [])
# Convert functions to OpenAI tool schema
functions = [convert_to_openai_tool(fn) for fn in functions]
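        # Each entry now follows the OpenAI tool schema, roughly:
        #   {"type": "function", "function": {"name": ..., "description": ..., "parameters": {...}}}
        # which is why _process_response looks up fn["function"]["name"] later.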
# Create system message with tool descriptions
system_message_prompt_template = SystemMessagePromptTemplate.from_template(
self.tool_system_prompt_template
)
system_message = system_message_prompt_template.format(
tools=json.dumps(functions, indent=2)
)
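        # The rendered system message embeds the JSON-dumped tool schemas in place
        # of {tools}; _generate, _agenerate, and astream prepend it to the messages.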
return system_message, functions
def _process_response(
self, response_message: BaseMessage, functions: List[Dict]
) -> AIMessage:
if not isinstance(response_message.content, str):
raise ValueError("ToolCallingLLM does not support non-string output.")
# Extract <think>...</think> content and text after </think> for further processing 20250726 jmd
think_text, post_think = extract_think(response_message.content)
# Remove backticks around code blocks
post_think = re.sub(r"^```json", "", post_think)
post_think = re.sub(r"^```", "", post_think)
post_think = re.sub(r"```$", "", post_think)
# Remove intervening backticks from adjacent code blocks
post_think = re.sub(r"```\n```json", ",", post_think)
# Remove trailing comma (if there is one)
post_think = post_think.rstrip(",")
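        # After this cleanup, output such as adjacent ```json fenced blocks (a
        # pattern some models produce) has been reduced to comma-separated JSON
        # objects that the parsing below can load, wrapping them in "[]" if needed.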
# Parse output for JSON (support multiple objects separated by commas)
        try:
            # Works for one JSON object, or multiple JSON objects enclosed in "[]"
            parsed_json_results = json.loads(post_think)
            if not isinstance(parsed_json_results, list):
                parsed_json_results = [parsed_json_results]
        except json.JSONDecodeError:
            try:
                # Works for multiple JSON objects not enclosed in "[]"
                parsed_json_results = json.loads(f"[{post_think}]")
except json.JSONDecodeError:
# Return entire response if JSON wasn't parsed or is missing
return AIMessage(content=response_message.content)
# print("parsed_json_results")
# print(parsed_json_results)
tool_calls = []
for parsed_json_result in parsed_json_results:
# Get tool name from output
            called_tool_name = parsed_json_result.get(
                "tool", parsed_json_result.get("name")
            )
# Check if tool name is in functions list
called_tool = next(
(fn for fn in functions if fn["function"]["name"] == called_tool_name),
None,
)
if called_tool is None:
# Issue a warning and skip this tool call
warnings.warn(f"Called tool ({called_tool_name}) not in functions list")
continue
# Get tool arguments from output
            called_tool_arguments = parsed_json_result.get(
                "tool_input", parsed_json_result.get("parameters", {})
            )
tool_calls.append(
ToolCall(
name=called_tool_name,
args=called_tool_arguments,
id=f"call_{str(uuid.uuid4()).replace('-', '')}",
)
)
if not tool_calls:
# If nothing valid, return original content
return AIMessage(content=response_message.content)
# Put together response message
response_message = AIMessage(
content=f"<think>\n{think_text}\n</think>",
tool_calls=tool_calls,
)
return response_message
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
system_message, functions = self._generate_system_message_and_functions(kwargs)
response_message = super()._generate( # type: ignore[safe-super]
[system_message] + messages, stop=stop, run_manager=run_manager, **kwargs
)
response = self._process_response(
response_message.generations[0].message, functions
)
return ChatResult(generations=[ChatGeneration(message=response)])
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
system_message, functions = self._generate_system_message_and_functions(kwargs)
response_message = await super()._agenerate(
[system_message] + messages, stop=stop, run_manager=run_manager, **kwargs
)
response = self._process_response(
response_message.generations[0].message, functions
)
return ChatResult(generations=[ChatGeneration(message=response)])
async def astream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[BaseMessageChunk]:
system_message, functions = self._generate_system_message_and_functions(kwargs)
generation: Optional[BaseMessageChunk] = None
        async for chunk in super().astream(
            [system_message] + super()._convert_input(input).to_messages(),
            config=config,
            stop=stop,
            **kwargs,
        ):
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
response = self._process_response(generation, functions)
yield cast(BaseMessageChunk, response)