"""
agent.py
This file defines the core logic for a sophisticated AI agent using LangGraph.
This version uses Groq's vision-capable models and includes proper reasoning steps.
"""
# ---------------------------------------------------------- | |
# Section 0: Imports and Configuration | |
# ---------------------------------------------------------- | |
import functools
import json
import os
import pickle
import re
import subprocess
import sys
import textwrap
from pathlib import Path
from typing import Any, Dict

import requests
from cachetools import TTLCache
from dotenv import load_dotenv
from langchain.schema import Document
from langchain.tools.retriever import create_retriever_tool
from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import FAISS
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import Tool, tool
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.graph import START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition

load_dotenv()
# --- Configuration and Caching --- | |
# Solved-example records, one JSON object per line, used to seed the retriever.
JSONL_PATH = Path("metadata.jsonl")
# Pickled FAISS vector store built from JSONL_PATH.
FAISS_CACHE = Path("faiss_index.pkl")
# Model name handed to HuggingFaceEmbeddings.
EMBED_MODEL = "sentence-transformers/all-mpnet-base-v2"
# Number of documents the retriever returns per query.
RETRIEVER_K = 5
# Time-to-live passed to the API response cache below.
CACHE_TTL = 600
# Bounded in-memory cache for external lookups, keyed as "<tool>:<query>".
API_CACHE = TTLCache(maxsize=256, ttl=CACHE_TTL)
def cached_get(key: str, fetch_fn):
    """Return the cached value for *key*, fetching and storing it on a miss.

    Args:
        key: Cache key identifying the request.
        fetch_fn: Zero-argument callable invoked to produce the value when
            *key* is not present in ``API_CACHE``.

    Returns:
        The cached value, or the fresh result of ``fetch_fn()``.

    Exceptions raised by ``fetch_fn`` propagate to the caller and nothing is
    stored in the cache in that case.
    """
    # EAFP instead of "if key in cache: return cache[key]": with a TTL cache
    # an entry can expire between the membership test and the lookup, which
    # would surface as an unexpected KeyError.
    try:
        return API_CACHE[key]
    except KeyError:
        pass
    # Fetch outside the except block so errors from fetch_fn are not chained
    # to the (uninteresting) cache-miss KeyError.
    val = fetch_fn()
    API_CACHE[key] = val
    return val
# ---------------------------------------------------------- | |
# Section 2: Tool Functions | |
# ---------------------------------------------------------- | |
def python_repl(code: str) -> str:
    """Executes a string of Python code and returns the stdout/stderr."""
    # NOTE: this function is registered as an agent tool in
    # create_agent_executor, so the docstring above is presumably surfaced to
    # the model as the tool description; it is deliberately left unchanged.
    #
    # code:    Python source; common leading indentation and surrounding
    #          whitespace are removed before execution.
    # returns: a text report with STDOUT (plus STDERR when the exit status is
    #          non-zero), or a timeout notice after 10 seconds.
    #
    # SECURITY: this runs arbitrary, model-generated code with the privileges
    # of the current process; nothing here sandboxes it.
    code = textwrap.dedent(code).strip()
    # Run the same interpreter that is executing this module rather than
    # whatever "python" resolves to on PATH, which may be missing (python3-only
    # systems) or belong to a different environment without our packages.
    interpreter = sys.executable or "python"
    try:
        result = subprocess.run(
            [interpreter, "-c", code],
            capture_output=True,
            text=True,
            timeout=10,
            check=False,
        )
    except subprocess.TimeoutExpired:
        return "Execution timed out (>10s)."
    if result.returncode == 0:
        return f"Execution successful.\nSTDOUT:\n```\n{result.stdout}\n```"
    return (
        f"Execution failed.\nSTDOUT:\n```\n{result.stdout}\n```"
        f"\nSTDERR:\n```\n{result.stderr}\n```"
    )
def web_search_func(query: str, cache_func) -> str:
    """Run a Tavily web search for *query* and render the hits as plain text.

    Args:
        query: Search string; empty or whitespace-only values are rejected.
        cache_func: Callable ``(key, fetch_fn)`` that returns the value for
            ``key``, calling ``fetch_fn()`` when it has to be produced.

    Returns:
        The usable hits formatted as "Source/Content" sections separated by
        ``---`` lines, or a short message when the query is empty, nothing
        was found, no hit had both a URL and content, or the search raised.
    """
    if not (query and query.strip()):
        return "Error: Empty search query"

    cache_key = f"web:{query}"

    def run_search():
        return TavilySearchResults(max_results=5).invoke(query)

    try:
        hits = cache_func(cache_key, run_search)
        if not hits:
            return "No search results found"
        # Keep only well-formed hits: dicts carrying both fields we print.
        sections = [
            f"Source: {hit['url']}\nContent: {hit['content']}"
            for hit in hits
            if isinstance(hit, dict) and 'url' in hit and 'content' in hit
        ]
        if not sections:
            return "No valid results found"
        return "\n\n---\n\n".join(sections)
    except Exception as exc:
        return f"Search error: {exc}"
