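"""aworldsearch-server: FastMCP server exposing AWorld web search.

Queries are POSTed to the search endpoint configured via AWORLD_SEARCH_*
environment variables, responses are filtered down to the fields agents need
(title, docAbstract, url, doc), and the merged documents are returned to the
MCP client as TextContent.
"""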
import asyncio
import json
import logging
import os
import sys

import aiohttp
from typing import List, Dict, Any, Optional, Union
from dotenv import load_dotenv
from mcp.server import FastMCP
from pydantic import Field
from mcp.types import TextContent

from aworld.logs.util import logger

mcp = FastMCP("aworldsearch-server")
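# Environment variables read by this server (all via os.getenv):
#   AWORLD_SEARCH_URL        - search endpoint URL (required)
#   AWORLD_SEARCH_SEARCHMODE - search mode forwarded in the request body (required)
#   AWORLD_SEARCH_SOURCE     - source identifier (required)
#   AWORLD_SEARCH_DOMAIN     - search domain (required)
#   AWORLD_SEARCH_UID        - user id forwarded as "userId" (optional)
#   AWORLD_SEARCH_TOTAL_NUM  - if set to a number, overrides the `num` argument
#   AWORLD_SEARCH_SLICE_NUM  - if set to a number, fixes the per-query result count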
async def search_single(query: str, num: int = 5) -> Optional[Dict[str, Any]]:
    """Execute a single search query, returns None on error."""
    try:
        url = os.getenv('AWORLD_SEARCH_URL')
        searchMode = os.getenv('AWORLD_SEARCH_SEARCHMODE')
        source = os.getenv('AWORLD_SEARCH_SOURCE')
        domain = os.getenv('AWORLD_SEARCH_DOMAIN')
        uid = os.getenv('AWORLD_SEARCH_UID')

        if not url or not searchMode or not source or not domain:
            logger.warning("Query failed: url, searchMode, source, domain parameters incomplete")
            return None

        headers = {
            'Content-Type': 'application/json'
        }
        data = {
            "domain": domain,
            "extParams": {},
            "page": 0,
            "pageSize": num,
            "query": query,
            "searchMode": searchMode,
            "source": source,
            "userId": uid
        }

        async with aiohttp.ClientSession() as session:
            try:
                async with session.post(url, headers=headers, json=data) as response:
                    if response.status != 200:
                        logger.warning(f"Query failed: {query}, status code: {response.status}")
                        return None
                    result = await response.json()
                    return result
            except aiohttp.ClientError:
                logger.warning(f"Request error: {query}")
                return None
    except Exception:
        logger.warning(f"Query exception: {query}")
        return None
def filter_valid_docs(result: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Filter valid document results, returns empty list if input is None."""
    if result is None:
        return []

    try:
        valid_docs = []

        # Check success field
        if not result.get("success"):
            return valid_docs

        # Check searchDocs field
        search_docs = result.get("searchDocs", [])
        if not search_docs:
            return valid_docs

        # Extract required fields
        required_fields = ["title", "docAbstract", "url", "doc"]
        for doc in search_docs:
            # Check if all required fields exist and are not empty
            is_valid = True
            for field in required_fields:
                if field not in doc or not doc[field]:
                    is_valid = False
                    break
            if is_valid:
                # Keep only required fields
                filtered_doc = {field: doc[field] for field in required_fields}
                valid_docs.append(filtered_doc)

        return valid_docs
    except Exception:
        return []
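# Minimal usage sketch (hypothetical payload): documents missing any required
# field are dropped, and extra fields such as "score" are stripped.
# >>> filter_valid_docs({"success": True, "searchDocs": [
# ...     {"title": "t", "docAbstract": "a", "url": "u", "doc": "d", "score": 1},
# ...     {"title": "t2", "docAbstract": "", "url": "u2", "doc": "d2"}]})
# [{'title': 't', 'docAbstract': 'a', 'url': 'u', 'doc': 'd'}]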
# Register search as an MCP tool on the FastMCP server.
@mcp.tool()
async def search(
        query_list: List[str] = Field(
            description="List format, queries to search for"
        ),
        num: int = Field(
            5,
            description="Maximum number of results per query, default is 5, please keep the total results within 15"
        )
) -> Union[str, TextContent]:
    """Execute search main function, supports single query or query list."""
    try:
        # Get configuration from environment variables
        env_total_num = os.getenv('AWORLD_SEARCH_TOTAL_NUM')
        if env_total_num and env_total_num.isdigit():
            # Force override input num parameter with environment variable
            num = int(env_total_num)

        # If no queries provided, return an empty result
        if not query_list:
            # Initialize TextContent with additional parameters
            return TextContent(
                type="text",
                text="",  # Empty string instead of None
                **{"metadata": {}}  # Pass as additional fields
            )

        # When query count is >= 3 or slice_num is set, use corresponding value
        slice_num = os.getenv('AWORLD_SEARCH_SLICE_NUM')
        if slice_num and slice_num.isdigit():
            actual_num = int(slice_num)
        else:
            actual_num = 2 if len(query_list) >= 3 else num

        # Execute all queries in parallel
        tasks = [search_single(q, actual_num) for q in query_list]
        raw_results = await asyncio.gather(*tasks)

        # Filter and merge results
        all_valid_docs = []
        for result in raw_results:
            valid_docs = filter_valid_docs(result)
            all_valid_docs.extend(valid_docs)

        # If no valid results found, return an empty result
        if not all_valid_docs:
            # Initialize TextContent with additional parameters
            return TextContent(
                type="text",
                text="",  # Empty string instead of None
                **{"metadata": {}}  # Pass as additional fields
            )

        # Format results as JSON
        result_json = json.dumps(all_valid_docs, ensure_ascii=False)

        # Create dictionary structure directly
        combined_query = ",".join(query_list)
        search_items = []
        for doc in all_valid_docs:
            search_items.append({
                "title": doc.get("title", ""),
                "url": doc.get("url", ""),
                "content": doc.get("doc", "")  # Map doc field to content
            })
        search_output_dict = {
            "query": combined_query,
            "results": search_items
        }

        # Log results
        logger.info(f"Completed {len(query_list)} queries, found {len(all_valid_docs)} valid documents")

        # Initialize TextContent with additional parameters
        return TextContent(
            type="text",
            text=result_json,
            **{"metadata": search_output_dict}  # Pass processed data as metadata
        )
    except Exception as e:
        # Handle errors
        logger.error(f"Search error: {e}")
        # Initialize TextContent with additional parameters
        return TextContent(
            type="text",
            text="",  # Empty string instead of None
            **{"metadata": {}}  # Pass as additional fields
        )
def main():
    from dotenv import load_dotenv
    load_dotenv(override=True)

    print("Starting MCP aworldsearch-server...", file=sys.stderr)
    mcp.run(transport="stdio")


# Make the module callable
def __call__():
    """
    Make the module callable for uvx.
    This function is called when the module is executed directly.
    """
    main()


sys.modules[__name__].__call__ = __call__

if __name__ == "__main__":
    main()
# if __name__ == "__main__":
#     # Configure logging
#     logging.basicConfig(
#         level=logging.INFO,
#         format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
#     )
#
#     # Test single query
#     # asyncio.run(search("Alibaba financial report"))
#
#     # Test multiple queries
#     test_queries = ["Alibaba financial report", "Tencent financial report", "Baidu financial report"]
#     asyncio.run(search(query_list=test_queries))