onisj committed · Commit 46c7672 · verified · 1 Parent(s): 9cd535d

Create app.py

Files changed (1): app.py +3 -3
app.py CHANGED
@@ -86,7 +86,7 @@ def initialize_llm():
 
     try:
         tokenizer = AutoTokenizer.from_pretrained(HF_MODEL, token=HF_API_TOKEN)
-        model = AutoModelForCausalLM.from_pretrained(HF_MODEL, token=HF_API_TOKEN, device_map="mps")
+        model = AutoModelForCausalLM.from_pretrained(HF_MODEL, token=HF_API_TOKEN, device_map="auto")
         logger.info(f"Initialized local Hugging Face model: {HF_MODEL}")
         return (model, tokenizer), "hf_local"
     except Exception as e:
@@ -155,7 +155,7 @@ async def parse_question(state: JARVISState) -> JARVISState:
         inputs = tokenizer.apply_chat_template(
             [{"role": "system", "content": prompt[0].content}, {"role": "user", "content": prompt[1].content}],
             return_tensors="pt"
-        ).to("mps")
+        ).to(model.device)
         outputs = model.generate(inputs, max_new_tokens=512, temperature=0.7)
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
         tools_needed = json.loads(response.strip())
@@ -322,7 +322,7 @@ Document results: {document_results}""")
     try:
         if llm_type == "hf_local":
             model, tokenizer = llm_client
-            inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to("mps")
+            inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(model.device)
             outputs = model.generate(inputs, max_new_tokens=512, temperature=0.7)
             answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
         else:
 
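All three edits share one idea: don't hardcode the Apple-silicon "mps" device. A minimal sketch of the pattern, assuming a generic instruct model (MODEL_ID below is illustrative, not the commit's HF_MODEL): device_map="auto" lets the transformers/accelerate loader place the weights on CUDA, MPS, or CPU as available, and model.device then gives the right target for the input tensors.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative model id; the commit loads HF_MODEL with an HF_API_TOKEN instead.
MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# device_map="auto" defers device selection to accelerate (CUDA, MPS, or CPU),
# so the same code runs on Spaces hardware as well as on a Mac.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")

messages = [{"role": "user", "content": "Say hello."}]
# Move the tokenized prompt to wherever the weights actually landed.
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)
outputs = model.generate(inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

The hardcoded device_map="mps" fails on any machine without Apple's Metal backend, which is why the pair of device_map="auto" and .to(model.device) is the portable choice here.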