def wiki_search_func(query: str, cache_func) -> str:
    """Look up *query* on Wikipedia and render up to two articles as text.

    Args:
        query: Search string; empty or whitespace-only values are rejected.
        cache_func: Callable ``(key, fetch_fn)`` that returns the value for
            ``key``, calling ``fetch_fn()`` when it has to be produced.

    Returns:
        One "Source" section per article separated by ``---`` lines, or a
        short message when the query is empty, nothing was found, or the
        lookup raised.
    """
    if not (query and query.strip()):
        return "Error: Empty search query"

    cache_key = f"wiki:{query}"

    def load_articles():
        loader = WikipediaLoader(
            query=query,
            load_max_docs=2,
            doc_content_chars_max=2000,
        )
        return loader.load()

    try:
        articles = cache_func(cache_key, load_articles)
        if not articles:
            return "No Wikipedia articles found"
        sections = []
        for article in articles:
            origin = article.metadata.get('source', 'Unknown')
            sections.append(f"Source: {origin}\n\n{article.page_content}")
        return "\n\n---\n\n".join(sections)
    except Exception as exc:
        return f"Wikipedia search error: {exc}"
def arxiv_search_func(query: str, cache_func) -> str:
    """Search Arxiv for *query* and render up to two papers as text.

    Args:
        query: Search string; empty or whitespace-only values are rejected.
        cache_func: Callable ``(key, fetch_fn)`` that returns the value for
            ``key``, calling ``fetch_fn()`` when it has to be produced.

    Returns:
        One section per paper (source, publication date, title and content)
        separated by ``---`` lines, or a short message when the query is
        empty, nothing was found, or the lookup raised.
    """
    if not (query and query.strip()):
        return "Error: Empty search query"

    cache_key = f"arxiv:{query}"

    def load_papers():
        return ArxivLoader(query=query, load_max_docs=2).load()

    def render(paper) -> str:
        # Missing metadata fields fall back to the literal "Unknown".
        meta = paper.metadata
        header = (
            f"Source: {meta.get('source', 'Unknown')}\n"
            f"Published: {meta.get('Published', 'Unknown')}\n"
            f"Title: {meta.get('Title', 'Unknown')}\n\n"
        )
        return header + f"Summary:\n{paper.page_content}"

    try:
        papers = cache_func(cache_key, load_papers)
        if not papers:
            return "No Arxiv papers found"
        return "\n\n---\n\n".join([render(paper) for paper in papers])
    except Exception as exc:
        return f"Arxiv search error: {exc}"
def analyze_task_and_reason(task_description: str) -> str:
    """
    Analyzes the task and provides reasoning about what approach to take.
    This tool helps determine what other tools might be needed.
    """
    # NOTE: this function is registered as an agent tool in
    # create_agent_executor, so the docstring above is presumably surfaced to
    # the model as the tool description; it is deliberately left unchanged.
    #
    # The classification is a plain case-insensitive substring match against
    # three keyword groups. The first group that matches decides the task type
    # and recommended approach; later groups only raise their own flag.
    lowered = task_description.lower()

    def mentions(*keywords: str) -> bool:
        return any(keyword in lowered for keyword in keywords)

    task_type = "unknown"
    approach = "Direct answer"

    # Image-related wording, URLs and common image extensions.
    has_image = mentions(
        'image', 'picture', 'photo', 'visual', 'see in', 'shown in',
        'attachment analysis', 'url:', 'http', '.jpg', '.png', '.gif',
    )
    if has_image:
        task_type = "image_analysis"
        approach = "Process image with vision model, then analyze content"

    # Wording that suggests fresh or factual information is required.
    needs_search = mentions(
        'current', 'recent', 'latest', 'news', 'today', 'what is',
        'who is', 'when did', 'research', 'find information',
    )
    if needs_search and task_type == "unknown":
        task_type = "information_search"
        approach = "Search for current information"

    # Wording that suggests calculation or code execution.
    needs_computation = mentions(
        'calculate', 'compute', 'math', 'formula', 'equation',
        'algorithm', 'code', 'program', 'python',
    )
    if needs_computation and task_type == "unknown":
        task_type = "computation"
        approach = "Use Python for calculations"

    report_lines = [
        "TASK ANALYSIS COMPLETE:",
        f"Task Type: {task_type}",
        f"Has Image: {has_image}",
        f"Needs Search: {needs_search}",
        f"Needs Computation: {needs_computation}",
        f"RECOMMENDED APPROACH: {approach}",
        "REASONING:",
        "- If this involves an image, I should process it directly with my vision capabilities",
        "- If this needs current information, I should use web search or Wikipedia",
        "- If this needs calculations, I should use the Python tool",
        "- I should always provide a comprehensive final answer",
        "NEXT STEPS: Proceed with the identified approach and use appropriate tools.",
    ]
    return "\n".join(report_lines)
