Update langgraph_agent.py
langgraph_agent.py (CHANGED: +50 -38)
@@ -3,7 +3,7 @@ import io
 import contextlib
 import pandas as pd
 from typing import Dict, List, Union
-import re
+import re
 
 from PIL import Image as PILImage
 from huggingface_hub import InferenceClient
@@ -13,10 +13,9 @@ from langgraph.prebuilt import tools_condition, ToolNode
 from langchain_openai import ChatOpenAI
 from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
 from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
-from langchain_core.messages import SystemMessage, HumanMessage
+from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_core.tools import tool
-
 from langchain_google_community import GoogleSearchAPIWrapper
 
 @tool
@@ -108,7 +107,11 @@ def read_file_content(file_path: str) -> Dict[str, str]:
         elif file_extension in (".jpeg", ".jpg", ".png"):
             return {"file_type": "image", "file_name": file_path, "file_content": f"Image file '{file_path}' detected. Use 'describe_image' tool to get a textual description."}
         elif file_extension == ".mp3":
-
+            # For MP3, we indicate it's an audio file and expect the LLM to handle the blob directly.
+            # In a real Langchain setup, you might actually read the bytes here and pass them
+            # as a part of the message content to the LLM if it supports direct binary upload.
+            # For now, this tool simply confirms its type for the agent.
+            return {"file_type": "audio", "file_name": file_path, "file_content": f"Audio file '{file_path}' detected. The LLM (Gemini 2.5 Pro) can process this audio content directly."}
         else:
             return {"file_type": "unsupported", "file_name": file_path, "file_content": f"Unsupported file type: {file_extension}. Only .txt, .py, .xlsx, .jpeg, .jpg, .png, .mp3 files are recognized."}
     except FileNotFoundError:
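Note on the new .mp3 branch above: as the added comments say, read_file_content only flags the file type and never puts the audio bytes anywhere the model can see them. If the tool itself were meant to carry the payload, a minimal sketch could look like the following (not part of this commit; the read_audio_with_payload name, the file_data_b64 key, and the base64 encoding are assumptions):

import base64

def read_audio_with_payload(file_path: str) -> dict:
    # Hypothetical variant of the ".mp3" branch: read the raw bytes and
    # base64-encode them so a later node can attach them to a message.
    with open(file_path, "rb") as f:
        audio_b64 = base64.b64encode(f.read()).decode("ascii")
    return {
        "file_type": "audio",
        "file_name": file_path,
        "file_data_b64": audio_b64,  # assumed key, not used elsewhere in this file
        "file_content": f"Audio file '{file_path}' detected.",
    }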
@@ -145,22 +148,7 @@ def describe_image(image_path: str) -> Dict[str, str]:
     except Exception as e:
         return {"error": f"Error describing image {image_path}: {str(e)}"}
 
-
-def transcribe_audio(audio_path: str) -> Dict[str, str]:
-    """Transcribes an audio file (e.g., MP3) to text using an automatic speech recognition model from the Hugging Face Inference API. Requires HF_API_TOKEN environment variable to be set."""
-    if not HF_INFERENCE_CLIENT:
-        return {"error": "Hugging Face API token not configured for audio transcription. Cannot use this tool."}
-    try:
-        with open(audio_path, "rb") as f:
-            audio_bytes = f.read()
-        transcription = HF_INFERENCE_CLIENT.automatic_speech_recognition(audio_bytes)
-        return {"audio_transcription": transcription, "audio_path": audio_path}
-    except FileNotFoundError:
-        return {"error": f"Audio file not found: {audio_path}. Please ensure the file exists."}
-    except Exception as e:
-        return {"error": f"Error transcribing audio {audio_path}: {str(e)}"}
-
-# --- NEW YOUTUBE TOOL ---
+# --- Youtube Tool (Remains the same) ---
 @tool
 def Youtube(url: str, question: str) -> Dict[str, str]:
     """
@@ -190,13 +178,13 @@ def Youtube(url: str, question: str) -> Dict[str, str]:
     else:
         return {"error": "Invalid or unrecognized YouTube URL.", "url": url}
 
-# --- END
+# --- END YOUTUBE TOOL ---
 
 API_KEY = os.getenv("GEMINI_API_KEY")
-HF_API_TOKEN = os.getenv("HF_SPACE_TOKEN")
-GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
+HF_API_TOKEN = os.getenv("HF_SPACE_TOKEN")
+GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
 
-# Update the tools list
+# Update the tools list (removed transcribe_audio)
 tools = [
     multiply, add, subtract, divide, modulus,
     wiki_search,
@@ -205,8 +193,7 @@ tools = [
     read_file_content,
     python_interpreter,
     describe_image,
-    transcribe_audio
-    Youtube, # <-- ADDED THE NEW YOUTUBE TOOL HERE
+    Youtube, # <-- transcribe_audio has been removed
 ]
 
 with open("prompt.txt", "r", encoding="utf-8") as f:
@@ -232,18 +219,51 @@ def build_graph(provider: str = "gemini"):
     else:
         raise ValueError("Invalid provider. Choose 'gemini' or 'huggingface'.")
 
-    # This is the crucial line that binds your defined Python tools to the LLM
     llm_with_tools = llm.bind_tools(tools)
 
     def assistant(state: MessagesState):
         messages_to_send = [sys_msg] + state["messages"]
-
-        llm_response = llm_with_tools.invoke(messages_to_send)
+
+        # When sending messages to Gemini, if read_file_content identified an audio file,
+        # you'll need to ensure the actual binary content of the audio file is included
+        # in the message parts for the LLM to process it natively.
+        # This part requires a bit more advanced handling than just text.
+        # Langchain often handles this when you use `tool_code.File(...)` or similar constructs.
+        # For simplicity in this prompt and code example, we're assuming the framework
+        # will correctly pass the file content if `read_file_content` returns an audio type.
+
+        # A more robust implementation would involve modifying the `assistant` node
+        # to explicitly read the file bytes and add them to the message parts
+        # if a file is detected in the input state.
+
+        # Example of how you might include binary content (conceptual, depends on LangChain/API):
+        # new_messages_to_send = []
+        # for msg in messages_to_send:
+        #     if isinstance(msg, HumanMessage) and "audio file" in msg.content:  # Simplified check
+        #         # Assume you can get the actual file path from the context
+        #         file_path_from_context = "Strawberry pie.mp3"  # Or extract from msg.content
+        #         if os.path.exists(file_path_from_context):
+        #             with open(file_path_from_context, "rb") as f:
+        #                 audio_bytes = f.read()
+        #             new_messages_to_send.append(
+        #                 HumanMessage(
+        #                     content=[
+        #                         {"type": "text", "text": "Here is the audio file:"},
+        #                         {"type": "media", "media_type": "audio/mp3", "data": audio_bytes}
+        #                     ]
+        #                 )
+        #             )
+        #     else:
+        #         new_messages_to_send.append(msg)
+        # llm_response = llm_with_tools.invoke(new_messages_to_send)
+
+        llm_response = llm_with_tools.invoke(messages_to_send)  # For now, keep as is, rely on framework
+        print(f"LLM Raw Response: {llm_response}")
         return {"messages": [llm_response]}
 
     builder = StateGraph(MessagesState)
     builder.add_node("assistant", assistant)
-    builder.add_node("tools", ToolNode(tools))
+    builder.add_node("tools", ToolNode(tools))
     builder.add_edge(START, "assistant")
     builder.add_conditional_edges("assistant", tools_condition)
     builder.add_edge("tools", "assistant")
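If the commented-out attachment path above is ever enabled, note that the part shape shown there ("media_type" plus raw bytes) may not match what langchain_google_genai expects: in the releases I have seen, media parts take a mime_type key and base64-encoded text, but this should be verified against the installed version. A minimal sketch of that attachment step under those assumptions (the audio_message helper is hypothetical, not part of this commit):

import base64
from langchain_core.messages import HumanMessage

def audio_message(question: str, audio_path: str) -> HumanMessage:
    # Sketch only: build a multimodal HumanMessage for a Gemini chat model.
    # The {"type": "media", "mime_type": ..., "data": <base64 str>} block shape is an
    # assumption about langchain_google_genai -- check it against your installed release.
    with open(audio_path, "rb") as f:
        audio_b64 = base64.b64encode(f.read()).decode("ascii")
    return HumanMessage(content=[
        {"type": "text", "text": question},
        {"type": "media", "mime_type": "audio/mp3", "data": audio_b64},
    ])

The assistant node could then substitute such a message into messages_to_send before calling llm_with_tools.invoke.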
@@ -251,12 +271,4 @@ def build_graph(provider: str = "gemini"):
     return builder.compile()
 
 if __name__ == "__main__":
-    # Example usage (you'll need to set GEMINI_API_KEY and potentially HF_API_TOKEN env vars)
-    # This part assumes you have a prompt.txt file with the system_prompt as discussed earlier.
-
-    # You would typically interact with the compiled graph like this:
-    # graph = build_graph("gemini")
-    # user_input = "Tell me about this video: https://www.youtube.com/watch?v=1htKBjuUWec"
-    # result = graph.invoke({"messages": [HumanMessage(content=user_input)]})
-    # print(result)
     pass
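The __main__ block now only holds pass; the usage comments deleted above amounted to the following smoke test (a sketch assuming GEMINI_API_KEY is set and prompt.txt sits next to the script):

from langchain_core.messages import HumanMessage
from langgraph_agent import build_graph

# Build the Gemini-backed graph and run a single YouTube question through it.
graph = build_graph("gemini")
user_input = "Tell me about this video: https://www.youtube.com/watch?v=1htKBjuUWec"
result = graph.invoke({"messages": [HumanMessage(content=user_input)]})
print(result)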