import warnings

import gradio as gr
import torch
from transformers import pipeline

warnings.filterwarnings("ignore")
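

# MultiModelHub builds one transformers pipeline per model lazily and caches
# it, so memory is only spent on the models a user actually selects.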
class MultiModelHub:
    def __init__(self):
        # Cache of loaded pipelines, keyed by the display names below.
        self.pipelines = {}
        self.model_configs = {
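            # Decoder-only models for free-form text generation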
            "GPT-2 Indonesia": {
                "model_name": "Lyon28/GPT-2",
                "type": "text-generation",
                "description": "GPT-2 fine-tuned for Indonesian"
            },
            "Tinny Llama": {
                "model_name": "Lyon28/Tinny-Llama",
                "type": "text-generation",
                "description": "Compact language model for chat"
            },
            "Pythia": {
                "model_name": "Lyon28/Pythia",
                "type": "text-generation",
                "description": "Pythia model for text generation"
            },
            "GPT-Neo": {
                "model_name": "Lyon28/GPT-Neo",
                "type": "text-generation",
                "description": "GPT-Neo for creative writing"
            },
            "Distil GPT-2": {
                "model_name": "Lyon28/Distil_GPT-2",
                "type": "text-generation",
                "description": "Lightweight GPT-2 variant"
            },
            "GPT-2 Tinny": {
                "model_name": "Lyon28/GPT-2-Tinny",
                "type": "text-generation",
                "description": "Compact GPT-2 model"
            },

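            # Encoder models for text classification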
            "BERT Tinny": {
                "model_name": "Lyon28/Bert-Tinny",
                "type": "text-classification",
                "description": "BERT for text classification"
            },
            "ALBERT Base": {
                "model_name": "Lyon28/Albert-Base-V2",
                "type": "text-classification",
                "description": "ALBERT for sentiment analysis"
            },
            "DistilBERT": {
                "model_name": "Lyon28/Distilbert-Base-Uncased",
                "type": "text-classification",
                "description": "Efficient BERT for classification"
            },
            "ELECTRA Small": {
                "model_name": "Lyon28/Electra-Small",
                "type": "text-classification",
                "description": "ELECTRA for text understanding"
            },

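            # Encoder-decoder model for text-to-text tasks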
            "T5 Small": {
                "model_name": "Lyon28/T5-Small",
                "type": "text2text-generation",
                "description": "T5 for a variety of NLP tasks"
            }
        }

    def load_model(self, model_key):
        """Load a model on demand to save memory; loaded pipelines are cached."""
        if model_key in self.pipelines:
            return self.pipelines[model_key]

        try:
            config = self.model_configs[model_key]
            model_name = config["model_name"]
            model_type = config["type"]

            if model_type == "text-generation":
                # Use half precision and automatic device placement when a GPU is available.
                pipe = pipeline(
                    "text-generation",
                    model=model_name,
                    tokenizer=model_name,
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    device_map="auto" if torch.cuda.is_available() else None
                )
            elif model_type == "text-classification":
                pipe = pipeline(
                    "text-classification",
                    model=model_name,
                    tokenizer=model_name
                )
            elif model_type == "text2text-generation":
                pipe = pipeline(
                    "text2text-generation",
                    model=model_name,
                    tokenizer=model_name
                )
            else:
                raise ValueError(f"Unsupported model type: {model_type}")

            self.pipelines[model_key] = pipe
            return pipe

        except Exception as e:
            # Return the error as a string; callers detect this with isinstance().
            return f"Error loading model {model_key}: {str(e)}"

    def generate_text(self, model_key, prompt, max_length=100, temperature=0.7, top_p=0.9):
        """Generate text using the selected model."""
        try:
            pipe = self.load_model(model_key)
            if isinstance(pipe, str):
                # load_model returned an error message instead of a pipeline.
                return pipe

            config = self.model_configs[model_key]

            if config["type"] == "text-generation":
                # Note: max_length counts prompt tokens plus generated tokens.
                result = pipe(
                    prompt,
                    max_length=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    pad_token_id=pipe.tokenizer.eos_token_id
                )
                generated_text = result[0]['generated_text']

                # Strip the echoed prompt so only the continuation is returned.
                if generated_text.startswith(prompt):
                    generated_text = generated_text[len(prompt):].strip()
                return generated_text

            elif config["type"] == "text-classification":
                result = pipe(prompt)
                return f"Label: {result[0]['label']}, Score: {result[0]['score']:.4f}"

            elif config["type"] == "text2text-generation":
                result = pipe(prompt, max_length=max_length)
                return result[0]['generated_text']

        except Exception as e:
            return f"Error generating text: {str(e)}"

    def get_model_info(self, model_key):
        """Get model information as Markdown."""
        config = self.model_configs[model_key]
        return f"**{model_key}**\n\nType: {config['type']}\n\nDescription: {config['description']}"


hub = MultiModelHub()


def chat_interface(model_choice, user_input, max_length, temperature, top_p, history):
    """Main chat interface"""
    if not user_input.strip():
        return history, ""

    response = hub.generate_text(
        model_choice,
        user_input,
        max_length=int(max_length),
        temperature=temperature,
        top_p=top_p
    )

    history.append([user_input, response])
    return history, ""


def get_model_description(model_choice):
    """Update model description"""
    return hub.get_model_info(model_choice)
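

# Build the Gradio UI: layout first, then event wiring.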
with gr.Blocks(title="Lyon28 Multi-Model Hub", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🤖 Lyon28 Multi-Model Hub

        Deploy and test all 11 Lyon28 models in a single interface.
        Pick a model, adjust the parameters, and start chatting!
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                choices=list(hub.model_configs.keys()),
                value="GPT-2 Indonesia",
                label="Select Model",
                info="Choose which model to use"
            )

            model_info = gr.Markdown(
                hub.get_model_info("GPT-2 Indonesia"),
                label="Model Information"
            )

            gr.Markdown("### Generation Parameters")
            max_length_slider = gr.Slider(
                minimum=20,
                maximum=500,
                value=100,
                step=10,
                label="Max Length",
                info="Maximum response length"
            )

            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.7,
                step=0.1,
                label="Temperature",
                info="Creativity level (higher = more creative)"
            )

            top_p_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-p",
                info="Nucleus sampling parameter"
            )

        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                label="Chat with Model",
                height=400,
                show_label=True
            )

            user_input = gr.Textbox(
                placeholder="Type your message here...",
                label="Your Message",
                lines=2
            )

            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear Chat", variant="secondary")

    gr.Markdown("### 💡 Example Prompts")
    example_prompts = gr.Examples(
        examples=[
            ["Ceritakan tentang Indonesia"],
            ["What is artificial intelligence?"],
            ["Write a Python function to sort a list"],
            ["Explain quantum computing in simple terms"],
            ["Create a short story about robots"],
        ],
        inputs=user_input,
        label="Click to use example prompts"
    )
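
    # Event wiring: the dropdown refreshes the info panel; Send and Enter both
    # run generation; Clear resets the conversation.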
    model_dropdown.change(
        fn=get_model_description,
        inputs=[model_dropdown],
        outputs=[model_info]
    )

    send_btn.click(
        fn=chat_interface,
        inputs=[model_dropdown, user_input, max_length_slider, temperature_slider, top_p_slider, chatbot],
        outputs=[chatbot, user_input]
    )

    user_input.submit(
        fn=chat_interface,
        inputs=[model_dropdown, user_input, max_length_slider, temperature_slider, top_p_slider, chatbot],
        outputs=[chatbot, user_input]
    )

    clear_btn.click(
        fn=lambda: ([], ""),
        outputs=[chatbot, user_input]
    )


# Re-open the Blocks context to append a footer.
with demo:
    gr.Markdown(
        """
        ---

        ### 🚀 Features:
        - **11 Models**: Access all Lyon28 models in one place
        - **Multiple Types**: Text generation, classification, and text2text
        - **Configurable**: Adjust temperature, top-p, and max length
        - **Memory Efficient**: Models loaded on demand
        - **API Ready**: Gradio auto-generates API endpoints

        ### 💡 API Usage:
        ```python
        import requests

        response = requests.post(
            "https://your-space-name.hf.space/api/predict",
            json={"data": ["GPT-2 Indonesia", "Hello world", 100, 0.7, 0.9, []]}
        )
        ```
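
        Alternatively, a sketch using the `gradio_client` package; the Space id
        and the auto-generated `api_name` below are placeholders and may differ:
        ```python
        from gradio_client import Client

        client = Client("your-space-name")  # placeholder Space id
        history, _ = client.predict(
            "GPT-2 Indonesia", "Hello world", 100, 0.7, 0.9, [],
            api_name="/chat_interface"  # assumed auto-generated endpoint name
        )
        ```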

        **Built by Lyon28** 🔥
        """
    )


if __name__ == "__main__":
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860
    )