# ---------------------------------------------------------- | |
# Section 3: SYSTEM PROMPT | |
# ---------------------------------------------------------- | |
# System prompt that the assistant node in create_agent_executor prepends to
# every conversation. The tool names it mentions (analyze_task_and_reason,
# web_search, python_repl, wiki_search, arxiv_search, retrieve_examples) match
# the tools registered there; keep the two in sync when either changes.
SYSTEM_PROMPT = """You are an expert multimodal AI assistant with vision capabilities and access to various tools.
**CORE CAPABILITIES:**
1. **Vision Processing**: You can directly process and analyze images from URLs
2. **Web Search**: Access current information via web search and Wikipedia
3. **Computation**: Execute Python code for calculations and data processing
4. **Research**: Search academic papers and retrieve similar examples
**CRITICAL WORKFLOW:**
1. **ANALYZE FIRST**: Always start by using the 'analyze_task_and_reason' tool to understand what you're being asked to do
2. **PROCESS IMAGES DIRECTLY**: When you encounter image URLs, process them directly with your vision model - DO NOT use separate image tools
3. **USE TOOLS STRATEGICALLY**: Based on your analysis, use appropriate tools (web search, Python, etc.)
4. **VALIDATE PARAMETERS**: Always check that you're passing correct parameters to tools
5. **SYNTHESIZE**: Combine all information into a comprehensive answer
**IMAGE HANDLING:**
- You have native vision capabilities - process image URLs directly
- Look for image URLs in the task description
- When you see an image URL, examine it carefully and describe what you see
- Relate your visual observations to the question being asked
**TOOL USAGE RULES:**
- Always use 'analyze_task_and_reason' first to plan your approach
- Use web_search for current events, factual information, or research
- Use python_repl for calculations, data processing, or code execution
- Use wiki_search for encyclopedic information
- Use arxiv_search for academic/scientific papers
- Use retrieve_examples for similar solved problems
**OUTPUT FORMAT:**
Always end your response with your answer clearly stated on the last line.
**PARAMETER VALIDATION:**
- Check that search queries are meaningful and specific
- Ensure Python code is safe and well-formed
- Verify image URLs are accessible before processing
"""
# ---------------------------------------------------------- | |
# Section 4: Factory Function for Agent Executor | |
# ---------------------------------------------------------- | |
def _build_llm(provider: str):
    """Create the chat model for *provider*; only ``"groq"` is supported."""
    if provider != "groq":
        raise ValueError(f"Provider '{provider}' not supported in this version")
    try:
        llm = ChatGroq(
            model_name="meta-llama/llama-4-maverick-17b-128e-instruct"  # Vision-capable model
        )
    except Exception as e:
        # Log for the operator, then let the caller see the original failure.
        print(f"Error initializing Groq: {e}")
        raise
    print("Initialized Groq LLM with vision capabilities")
    return llm


