import logging
import os
import re
import subprocess
import sys
import time
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import requests
import spacy
from openai import OpenAI

from tool import Browser, SearchInformationTool

# Load the spaCy English model, downloading it on first use if it is missing.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True)
    nlp = spacy.load("en_core_web_sm")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

def initialize_openai_client():
    try:
        return OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    except Exception as e:
        logger.error(f"Failed to initialize OpenAI client: {e}")
        raise

class EnAgent:
    def __init__(self, api_url: str = DEFAULT_API_URL):
        self.api_url = api_url
        self.openai_client = initialize_openai_client()
        self.browser = Browser()
        self.search_tool = SearchInformationTool(browser=self.browser)
        logger.info("EnAgent initialized.")

    def fetch_questions(self) -> Optional[List[Dict]]:
        try:
            response = requests.get(f"{self.api_url}/questions", timeout=15)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Error fetching questions: {e}")
            return None

    def submit_answers(self, answers_payload: List[Dict], username: str, agent_code: str) -> Optional[Dict]:
        try:
            response = requests.post(
                f"{self.api_url}/submit",
                json={"username": username, "agent_code": agent_code, "answers": answers_payload},
                timeout=60,
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:
            logger.error(f"Error submitting answers: {e}")
            return None

    def answer_question_with_context(self, context: str) -> str:
        full_prompt = f"""{context}
When answering, provide only the exact answer requested.
Do not include explanations, steps, justifications, or additional text.
"""
        try:
            # Delegate to the main answering pipeline, then normalise the output.
            answer = self._clean_answer(self.answer_question(full_prompt))
            logger.info(f"Generated answer: {answer}")
            return answer
        except Exception as e:
            error_msg = f"Error answering question: {e}"
            logger.error(error_msg)
            return error_msg

    def _clean_answer(self, answer: Any) -> str:
        """
        Clean up a model response by removing common prefixes and formatting.

        Args:
            answer: The raw answer from the model.

        Returns:
            The cleaned answer as a string.
        """
        if not isinstance(answer, str):
            if isinstance(answer, float):
                # Render whole-number floats without a trailing ".0".
                return str(int(answer)) if answer.is_integer() else str(answer)
            return str(answer)

        answer = answer.strip()

        prefixes_to_remove = [
            "The answer is ",
            "Answer: ",
            "Final answer: ",
            "The result is ",
            "To answer this question: ",
            "Based on the information provided, ",
            "According to the information: ",
        ]
        for prefix in prefixes_to_remove:
            if answer.startswith(prefix):
                answer = answer[len(prefix):].strip()

        # Strip one pair of surrounding quotes, if present.
        if (answer.startswith('"') and answer.endswith('"')) or (answer.startswith("'") and answer.endswith("'")):
            answer = answer[1:-1].strip()

        return answer

    def analyze_question_intent(self, question: str) -> str:
        # Lightweight keyword heuristic over the tokenised question.
        doc = nlp(question.lower())
        for token in doc:
            if token.text in ["how", "many", "much", "number", "count"]:
                return "count"
            elif token.text in ["who", "name", "person"]:
                return "name"
            elif token.text in ["when", "date", "year"]:
                return "date"
            elif token.text in ["where", "place", "location"]:
                return "location"
            elif token.text in ["what", "which"]:
                return "fact"
        return "unknown"

    def extract_number_between_years(self, text: str, start: int, end: int) -> Optional[int]:
        # Non-capturing group so findall returns full four-digit years, not just "19"/"20".
        year_matches = re.findall(r"\b(?:19|20)\d{2}\b", text)
        years = [int(y) for y in year_matches if start <= int(y) <= end]
        return len(set(years)) if years else None

    def format_answer(self, question: str, answer: str, intent: str) -> str:
        answer = answer.strip()
        logger.info(f"Intent: {intent} | Raw answer: {answer}")

        if intent == "count":
            # Non-capturing group so findall returns full years rather than "19"/"20".
            year_matches = re.findall(r"\b(?:19|20)\d{2}\b", question)
            years = list(map(int, year_matches))

            if len(years) >= 2:
                start, end = sorted(years[:2])
                number = self.extract_number_between_years(answer, start, end)
                if number is not None:
                    logger.info(f"Extracted number from years: {number}")
                    return str(number)

            album_match = re.search(r"(one|two|three|four|five|\d+)\s+(studio\s+)?albums?", answer.lower())
            if album_match:
                number_word = album_match.group(1)
                number = self.convert_word_to_number(number_word) if number_word.isalpha() else int(number_word)
                if number:
                    logger.info(f"Extracted number from album phrase: {number}")
                    return str(number)

            numbers = re.findall(r"\d+", answer)
            if numbers:
                logger.info(f"Extracted fallback number: {numbers[0]}")
                return numbers[0]

            return answer

        elif intent == "name":
            doc = nlp(answer)
            persons = [ent.text for ent in doc.ents if ent.label_ in ["PERSON", "ORG"]]
            return persons[0] if persons else answer

        elif intent == "date":
            doc = nlp(answer)
            for ent in doc.ents:
                if ent.label_ == "DATE":
                    return ent.text
            return answer

        elif intent == "location":
            doc = nlp(answer)
            for ent in doc.ents:
                if ent.label_ == "GPE":
                    return ent.text
            return answer

        elif intent == "fact":
            return answer

        return answer

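    # The call to self.convert_word_to_number in format_answer has no definition in the
    # original source; the helper below is a minimal assumed implementation covering the
    # number words matched by the album regex. Extend the mapping as needed.
    def convert_word_to_number(self, word: str) -> Optional[int]:
        word_to_num = {"one": 1, "two": 2, "three": 3, "four": 4, "five": 5}
        return word_to_num.get(word.lower())
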
    @staticmethod
    def find_country_with_min_athletes(text: str) -> Optional[str]:
        # Match "Country (count)" pairs and return the alphabetically first country
        # among those with the smallest count.
        matches = re.findall(r"\b([A-Z][a-z]+(?: [A-Z][a-z]+)?)\s*\((\d+)\)", text)
        if not matches:
            return None
        min_count = min(int(c) for _, c in matches)
        filtered = [country for country, count in matches if int(count) == min_count]
        return sorted(filtered)[0] if filtered else None

    @staticmethod
    def extract_ioc_code(country_name: str, ioc_text: str) -> Optional[str]:
        pattern = re.compile(rf"{re.escape(country_name)}\s*\((\w{{3}})\)", re.IGNORECASE)
        match = pattern.search(ioc_text)
        return match.group(1).upper() if match else None

    def preprocess_question(self, question: str) -> str:
        question = question.strip().lower()
        question = re.sub(r"[^\w\s]", "", question)
        question = re.sub(r"\s+", " ", question)
        return question

    def search_with_reference(self, query: str) -> str:
        domains = []
        query_lower = query.lower()
        wikipedia_related_keywords = ["information", "article", "search", "learn", "facts", "data", "country", "athlete"]
        # A single check covers both generic keywords and explicit Wikipedia mentions,
        # so the domain is only appended once.
        if any(keyword in query_lower for keyword in wikipedia_related_keywords) or "wikipedia" in query_lower or "wikipedia.org" in query_lower:
            domains.append("en.wikipedia.org")
        if "baseball reference" in query_lower:
            domains.append("www.baseball-reference.com")
        if "imdb" in query_lower:
            domains.append("www.imdb.com")
        if domains:
            domain_filters = " OR ".join([f"site:{domain}" for domain in domains])
            query = f"{query} ({domain_filters})"
        search_result = self.search_tool.forward(query)
        if not search_result or "An error occurred" in search_result or "No results found" in search_result:
            logger.warning("Search returned no usable results.")
            return ""
        return search_result[:1000]

    @lru_cache(maxsize=128)
    def answer_question(self, question: str) -> str:
        logger.info(f"Answering question with reasoning: {question[:50]}...")
        try:
            source_text = self.search_with_reference(question)
            intent = self.analyze_question_intent(question)
            system_prompt = (
                "You are a concise assistant. Reason step by step internally, using Wikipedia "
                "and any sources named in the question, but answer only the question asked. "
                "When answering, provide only the exact answer requested. "
                "Do not include explanations, steps, justifications, or additional text. "
                "For example, if asked 'What is the capital of France?', answer only: Paris. "
                "If the answer is a quantity such as four chairs, answer only: 4."
            )
            content_block = f"Question: {question}"
            if source_text:
                content_block += f"\n\nSource:\n{source_text}"
            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content_block},
            ]
            response = self.openai_client.chat.completions.create(
                model="gpt-4o",
                messages=messages,
                temperature=0.3,
                max_tokens=50,
            )
            if response.choices:
                raw = response.choices[0].message.content.strip()
                # Strip an optional leading "Answer:"-style prefix.
                match = re.search(r"(?i)answer\s*[:\-]?\s*(.*)", raw)
                final = match.group(1).strip() if match else raw
                return self.format_answer(question, final, intent)
            return "No answer."
        except Exception as e:
            logger.error(f"Error answering question: {e}")
            return f"Error: {e}"

    def process_questions(self, questions: List[str]) -> List[Dict]:
        results = []
        for question in questions:
            time.sleep(1)  # small delay between questions
            answer = self.answer_question(question)
            results.append({"Question": question, "Answer": answer})
        return results

    def process_uploaded_file(self, file_path: str) -> List[str]:
        try:
            ext = os.path.splitext(file_path)[1].lower()
            if ext == ".pdf":
                return self.extract_questions_from_pdf(file_path)
            elif ext == ".txt":
                return self.extract_questions_from_txt(file_path)
            elif ext == ".md":
                return self.extract_questions_from_markdown(file_path)
            elif ext in [".xls", ".xlsx"]:
                return self.extract_questions_from_excel(file_path)
            elif ext == ".csv":
                return self.extract_questions_from_csv(file_path)
            elif ext in [".mp4", ".avi", ".mov"]:
                return self.extract_images_from_video(file_path)
            else:
                logger.error("Unsupported file format.")
                return []
        except Exception as e:
            logger.error(f"Error processing file: {e}")
            return []

    def extract_questions_from_pdf(self, file_path: str) -> List[str]:
        # Placeholder implementation.
        return ["Question from PDF"]

    def extract_questions_from_txt(self, file_path: str) -> List[str]:
        # Placeholder implementation.
        return ["Question from TXT"]

    def extract_questions_from_markdown(self, file_path: str) -> List[str]:
        # Placeholder implementation.
        return ["Question from Markdown"]

    def extract_questions_from_excel(self, file_path: str) -> List[str]:
        try:
            import pandas as pd

            df = pd.read_excel(file_path)
            # Return the first text-like column as the list of questions.
            for col in df.columns:
                if df[col].dtype == object:
                    return df[col].dropna().astype(str).tolist()
            return []
        except Exception as e:
            logger.error(f"Error extracting from Excel: {e}")
            return []

    def extract_questions_from_csv(self, file_path: str) -> List[str]:
        # Placeholder implementation.
        return ["Question from CSV"]

    def extract_images_from_video(self, file_path: str) -> List[str]:
        # Placeholder implementation.
        return ["Frame 1", "Frame 2"]

def run_and_submit_all(profile: Optional[gr.OAuthProfile]) -> Tuple[str, Optional[List[Dict]]]:
    try:
        if profile is None or not hasattr(profile, "username"):
            return "❌ Please log in to Hugging Face.", None
        username = profile.username
        space_id = os.getenv("SPACE_ID")
        if not space_id:
            return "❌ SPACE_ID environment variable not set.", None
        agent = EnAgent()
        questions = agent.fetch_questions()
        if not questions:
            return "❌ Failed to fetch questions.", None
        results = agent.process_questions([q["question"] for q in questions])
        answers_payload = [
            {
                "task_id": q["id"] if "id" in q else q["task_id"],
                "final_answer": next((r["Answer"] for r in results if r["Question"] == q["question"]), ""),
            }
            for q in questions
        ]
        submission_result = agent.submit_answers(answers_payload, username, space_id)
        if not submission_result:
            return "❌ Failed to submit answers.", None
        return "✅ Answers submitted successfully.", results
    except Exception as e:
        logger.error(f"Unexpected error in run_and_submit_all: {e}")
        return f"❌ Unexpected error: {e}", None
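
# Minimal Gradio wiring (a sketch, not part of the original file): run_and_submit_all
# expects a gr.OAuthProfile, which Gradio injects automatically when the callback
# declares a parameter of that type and the app contains a LoginButton. Component
# names and layout below are assumptions.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown("# EnAgent Evaluation Runner")
        gr.LoginButton()
        run_button = gr.Button("Run Evaluation & Submit All Answers")
        status_output = gr.Textbox(label="Status", lines=2, interactive=False)
        results_output = gr.JSON(label="Questions and Answers")
        run_button.click(fn=run_and_submit_all, outputs=[status_output, results_output])
    demo.launch()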