# GPT2R / app.py — Gradio demo for OpenAI GPT-2 text generation.
# Author: XMichaelX (commit 523bd62, verified)
import gradio as gr
from transformers import pipeline, GPT2LMHeadModel, GPT2Tokenizer
# UI metadata passed to gr.Interface below.
title = "GPT2"
description = "Gradio Demo for OpenAI GPT2. To use it, simply add your text, or click one of the examples to load them."
# Footer link to the original GPT-2 paper, rendered as HTML under the demo.
article = "<p style='text-align: center'><a href='https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf' target='_blank'>Language Models are Unsupervised Multitask Learners</a></p>"
# Each example row is [input_text, model_name], matching the Interface inputs.
examples = [
['Paris is the capital of', "gpt2-medium"]
]
# Initialize models dictionary to cache loaded models
models = {}  # model_name -> transformers text-generation pipeline
def load_model(model_name):
    """Return a text-generation pipeline for *model_name*, loading it on first use.

    Pipelines are memoized in the module-level ``models`` dict so each
    checkpoint is downloaded and instantiated at most once per process.
    """
    cached = models.get(model_name)
    if cached is None:
        tok = GPT2Tokenizer.from_pretrained(model_name)
        lm = GPT2LMHeadModel.from_pretrained(model_name)
        cached = pipeline("text-generation", model=lm, tokenizer=tok)
        models[model_name] = cached
    return cached
def inference(text, model_name):
    """Generate a continuation of *text* with the selected GPT-2 variant.

    Parameters:
        text: Prompt string to continue.
        model_name: One of the dropdown choices; any unknown name falls
            back to "distilgpt2".

    Returns:
        The prompt followed by the newly generated text.
    """
    # Map the UI names to their Hugging Face identifiers. Identity today,
    # but kept so UI labels can diverge from hub ids without touching callers.
    model_map = {
        "distilgpt2": "distilgpt2",
        "gpt2-medium": "gpt2-medium",
        "gpt2-large": "gpt2-large",
        "gpt2-xl": "gpt2-xl",
    }
    hf_model_name = model_map.get(model_name, "distilgpt2")
    # Load the model (cached in `models` after the first load).
    generator = load_model(hf_model_name)
    # max_new_tokens counts only generated tokens; the previous
    # max_length=50 counted prompt + generation, so prompts near or over
    # 50 tokens produced warnings and little or no output.
    generated = generator(text, max_new_tokens=50, num_return_sequences=1)
    return generated[0]['generated_text']
# Build the web UI: a prompt textbox plus a model selector feeding `inference`.
iface = gr.Interface(
    inference,
    [
        gr.Textbox(label="Input"),
        gr.Dropdown(
            choices=["distilgpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"],
            value="gpt2-medium",
            label="Model"
        )
    ],
    gr.Textbox(label="Output"),
    examples=examples,
    article=article,
    title=title,
    description=description
)
# `launch(enable_queue=True)` was deprecated in Gradio 3.x and removed in
# 4.x; enable request queuing via `.queue()` so concurrent generation
# requests do not time out the HTTP connection.
iface.queue().launch()