magdap116 commited on
Commit
0644a6d
·
verified ·
1 Parent(s): 51a04fc

Update tooling.py

Browse files
Files changed (1) hide show
  1. tooling.py +128 -125
tooling.py CHANGED
@@ -1,125 +1,128 @@
1
- from smolagents import Tool
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
3
- import torch
4
- from wikipedia_utils import *
5
- from youtube_utils import *
6
-
7
-
8
- class MathModelQuerer(Tool):
9
- name = "math_model"
10
- description = "Answers advanced math questions using a pretrained math model."
11
-
12
- inputs = {
13
- "problem": {
14
- "type": "string",
15
- "description": "Math problem to solve.",
16
- }
17
- }
18
-
19
- output_type = "string"
20
-
21
- def __init__(self, model_name="deepseek-ai/deepseek-math-7b-base"):
22
- print(f"Loading math model: {model_name}")
23
-
24
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
25
- print("loaded tokenizer")
26
- self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
27
- print("loaded auto model")
28
-
29
- self.model.generation_config = GenerationConfig.from_pretrained(model_name)
30
- print("loaded coonfig")
31
-
32
- self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
33
- print("loaded pad token")
34
-
35
- def forward(self, problem: str) -> str:
36
- try:
37
- print(f"[MathModelTool] Question: {problem}")
38
-
39
- inputs = self.tokenizer(problem, return_tensors="pt")
40
- outputs = self.model.generate(**inputs, max_new_tokens=100)
41
-
42
- result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
43
-
44
- return result
45
- except:
46
- return f"Failed using the tool {self.name}"
47
-
48
-
49
- class CodeModelQuerer(Tool):
50
- name = "code_querer"
51
- description = "Given a problem description, generates a piece of code used specialized LLM model. Returns output of the model."
52
-
53
- inputs = {
54
- "problem": {
55
- "type": "string",
56
- "description": "Description of a code sample to be generated",
57
- }
58
- }
59
-
60
- output_type = "string"
61
-
62
- def __init__(self, model_name="Qwen/Qwen2.5-Coder-32B-Instruct"):
63
- from smolagents import HfApiModel
64
- print(f"Loading llm for Code tool: {model_name}")
65
- self.model = HfApiModel()
66
-
67
- def forward(self, problem: str) -> str:
68
- try:
69
- return self.model.generate(problem, max_new_tokens=512)
70
- except:
71
- return f"Failed using the tool {self.name}"
72
-
73
-
74
- class WikipediaPageFetcher(Tool):
75
- name = "wiki_page_fetcher"
76
- description = "Searches Wikipedia and provides summary about the queried topic as a string."
77
-
78
- inputs = {
79
- "query": {
80
- "type": "string",
81
- "description": "Topic of wikipedia search",
82
- }
83
- }
84
-
85
- output_type = "string"
86
-
87
- def forward(self, query: str) -> str:
88
- try:
89
- wiki_query = query(query)
90
- wiki_page = fetch_wikipedia_page(wiki_query)
91
- return wiki_page
92
- except:
93
- return f"Failed using the tool {self.name}"
94
-
95
-
96
- class YoutubeTranscriptFetcher(Tool):
97
- name = "youtube_transcript_fetcher"
98
- description = "Attempts to fetch a youtube transcript in english, if provided with a query \\" \
99
- " that contains a youtube link with video id. Returns a transcript content as a string. Alternatively, if tool is provided with a\\"" \
100
- youtube video id, it can fetch the transcript directly."
101
-
102
- inputs = {
103
- "query": {
104
- "type": "string",
105
- "description": "A query that includes youtube id."
106
- },
107
- "video_id" : {
108
- "type" : "string",
109
- "description" : "Optional string with video id from youtube.",
110
- "nullable" : True
111
- }
112
- }
113
-
114
- output_type = "string"
115
-
116
- def forward(self, query: str, video_id=None) -> str:
117
- try:
118
- if video_id is None:
119
- video_id = get_youtube_video_id(query)
120
-
121
- fetched_transcript = fetch_transcript_english(video_id)
122
-
123
- return post_process_transcript(fetched_transcript)
124
- except:
125
- return f"Failed using the tool {self.name}"
 
 
 
 
1
+ from smolagents import Tool
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
3
+ import torch
4
+ from wikipedia_utils import *
5
+ from youtube_utils import *
6
+
7
+
8
+ class MathModelQuerer(Tool):
9
+ name = "math_model"
10
+ description = "Answers advanced math questions using a pretrained math model."
11
+
12
+ inputs = {
13
+ "problem": {
14
+ "type": "string",
15
+ "description": "Math problem to solve.",
16
+ }
17
+ }
18
+
19
+ output_type = "string"
20
+
21
+ def __init__(self, model_name="deepseek-ai/deepseek-math-7b-base"):
22
+ print(f"Loading math model: {model_name}")
23
+
24
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
25
+ print("loaded tokenizer")
26
+ self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
27
+ print("loaded auto model")
28
+
29
+ self.model.generation_config = GenerationConfig.from_pretrained(model_name)
30
+ print("loaded coonfig")
31
+
32
+ self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
33
+ print("loaded pad token")
34
+
35
+ def forward(self, problem: str) -> str:
36
+ try:
37
+ print(f"[MathModelTool] Question: {problem}")
38
+
39
+ inputs = self.tokenizer(problem, return_tensors="pt")
40
+ outputs = self.model.generate(**inputs, max_new_tokens=100)
41
+
42
+ result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
43
+
44
+ return result
45
+ except:
46
+ return f"Failed using the tool {self.name}"
47
+
48
+
49
+ class CodeModelQuerer(Tool):
50
+ name = "code_querer"
51
+ description = "Given a problem description, generates a piece of code used specialized LLM model. Returns output of the model."
52
+
53
+ inputs = {
54
+ "problem": {
55
+ "type": "string",
56
+ "description": "Description of a code sample to be generated",
57
+ }
58
+ }
59
+
60
+ output_type = "string"
61
+
62
+ def __init__(self, model_name="Qwen/Qwen2.5-Coder-32B-Instruct"):
63
+ from smolagents import HfApiModel
64
+ print(f"Loading llm for Code tool: {model_name}")
65
+ self.model = HfApiModel()
66
+
67
+ def forward(self, problem: str) -> str:
68
+ try:
69
+ return self.model.generate(problem, max_new_tokens=512)
70
+ except:
71
+ return f"Failed using the tool {self.name}"
72
+
73
+
74
+ class WikipediaPageFetcher(Tool):
75
+ name = "wiki_page_fetcher"
76
+ description = "Searches Wikipedia and provides summary about the queried topic as a string.\
77
+ Use for all wikipedia queries regardless of the language and version.\
78
+ Only provide query as an input parameter."
79
+
80
+ inputs = {
81
+ "query": {
82
+ "type": "string",
83
+ "description": "Topic of wikipedia search",
84
+ }
85
+ }
86
+
87
+ output_type = "string"
88
+
89
+ def forward(self, query: str) -> str:
90
+ try:
91
+ wiki_query = query(query)
92
+ wiki_page = fetch_wikipedia_page(wiki_query)
93
+ return wiki_page
94
+ except:
95
+ return f"Failed using the tool {self.name}"
96
+
97
+
98
+ class YoutubeTranscriptFetcher(Tool):
99
+ name = "youtube_transcript_fetcher"
100
+ description = "Attempts to fetch a youtube transcript in english, if provided with a query \\" \
101
+ " that contains a youtube link with video id. Returns a transcript content as a string. Alternatively, if tool is provided with a\\"" \
102
+ youtube video id, it can fetch the transcript directly. Video id consist of last 11 strings of the url. Only provide this parameter, if the video id doesn't have\
103
+ to be parsed from a url."
104
+
105
+ inputs = {
106
+ "query": {
107
+ "type": "string",
108
+ "description": "A query that includes youtube id."
109
+ },
110
+ "video_id" : {
111
+ "type" : "string",
112
+ "description" : "Optional string with video id from youtube.",
113
+ "nullable" : True
114
+ }
115
+ }
116
+
117
+ output_type = "string"
118
+
119
+ def forward(self, query: str, video_id=None) -> str:
120
+ try:
121
+ if video_id is None:
122
+ video_id = get_youtube_video_id(query)
123
+
124
+ fetched_transcript = fetch_transcript_english(video_id)
125
+
126
+ return post_process_transcript(fetched_transcript)
127
+ except:
128
+ return f"Failed using the tool {self.name}"