import os
import random
import string
import logging

import gradio as gr
import torch
from transformers import pipeline, set_seed
from transformers import AutoTokenizer, AutoModelForCausalLM

# Monkey patch
import inspect
from gradio import routes
from typing import List, Type
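

# The get_types override below appears to be a workaround for older Gradio releases
# whose auto-generated API docs fail while parsing component docstrings; it extracts
# each component's type and description directly from the relevant docstring line.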
def get_types(cls_set: List[Type], component: str):
    docset = []
    types = []
    if component == "input":
        for cls in cls_set:
            doc = inspect.getdoc(cls)
            doc_lines = doc.split("\n")
            docset.append(doc_lines[1].split(":")[-1])
            types.append(doc_lines[1].split(")")[0].split("(")[-1])
    else:
        for cls in cls_set:
            doc = inspect.getdoc(cls)
            doc_lines = doc.split("\n")
            docset.append(doc_lines[-1].split(":")[-1])
            types.append(doc_lines[-1].split(")")[0].split("(")[-1])
    return docset, types


routes.get_types = get_types

logger = logging.getLogger()
logger.addHandler(logging.StreamHandler())
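
# Runtime configuration comes from environment variables. DEBUG counts as enabled
# when the value starts with "t", "y" or "1" (e.g. "true", "yes", "1").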
DEBUG = os.environ.get("DEBUG", "false").lower()[0] in "ty1"
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
DEVICE = os.environ.get("DEVICE", "cpu")  # cuda:0
if DEVICE != "cpu" and not torch.cuda.is_available():
    DEVICE = "cpu"
logger.info(f"DEVICE {DEVICE}")
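# Torch dtype: honour an explicit DTYPE env var (e.g. "float16") when it names a valid
# torch attribute; otherwise fall back to float32 on CPU and float16 on GPU.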
DTYPE = getattr(
    torch,
    os.environ.get("DTYPE", ""),
    torch.float32 if DEVICE == "cpu" else torch.float16
)
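# Note that bool() of any non-empty string is True, so LOW_CPU_MEM is effectively
# enabled whenever the env var is set to a non-empty value (and by default on GPU).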
LOW_CPU_MEM = bool(os.environ.get("LOW_CPU_MEM", False if DEVICE == "cpu" else True))
MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
MODEL_REVISION = os.environ.get("MODEL_REVISION", "main")
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))

display_model_name = "BERTIN GPT-J-6B" if MODEL_NAME == "bertin-project/bertin-gpt-j-6B" else MODEL_NAME.upper()

HEADER_INFO = f"""
# {display_model_name}
Spanish {display_model_name} Model.
""".strip()
LOGO = "https://huggingface.co/bertin-project/bertin-roberta-base-spanish/resolve/main/images/bertin.png"
HEADER = f"""
<link href="https://fonts.googleapis.com/css2?family=Roboto:wght@300&display=swap" rel="stylesheet">
<style>
.ltr,
textarea {{
    font-family: Roboto !important;
    text-align: left;
    direction: ltr !important;
}}
.ltr-box {{
    border-bottom: 1px solid #ddd;
    padding-bottom: 20px;
}}
.rtl {{
    text-align: left;
    direction: ltr !important;
}}
span.result-text {{
    padding: 3px 3px;
    line-height: 32px;
}}
span.generated-text {{
    background-color: rgb(118 200 147 / 13%);
}}
</style>
<div align=center>
<img src="{LOGO}" width=150/>
# {display_model_name}
BERTIN proporciona una serie de modelos de lenguaje en Español entrenados en abierto.
Este modelo ha sido entrenado con [Mesh Transformer JAX](https://github.com/kingoflolz/mesh-transformer-jax) en TPUs proporcionadas por Google a través del programa TPU Research Cloud, a partir del modelo [GPT-J de EleutherAI](https://huggingface.co/EleutherAI/gpt-j-6B) con el corpus [mC4-es-sampled (gaussian)](https://huggingface.co/datasets/bertin-project/mc4-es-sampled). Esta demo funciona sobre una GPU proporcionada por HuggingFace.
</div>
"""
FOOTER = f"""
<div align=center>
Para más información, visite el repositorio del modelo: <a href="https://huggingface.co/{MODEL_NAME}">{display_model_name}</a>.
<img src="https://visitor-badge.glitch.me/badge?page_id=spaces/{MODEL_NAME}"/>
</div>
""".strip()

EXAMPLES = [
    "",
    "Érase una vez,",
    "¿Cuál es la capital de Francia? Respuesta:",
    "En un lugar de la Mancha, de cuyo nombre no quiero acordarme, no ha mucho tiempo que vivía un hidalgo de los de lanza en astillero, adarga antigua, rocín flaco y galgo corredor.",
    """Los templos egipcios fueron construidos para el culto oficial de los dioses y la conmemoración de los faraones del Antiguo Egipto en las regiones bajo su dominio. Los templos eran vistos como el hogar de los dioses o faraones deificados a quienes eran dedicados, y en ellos los faraones y el clero egipcio llevaban a cabo diversos rituales, las funciones centrales de la religión egipcia: realizar ofrendas a sus dioses, recrear pasajes mitológicos mediante festivales y protegerse de las fuerzas del caos. Estos rituales eran vistos como necesarios para que los dioses mantuvieran la maat, el orden divino del universo.
El cuidado del hogar de los dioses era obligación de los faraones, que dedicaron ingentes cantidades de recursos para la construcción y el mantenimiento de los templos. Por necesidad, los faraones delegaban la mayoría de los rituales en una amplia casta sacerdotal, aunque la mayor parte del pueblo llano permanecía al margen de la participación directa en las ceremonias por tener prohibido el acceso a las zonas más sagradas de los templos. A pesar de ello, el templo siempre fue un importante centro religioso para todos los egipcios, que iban a ellos a rezar, realizar ofrendas y buscar la guía de los oráculos.
Pregunta: ¿Quién cuidaba del hogar de los dioses?
Respuesta:""",
]
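
# Chat-mode prompt pieces: the conversation is framed as a radio interview in which
# {USER} asks questions and {AGENT} answers; CONTEXT is the preamble prepended to
# every exchange.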
AGENT = os.environ.get("AGENT_NAME", "BERTIN")
PREV = "PREV"
USER = "ENTREVISTADOR"
CONTEXT = """La siguiente conversación es un extracto de una entrevista a {AGENT} celebrada en Madrid para Radio Televisión Española:
{USER}: Bienvenido, {AGENT}. Un placer tenerlo hoy con nosotros.
{AGENT}: Gracias. El placer es mío."""


class Normalizer:
    def remove_repetitions(self, text):
        """Remove exact repeated sentences, keeping only their first occurrence."""
        first_ocurrences = []
        for sentence in text.split("."):
            if sentence not in first_ocurrences:
                first_ocurrences.append(sentence)
        return '.'.join(first_ocurrences)

    def trim_last_sentence(self, text):
        """Trim the last sentence if incomplete (returns "" when there is no period at all)."""
        return text[:text.rfind(".") + 1]

    def clean_txt(self, text):
        return self.trim_last_sentence(self.remove_repetitions(text))
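

# TextGeneration wraps tokenizer/model loading and the generation loop shared by the
# "Generar", "Añadir" and "Charlar" callbacks defined below.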
class TextGeneration:
    def __init__(self):
        self.tokenizer = None
        self.generator = None
        self.task = "text-generation"
        self.model_name_or_path = MODEL_NAME
        set_seed(42)

    def load(self):
        logger.info("Loading model...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name_or_path, revision=MODEL_REVISION, use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
        )
        self.tokenizer_prefix_space = AutoTokenizer.from_pretrained(
            self.model_name_or_path, add_prefix_space=True, revision=MODEL_REVISION, use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name_or_path, revision=MODEL_REVISION,
            use_auth_token=HF_AUTH_TOKEN if HF_AUTH_TOKEN else None,
            pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id,
            torch_dtype=DTYPE, low_cpu_mem_usage=LOW_CPU_MEM,
        ).to(device=DEVICE, non_blocking=False)
        _ = self.model.eval()
        device_number = -1 if DEVICE == "cpu" else int(DEVICE.split(":")[-1])
        self.generator = pipeline(self.task, model=self.model, tokenizer=self.tokenizer, device=device_number)
        logger.info("Loading model done.")

    # with torch.no_grad():
    #     tokens = tokenizer.encode(prompt, return_tensors='pt').to(device=device, non_blocking=True)
    #     gen_tokens = self.model.generate(tokens, do_sample=True, temperature=0.8, max_length=128)
    #     generated = tokenizer.batch_decode(gen_tokens)[0]
    #     return generated

    def generate(self, text, generation_kwargs, previous_text=None):
        do_clean = generation_kwargs.pop("do_clean", False)
        bad_words = generation_kwargs.pop("bad_words", "")
        if bad_words:
            generation_kwargs["bad_words_ids"] = self.tokenizer_prefix_space(
                [word.strip() for word in bad_words.split(",")], add_special_tokens=False
            ).input_ids
        if "repetition_penalty" in generation_kwargs:
            generation_kwargs["repetition_penalty"] = float(generation_kwargs["repetition_penalty"])
        input_text = previous_text or text
        # max_length = len(self.tokenizer(input_text)["input_ids"]) + generation_kwargs["max_length"]
        # generation_kwargs["max_length"] = min(max_length, self.model.config.n_positions)
        generation_kwargs["max_new_tokens"] = generation_kwargs.pop("max_length", 50)
        generated_text = None
        if input_text:
            pre_input_text = ""
            input_ids = self.tokenizer(input_text).input_ids
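            # If the prompt plus the requested new tokens would exceed the model's
            # 2048-token context window, generate from the most recent tokens only and
            # re-attach the truncated prefix to the output afterwards.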
            if len(input_ids) + generation_kwargs["max_new_tokens"] >= 2048:
                prompt_cutoff = 2048 - generation_kwargs["max_new_tokens"] + 1
                pre_input_text = self.tokenizer.decode(input_ids[:-prompt_cutoff])
                input_text = self.tokenizer.decode(input_ids[-prompt_cutoff:])
            for _ in range(10):
                generated_text = pre_input_text + (" " if do_clean else "") + self.generator(
                    input_text,
                    **generation_kwargs,
                )[0]["generated_text"]
                input_text = self.tokenizer.decode(input_ids)
                if generated_text.strip().startswith(input_text):
                    generated_text = generated_text.replace(input_text, "", 1).strip()
                if do_clean:
                    generated_text = cleaner.clean_txt(generated_text)
                if generated_text:
                    if previous_text and previous_text != text:
                        diff = [
                            (text, None), (previous_text.replace(text, " ", 1).strip(), PREV), (generated_text, AGENT)
                        ]
                    else:
                        diff = [(text, None), (generated_text, AGENT)]
                    return (
                        input_text + " " + generated_text,
                        diff
                    )
            if not generated_text:
                return (
                    "",
                    [(f"Tras 10 intentos {AGENT} no generó nada. Pruebe cambiando las opciones.", "ERROR")]
                )
        return (
            "",
            [("Debe escribir algo primero.", "ERROR")]
        )


#@st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None})
#@st.cache(allow_output_mutation=True)
#@st.cache(allow_output_mutation=True, hash_funcs={TextGeneration: lambda _: None})
def load_text_generator():
    text_generator = TextGeneration()
    text_generator.load()
    return text_generator


cleaner = Normalizer()
generator = load_text_generator()
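

# Both helpers below build the generation kwargs from the UI controls and delegate to
# the module-level generator; "Generar" starts from the textbox text, while "Añadir"
# continues from the hidden state that holds the previously generated text.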
def complete_with_gpt(text, max_length, top_k, top_p, penalty_alpha, num_beams, temperature, repetition_penalty, no_repeat_ngram_size, bad_words, do_sample, do_clean):
    generation_kwargs = {
        "max_length": max_length,
        "top_k": top_k,
        "top_p": top_p,
        "penalty_alpha": penalty_alpha,
        "num_beams": num_beams,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "bad_words": bad_words,
        "do_sample": do_sample,
        "do_clean": do_clean,
    }
    return generator.generate(text, generation_kwargs)


def expand_with_gpt(hidden, text, max_length, top_k, top_p, penalty_alpha, num_beams, temperature, repetition_penalty, no_repeat_ngram_size, bad_words, do_sample, do_clean):
    generation_kwargs = {
        "max_length": max_length,
        "top_k": top_k,
        "top_p": top_p,
        "penalty_alpha": penalty_alpha,
        "num_beams": num_beams,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "bad_words": bad_words,
        "do_sample": do_sample,
        "do_clean": do_clean,
    }
    return generator.generate(text, generation_kwargs, previous_text=hidden)


def chat_with_gpt(agent, user, context, user_message, history, max_length, top_k, top_p, penalty_alpha, num_beams, temperature, repetition_penalty, no_repeat_ngram_size, bad_words, do_sample, do_clean):
    # agent = AGENT
    # user = USER
    generation_kwargs = {
        "max_length": max_length,
        "top_k": top_k,
        "top_p": top_p,
        "penalty_alpha": penalty_alpha,
        "num_beams": num_beams,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
        "bad_words": bad_words,
        "do_sample": do_sample,
        "do_clean": do_clean,
        # "num_return_sequences": 1,
        # "return_full_text": False,
    }
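    # history is a list of (user_message, agent_response) tuples kept in a gr.Variable;
    # the user's message gets its first word capitalized before entering the prompt.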
    first_word, _, rest = user_message.partition(" ")
    message = (first_word.capitalize() + " " + rest).strip()
    history = history or []  # [(f"{user}: Bienvenido. Encantado de tenerle con nosotros.", f"{agent}: Un placer, muchas gracias por la invitación.")]
    context = context.format(USER=user or USER, AGENT=agent or AGENT).strip()
    if context and context[-1] not in ".:":
        context += "."
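    # Build the conversational context: drop the oldest turns until the history plus the
    # requested number of new tokens fits within the model's n_positions window.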
    context_length = len(context.split())
    history_take = 0
    history_context = "\n".join(f"{user}: {history_message.capitalize()}.\n{agent}: {history_response}." for history_message, history_response in history[-len(history) + history_take:])
    while len(history_context.split()) > generator.model.config.n_positions - (generation_kwargs["max_length"] + context_length):
        history_take += 1
        history_context = "\n".join(f"{user}: {history_message.capitalize()}.\n{agent}: {history_response}." for history_message, history_response in history[-len(history) + history_take:])
        if history_take >= generator.model.config.n_positions:
            break
    context += history_context
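    # Retry a few times until the model returns a non-empty reply that is not just an
    # echo of the user's message.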
    for _ in range(5):
        prompt = f"{context}\n\n{user}: {message}.\n"
        response = generator.generate(prompt, generation_kwargs)[0]
        if DEBUG:
            print("\n-----\n" + response + "\n-----\n")
        # response = response.split("\n")[-1]
        # if agent in response and response.split(agent)[-1]:
        #     response = response.split(agent)[-1]
        # if user in response and response.split(user)[-1]:
        #     response = response.split(user)[-1]
        # Take the first response
        candidates = [r for r in response.replace(prompt, "").split(f"{AGENT}:") if r.strip()]
        response = candidates[0].split(USER)[0].replace(f"{AGENT}:", "\n").strip() if candidates else ""
        if response and response[0] in string.punctuation:
            response = response[1:].strip()
        if response.strip().startswith(f"{user}: {message}"):
            response = response.strip().split(f"{user}: {message}")[-1]
        if response.replace(".", "").strip() and message.replace(".", "").strip() != response.replace(".", "").strip():
            break
    if DEBUG:
        print()
        print("CONTEXT:")
        print(context)
        print()
        print("MESSAGE:")
        print(message)
        print()
        print("RESPONSE:")
        print(response)
    if not response.strip():
        response = random.choice(["No sé muy bien cómo contestar a eso.", "No puedo contestar con seguridad.", "Prefiero no contestar.", "Ni idea.", "¿Podemos cambiar de tema?"])
    history.append((user_message, response))
    return history, history, ""


# css="#htext span {white-space: pre}"
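# Gradio UI: a narrow left column with the generation options and a wide right column
# with the "Generar" (free-form generation) and "Charlar" (chat) tabs.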
with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                with gr.Box():
                    gr.Markdown("Opciones")
                    with gr.Tabs():
                        with gr.TabItem("Generación"):
                            max_length = gr.Slider(
                                label='Palabras a generar',
                                # help="Número máximo (aproximado) de palabras a generar.",
                                minimum=1,
                                maximum=MAX_LENGTH,
                                value=50,
                                step=1
                            )
                            top_k = gr.Slider(
                                label='Top-k',
                                # help="Número de palabras con alta probabilidad a mantener para el filtrado `top-k`",
                                minimum=0,
                                maximum=80,
                                value=50,
                                step=1
                            )
                            top_p = gr.Slider(
                                label='Top-p',
                                # help="Solo las palabras más probables con probabilidades que sumen `top_p` o más se mantienen para la generación.",
                                minimum=0.01,
                                maximum=5.0,
                                value=0.95,
                                step=0.01
                            )
                            penalty_alpha = gr.Slider(
                                label='Penalización (alpha)',
                                # help="Penalización para contrastive search.",
                                minimum=0.0,
                                maximum=1.0,
                                value=0.0,
                                step=0.01
                            )
                            num_beams = gr.Slider(
                                label='Haces (beams)',
                                # help="Número de beams para búsqueda.",
                                minimum=1,
                                maximum=50,
                                value=1,
                                step=1
                            )
                            temperature = gr.Slider(
                                label='Temperatura',
                                # help="Valor utilizado para modular las probabilidades de las siguientes palabras generadas.",
                                minimum=0.0,
                                maximum=10.0,
                                value=0.8,
                                step=0.05
                            )
                            do_sample = gr.Checkbox(
                                label='¿Muestrear?',
                                value=True,
                                # options=(True, False),
                                # help="Si no se muestrea se usará una decodificación voraz (_greedy_).",
                            )
                            do_clean = gr.Checkbox(
                                label='¿Limpiar texto?',
                                value=False,
                                # options=(True, False),
                                # help="Si eliminar o no las palabras repetidas y recortar las últimas frases sin terminar.",
                            )
                        with gr.TabItem("Control de repetición"):
                            repetition_penalty = gr.Slider(
                                label='Penalización por repetición',
                                info="Un valor de 1 significa no penalización.",
                                minimum=1.0,
                                maximum=10.0,
                                value=1.0,
                                step=0.01
                            )
                            no_repeat_ngram_size = gr.Slider(
                                label='No repetir ngrams de tamaño',
                                minimum=0,
                                maximum=10,
                                value=0,
                                step=1
                            )
                            bad_words = gr.Textbox(
                                label="Palabras a evitar",
                                info="Lista de palabras separadas por comas",
                                lines=1,
                                value="",
                            )
            with gr.Accordion("Estrategias", open=False):
                gr.Markdown("""
- **greedy decoding** si `num_beams=1` y `do_sample=False`
- **contrastive search** si `penalty_alpha>0.0` y `top_k>1`
- **multinomial sampling** si `num_beams=1` y `do_sample=True`
- **beam-search decoding** si `num_beams>1` y `do_sample=False`
- **beam-search multinomial sampling** si `num_beams>1` y `do_sample=True`
""")
        with gr.Column(scale=4):
            with gr.Tabs():
                with gr.TabItem("Generar"):
                    textbox = gr.Textbox(label="Texto", placeholder="Escriba algo (o seleccione un ejemplo) y pulse 'Generar'...", lines=8)
                    examples = gr.Dropdown(label="Ejemplos", choices=EXAMPLES, value=None, type="value")
                    hidden = gr.Textbox(visible=False, show_label=False)
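                    # `hidden` keeps the full prompt + generated text between calls so that
                    # "Añadir" can keep extending it and "Editar" can copy it back into the textbox.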
                    with gr.Box():
                        # output = gr.Markdown()
                        output = gr.HighlightedText(
                            elem_id="htext",
                            label="Resultado",
                            combine_adjacent=True,
                        ).style(
                            color_map={AGENT: "green", "ERROR": "red", PREV: "blue"},
                        )
                    with gr.Row():
                        generate_btn = gr.Button("Generar")
                        generate_btn.click(complete_with_gpt, inputs=[textbox, max_length, top_k, top_p, penalty_alpha, num_beams, temperature, repetition_penalty, no_repeat_ngram_size, bad_words, do_sample, do_clean], outputs=[hidden, output], api_name="generate")
                        expand_btn = gr.Button("Añadir")
                        expand_btn.click(expand_with_gpt, inputs=[hidden, textbox, max_length, top_k, top_p, penalty_alpha, num_beams, temperature, repetition_penalty, no_repeat_ngram_size, bad_words, do_sample, do_clean], outputs=[hidden, output])
                        edit_btn = gr.Button("Editar", variant="secondary")
                        edit_btn.click(lambda x: (x, "", []), inputs=[hidden], outputs=[textbox, hidden, output])
                        clean_btn = gr.Button("Borrar", variant="secondary")
                        clean_btn.click(lambda: ("", "", [], ""), inputs=[], outputs=[textbox, hidden, output, examples])
                    examples.change(lambda x: x, inputs=[examples], outputs=[textbox])
                with gr.TabItem("Charlar") as tab_chat:
                    # tab_chat.select(lambda: 25, inputs=[], outputs=[max_length])
                    context = gr.Textbox(label="Contexto", value=CONTEXT, lines=5)
                    with gr.Row():
                        agent = gr.Textbox(label="Agente", value=AGENT)
                        user = gr.Textbox(label="Usuario", value=USER)
                    history = gr.Variable(value=[])
                    chatbot = gr.Chatbot().style(color_map=("green", "gray"))
                    with gr.Row():
                        message = gr.Textbox(placeholder="Escriba aquí su mensaje y pulse 'Enviar'", show_label=False)
                        chat_btn = gr.Button("Enviar")
                    chat_btn.click(chat_with_gpt, inputs=[agent, user, context, message, history, max_length, top_k, top_p, penalty_alpha, num_beams, temperature, repetition_penalty, no_repeat_ngram_size, bad_words, do_sample, do_clean], outputs=[chatbot, history, message])
    gr.Markdown(FOOTER)

# with gr.Interface(lambda: None, inputs=["text", max_length, top_k, top_p, penalty_alpha, num_beams, temperature, do_sample, do_clean], outputs=[hidden, output]) as iface:
#     demo.examples = None
#     demo.predict_durations = []
#     demo.input_components = iface.input_components
#     demo.output_components = iface.output_components

demo.queue()
demo.launch(share=True)