Den Pavloff committed on
Commit
e4e9267
·
1 Parent(s): 8a1b058

hf token problem fix 2

Browse files
Files changed (1) hide show
  1. util.py +11 -16
util.py CHANGED
@@ -210,33 +210,28 @@ class KaniModel:
210
  self.player = player
211
  self.hf_token = token
212
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
213
-
214
  print(f"Loading model: {self.conf.model_name}")
215
  print(f"Target device: {self.device}")
216
-
217
- # Load model with proper configuration
218
- load_kwargs = {
219
- "dtype": torch.bfloat16,
220
- "device_map": self.conf.device_map,
221
- "trust_remote_code": True
222
- }
223
  if self.hf_token:
224
- load_kwargs["token"] = self.hf_token
225
 
 
 
226
  self.model = AutoModelForCausalLM.from_pretrained(
227
  self.conf.model_name,
228
- **load_kwargs
 
 
229
  )
230
 
231
- tokenizer_kwargs = {"trust_remote_code": True}
232
- if self.hf_token:
233
- tokenizer_kwargs["token"] = self.hf_token
234
-
235
  self.tokenizer = AutoTokenizer.from_pretrained(
236
  self.conf.model_name,
237
- **tokenizer_kwargs
238
  )
239
-
240
  print(f"Model loaded successfully on device: {next(self.model.parameters()).device}")
241
 
242
  def get_input_ids(self, text_prompt: str, speaker_id:str) -> tuple[torch.tensor]:
 
210
  self.player = player
211
  self.hf_token = token
212
  self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
213
+
214
  print(f"Loading model: {self.conf.model_name}")
215
  print(f"Target device: {self.device}")
216
+
217
+ # Set HF_TOKEN in environment to avoid parameter passing issues
 
 
 
 
 
218
  if self.hf_token:
219
+ os.environ['HF_TOKEN'] = self.hf_token
220
 
221
+ # Load model with proper configuration
222
+ # Don't pass token parameter - it will be read from HF_TOKEN env var
223
  self.model = AutoModelForCausalLM.from_pretrained(
224
  self.conf.model_name,
225
+ dtype=torch.bfloat16,
226
+ device_map=self.conf.device_map,
227
+ trust_remote_code=True
228
  )
229
 
 
 
 
 
230
  self.tokenizer = AutoTokenizer.from_pretrained(
231
  self.conf.model_name,
232
+ trust_remote_code=True
233
  )
234
+
235
  print(f"Model loaded successfully on device: {next(self.model.parameters()).device}")
236
 
237
  def get_input_ids(self, text_prompt: str, speaker_id:str) -> tuple[torch.tensor]: