pmolchanov committed (verified)
Commit a57c2fb · 1 Parent(s): 43898b1

Update app_chat.py

Files changed (1):
1. app_chat.py (+19 -1)
app_chat.py CHANGED
@@ -6,6 +6,7 @@ import gradio as gr
  import spaces
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ # from transformers import StoppingCriteria, StoppingCriteriaList, StopStringCriteria

  MAX_MAX_NEW_TOKENS = 1024
  DEFAULT_MAX_NEW_TOKENS = 512
@@ -21,6 +22,19 @@ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat1
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
  #tokenizer.use_default_system_prompt = False

+ # class StoppingCriteriaSub(StoppingCriteria):
+ #     def __init__(self, tokenizer, stops = [], encounters=1):
+ #         super().__init__()
+ #         self.stops = [stop.to("cuda") for stop in stops]
+ #         self.tokenizer = tokenizer
+ #         self.num_mamba_stop_ids = 8
+
+ #     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+ #         last_token = input_ids[0][-self.num_mamba_stop_ids:]
+ #         for stop in self.stops:
+ #             if self.tokenizer.decode(stop) in self.tokenizer.decode(last_token):
+ #                 return True
+ #         return False

  @spaces.GPU
  def generate(
@@ -39,7 +53,10 @@ def generate(
      conversation += chat_history
      conversation.append({"role": "User", "content": message})

-     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+     input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt")
+
+     # stopping_criteria = StoppingCriteriaList([StopStringCriteria(tokenizer=tokenizer, stop_strings="</s>")])
+
      if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
          input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
          gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
@@ -56,6 +73,7 @@
          temperature=temperature,
          num_beams=1,
          repetition_penalty=repetition_penalty,
+         # "stopping_criteria": stopping_criteria,
      )
      t = Thread(target=model.generate, kwargs=generate_kwargs)
      t.start()
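
For reference, only the apply_chat_template change (tokenize=True, add_generation_prompt=True) is active in this commit; the stop-string logic stays commented out. The sketch below is not part of the commit: it shows roughly how the referenced pieces (StopStringCriteria, StoppingCriteriaList) would be wired into the threaded generate call if enabled. It reuses the model, tokenizer, and conversation names from app_chat.py, assumes a transformers release that ships StopStringCriteria, and uses placeholder values for max_new_tokens and temperature.

# Minimal sketch, assuming model, tokenizer, and conversation are defined as in app_chat.py.
from threading import Thread
from transformers import StoppingCriteriaList, StopStringCriteria, TextIteratorStreamer

# Tokenize the chat history and append the generation prompt for the assistant turn.
input_ids = tokenizer.apply_chat_template(
    conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# Stop generation as soon as the decoded output contains the "</s>" string.
stopping_criteria = StoppingCriteriaList(
    [StopStringCriteria(tokenizer=tokenizer, stop_strings="</s>")]
)

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
    input_ids=input_ids,
    streamer=streamer,
    max_new_tokens=512,      # placeholder; the app passes its own sliders' values
    do_sample=True,
    temperature=0.7,         # placeholder
    num_beams=1,
    stopping_criteria=stopping_criteria,  # keyword argument, matching the other entries
)
Thread(target=model.generate, kwargs=generate_kwargs).start()

Note that generate_kwargs in app_chat.py is built from keyword arguments (temperature=temperature, ...), so the commented-out dict-key line "stopping_criteria": stopping_criteria, would need to become stopping_criteria=stopping_criteria if it were ever uncommented.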