import os
from typing import Annotated, List, Optional, Dict, Any
from typing_extensions import TypedDict
from pathlib import Path
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter
import base64
import io
from dotenv import load_dotenv
from langchain.tools import tool
from langchain_tavily import TavilySearch
# Import math tools
import cmath  # needed for square_root of negative numbers
from langchain_community.document_loaders import WikipediaLoader
from langchain_core.messages import SystemMessage, BaseMessage, HumanMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode

# Load environment variables from .env file
load_dotenv()


# Define the state for the agent
class State(TypedDict):
    messages: Annotated[List[BaseMessage], add_messages]


def wikipedia(query: str) -> str:
    """
    Searches Wikipedia for the given query and returns the content of the top 2 most relevant documents.

    Use this tool to answer questions about historical events, scientific concepts,
    or any other topic that can be found on Wikipedia.
    Sometimes the tavily_search tool is better.

    Args:
        query: The search query.

    Returns:
        A formatted string containing the search results.
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2, doc_content_chars_max=50000).load()
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
            for doc in search_docs
        ]
    )
    return formatted_search_docs


# -----------------------------------------------------------------------------
# Search Tools
# -----------------------------------------------------------------------------
def tavily_search(query: str) -> str:
    """If Wikipedia searches fail, try this tool: search the web using the Tavily Search API and return a formatted string of the top results."""
    api_key = os.getenv("TAVILY_API_KEY")
    if not api_key:
        return "Error: TAVILY_API_KEY environment variable is not set."
    try:
        search_tool = TavilySearch(api_key=api_key, max_results=5)
        results = search_tool.invoke(query)
    except Exception as exc:
        return f"Error: Tavily search failed: {exc}"
    # TavilySearch may return a list of result dicts or a dict with a "results" key.
    if isinstance(results, dict) and "results" in results:
        results = results["results"]
    if isinstance(results, list):
        formatted = "\n\n---\n\n".join(
            [f"Title: {r.get('title', '')}\nURL: {r.get('url', '')}\nSnippet: {r.get('content', '')}" for r in results]
        )
        return formatted or "No results found."
    return str(results)


# -----------------------------------------------------------------------------
# Serper Search Tool (Google)
# -----------------------------------------------------------------------------
def serper_search(query: str) -> str:
    """Search the web using the Serper API (Google Search) and return a formatted
    string of the top results."""
    api_key = os.getenv("SERPER_API_KEY")
    if not api_key:
        return "Error: SERPER_API_KEY environment variable is not set."
    import requests
    try:
        resp = requests.post(
            "https://google.serper.dev/search",
            headers={"X-API-KEY": api_key, "Content-Type": "application/json"},
            json={"q": query, "num": 10},  # request up to 10 results; only the top 5 are formatted
            timeout=20,
        )
        resp.raise_for_status()
        data = resp.json()
    except Exception as exc:
        return f"Error: Serper search failed: {exc}"
    results = data.get("organic", [])[:5]
    if not results:
        return "No results found."
    formatted = "\n\n---\n\n".join(
        [f"Title: {r.get('title', '')}\nURL: {r.get('link', '')}\nSnippet: {r.get('snippet', '')}" for r in results]
    )
    return formatted or "No results found."


# -----------------------------------------------------------------------------
# URL Retrieval Tool
# -----------------------------------------------------------------------------
def open_url(url: str, max_chars: int = 50000) -> str:
    """Download a web page and return its plain-text content (truncated). Supports HTML and other text types.

    Args:
        url: The HTTP/HTTPS URL to fetch.

    Returns:
        Cleaned text or an error string.
    """
    import requests
    from bs4 import BeautifulSoup
    try:
        resp = requests.get(url, timeout=20, headers={"User-Agent": "Mozilla/5.0 (compatible; LangChain-Agent/1.0)"})
        resp.raise_for_status()
        content_type = resp.headers.get("Content-Type", "")
        # If HTML, strip tags; otherwise return raw text
        if "text/html" in content_type:
            soup = BeautifulSoup(resp.text, "html.parser")
            # Remove non-visible elements
            for tag in soup(["script", "style", "noscript"]):
                tag.decompose()
            text = soup.get_text("\n")
        else:
            text = resp.text
        return text.strip()[:max_chars] or "No readable text found."
    except Exception as exc:
        return f"Error fetching {url}: {exc}"


# -----------------------------------------------------------------------------
# Composite web search + retrieval tool
# -----------------------------------------------------------------------------
def web_lookup(query: str) -> dict:
    """
    Search the web using Tavily and automatically retrieve the plain-text content
    of the top result.

    Args:
        query: Search query.

    Returns:
        Dict containing:
            - top_results: List with one Tavily result dict
            - page_url: URL opened
            - page_content: Cleaned page text (truncated)
            - error: present only if something went wrong
    """
    api_key = os.getenv("TAVILY_API_KEY")
    if not api_key:
        return {"error": "TAVILY_API_KEY environment variable is not set."}
    # Always fetch exactly one result
    num_results = 1
    try:
        search_tool = TavilySearch(api_key=api_key, max_results=num_results)
        raw_results = search_tool.invoke(query)
    except Exception as exc:
        return {"error": f"Tavily search failed: {exc}"}
    # TavilySearch sometimes returns a list of dicts, sometimes a dict with a
    # "results" key – normalise to a list.
    if isinstance(raw_results, list):
        results = raw_results
    elif isinstance(raw_results, dict) and "results" in raw_results:
        results = raw_results["results"]
    else:
        return {"error": f"Unexpected Tavily response: {type(raw_results)}"}
    if not results:
        return {"error": "No Tavily results found."}
    best_url = results[0].get("url") if isinstance(results[0], dict) else None
    if not best_url:
        return {"error": "Top Tavily result had no URL field."}
    # Use open_url default truncation
    page_text = open_url(best_url)
    return {
        "top_results": results,
        "page_url": best_url,
        "page_content": page_text,
    }


# -----------------------------------------------------------------------------
# Multimedia Tools
# -----------------------------------------------------------------------------
def transcribe_audio(audio_path: str) -> str:
    """Transcribe the supplied audio file to text using the OpenAI Whisper API (``whisper-1``).

    Args:
        audio_path: The path to the audio file to transcribe.

    Returns:
        The transcribed text or an error string.
    """
    if not Path(audio_path).exists():
        return f"Error: Audio file not found at {audio_path}"
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        return "Error: OPENAI_API_KEY environment variable is not set."
    try:
        from openai import OpenAI  # type: ignore

        client = OpenAI(api_key=api_key)
        with open(audio_path, "rb") as f:
            transcription = client.audio.transcriptions.create(
                model="whisper-1",
                file=f,
            )
        text: Optional[str] = getattr(transcription, "text", None)
        if text:
            return text.strip()
        return "Error: Transcription response did not contain text."
    except Exception as exc:
        return f"Error: OpenAI transcription failed: {exc}"


# -----------------------------------------------------------------------------
# Math Tools
# -----------------------------------------------------------------------------
def multiply(a: float, b: float) -> float:
    """Multiply two numbers and return the product."""
    return a * b


def add(a: float, b: float) -> float:
    """Add two numbers and return the sum."""
    return a + b


def subtract(a: float, b: float) -> float:
    """Subtract the second number from the first and return the result."""
    return a - b


def divide(a: float, b: float) -> float:
    """Divide the first number by the second and return the quotient.

    Raises:
        ValueError: If b is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


def modulus(a: int, b: int) -> int:
    """Return the modulus of two integers."""
    return a % b


def power(a: float, b: float) -> float:
    """Return a to the power of b."""
    return a ** b


def square_root(a: float):
    """Return the square root of a. Supports complex results for negative inputs."""
    if a >= 0:
        return a ** 0.5
    return cmath.sqrt(a)


# -----------------------------------------------------------------------------
# File handling tools
# -----------------------------------------------------------------------------
def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
    """
    Download a file from a URL and return the local file path.

    Args:
        url: The URL to download the file from.
        filename: The optional name to save the file as. If not provided, it's inferred from the URL.
    """
    import requests

    # If a filename isn't provided, infer it from the URL.
    if not filename:
        filename = url.split("/")[-1]
    download_dir = Path("downloads")
    download_dir.mkdir(parents=True, exist_ok=True)
    local_path = download_dir / filename
    try:
        resp = requests.get(url, stream=True, timeout=30)
        resp.raise_for_status()
        with open(local_path, "wb") as f:
            for chunk in resp.iter_content(1024):
                f.write(chunk)
    except Exception as e:
        return f"Error downloading file from {url}: {e}"
    return str(local_path)


def analyze_csv_file(file_path: str) -> str:
    """
    Read a CSV at file_path and return JSON records.
    """
    import pandas as pd

    if not Path(file_path).exists():
        return f"Error: file not found at {file_path}"
    df = pd.read_csv(file_path)
    return df.to_json(orient="records")


def analyze_excel_file(file_path: str) -> str:
    """
    Read an Excel file at file_path and return JSON per sheet.
    """
    import pandas as pd
    import json

    if not Path(file_path).exists():
        return f"Error: file not found at {file_path}"
    xls = pd.read_excel(file_path, sheet_name=None)
    result = {name: df.to_json(orient="records") for name, df in xls.items()}
    return json.dumps(result)


def decode_image(image_base64: str) -> Image.Image:
    """Decode a base64 encoded image string to a PIL Image."""
    image_data = base64.b64decode(image_base64)
    return Image.open(io.BytesIO(image_data))


def encode_image(image_path: str) -> str:
    """Encode an image file to a base64 string."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")


def save_image(img: Image.Image, subdir: str = "transformed") -> str:
    """Save a PIL image to a file and return the path."""
    output_dir = Path("images") / subdir
    output_dir.mkdir(parents=True, exist_ok=True)
    # Create a unique filename
    import uuid

    filename = f"{uuid.uuid4()}.png"
    filepath = output_dir / filename
    img.save(filepath)
    return str(filepath)


### ============== IMAGE PROCESSING AND GENERATION TOOLS =============== ###
def analyze_image(image_base64: str) -> Dict[str, Any]:
    """
    Analyze basic properties of an image (size, mode, color analysis, thumbnail preview).

    Args:
        image_base64 (str): Base64 encoded image string

    Returns:
        Dictionary with analysis result
    """
    try:
        img = decode_image(image_base64)
        width, height = img.size
        mode = img.mode
        if mode in ("RGB", "RGBA"):
            arr = np.array(img)
            avg_colors = arr.mean(axis=(0, 1))
            dominant = ["Red", "Green", "Blue"][np.argmax(avg_colors[:3])]
            brightness = avg_colors.mean()
            color_analysis = {
                "average_rgb": avg_colors.tolist(),
                "brightness": brightness,
                "dominant_color": dominant,
            }
        else:
            color_analysis = {"note": f"No color analysis for mode {mode}"}
        thumbnail = img.copy()
        thumbnail.thumbnail((100, 100))
        thumb_path = save_image(thumbnail, "thumbnails")
        thumbnail_base64 = encode_image(thumb_path)
        return {
            "dimensions": (width, height),
            "mode": mode,
            "color_analysis": color_analysis,
            "thumbnail": thumbnail_base64,
        }
    except Exception as e:
        return {"error": str(e)}


def transform_image(
    image_base64: str, operation: str, params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Apply transformations: resize, rotate, crop, flip, brightness, contrast, blur, sharpen, grayscale.

    Args:
        image_base64 (str): Base64 encoded input image
        operation (str): Transformation operation
        params (Dict[str, Any], optional): Parameters for the operation

    Returns:
        Dictionary with transformed image (base64)
    """
    try:
        img = decode_image(image_base64)
        params = params or {}
        if operation == "resize":
            img = img.resize(
                (
                    params.get("width", img.width // 2),
                    params.get("height", img.height // 2),
                )
            )
        elif operation == "rotate":
            img = img.rotate(params.get("angle", 90), expand=True)
        elif operation == "crop":
            img = img.crop(
                (
                    params.get("left", 0),
                    params.get("top", 0),
                    params.get("right", img.width),
                    params.get("bottom", img.height),
                )
            )
        elif operation == "flip":
            if params.get("direction", "horizontal") == "horizontal":
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
            else:
                img = img.transpose(Image.FLIP_TOP_BOTTOM)
        elif operation == "adjust_brightness":
            img = ImageEnhance.Brightness(img).enhance(params.get("factor", 1.5))
        elif operation == "adjust_contrast":
            img = ImageEnhance.Contrast(img).enhance(params.get("factor", 1.5))
        elif operation == "blur":
            img = img.filter(ImageFilter.GaussianBlur(params.get("radius", 2)))
        elif operation == "sharpen":
            img = img.filter(ImageFilter.SHARPEN)
        elif operation == "grayscale":
            img = img.convert("L")
        else:
            return {"error": f"Unknown operation: {operation}"}
        result_path = save_image(img)
        result_base64 = encode_image(result_path)
        return {"transformed_image": result_base64}
    except Exception as e:
        return {"error": str(e)}


class Agent:
    def __init__(self):
        """
        Initializes the Agent by setting up the LLM, tools, and the LangGraph graph.
        """
        # Initialize the LLM
        # Make sure to set the NEBIUS_API_KEY environment variable
        nebius_api_key = os.environ.get("NEBIUS_API_KEY")
        if not nebius_api_key:
            try:
                from huggingface_hub import HfApi

                nebius_api_key = HfApi().get_secret("NEBIUS_API_KEY")
            except Exception as e:
                print(f"Could not get NEBIUS_API_KEY from secrets: {e}")
            if not nebius_api_key:
                raise ValueError("NEBIUS_API_KEY environment variable or secret not set.")
        llm = ChatOpenAI(
            model="Qwen/Qwen3-235B-A22B-Instruct-2507",
            api_key=nebius_api_key,
            base_url="https://api.studio.nebius.com/v1/",
        )
        # llm = ChatOpenAI(
        #     model="gpt-4.1-2025-04-14",
        # )

        # Load default system prompt
        prompt_path = Path(__file__).with_name("system_promt.txt")
        self.default_system_prompt = (
            prompt_path.read_text(encoding="utf-8")
            if prompt_path.exists()
            else "You are a helpful assistant. Answer user questions accurately. If tools are available, think whether they are needed. Provide the final answer only."
        )

        # -----------------------------------------------------------------------------
        # Assemble tool groups for clarity
        # -----------------------------------------------------------------------------
        self.retrieval_tools = [serper_search, open_url]
        self.media_tools = [transcribe_audio]
        self.file_tools = [download_file_from_url, analyze_csv_file, analyze_excel_file]
        self.math_tools = [multiply, add, subtract, divide, modulus, power, square_root]
        self.image_tools = [analyze_image, transform_image]
        self.tools = self.retrieval_tools + self.media_tools + self.file_tools + self.math_tools + self.image_tools

        # -----------------------------------------------------------------------------
        # Bind tools
        # -----------------------------------------------------------------------------
        self.llm_with_tools = llm.bind_tools(self.tools)

        # -----------------------------------------------------------------------------
        # Agent Graph Definition
        # -----------------------------------------------------------------------------
        graph_builder = StateGraph(State)
        graph_builder.add_node("assistant", self.assistant_node)
        graph_builder.add_node("tools", ToolNode(self.tools))
        graph_builder.add_node("parser", self.parse_node)
        graph_builder.add_edge(START, "assistant")
        graph_builder.add_conditional_edges(
            "assistant",
            self.should_continue,
            {"continue": "tools", "end": "parser"},
        )
        graph_builder.add_edge("tools", "assistant")
        graph_builder.add_edge("parser", "__end__")
        self.graph = graph_builder.compile()
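
    # Control-flow sketch of the compiled graph (descriptive comment only, added
    # for orientation; node and edge names match the wiring above):
    #
    #   START -> assistant --(tool calls)--> tools -> assistant (loop)
    #                      \--(no tool calls)--> parser -> END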

    def assistant_node(self, state: State):
        """
        The assistant node in the graph. It calls the LLM with the current state
        to decide the next action (respond or call a tool).
        """
        messages = state["messages"]
        system_message = SystemMessage(content=self.default_system_prompt)
        # Ensure the system message is the first message
        if not messages or not isinstance(messages[0], SystemMessage):
            messages.insert(0, system_message)
        response = self.llm_with_tools.invoke(messages)
        return {"messages": [response]}

    def should_continue(self, state: State) -> str:
        """
        Determines whether to continue with tool calls or end the process.
        """
        if state["messages"][-1].tool_calls:
            return "continue"
        return "end"

    def parse_node(self, state: State):
        """
        Parses the final answer to remove the <think> tags.
        """
        import re

        last_message = state["messages"][-1]
        content = last_message.content
        # Use regex to find and extract the content after </think>
        match_think = re.search(r"</think>\s*(.*)", content, re.DOTALL)
        if match_think:
            content = match_think.group(1).strip()
        # Check for 'FINAL ANSWER:' and extract the content after it
        match_final_answer = re.search(r"FINAL ANSWER:\s*(.*)", content, re.IGNORECASE | re.DOTALL)
        if match_final_answer:
            content = match_final_answer.group(1).strip()
        last_message.content = content
        return {"messages": [last_message]}

    def __call__(self, item: dict, api_url: str) -> str:
        """
        Main entry point for the agent.

        Args:
            item: A dictionary containing the question, file_name, etc.
            api_url: The base URL of the API service.

        Returns:
            The agent's final answer as a string.
        """
        question = item.get("question", "")
        file_name = item.get("file_name")
        print(f"Agent received question: {question[:100]}...")
        initial_content = f"Question: {question}"
        if file_name:
            task_id = item.get("task_id")
            # Construct the correct URL for the file using the task_id
            file_url = f"{api_url}/files/{task_id}"
            print(f"File detected. Download URL: {file_url}")
            # Add information about the file to the initial prompt
            initial_content += f'\n\nThere is a file associated with this question named `{file_name}`. To access its contents, first, download it using the `download_file_from_url` tool. Use the URL `"{file_url}"` and be sure to pass the filename `"{file_name}"` to the `filename` argument. After downloading, use the appropriate tool to analyze the file (e.g., `transcribe_audio` for audio files).'
        initial_state = {"messages": [HumanMessage(content=initial_content)]}
        # Invoke the graph
        final_state = self.graph.invoke(initial_state)
        # The final answer is the last message from the assistant
        answer = final_state["messages"][-1].content
        print(f"Agent returning answer: {answer[:100]}...")
        return answer
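

if __name__ == "__main__":
    # Minimal local smoke test (a sketch, not part of the evaluation flow).
    # Assumptions: NEBIUS_API_KEY is available, and the example item and api_url
    # below are made-up placeholders; the real service supplies its own payloads.
    example_item = {
        "task_id": "demo-task",
        "question": "What is 12 multiplied by 7?",
        "file_name": None,
    }
    agent = Agent()
    print(agent(example_item, api_url="https://example.invalid/api"))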