Spaces:
Sleeping
Sleeping
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import gradio as gr | |
# 1. Konfiguracja modelu i tokenizera | |
MODEL_ID = "tiiuae/Falcon-H1-1.5B-Deep-Instruct" | |
# Ładowanie tokenizera | |
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) | |
# Ładowanie modelu z optymalizacją autodevice i bfloat16 (jeśli wspierane) | |
model = AutoModelForCausalLM.from_pretrained( | |
MODEL_ID, | |
torch_dtype=torch.bfloat16, # lub torch.float16 / torch.float32, zależnie od dostępnego sprzętu | |
device_map="auto", # automatyczne rozłożenie na GPU/CPU | |
) | |
# 2. Funkcja generująca odpowiedź | |
def generate_text(prompt: str, max_length: int = 256, temperature: float = 0.7, top_p: float = 0.9): | |
# Tokenizacja wejścia | |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) | |
# Generacja sekwencji | |
outputs = model.generate( | |
**inputs, | |
max_new_tokens=max_length, | |
temperature=temperature, | |
top_p=top_p, | |
do_sample=True, | |
pad_token_id=tokenizer.eos_token_id | |
) | |
# Dekodowanie na tekst | |
generated = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
# Usuń powtórzone zapytanie | |
return generated[len(prompt):].strip() | |
# 3. Interfejs Gradio | |
with gr.Blocks(title="Falcon-H1-1.5B Deep Instruct") as demo: | |
gr.Markdown("## Falcon-H1-1.5B-Deep-Instruct\nInteraktywny interfejs do generowania tekstu za pomocą modelu Instrukcyjnego") | |
with gr.Row(): | |
with gr.Column(scale=3): | |
prompt_input = gr.Textbox(label="Wpisz prompt", lines=6, placeholder="Napisz coś...") | |
max_len_slider = gr.Slider(minimum=16, maximum=1024, value=256, step=16, label="Maksymalna długość odpowiedzi") | |
temp_slider = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.05, label="Temperature") | |
top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)") | |
submit_btn = gr.Button("Generuj") | |
with gr.Column(scale=5): | |
output_box = gr.Textbox(label="Wygenerowany tekst", lines=10) | |
# Powiązanie przycisku z funkcją | |
submit_btn.click( | |
fn=generate_text, | |
inputs=[prompt_input, max_len_slider, temp_slider, top_p_slider], | |
outputs=output_box | |
) | |
# 4. Uruchomienie serwera | |
if __name__ == "__main__": | |
demo.launch(share=False, server_name="0.0.0.0", server_port=7860) | |