"""LangGraph question-answering agent.

Uses an adjusted system prompt and relies on the Serper search and open_url
tools for question answering, with Wikipedia, Tavily, file, audio, math, and
image tools defined alongside.
"""
import os
from typing import Annotated, List, Optional, Dict, Any
from typing_extensions import TypedDict
from pathlib import Path
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter
import base64
import io
from dotenv import load_dotenv
from langchain.tools import tool
from langchain_tavily import TavilySearch
# Import math tools
import cmath # needed for square_root of negative numbers
from langchain_community.document_loaders import WikipediaLoader
from langchain_core.messages import SystemMessage, BaseMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
# Load environment variables from .env file
load_dotenv()
# Define the state for the agent
class State(TypedDict):
messages: Annotated[List[BaseMessage], add_messages]
@tool
def wikipedia(query: str) -> str:
"""
    Searches Wikipedia for the given query and returns the content of the top 2 most relevant documents.
    Use this tool to answer questions about historical events, scientific concepts,
    or any other topic that can be found on Wikipedia.
    If the results are unsatisfactory, the tavily_search tool may work better.
    Args:
        query: The search query.
    Returns:
        A formatted string of the search results.
"""
search_docs = WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=50000).load()
formatted_search_docs = "\n\n---\n\n".join(
[
f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
for doc in search_docs
]
)
return {"wiki_results": formatted_search_docs}
# -----------------------------------------------------------------------------
# Search Tools
# -----------------------------------------------------------------------------
@tool
def tavily_search(query: str) -> str:
"""If Wikipedia searches fail try this tool to Search the web using Tavily Search API and return a formatted string of the top results."""
api_key = os.getenv("TAVILY_API_KEY")
if not api_key:
return "Error: TAVILY_API_KEY environment variable is not set."
try:
search_tool = TavilySearch(api_key=api_key, max_results=5)
results = search_tool.invoke(query)
except Exception as exc:
return f"Error: Tavily search failed: {exc}"
    # TavilySearch may return a list of result dicts or a dict with a
    # "results" key -- normalise to a list before formatting.
    if isinstance(results, dict) and "results" in results:
        results = results["results"]
    if isinstance(results, list):
        formatted = "\n\n---\n\n".join(
            [
                f"Title: {r.get('title', '')}\nURL: {r.get('url', '')}\nSnippet: {r.get('content') or r.get('snippet', '')}"
                for r in results
            ]
        )
        return formatted or "No results found."
    return str(results)
# -----------------------------------------------------------------------------
# Serper Search Tool (Google)
# -----------------------------------------------------------------------------
@tool
def serper_search(query: str) -> str:
"""Search the web using the Serper API (Google Search) and return a formatted
string of the top results."""
api_key = os.getenv("SERPER_API_KEY")
if not api_key:
return "Error: SERPER_API_KEY environment variable is not set."
import requests
try:
resp = requests.post(
"https://google.serper.dev/search",
headers={"X-API-KEY": api_key, "Content-Type": "application/json"},
json={"q": query, "num": 10}, # return up to 10 results, we'll format top 5
timeout=20,
)
resp.raise_for_status()
data = resp.json()
except Exception as exc:
return f"Error: Serper search failed: {exc}"
results = data.get("organic", [])[:5]
if not results:
return "No results found."
formatted = "\n\n---\n\n".join(
[f"Title: {r.get('title', '')}\nURL: {r.get('link', '')}\nSnippet: {r.get('snippet', '')}" for r in results]
)
return formatted or "No results found."
# -----------------------------------------------------------------------------
# URL Retrieval Tool
# -----------------------------------------------------------------------------
@tool
def open_url(url: str, max_chars: int = 50000) -> str:
"""Download a web page and return its plain-text content (truncated). Supports HTML and other text types.
Args:
url: The HTTP/HTTPS URL to fetch.
Returns:
Cleaned text or an error string.
"""
import requests
from bs4 import BeautifulSoup
try:
resp = requests.get(url, timeout=20, headers={"User-Agent": "Mozilla/5.0 (compatible; LangChain-Agent/1.0)"})
resp.raise_for_status()
content_type = resp.headers.get("Content-Type", "")
# If HTML, strip tags; otherwise return raw text
if "text/html" in content_type:
soup = BeautifulSoup(resp.text, "html.parser")
# Remove non-visible elements
for tag in soup(["script", "style", "noscript"]):
tag.decompose()
text = soup.get_text("\n")
else:
text = resp.text
return text.strip()[:max_chars] or "No readable text found."
except Exception as exc:
return f"Error fetching {url}: {exc}"
# -----------------------------------------------------------------------------
# Composite web search + retrieval tool
# -----------------------------------------------------------------------------
@tool
def web_lookup(query: str) -> dict:
"""
Search the web using Tavily and automatically retrieve the plain-text content
of the top result.
Args:
query: Search query.
Returns:
Dict containing:
- top_results: List with one Tavily result dict
- page_url: URL opened
- page_content: Cleaned page text (truncated)
- error: present only if something went wrong
"""
api_key = os.getenv("TAVILY_API_KEY")
if not api_key:
return {"error": "TAVILY_API_KEY environment variable is not set."}
# Always fetch exactly one result
num_results = 1
try:
search_tool = TavilySearch(api_key=api_key, max_results=num_results)
raw_results = search_tool.invoke(query)
except Exception as exc:
return {"error": f"Tavily search failed: {exc}"}
# TavilySearch sometimes returns a list of dicts, sometimes a dict with a
# "results" key – normalise to a list.
if isinstance(raw_results, list):
results = raw_results
elif isinstance(raw_results, dict) and "results" in raw_results:
results = raw_results["results"]
else:
return {"error": f"Unexpected Tavily response: {type(raw_results)}"}
if not results:
return {"error": "No Tavily results found."}
best_url = results[0].get("url") if isinstance(results[0], dict) else None
if not best_url:
return {"error": "Top Tavily result had no URL field."}
    # Use open_url default truncation; invoke via the tool interface so the
    # @tool wrapper handles argument validation.
    page_text = open_url.invoke({"url": best_url})
return {
"top_results": results,
"page_url": best_url,
"page_content": page_text,
}
# -----------------------------------------------------------------------------
# Multimedia Tools
# -----------------------------------------------------------------------------
@tool
def transcribe_audio(audio_path: str) -> str:
"""Transcribe the supplied audio file to text using the OpenAI Whisper API (``whisper-1``).
Args:
audio_path: The path to the audio file to transcribe.
Returns:
The transcribed text or an error string.
"""
if not Path(audio_path).exists():
return f"Error: Audio file not found at {audio_path}"
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
return "Error: OPENAI_API_KEY environment variable is not set."
try:
from openai import OpenAI # type: ignore
client = OpenAI(api_key=api_key)
with open(audio_path, "rb") as f:
transcription = client.audio.transcriptions.create(
model="whisper-1",
file=f,
)
        text: Optional[str] = getattr(transcription, "text", None)
if text:
return text.strip()
return "Error: Transcription response did not contain text."
except Exception as exc:
return f"Error: OpenAI transcription failed: {exc}"
# -----------------------------------------------------------------------------
# Math Tools
# -----------------------------------------------------------------------------
@tool
def multiply(a: float, b: float) -> float:
"""Multiply two numbers and return the product."""
return a * b
@tool
def add(a: float, b: float) -> float:
"""Add two numbers and return the sum."""
return a + b
@tool
def subtract(a: float, b: float) -> float:
"""Subtract the second number from the first and return the result."""
return a - b
@tool
def divide(a: float, b: float) -> float:
"""Divide the first number by the second and return the quotient.
Raises:
ValueError: If b is zero.
"""
if b == 0:
raise ValueError("Cannot divide by zero.")
return a / b
@tool
def modulus(a: int, b: int) -> int:
"""Return the modulus of two integers."""
return a % b
@tool
def power(a: float, b: float) -> float:
"""Return a to the power of b."""
return a ** b
@tool
def square_root(a: float):
"""Return the square root of a. Supports complex results for negative inputs."""
if a >= 0:
return a ** 0.5
return cmath.sqrt(a)
# -----------------------------------------------------------------------------
# File handling tools
# -----------------------------------------------------------------------------
@tool
def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
"""
Download a file from a URL and return the local file path.
Args:
url: The URL to download the file from.
filename: The optional name to save the file as. If not provided, it's inferred from the URL.
"""
import requests
from pathlib import Path
# If a filename isn't provided, infer it from the URL.
if not filename:
filename = url.split("/")[-1]
download_dir = Path("downloads")
download_dir.mkdir(parents=True, exist_ok=True)
local_path = download_dir / filename
try:
resp = requests.get(url, stream=True, timeout=30)
resp.raise_for_status()
with open(local_path, 'wb') as f:
for chunk in resp.iter_content(1024):
f.write(chunk)
except Exception as e:
return f"Error downloading file from {url}: {e}"
return str(local_path)
@tool
def analyze_csv_file(file_path: str) -> str:
"""
Read a CSV at file_path and return JSON records.
"""
import pandas as pd
from pathlib import Path
if not Path(file_path).exists():
return f"Error: file not found at {file_path}"
df = pd.read_csv(file_path)
return df.to_json(orient="records")
@tool
def analyze_excel_file(file_path: str) -> str:
"""
Read an Excel file at file_path and return JSON per sheet.
"""
import pandas as pd
from pathlib import Path
import json
if not Path(file_path).exists():
return f"Error: file not found at {file_path}"
xls = pd.read_excel(file_path, sheet_name=None)
result = {name: df.to_json(orient="records") for name, df in xls.items()}
return json.dumps(result)
def decode_image(image_base64: str) -> Image.Image:
"""Decode a base64 encoded image string to a PIL Image."""
image_data = base64.b64decode(image_base64)
return Image.open(io.BytesIO(image_data))
def encode_image(image_path: str) -> str:
"""Encode an image file to a base64 string."""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode('utf-8')
def save_image(img: Image.Image, subdir: str = "transformed") -> str:
"""Save a PIL image to a file and return the path."""
output_dir = Path("images") / subdir
output_dir.mkdir(parents=True, exist_ok=True)
# Create a unique filename
import uuid
filename = f"{uuid.uuid4()}.png"
filepath = output_dir / filename
img.save(filepath)
return str(filepath)
### ============== IMAGE PROCESSING AND GENERATION TOOLS =============== ###
@tool
def analyze_image(image_base64: str) -> Dict[str, Any]:
"""
Analyze basic properties of an image (size, mode, color analysis, thumbnail preview).
Args:
image_base64 (str): Base64 encoded image string
Returns:
Dictionary with analysis result
"""
try:
img = decode_image(image_base64)
width, height = img.size
mode = img.mode
if mode in ("RGB", "RGBA"):
            arr = np.array(img)
            avg_colors = arr.mean(axis=(0, 1))
            dominant = ["Red", "Green", "Blue"][int(np.argmax(avg_colors[:3]))]
            brightness = float(avg_colors[:3].mean())  # ignore alpha channel if present
            color_analysis = {
                "average_rgb": avg_colors[:3].tolist(),
                "brightness": brightness,
                "dominant_color": dominant,
            }
else:
color_analysis = {"note": f"No color analysis for mode {mode}"}
thumbnail = img.copy()
thumbnail.thumbnail((100, 100))
thumb_path = save_image(thumbnail, "thumbnails")
thumbnail_base64 = encode_image(thumb_path)
return {
"dimensions": (width, height),
"mode": mode,
"color_analysis": color_analysis,
"thumbnail": thumbnail_base64,
}
except Exception as e:
return {"error": str(e)}
@tool
def transform_image(
image_base64: str, operation: str, params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
Apply transformations: resize, rotate, crop, flip, brightness, contrast, blur, sharpen, grayscale.
Args:
image_base64 (str): Base64 encoded input image
operation (str): Transformation operation
params (Dict[str, Any], optional): Parameters for the operation
Returns:
Dictionary with transformed image (base64)
"""
try:
img = decode_image(image_base64)
params = params or {}
if operation == "resize":
img = img.resize(
(
params.get("width", img.width // 2),
params.get("height", img.height // 2),
)
)
elif operation == "rotate":
img = img.rotate(params.get("angle", 90), expand=True)
elif operation == "crop":
img = img.crop(
(
params.get("left", 0),
params.get("top", 0),
params.get("right", img.width),
params.get("bottom", img.height),
)
)
elif operation == "flip":
if params.get("direction", "horizontal") == "horizontal":
img = img.transpose(Image.FLIP_LEFT_RIGHT)
else:
img = img.transpose(Image.FLIP_TOP_BOTTOM)
elif operation == "adjust_brightness":
img = ImageEnhance.Brightness(img).enhance(params.get("factor", 1.5))
elif operation == "adjust_contrast":
img = ImageEnhance.Contrast(img).enhance(params.get("factor", 1.5))
elif operation == "blur":
img = img.filter(ImageFilter.GaussianBlur(params.get("radius", 2)))
elif operation == "sharpen":
img = img.filter(ImageFilter.SHARPEN)
elif operation == "grayscale":
img = img.convert("L")
else:
return {"error": f"Unknown operation: {operation}"}
result_path = save_image(img)
result_base64 = encode_image(result_path)
return {"transformed_image": result_base64}
except Exception as e:
return {"error": str(e)}
class Agent:
def __init__(self):
"""
Initializes the Agent by setting up the LLM, tools, and the LangGraph graph.
"""
# Initialize the LLM
# Make sure to set the NEBIUS_API_KEY environment variable
nebius_api_key = os.environ.get("NEBIUS_API_KEY")
        if not nebius_api_key:
            try:
                from huggingface_hub import HfApi
                nebius_api_key = HfApi().get_secret("NEBIUS_API_KEY")
            except Exception as e:
                print(f"Could not get NEBIUS_API_KEY from secrets: {e}")
            if not nebius_api_key:
                raise ValueError("NEBIUS_API_KEY environment variable or secret not set.")
llm = ChatOpenAI(
model="Qwen/Qwen3-235B-A22B-Instruct-2507",
api_key=nebius_api_key,
base_url="https://api.studio.nebius.com/v1/"
)
#llm = ChatOpenAI(
# model="gpt-4.1-2025-04-14",
#)
# Load default system prompt
prompt_path = Path(__file__).with_name("system_promt.txt")
self.default_system_prompt = (
prompt_path.read_text(encoding="utf-8")
if prompt_path.exists()
else "You are a helpful assistant. Answer user questions accurately. If tools are available, think whether they are needed. Provide the final answer only."
)
# -----------------------------------------------------------------------------
# Assemble tool groups for clarity
# -----------------------------------------------------------------------------
self.retrieval_tools = [serper_search, open_url]
self.media_tools = [transcribe_audio]
self.file_tools = [download_file_from_url, analyze_csv_file, analyze_excel_file]
self.math_tools = [multiply, add, subtract, divide, modulus, power, square_root]
self.image_tools = [analyze_image, transform_image]
self.tools = self.retrieval_tools + self.media_tools + self.file_tools + self.math_tools + self.image_tools
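        # Note: wikipedia, tavily_search, and web_lookup are defined above but are
        # deliberately left out of self.tools (the agent relies on serper_search
        # and open_url for retrieval), so the model cannot call them.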
# Bind tools
# -----------------------------------------------------------------------------
self.llm_with_tools = llm.bind_tools(self.tools)
# -----------------------------------------------------------------------------
# Agent Graph Definition
# -----------------------------------------------------------------------------
graph_builder = StateGraph(State)
graph_builder.add_node("assistant", self.assistant_node)
graph_builder.add_node("tools", ToolNode(self.tools))
graph_builder.add_node("parser", self.parse_node)
graph_builder.add_edge(START, "assistant")
graph_builder.add_conditional_edges(
"assistant",
self.should_continue,
{"continue": "tools", "end": "parser"}
)
graph_builder.add_edge("tools", "assistant")
graph_builder.add_edge("parser", "__end__")
self.graph = graph_builder.compile()
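        # Resulting control flow:
        #   START -> assistant; assistant -> tools (when the reply contains tool
        #   calls) -> assistant; otherwise assistant -> parser -> __end__.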
def assistant_node(self, state: State):
"""
The assistant node in the graph. It calls the LLM with the current state
to decide the next action (respond or call a tool).
"""
        messages = state["messages"]
        # Ensure the system message is first without mutating the state list in place
        if not messages or not isinstance(messages[0], SystemMessage):
            messages = [SystemMessage(content=self.default_system_prompt)] + list(messages)
        response = self.llm_with_tools.invoke(messages)
        return {"messages": [response]}
def should_continue(self, state: State) -> str:
"""
Determines whether to continue with tool calls or end the process.
"""
if state["messages"][-1].tool_calls:
return "continue"
return "end"
def parse_node(self, state: State):
"""
Parses the final answer to remove the <think> tags.
"""
import re
last_message = state["messages"][-1]
content = last_message.content
# Use regex to find and extract the content after </think>
match_think = re.search(r"</think>\s*(.*)", content, re.DOTALL)
if match_think:
content = match_think.group(1).strip()
# Check for 'FINAL ANSWER:' and extract the content after it
match_final_answer = re.search(r"FINAL ANSWER:\s*(.*)", content, re.IGNORECASE | re.DOTALL)
if match_final_answer:
content = match_final_answer.group(1).strip()
last_message.content = content
return {"messages": [last_message]}
def __call__(self, item: dict, api_url: str) -> str:
"""
Main entry point for the agent.
Args:
item: A dictionary containing the question, file_name, etc.
api_url: The base URL of the API service.
Returns:
The agent's final answer as a string.
"""
question = item.get("question", "")
file_name = item.get("file_name")
print(f"Agent received question: {question[:100]}...")
initial_content = f"Question: {question}"
if file_name:
task_id = item.get("task_id")
# Construct the correct URL for the file using the task_id
file_url = f"{api_url}/files/{task_id}"
print(f"File detected. Download URL: {file_url}")
# Add information about the file to the initial prompt
initial_content += f'\n\nThere is a file associated with this question named `{file_name}`. To access its contents, first, download it using the `download_file_from_url` tool. Use the URL `"{file_url}"` and be sure to pass the filename `"{file_name}"` to the `filename` argument. After downloading, use the appropriate tool to analyze the file (e.g., `transcribe_audio` for audio files).'
initial_state = {"messages": [HumanMessage(content=initial_content)]}
# Invoke the graph
final_state = self.graph.invoke(initial_state)
# The final answer is the last message from the assistant
answer = final_state["messages"][-1].content
print(f"Agent returning answer: {answer[:100]}...")
return answer
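
# -----------------------------------------------------------------------------
# Example usage (illustrative sketch)
# -----------------------------------------------------------------------------
# A minimal, hypothetical driver showing how the Agent above can be exercised
# locally. The item fields and the api_url value below are assumptions made for
# illustration only; they mirror the structure consumed by Agent.__call__ and
# are not defined elsewhere in this module.
if __name__ == "__main__":
    agent = Agent()
    sample_item = {
        "task_id": "example-task-id",   # hypothetical identifier
        "question": "What is 12 * 7?",  # answerable with the math tools
        "file_name": None,              # no attached file in this example
    }
    # Hypothetical API base URL; only used to build file download links.
    print(agent(sample_item, api_url="https://example.com/api"))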