Spaces: Running on Zero

Commit: "Change back to User-Assistant conversation"
Browse files
app.py CHANGED

@@ -41,10 +41,10 @@ token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(toke
 
 def load_model():
     ckpt_path = hf_hub_download(
-        repo_id="ruurd/…",  [old repo_id truncated in the page extraction — original value not recoverable]
+        repo_id="ruurd/tini",
         filename="diffusion-model.pth",
         token=os.getenv("HF_TOKEN"),
-        revision="5a22a8b6168466dbbf704efd00d8cbf2eee51426",
+        # revision="5a22a8b6168466dbbf704efd00d8cbf2eee51426",
     )
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -71,8 +71,6 @@ def load_model():
     model.eval()
     return model
 
-
-
 rng = np.random.default_rng()
 
 # --- Utility Functions ---

@@ -204,13 +202,11 @@ def diffusion_chat(question, eot_weight, mask_weight, max_it, pause_length, shar
 
     print('started generation')
     prompt = f"User: {question}\nAssistant:"
-    prompt = question
     input_ids = tokenizer.encode(prompt, add_special_tokens=False)
     answer_start = find_answer_start(input_ids, assistant_marker_ids)
-
-
-
-    answer_start = len(input_ids)
+    if answer_start is None:
+        yield "Error: Could not find Assistant marker in input."
+        return
 
     if len(input_ids) < 256:
         input_ids += [pad_token] * (256 - len(input_ids))