magdap116 commited on
Commit
496166b
·
verified ·
1 Parent(s): 8324716

Upload tooling.py

Browse files
Files changed (1) hide show
  1. tooling.py +92 -0
tooling.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import DuckDuckGoSearchTool, HfApiModel, load_tool, CodeAgent, PythonInterpreterTool, VisitWebpageTool, \
2
+ Tool
3
+ import hashlib
4
+ import json
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ import os
7
+
8
+
9
+ class ModelMathTool(Tool):
10
+ name = "math_model"
11
+ description = "Answers advanced math questions using a pretrained math model."
12
+
13
+ inputs = {
14
+ "problem": {
15
+ "type": "string",
16
+ "description": "Math problem to solve.",
17
+ }
18
+ }
19
+
20
+ output_type = "string"
21
+
22
+ def __init__(self, model_id="Qwen/Qwen2.5-Math-7B"):
23
+ print(f"Loading math model: {model_id}")
24
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
25
+ self.model = HfApiModel(model_id=model_id, max_tokens=512)
26
+
27
+ def forward(self, problem: str) -> str:
28
+ print(f"[MathModelTool] Question: {problem}")
29
+ response = self.model.__call__(problem)
30
+ return response
31
+
32
+
33
+ # (Keep Constants as is)
34
+ # --- Constants ---
35
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
36
+
37
+ web_search = DuckDuckGoSearchTool()
38
+ python_interpreter = PythonInterpreterTool()
39
+ visit_webpage_tool = VisitWebpageTool()
40
+ model_math_tool = ModelMathTool()
41
+
42
+ # If the agent does not answer, the model is overloaded, please use another model or the following Hugging Face Endpoint that also contains qwen2.5 coder:
43
+ # model_id='https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud'
44
+
45
+ model = HfApiModel(model_id="HuggingFaceH4/zephyr-7b-beta", max_tokens=512, token=tok)
46
+
47
+
48
+ def get_cache_key(question: str) -> str:
49
+ return hashlib.sha256(question.encode()).hexdigest()
50
+
51
+
52
+ def load_cached_answer(question: str) -> str | None:
53
+ key = get_cache_key(question)
54
+ path = f"cache/{key}.json"
55
+ if os.path.exists(path):
56
+ with open(path, "r") as f:
57
+ data = json.load(f)
58
+ return data.get("answer")
59
+ return None
60
+
61
+
62
+ def cache_answer(question: str, answer: str):
63
+ key = get_cache_key(question)
64
+ path = f"cache/{key}.json"
65
+ with open(path, "w") as f:
66
+ json.dump({"question": question, "answer": answer}, f)
67
+
68
+
69
+ class BasicAgent:
70
+ def __init__(self):
71
+ print("BasicAgent initialized.")
72
+ self.agent = CodeAgent(
73
+ model=model,
74
+ tools=[model_math_tool],
75
+ max_steps=1,
76
+ verbosity_level=0,
77
+ grammar=None,
78
+ planning_interval=3,
79
+
80
+ )
81
+
82
+ def __call__(self, question: str) -> str:
83
+ print(f"Agent received question (first 50 chars): {question[:50]}...")
84
+ answer = self.agent.run(question)
85
+ return answer
86
+
87
+
88
+
89
+ agent = BasicAgent()
90
+
91
+ response = agent.__call__(question="How much is 2*2?")
92
+ print(response)