boryasbora commited on
Commit
9b860ac
·
verified ·
1 Parent(s): 6f785d9

Update huggingface_llm.py

Browse files
Files changed (1) hide show
  1. huggingface_llm.py +3 -3
huggingface_llm.py CHANGED
@@ -16,13 +16,13 @@ class HuggingFaceLLM(LLM):
16
 
17
  def __init__(self, **kwargs):
18
  super().__init__(**kwargs)
19
- if self.device == "cpu":
20
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
21
  self._load_model()
22
 
23
  def _load_model(self):
24
  self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
25
- self._model = AutoModelForCausalLM.from_pretrained(self.model_id).to(self.device)
 
26
 
27
  @property
28
  def _llm_type(self) -> str:
 
16
 
17
  def __init__(self, **kwargs):
18
  super().__init__(**kwargs)
19
+ self.device = "cuda" if torch.cuda.is_available() and self.device != "cpu" else "cpu"
 
20
  self._load_model()
21
 
22
  def _load_model(self):
23
  self._tokenizer = AutoTokenizer.from_pretrained(self.model_id)
24
+ self._model = AutoModelForCausalLM.from_pretrained(self.model_id)
25
+ self._model = self._model.to(torch.device(self.device))
26
 
27
  @property
28
  def _llm_type(self) -> str: