import re
import json
import uuid
import warnings
from abc import ABC
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import (
SystemMessage,
AIMessage,
BaseMessage,
BaseMessageChunk,
ToolCall,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.prompts import SystemMessagePromptTemplate
from pydantic import BaseModel
from langchain_core.runnables import Runnable, RunnableConfig
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools:
{tools}
You must always select one of the above tools and respond with only a JSON object matching the following schema:
{{
"tool": <name of selected tool 1>,
"tool_input": <parameters for selected tool 1, matching the tool's JSON schema>
}},
{{
"tool": <name of selected tool 2>,
"tool_input": <parameters for selected tool 2, matching the tool's JSON schema>
}}
""" # noqa: E501
def extract_think(content: str) -> Tuple[str, str]:
    """Split content into the text inside <think>...</think> and the text after it."""
# Added by Cursor 20250726 jmd
# Extract content within <think>...</think>
think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
think_text = think_match.group(1).strip() if think_match else ""
# Extract text after </think>
if think_match:
post_think = content[think_match.end() :].lstrip()
else:
# Check if content starts with <think> but missing closing tag
if content.strip().startswith("<think>"):
# Extract everything after <think>
think_start = content.find("<think>") + len("<think>")
think_text = content[think_start:].strip()
post_think = ""
else:
# No <think> found, so return entire content as post_think
post_think = content
return think_text, post_think
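# Illustrative behaviour of extract_think (hypothetical input):
#   extract_think('<think>compare the cities</think>{"tool": "GetWeather"}')
#   -> ('compare the cities', '{"tool": "GetWeather"}')
# If no <think> block is present, the full content is returned as post_think.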
class ToolCallingLLM(BaseChatModel, ABC):
"""ToolCallingLLM mixin to enable tool calling features on non tool calling models.
Note: This is an incomplete mixin and should not be used directly. It must be used to extent an existing Chat Model.
Setup:
Install dependencies for your Chat Model.
        Any API keys or other setup needed for your Chat Model still apply.
    Key init args - completion params:
Refer to the documentation of the Chat Model you wish to extend with Tool Calling.
    Key init args - client params:
Refer to the documentation of the Chat Model you wish to extend with Tool Calling.
See full list of supported init args and their descriptions in the params section.
Instantiate:
```
# Example implementation using LiteLLM
from langchain_community.chat_models import ChatLiteLLM
class LiteLLMFunctions(ToolCallingLLM, ChatLiteLLM):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
@property
def _llm_type(self) -> str:
return "litellm_functions"
llm = LiteLLMFunctions(model="ollama/phi3")
```
Invoke:
```
messages = [
("human", "What is the capital of France?")
]
llm.invoke(messages)
```
```
AIMessage(content='The capital of France is Paris.', id='run-497d0e1a-d63b-45e8-9c8b-5e76d99b9468-0')
```
Tool calling:
```
from pydantic import BaseModel, Field
class GetWeather(BaseModel):
'''Get the current weather in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
class GetPopulation(BaseModel):
'''Get the current population in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
ai_msg = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?")
ai_msg.tool_calls
```
```
[{'name': 'GetWeather', 'args': {'location': 'Austin, TX'}, 'id': 'call_25ed526917b94d8fa5db3fe30a8cf3c0'}]
```
    Response metadata:
Refer to the documentation of the Chat Model you wish to extend with Tool Calling.
""" # noqa: E501
tool_system_prompt_template: str = DEFAULT_SYSTEM_TEMPLATE
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
def _generate_system_message_and_functions(
self,
kwargs: Dict[str, Any],
) -> Tuple[BaseMessage, List]:
functions = kwargs.get("tools", [])
# Convert functions to OpenAI tool schema
functions = [convert_to_openai_tool(fn) for fn in functions]
# Create system message with tool descriptions
system_message_prompt_template = SystemMessagePromptTemplate.from_template(
self.tool_system_prompt_template
)
system_message = system_message_prompt_template.format(
tools=json.dumps(functions, indent=2)
)
return system_message, functions
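    # Rough shape of what convert_to_openai_tool produces for each bound tool
    # (hypothetical GetWeather example, abbreviated):
    #   {"type": "function",
    #    "function": {"name": "GetWeather",
    #                 "description": "Get the current weather in a given location",
    #                 "parameters": {...JSON schema for the tool's arguments...}}}
    # This list is JSON-dumped into the {tools} slot of the system prompt.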
def _process_response(
self, response_message: BaseMessage, functions: List[Dict]
) -> AIMessage:
if not isinstance(response_message.content, str):
raise ValueError("ToolCallingLLM does not support non-string output.")
# Extract <think>...</think> content and text after </think> for further processing 20250726 jmd
think_text, post_think = extract_think(response_message.content)
## For debugging
# print("post_think")
# print(post_think)
# Remove backticks around code blocks
post_think = re.sub(r"^```json", "", post_think)
post_think = re.sub(r"^```", "", post_think)
post_think = re.sub(r"```$", "", post_think)
# Remove intervening backticks from adjacent code blocks
post_think = re.sub(r"```\n```json", ",", post_think)
# Remove trailing comma (if there is one)
post_think = post_think.rstrip(",")
# Parse output for JSON (support multiple objects separated by commas)
try:
# Works for one JSON object, or multiple JSON objects enclosed in "[]"
            parsed_json_results = json.loads(post_think)
if not isinstance(parsed_json_results, list):
parsed_json_results = [parsed_json_results]
        except json.JSONDecodeError:
try:
# Works for multiple JSON objects not enclosed in "[]"
parsed_json_results = json.loads(f"[{post_think}]")
except json.JSONDecodeError:
# Return entire response if JSON wasn't parsed or is missing
return AIMessage(content=response_message.content)
# print("parsed_json_results")
# print(parsed_json_results)
tool_calls = []
for parsed_json_result in parsed_json_results:
# Get tool name from output
called_tool_name = (
parsed_json_result["tool"]
if "tool" in parsed_json_result
else (
parsed_json_result["name"] if "name" in parsed_json_result else None
)
)
# Check if tool name is in functions list
called_tool = next(
(fn for fn in functions if fn["function"]["name"] == called_tool_name),
None,
)
if called_tool is None:
# Issue a warning and skip this tool call
warnings.warn(f"Called tool ({called_tool_name}) not in functions list")
continue
# Get tool arguments from output
called_tool_arguments = (
parsed_json_result["tool_input"]
if "tool_input" in parsed_json_result
else (
parsed_json_result["parameters"]
if "parameters" in parsed_json_result
else {}
)
)
tool_calls.append(
ToolCall(
name=called_tool_name,
args=called_tool_arguments,
id=f"call_{str(uuid.uuid4()).replace('-', '')}",
)
)
if not tool_calls:
# If nothing valid, return original content
return AIMessage(content=response_message.content)
# Put together response message
response_message = AIMessage(
content=f"<think>\n{think_text}\n</think>",
tool_calls=tool_calls,
)
return response_message
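    # Worked example (hypothetical model output): given the raw content
    #   '<think>need the forecast</think>{"tool": "GetWeather", "tool_input": {"location": "LA"}}'
    # and GetWeather in the functions list, this method returns an AIMessage with
    # content '<think>\nneed the forecast\n</think>' and a single
    # ToolCall(name="GetWeather", args={"location": "LA"}, id="call_...").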
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
system_message, functions = self._generate_system_message_and_functions(kwargs)
response_message = super()._generate( # type: ignore[safe-super]
[system_message] + messages, stop=stop, run_manager=run_manager, **kwargs
)
response = self._process_response(
response_message.generations[0].message, functions
)
return ChatResult(generations=[ChatGeneration(message=response)])
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
system_message, functions = self._generate_system_message_and_functions(kwargs)
response_message = await super()._agenerate(
[system_message] + messages, stop=stop, run_manager=run_manager, **kwargs
)
response = self._process_response(
response_message.generations[0].message, functions
)
return ChatResult(generations=[ChatGeneration(message=response)])
async def astream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[BaseMessageChunk]:
system_message, functions = self._generate_system_message_and_functions(kwargs)
generation: Optional[BaseMessageChunk] = None
        async for chunk in super().astream(
            [system_message] + super()._convert_input(input).to_messages(),
            config=config,
            stop=stop,
            **kwargs,
        ):
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
response = self._process_response(generation, functions)
yield cast(BaseMessageChunk, response)
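    # Note: tool-call JSON can only be parsed once the reply is complete, so this
    # override aggregates all chunks and yields a single processed message, e.g.
    # (hypothetical usage):
    #   async for msg in llm_with_tools.astream("What is the weather in LA?"):
    #       print(msg.tool_calls)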