Coool2 commited on
Commit
5e30500
·
verified ·
1 Parent(s): 1d635fd

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +31 -5
agent.py CHANGED
@@ -21,6 +21,7 @@ from llama_index.core import Settings
21
 
22
  from transformers import AutoModelForCausalLM, AutoTokenizer
23
  from llama_index.llms.huggingface import HuggingFaceLLM
 
24
 
25
  llama_debug = LlamaDebugHandler(print_trace_on_end=True)
26
  callback_manager = CallbackManager([llama_debug])
@@ -579,26 +580,52 @@ class EnhancedGAIAAgent:
579
  Answer:"""
580
 
581
  try:
582
- # Use a simple, fast LLM for formatting
583
  formatting_response = proj_llm.complete(format_prompt)
584
  answer = str(formatting_response).strip()
585
 
 
 
 
 
586
  return answer
587
 
588
  except Exception as e:
589
  print(f"Error in formatting: {e}")
590
  return self._extract_fallback_answer(raw_response)
591
 
592
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
593
  async def solve_gaia_question(self, question_data: Dict[str, Any]) -> str:
594
  question = question_data.get("Question", "")
595
  task_id = question_data.get("task_id", "")
596
  print("data",question_data)
597
-
 
 
 
 
 
 
 
 
 
598
  context_prompt = f"""
599
  GAIA Task ID: {task_id}
600
  Question: {question}
601
- {f"Associated files: {question_data.get('file_name', '')}" if 'file_name' in question_data else 'No files provided'}
602
 
603
  Analyze this question and provide your reasoning and final answer.
604
  """
@@ -611,7 +638,6 @@ class EnhancedGAIAAgent:
611
  # Post-process to extract exact GAIA format
612
  formatted_answer = await self.format_gaia_answer(str(raw_response), question)
613
 
614
- print(f"Raw response: {raw_response}")
615
  print(f"Formatted answer: {formatted_answer}")
616
 
617
  return formatted_answer
 
21
 
22
  from transformers import AutoModelForCausalLM, AutoTokenizer
23
  from llama_index.llms.huggingface import HuggingFaceLLM
24
+ import requests
25
 
26
  llama_debug = LlamaDebugHandler(print_trace_on_end=True)
27
  callback_manager = CallbackManager([llama_debug])
 
580
  Answer:"""
581
 
582
  try:
 
583
  formatting_response = proj_llm.complete(format_prompt)
584
  answer = str(formatting_response).strip()
585
 
586
+ # Extract just the answer after "Answer:"
587
+ if "Answer:" in answer:
588
+ answer = answer.split("Answer:")[-1].strip()
589
+
590
  return answer
591
 
592
  except Exception as e:
593
  print(f"Error in formatting: {e}")
594
  return self._extract_fallback_answer(raw_response)
595
 
596
+ def download_gaia_file(self, task_id: str, api_url: str = "https://agents-course-unit4-scoring.hf.space") -> str:
597
+ """Download file associated with task_id"""
598
+ try:
599
+ response = requests.get(f"{api_url}/files/{task_id}", timeout=30)
600
+ response.raise_for_status()
601
+
602
+ # Save file locally
603
+ filename = f"task_{task_id}_file"
604
+ with open(filename, 'wb') as f:
605
+ f.write(response.content)
606
+ return filename
607
+ except Exception as e:
608
+ print(f"Failed to download file for task {task_id}: {e}")
609
+ return None
610
+
611
  async def solve_gaia_question(self, question_data: Dict[str, Any]) -> str:
612
  question = question_data.get("Question", "")
613
  task_id = question_data.get("task_id", "")
614
  print("data",question_data)
615
+
616
+ try:
617
+ file_path = self.download_gaia_file(task_id)
618
+ except FileNotFoundError as e:
619
+ print(f"File not found for task {task_id}: {e}")
620
+ file_path = None
621
+ except Exception as e:
622
+ print(f"Unexpected error downloading file for task {task_id}: {e}")
623
+ file_path = None
624
+
625
  context_prompt = f"""
626
  GAIA Task ID: {task_id}
627
  Question: {question}
628
+ {'File downloaded: ' + file_path if file_path else 'No files referenced'}
629
 
630
  Analyze this question and provide your reasoning and final answer.
631
  """
 
638
  # Post-process to extract exact GAIA format
639
  formatted_answer = await self.format_gaia_answer(str(raw_response), question)
640
 
 
641
  print(f"Formatted answer: {formatted_answer}")
642
 
643
  return formatted_answer