arshiaafshani committed
Commit 8e0463b · verified · 1 Parent(s): 5a3d95f

Update app.py

Files changed (1): app.py (+14 -5)
app.py CHANGED

@@ -16,10 +16,14 @@ def patched_get_type(schema):
 client_utils.get_type = patched_get_type
 
 # Load model and tokenizer
-model_name = "arshiaafshani/Arsh-llm"
+model_name = "arshiaafshani/Arsh-llm"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
 
+# Set special tokens
+tokenizer.bos_token = "<sos>"
+tokenizer.eos_token = "<|endoftext|>"
+
 # Create pipeline
 pipe = pipeline(
     "text-generation",
@@ -29,8 +33,12 @@ pipe = pipeline(
 )
 
 def respond(message, chat_history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty):
-    # Prepare prompt
-    prompt = f"{system_message}\n\nUser: {message}\nAssistant:"
+    # Prepare prompt using apply_chat_template
+    chat_history = chat_history or []
+    messages = [{"role": "system", "content": system_message}] + \
+               [{"role": "user", "content": msg} for msg, _ in chat_history] + \
+               [{"role": "user", "content": message}, {"role": "assistant", "content": ""}]
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False)
 
     # Generate response
     output = pipe(
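Note that the rebuilt messages list keeps only the user side of chat_history (for msg, _ in chat_history) and cues the model with an empty assistant turn. If the tokenizer ships a chat template with generation support, replaying both sides of the history and letting add_generation_prompt open the assistant turn may track the training format more closely; a sketch under that assumption:

# Sketch: replay both sides of the history; add_generation_prompt=True lets the
# chat template itself append the assistant header (assumes the tokenizer defines one).
messages = [{"role": "system", "content": system_message}]
for user_msg, assistant_msg in chat_history:
    messages.append({"role": "user", "content": user_msg})
    messages.append({"role": "assistant", "content": assistant_msg})
messages.append({"role": "user", "content": message})
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)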
@@ -44,7 +52,8 @@ def respond(message, chat_history, system_message, max_tokens, temperature, top_p, top_k, repeat_penalty):
         pad_token_id=tokenizer.eos_token_id
     )
 
-    response = output[0]['generated_text'].split("Assistant:")[-1].strip()
+    # Extract response
+    response = output[0]['generated_text'][len(prompt):].strip()
 
     # Update chat history
     chat_history.append((message, response))
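The new extraction assumes the pipeline echoes the rendered prompt verbatim at the head of generated_text. The text-generation pipeline can instead be asked to return only the completion, which sidesteps the len(prompt) slicing; a sketch, keeping the commit's other sampling arguments as they are:

# Sketch: return_full_text=False makes the pipeline return only newly generated
# text, so no prompt-length slicing is needed.
output = pipe(prompt, return_full_text=False)  # plus the existing sampling kwargs
response = output[0]["generated_text"].strip()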
@@ -80,4 +89,4 @@ with gr.Blocks() as demo:
     clear.click(lambda: None, None, chatbot, queue=False)
 
 if __name__ == "__main__":
-    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
+    demo.launch(server_name="0.0.0.0", server_port=7860)
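Dropping share=True is consistent with running inside a Space, where the app is already served from the container and binding to 0.0.0.0:7860 is enough. If concurrent users are expected, enabling Gradio's request queue before launching is a common companion change (a sketch, not part of this commit):

# Sketch: queue GPU-bound generation requests instead of running them in parallel.
demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860)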
 