pmolchanov committed
Commit 29d26a3 · verified · 1 Parent(s): 0c8c67c

Update app.py

Files changed (1): app.py (+12, -0)
app.py CHANGED

@@ -3,6 +3,11 @@ import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer, StopStringCriteria, StoppingCriteriaList
 import torch
 
+import torch
+import os
+os.system("nvidia-smi")
+print("TORCH_CUDA", torch.cuda.is_available())
+
 # Load the tokenizer and model
 repo_name = "nvidia/Hymba-1.5B-Instruct"
 
@@ -10,6 +15,9 @@ tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
 model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True)
 model = model.cuda().to(torch.bfloat16)
 
+print("model is loaded")
+
+
 # Chat with Hymba
 # prompt = input()
 prompt = "Who are you?"
@@ -22,6 +30,10 @@ messages.append({"role": "user", "content": prompt})
 # Apply chat template
 tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to('cuda')
 stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings="</s>")])
+
+print("generating prompt")
+
+
 outputs = model.generate(
     tokenized_chat,
     max_new_tokens=256,
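
The change is diagnostics only: it shells out to nvidia-smi, logs whether PyTorch can see CUDA, and prints progress markers around model loading and generation (the second import torch is redundant but harmless). Below is a minimal sketch of the same startup check; the hard guard at the end is my assumption and not part of this commit.

# Sketch only (assumption, not part of this commit): the same GPU diagnostics,
# plus an early failure so the Space stops before model.cuda() would error out.
import os
import torch

os.system("nvidia-smi")                         # dump GPU status to the Space logs
print("TORCH_CUDA", torch.cuda.is_available())  # True only if PyTorch sees a CUDA device

if not torch.cuda.is_available():
    raise RuntimeError("No CUDA device visible; the model cannot be moved to GPU")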