Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -16,7 +16,7 @@ def load_model(model_name):
|
|
16 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
17 |
model = AutoModelForCausalLM.from_pretrained(model_name)
|
18 |
return model, tokenizer
|
19 |
-
def extend(input_text, num_return_sequences, max_size=20, top_k=50, top_p=0.95
|
20 |
if len(input_text) == 0:
|
21 |
input_text = ""
|
22 |
encoded_prompt = tokenizer.encode(
|
@@ -25,8 +25,7 @@ def extend(input_text, num_return_sequences, max_size=20, top_k=50, top_p=0.95,
|
|
25 |
if encoded_prompt.size()[-1] == 0:
|
26 |
input_ids = None
|
27 |
else:
|
28 |
-
input_ids = encoded_prompt
|
29 |
-
|
30 |
bad_words = bad_words.split()
|
31 |
bad_word_ids = []
|
32 |
for bad_word in bad_words:
|
@@ -90,11 +89,11 @@ if __name__ == "__main__":
|
|
90 |
if len(text_area.strip()) == 0:
|
91 |
text_area = random.choice(suggested_text_list)
|
92 |
result = extend(input_text=text_area,
|
93 |
-
num_return_sequences=int(num_return_sequences),
|
|
|
94 |
max_size=int(max_len),
|
95 |
top_k=int(top_k),
|
96 |
-
top_p=float(top_p)
|
97 |
-
bad_words = bad_words)
|
98 |
print("Done length: " + str(len(result)) + " bytes")
|
99 |
#<div class="rtl" dir="rtl" style="text-align:right;">
|
100 |
st.markdown(f"{result}", unsafe_allow_html=True)
|
|
|
16 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
17 |
model = AutoModelForCausalLM.from_pretrained(model_name)
|
18 |
return model, tokenizer
|
19 |
+
def extend(input_text, num_return_sequences, bad_words, max_size=20, top_k=50, top_p=0.95):
|
20 |
if len(input_text) == 0:
|
21 |
input_text = ""
|
22 |
encoded_prompt = tokenizer.encode(
|
|
|
25 |
if encoded_prompt.size()[-1] == 0:
|
26 |
input_ids = None
|
27 |
else:
|
28 |
+
input_ids = encoded_prompt
|
|
|
29 |
bad_words = bad_words.split()
|
30 |
bad_word_ids = []
|
31 |
for bad_word in bad_words:
|
|
|
89 |
if len(text_area.strip()) == 0:
|
90 |
text_area = random.choice(suggested_text_list)
|
91 |
result = extend(input_text=text_area,
|
92 |
+
num_return_sequences=int(num_return_sequences),
|
93 |
+
bad_words = bad_words,
|
94 |
max_size=int(max_len),
|
95 |
top_k=int(top_k),
|
96 |
+
top_p=float(top_p))
|
|
|
97 |
print("Done length: " + str(len(result)) + " bytes")
|
98 |
#<div class="rtl" dir="rtl" style="text-align:right;">
|
99 |
st.markdown(f"{result}", unsafe_allow_html=True)
|