import re import json import uuid import warnings from abc import ABC from typing import ( Any, AsyncIterator, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union, cast, ) from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) from langchain_core.language_models import BaseChatModel, LanguageModelInput from langchain_core.messages import ( SystemMessage, AIMessage, BaseMessage, BaseMessageChunk, ToolCall, ) from langchain_core.outputs import ChatGeneration, ChatResult from langchain_core.prompts import SystemMessagePromptTemplate from pydantic import BaseModel from langchain_core.runnables import Runnable, RunnableConfig from langchain_core.tools import BaseTool from langchain_core.utils.function_calling import convert_to_openai_tool DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools: {tools} You must always select one of the above tools and respond with only a JSON object matching the following schema: {{ "tool": , "tool_input": }}, {{ "tool": , "tool_input": }} """ # noqa: E501 def extract_think(content): # Added by Cursor 20250726 jmd # Extract content within ... think_match = re.search(r"(.*?)", content, re.DOTALL) think_text = think_match.group(1).strip() if think_match else "" # Extract text after if think_match: post_think = content[think_match.end() :].lstrip() else: # Check if content starts with but missing closing tag if content.strip().startswith(""): # Extract everything after think_start = content.find("") + len("") think_text = content[think_start:].strip() post_think = "" else: # No found, so return entire content as post_think post_think = content return think_text, post_think class ToolCallingLLM(BaseChatModel, ABC): """ToolCallingLLM mixin to enable tool calling features on non tool calling models. Note: This is an incomplete mixin and should not be used directly. It must be used to extent an existing Chat Model. Setup: Install dependencies for your Chat Model. Any API Keys or setup needed for your Chat Model is still applicable. Key init args — completion params: Refer to the documentation of the Chat Model you wish to extend with Tool Calling. Key init args — client params: Refer to the documentation of the Chat Model you wish to extend with Tool Calling. See full list of supported init args and their descriptions in the params section. Instantiate: ``` # Example implementation using LiteLLM from langchain_community.chat_models import ChatLiteLLM class LiteLLMFunctions(ToolCallingLLM, ChatLiteLLM): def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @property def _llm_type(self) -> str: return "litellm_functions" llm = LiteLLMFunctions(model="ollama/phi3") ``` Invoke: ``` messages = [ ("human", "What is the capital of France?") ] llm.invoke(messages) ``` ``` AIMessage(content='The capital of France is Paris.', id='run-497d0e1a-d63b-45e8-9c8b-5e76d99b9468-0') ``` Tool calling: ``` from pydantic import BaseModel, Field class GetWeather(BaseModel): '''Get the current weather in a given location''' location: str = Field(..., description="The city and state, e.g. San Francisco, CA") class GetPopulation(BaseModel): '''Get the current population in a given location''' location: str = Field(..., description="The city and state, e.g. San Francisco, CA") llm_with_tools = llm.bind_tools([GetWeather, GetPopulation]) ai_msg = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?") ai_msg.tool_calls ``` ``` [{'name': 'GetWeather', 'args': {'location': 'Austin, TX'}, 'id': 'call_25ed526917b94d8fa5db3fe30a8cf3c0'}] ``` Response metadata Refer to the documentation of the Chat Model you wish to extend with Tool Calling. """ # noqa: E501 tool_system_prompt_template: str = DEFAULT_SYSTEM_TEMPLATE def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) def _generate_system_message_and_functions( self, kwargs: Dict[str, Any], ) -> Tuple[BaseMessage, List]: functions = kwargs.get("tools", []) # Convert functions to OpenAI tool schema functions = [convert_to_openai_tool(fn) for fn in functions] # Create system message with tool descriptions system_message_prompt_template = SystemMessagePromptTemplate.from_template( self.tool_system_prompt_template ) system_message = system_message_prompt_template.format( tools=json.dumps(functions, indent=2) ) return system_message, functions def _process_response( self, response_message: BaseMessage, functions: List[Dict] ) -> AIMessage: if not isinstance(response_message.content, str): raise ValueError("ToolCallingLLM does not support non-string output.") # Extract ... content and text after for further processing 20250726 jmd think_text, post_think = extract_think(response_message.content) ## For debugging # print("post_think") # print(post_think) # Remove backticks around code blocks post_think = re.sub(r"^```json", "", post_think) post_think = re.sub(r"^```", "", post_think) post_think = re.sub(r"```$", "", post_think) # Remove intervening backticks from adjacent code blocks post_think = re.sub(r"```\n```json", ",", post_think) # Remove trailing comma (if there is one) post_think = post_think.rstrip(",") # Parse output for JSON (support multiple objects separated by commas) try: # Works for one JSON object, or multiple JSON objects enclosed in "[]" parsed_json_results = json.loads(f"{post_think}") if not isinstance(parsed_json_results, list): parsed_json_results = [parsed_json_results] except: try: # Works for multiple JSON objects not enclosed in "[]" parsed_json_results = json.loads(f"[{post_think}]") except json.JSONDecodeError: # Return entire response if JSON wasn't parsed or is missing return AIMessage(content=response_message.content) # print("parsed_json_results") # print(parsed_json_results) tool_calls = [] for parsed_json_result in parsed_json_results: # Get tool name from output called_tool_name = ( parsed_json_result["tool"] if "tool" in parsed_json_result else ( parsed_json_result["name"] if "name" in parsed_json_result else None ) ) # Check if tool name is in functions list called_tool = next( (fn for fn in functions if fn["function"]["name"] == called_tool_name), None, ) if called_tool is None: # Issue a warning and skip this tool call warnings.warn(f"Called tool ({called_tool_name}) not in functions list") continue # Get tool arguments from output called_tool_arguments = ( parsed_json_result["tool_input"] if "tool_input" in parsed_json_result else ( parsed_json_result["parameters"] if "parameters" in parsed_json_result else {} ) ) tool_calls.append( ToolCall( name=called_tool_name, args=called_tool_arguments, id=f"call_{str(uuid.uuid4()).replace('-', '')}", ) ) if not tool_calls: # If nothing valid, return original content return AIMessage(content=response_message.content) # Put together response message response_message = AIMessage( content=f"\n{think_text}\n", tool_calls=tool_calls, ) return response_message def _generate( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: system_message, functions = self._generate_system_message_and_functions(kwargs) response_message = super()._generate( # type: ignore[safe-super] [system_message] + messages, stop=stop, run_manager=run_manager, **kwargs ) response = self._process_response( response_message.generations[0].message, functions ) return ChatResult(generations=[ChatGeneration(message=response)]) async def _agenerate( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: system_message, functions = self._generate_system_message_and_functions(kwargs) response_message = await super()._agenerate( [system_message] + messages, stop=stop, run_manager=run_manager, **kwargs ) response = self._process_response( response_message.generations[0].message, functions ) return ChatResult(generations=[ChatGeneration(message=response)]) async def astream( self, input: LanguageModelInput, config: Optional[RunnableConfig] = None, *, stop: Optional[List[str]] = None, **kwargs: Any, ) -> AsyncIterator[BaseMessageChunk]: system_message, functions = self._generate_system_message_and_functions(kwargs) generation: Optional[BaseMessageChunk] = None async for chunk in super().astream( [system_message] + super()._convert_input(input).to_messages(), stop=stop, **kwargs, ): if generation is None: generation = chunk else: generation += chunk assert generation is not None response = self._process_response(generation, functions) yield cast(BaseMessageChunk, response)