import gradio as gr
from transformers import pipeline, GPT2LMHeadModel, GPT2Tokenizer

title = "GPT-2"
description = "Gradio demo for OpenAI GPT-2. Enter your text, or click one of the examples to load it."
article = "<p style='text-align: center'><a href='https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf' target='_blank'>Language Models are Unsupervised Multitask Learners</a></p>"
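
# Each example row supplies one value per input component of the interface
# below: (prompt text, model choice for the dropdown).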
examples = [
    ["Paris is the capital of", "gpt2-medium"],
]

# Cache of loaded text-generation pipelines, keyed by model name, so each
# checkpoint is downloaded and initialized at most once per process.
models = {}


def load_model(model_name):
    # Lazily build the tokenizer/model pair on first use and cache the
    # resulting pipeline in `models` for later calls.
    if model_name not in models:
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)
        models[model_name] = pipeline("text-generation", model=model, tokenizer=tokenizer)
    return models[model_name]
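
# For example, load_model("gpt2-medium") loads the checkpoint (downloading
# it if it is not already cached locally) on the first call, then returns
# the same pipeline object on every later call.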


def inference(text, model_name):
    # Restrict generation to the known GPT-2 checkpoints, falling back to
    # the smallest one if the dropdown value is unrecognized.
    valid_models = {"distilgpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"}
    hf_model_name = model_name if model_name in valid_models else "distilgpt2"

    generator = load_model(hf_model_name)

    # Generate a single completion; max_length caps the combined length of
    # the prompt and the generated continuation at 50 tokens.
    generated = generator(text, max_length=50, num_return_sequences=1)
    return generated[0]["generated_text"]
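
# The text-generation pipeline returns a list with one dict per requested
# sequence, e.g. [{"generated_text": "Paris is the capital of ..."}], which
# is why inference() unwraps element 0.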


iface = gr.Interface(
    inference,
    [
        gr.Textbox(label="Input"),
        gr.Dropdown(
            choices=["distilgpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"],
            value="gpt2-medium",
            label="Model"
        )
    ],
    gr.Textbox(label="Output"),
    examples=examples,
    article=article,
    title=title,
    description=description
)

# launch(enable_queue=True) is deprecated in recent Gradio releases; enable
# queuing via queue() so long-running generations don't hit request timeouts.
iface.queue()
iface.launch()