def _load_or_build_vector_store(embeddings):
    """Load the pickled FAISS index, or build one from JSONL_PATH or a placeholder."""
    if FAISS_CACHE.exists():
        # SECURITY: unpickling executes code embedded in the file. Only point
        # FAISS_CACHE at a file this application wrote itself.
        with open(FAISS_CACHE, "rb") as f:
            vector_store = pickle.load(f)
        print("Loaded existing FAISS index")
        return vector_store

    docs = []
    if JSONL_PATH.exists():
        with open(JSONL_PATH, "rt", encoding="utf-8") as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline at end of
                # file); json.loads would raise on them.
                if not line.strip():
                    continue
                rec = json.loads(line)
                docs.append(Document(
                    page_content=f"Question: {rec['Question']}\n\nFinal answer: {rec['Final answer']}",
                    metadata={"source": rec["task_id"]},
                ))

    if docs:
        vector_store = FAISS.from_documents(docs, embeddings)
        with open(FAISS_CACHE, "wb") as f:
            pickle.dump(vector_store, f)
        print(f"Created new FAISS index with {len(docs)} documents")
        return vector_store

    # No metadata file, or one without any records: fall back to a one-document
    # placeholder so the retriever tool can still be constructed. An index
    # cannot be built from an empty document list. The placeholder is not
    # written to FAISS_CACHE, so real data is picked up on a later start.
    placeholder = [Document(page_content="Sample document", metadata={"source": "sample"})]
    vector_store = FAISS.from_documents(placeholder, embeddings)
    print("Created minimal FAISS index")
    return vector_store


def _build_tools(retriever):
    """Assemble the tool list exposed to the model and executed by the tool node."""
    return [
        analyze_task_and_reason,
        Tool(
            name="web_search",
            func=functools.partial(web_search_func, cache_func=cached_get),
            description="Search the web for current information. Use specific, focused queries.",
        ),
        Tool(
            name="wiki_search",
            func=functools.partial(wiki_search_func, cache_func=cached_get),
            description="Search Wikipedia for encyclopedic information.",
        ),
        Tool(
            name="arxiv_search",
            func=functools.partial(arxiv_search_func, cache_func=cached_get),
            description="Search Arxiv for academic papers and research.",
        ),
        python_repl,
        create_retriever_tool(
            retriever=retriever,
            name="retrieve_examples",
            description="Retrieve similar solved examples from the knowledge base.",
        ),
    ]


def create_agent_executor(provider: str = "groq"):
    """
    Factory function to create and compile the LangGraph agent executor.

    Args:
        provider: LLM backend to use. Only "groq" is supported.

    Returns:
        The compiled graph: an assistant node that calls the tool-bound LLM,
        and a tools node that executes requested tool calls and hands the
        results back to the assistant until no more tools are requested.

    Raises:
        ValueError: If *provider* is not "groq".
        Exception: Whatever the Groq client raises during initialization.
    """
    print(f"Initializing agent with provider: {provider}")

    # Step 1: Initialize LLM with vision capabilities
    llm = _build_llm(provider)

    # Step 2: Build Retriever (if metadata exists)
    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
    vector_store = _load_or_build_vector_store(embeddings)
    retriever = vector_store.as_retriever(search_kwargs={"k": RETRIEVER_K})

    # Step 3: Create tools list
    tools_list = _build_tools(retriever)
    llm_with_tools = llm.bind_tools(tools_list)
    # Built once here rather than on every tool invocation.
    tool_node = ToolNode(tools_list)

    # Step 4: Define Graph Nodes
    def assistant_node(state: MessagesState):
        """Main assistant node that processes user input and tool responses."""
        messages = [SystemMessage(content=SYSTEM_PROMPT)] + state["messages"]
        try:
            result = llm_with_tools.invoke(messages)
            return {"messages": [result]}
        except Exception as e:
            error_msg = f"LLM Error: {e}"
            print(error_msg)
            # A plain AI message carries no tool calls, so the graph ends here.
            return {"messages": [AIMessage(content=f"I encountered an error: {error_msg}")]}

    def tools_node_wrapper(state: MessagesState):
        """Wrapper for tool execution with error handling."""
        try:
            return tool_node.invoke(state)
        except Exception as e:
            error_msg = f"Tool execution error: {e}"
            print(error_msg)
            history = state.get("messages") or []
            pending = getattr(history[-1], "tool_calls", None) if history else None
            if pending:
                # Every tool call requested by the assistant must be answered
                # by a tool message with the matching id; otherwise the next
                # LLM request is made with an invalid conversation.
                return {"messages": [
                    ToolMessage(
                        content=error_msg,
                        tool_call_id=call.get("id") or "",
                        name=call.get("name"),
                    )
                    for call in pending
                ]}
            return {"messages": [AIMessage(content=error_msg)]}

    # Step 5: Build Graph
    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant_node)
    builder.add_node("tools", tools_node_wrapper)
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
        {"tools": "tools", "__end__": "__end__"},
    )
    builder.add_edge("tools", "assistant")
    agent_executor = builder.compile()
    print("Agent Executor created successfully with vision capabilities")
    return agent_executor