magdap116 committed on
Commit
718ea39
·
verified ·
1 Parent(s): f34a83a

Update tooling.py

Browse files
Files changed (1) hide show
  1. tooling.py +28 -11
tooling.py CHANGED
@@ -1,8 +1,6 @@
1
- from smolagents import Tool, HfApiModel
2
- import hashlib
3
- import json
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
5
- import os
6
 
7
 
8
  class ModelMathTool(Tool):
@@ -18,13 +16,32 @@ class ModelMathTool(Tool):
18
 
19
  output_type = "string"
20
 
21
- def __init__(self, model_id="Qwen/Qwen2.5-Math-7B"):
22
- print(f"Loading math model: {model_id}")
23
- self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
24
- self.model = HfApiModel(model_id=model_id, max_tokens=512)
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def forward(self, problem: str) -> str:
27
  print(f"[MathModelTool] Question: {problem}")
28
- response = self.model.__call__(problem)
29
- return response
 
 
 
 
 
 
 
 
30
 
 
1
+ from smolagents import Tool
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
3
+ import torch
 
 
4
 
5
 
6
  class ModelMathTool(Tool):
 
16
 
17
  output_type = "string"
18
 
19
+ def __init__(self, model_name= "deepseek-ai/deepseek-math-7b-base"):
20
+ print(f"Loading math model: {model_name}")
21
+
22
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
23
+ print("loaded tokenizer")
24
+ self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
25
+ print("loaded auto model")
26
+
27
+ self.model.generation_config = GenerationConfig.from_pretrained(model_name)
28
+ print("loaded coonfig")
29
+
30
+ self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
31
+ print("loaded pad token")
32
+
33
+
34
 
35
  def forward(self, problem: str) -> str:
36
  print(f"[MathModelTool] Question: {problem}")
37
+
38
+ inputs = self.tokenizer(problem, return_tensors="pt")
39
+ outputs =self.model.generate(**inputs, max_new_tokens=100)
40
+
41
+ result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
42
+
43
+ return result
44
+
45
+
46
+
47