BigSalmon commited on
Commit
2a2f758
·
1 Parent(s): b77881f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -16,7 +16,7 @@ def load_model(model_name):
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
  model = AutoModelForCausalLM.from_pretrained(model_name)
18
  return model, tokenizer
19
- def extend(input_text, num_return_sequences, max_size=20, top_k=50, top_p=0.95, bad_words):
20
  if len(input_text) == 0:
21
  input_text = ""
22
  encoded_prompt = tokenizer.encode(
@@ -25,8 +25,7 @@ def extend(input_text, num_return_sequences, max_size=20, top_k=50, top_p=0.95,
25
  if encoded_prompt.size()[-1] == 0:
26
  input_ids = None
27
  else:
28
- input_ids = encoded_prompt
29
-
30
  bad_words = bad_words.split()
31
  bad_word_ids = []
32
  for bad_word in bad_words:
@@ -90,11 +89,11 @@ if __name__ == "__main__":
90
  if len(text_area.strip()) == 0:
91
  text_area = random.choice(suggested_text_list)
92
  result = extend(input_text=text_area,
93
- num_return_sequences=int(num_return_sequences),
 
94
  max_size=int(max_len),
95
  top_k=int(top_k),
96
- top_p=float(top_p),
97
- bad_words = bad_words)
98
  print("Done length: " + str(len(result)) + " bytes")
99
  #<div class="rtl" dir="rtl" style="text-align:right;">
100
  st.markdown(f"{result}", unsafe_allow_html=True)
 
16
  tokenizer = AutoTokenizer.from_pretrained(model_name)
17
  model = AutoModelForCausalLM.from_pretrained(model_name)
18
  return model, tokenizer
19
+ def extend(input_text, num_return_sequences, bad_words, max_size=20, top_k=50, top_p=0.95):
20
  if len(input_text) == 0:
21
  input_text = ""
22
  encoded_prompt = tokenizer.encode(
 
25
  if encoded_prompt.size()[-1] == 0:
26
  input_ids = None
27
  else:
28
+ input_ids = encoded_prompt
 
29
  bad_words = bad_words.split()
30
  bad_word_ids = []
31
  for bad_word in bad_words:
 
89
  if len(text_area.strip()) == 0:
90
  text_area = random.choice(suggested_text_list)
91
  result = extend(input_text=text_area,
92
+ num_return_sequences=int(num_return_sequences),
93
+ bad_words = bad_words,
94
  max_size=int(max_len),
95
  top_k=int(top_k),
96
+ top_p=float(top_p))
 
97
  print("Done length: " + str(len(result)) + " bytes")
98
  #<div class="rtl" dir="rtl" style="text-align:right;">
99
  st.markdown(f"{result}", unsafe_allow_html=True)