
Ruurd committed
Commit b1cf46e · verified · Parent(s): 800af7e

Update model to be Llama 3.2 3B-Instruct based; change prompt format accordingly
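In short (both lines taken from the diff below), prompt construction moves from a plain transcript to Llama 3 header tokens, and the assistant marker used to locate the answer span is updated to match:

    # before
    prompt = f"User: {question}\nAssistant:"
    # after
    prompt = format_chat_prompt(question)  # builds <|start_header_id|>...<|end_header_id|> style headers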

Files changed (1):
  app.py (+13 −2)
app.py CHANGED
@@ -18,7 +18,7 @@ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=Tr
 vocab_size = len(tokenizer)
 eos_token_id = tokenizer.eos_token_id
 mask_token_id = tokenizer.encode('MASK', add_special_tokens=False)[0]
-assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)
+assistant_marker_ids = tokenizer.encode("<|start_header_id|>assistant<|end_header_id|>", add_special_tokens=False)
 
 # def load_model():
 #     ckpt_path = hf_hub_download(
@@ -195,6 +195,17 @@ def generate_diffusion_text(input_ids, top_p, top_k):
     conf = probs[range(len(sampled)), sampled].cpu().numpy()
     return sampled, conf
 
+def format_chat_prompt(question):
+    return (
+        "<|begin_of_text|>\n"
+        "<|start_header_id|>system<|end_header_id|>\n"
+        "You are a helpful assistant.\n"
+        "<|start_header_id|>user<|end_header_id|>\n"
+        f"{question}\n"
+        "<|start_header_id|>assistant<|end_header_id|>\n"
+    )
+
+
 # --- Inference Wrapper ---
 def diffusion_chat(question, max_it, pause_length, sharpness,
                    clustering, noise_start, use_confidence_noising,
@@ -204,7 +215,7 @@ def diffusion_chat(question, max_it, pause_length, sharpness,
         question = placeholder
 
     print('started generation')
-    prompt = f"User: {question}\nAssistant:"
+    prompt = format_chat_prompt(question)
     input_ids = tokenizer.encode(prompt, add_special_tokens=False)
     answer_start = find_answer_start(input_ids, assistant_marker_ids)
     if answer_start is None:
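For context: find_answer_start is defined elsewhere in app.py and is not part of this commit. A minimal sketch of what such a helper presumably does, assuming it scans the encoded prompt for the marker-id subsequence and returns the index just past it (illustrative, not the Space's actual implementation):

    def find_answer_start(input_ids, marker_ids):
        # Slide over input_ids looking for the marker_ids subsequence;
        # return the index just past its first occurrence, or None if absent.
        n = len(marker_ids)
        for i in range(len(input_ids) - n + 1):
            if input_ids[i:i + n] == marker_ids:
                return i + n
        return None

    # Hypothetical sanity check: the header strings must encode to the same
    # token ids inside the full prompt, or the marker is never found and
    # diffusion_chat falls into its answer_start-is-None branch.
    ids = tokenizer.encode(format_chat_prompt("Hello"), add_special_tokens=False)
    assert find_answer_start(ids, assistant_marker_ids) is not None

One caveat on the new format: it is close to, but not identical to, Llama 3's official chat template, which puts a blank line after each <|end_header_id|> and closes each turn with <|eot_id|>. Since answer extraction here keys only on the assistant header tokens, the marker match itself is unaffected by that difference.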