frimelle (HF Staff) committed
Commit: 91b2732
Parent(s): 865324e

add zerogpu setup

Files changed (2):
  1. app.py +20 -16
  2. requirements.txt +3 -2
app.py CHANGED
@@ -4,6 +4,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 import uuid
 import os
 from datetime import datetime
+import spaces  # required for ZeroGPU
 
 # ----- Constants -----
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@@ -12,27 +13,31 @@ with open("system_prompt.txt", "r") as f:
 LOG_DIR = "chat_logs"
 os.makedirs(LOG_DIR, exist_ok=True)
 
-# ----- Load model and tokenizer -----
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME,
-    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
-    device_map="auto" if device == "cuda" else None
-)
-model.eval()
-
-# ----- Log setup -----
+# Global vars to hold model and tokenizer
+model = None
+tokenizer = None
 session_id = str(uuid.uuid4())
 
+# ----- Log Chat -----
 def log_chat(session_id, user_msg, bot_msg):
     timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     with open(os.path.join(LOG_DIR, f"{session_id}.txt"), "a") as f:
         f.write(f"[{timestamp}] User: {user_msg}\n")
         f.write(f"[{timestamp}] Bot: {bot_msg}\n\n")
 
-# ----- Inference -----
+# ----- Required by ZeroGPU -----
+@spaces.GPU
+def load_model():
+    global model, tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    model = AutoModelForCausalLM.from_pretrained(
+        MODEL_NAME,
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+        device_map="auto"
+    )
+    model.eval()
+
+# ----- Inference Function -----
 def format_chat_prompt(history, new_input):
     messages = [{"role": "system", "content": SYSTEM_PROMPT}]
     for user_msg, bot_msg in history:
@@ -54,14 +59,13 @@ def respond(message, history):
         pad_token_id=tokenizer.eos_token_id
     )
     decoded = tokenizer.decode(output[0], skip_special_tokens=True)
-    # Extract the assistant's final message
     response = decoded.split(message)[-1].strip().split("\n")[0].strip()
     log_chat(session_id, message, response)
     return response
 
-# ----- Gradio Chat Interface -----
+# ----- Gradio App -----
 gr.ChatInterface(
     fn=respond,
     title="BoundrAI",
-    theme="soft",  # optional aesthetic
+    theme="soft"
 ).launch()
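
Note on the hunks above: model and tokenizer now start as None and only the @spaces.GPU-decorated load_model() fills them, but the diff never shows where that loader is invoked. A minimal sketch of the lazy-load guard this pattern usually implies follows; the ensure_loaded helper and the call site inside respond() are assumptions for illustration, not part of the commit.

def ensure_loaded():
    # Assumed helper: run the @spaces.GPU-decorated loader from the diff
    # once, on the first request that actually needs the model.
    if model is None or tokenizer is None:
        load_model()

def respond(message, history):
    ensure_loaded()  # hypothetical wiring; the commit does not show this call
    prompt = format_chat_prompt(history, message)
    # ... generation continues as in the unchanged middle of app.py ...

On ZeroGPU, @spaces.GPU requests GPU hardware only while the decorated function runs, so deferring the heavy from_pretrained() call like this keeps the Space's startup fast instead of loading a 7B model at import time.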
requirements.txt CHANGED
@@ -1,4 +1,5 @@
 huggingface_hub==0.25.2
-transformers
 gradio
-torch
+transformers
+torch
+spaces
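
The new spaces dependency installs as an ordinary pip package, and its GPU decorator is documented to have no effect outside ZeroGPU Spaces. For running app.py somewhere the package is not installed at all, a small import guard keeps the module importable; this is a sketch with a hypothetical stand-in class, not something the commit adds.

try:
    import spaces
except ImportError:
    # Hypothetical fallback for environments without the `spaces` wheel:
    # mimic only the bare-decorator form used in app.py (@spaces.GPU).
    class spaces:
        @staticmethod
        def GPU(fn):
            return fn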