# Spaces: Runtime error
import warnings | |
warnings.filterwarnings("ignore") | |
import logging | |
logging.getLogger("streamlit").setLevel(logging.ERROR) | |
import streamlit as st | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import torch | |
@st.cache_resource(show_spinner=False)
def load_model():
    """Load the tokenizer and causal language model used by the chatbot.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the model would be re-instantiated on each rerun. The
    ``st.cache_resource`` decorator keeps a single shared instance alive for
    the lifetime of the server process.

    ``show_spinner=False`` keeps this call from emitting any Streamlit
    element, so it remains safe to call before ``st.set_page_config``.

    Returns:
        A ``(tokenizer, model)`` tuple for ``radlab/polish-gpt2-small-v2``.
    """
    model_name = "radlab/polish-gpt2-small-v2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return tokenizer, model
# Page configuration goes first: Streamlit expects set_page_config to be the
# first Streamlit command of the script, and emitting the header before the
# (potentially slow) model load means the page is not blank while loading.
st.set_page_config(page_title="Polski Chatbot AI", page_icon="🤖")
st.title("🤖 Polski Chatbot AI")
st.caption("Model: radlab/polish-gpt2-small-v2")

tokenizer, model = load_model()

# The conversation is kept as one plain-text transcript in session state so
# that it survives Streamlit's script reruns.
if "history" not in st.session_state:
    st.session_state.history = ""
# Number of tokens the model is allowed to add per reply.
_MAX_NEW_TOKENS = 80


def _generate_reply(prompt: str) -> str:
    """Sample one model reply for *prompt* and return it as plain text.

    The prompt is clipped from the LEFT so that the most recent turns (and
    the trailing ``AI:`` cue) are always kept, and so that prompt length plus
    ``_MAX_NEW_TOKENS`` never exceeds the model's positional limit.

    Args:
        prompt: Full transcript ending with the ``AI:`` cue.

    Returns:
        The generated reply, cut at the first ``Użytkownik:`` marker (the
        model tends to continue the dialogue on its own) and stripped.
    """
    # Context size comes from the model config; 1024 is the value this script
    # previously hard-coded as the prompt cap.
    context_limit = getattr(model.config, "max_position_embeddings", 1024)
    prompt_budget = max(1, context_limit - _MAX_NEW_TOKENS)

    encoded = tokenizer(prompt, return_tensors="pt")
    # Keep only the tail of the transcript: the newest turns matter most.
    input_ids = encoded["input_ids"][:, -prompt_budget:]
    attention_mask = encoded["attention_mask"][:, -prompt_budget:]

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=_MAX_NEW_TOKENS,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.7,
    )

    # Decode only the freshly generated tokens instead of slicing the decoded
    # text by character count, which breaks whenever decode(encode(text)) is
    # not character-for-character identical to the prompt.
    new_tokens = output[0, input_ids.shape[1]:]
    reply_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return reply_text.split("Użytkownik:")[0].strip()


user_input = st.text_input("Wpisz wiadomość:", "")

if st.button("Wyślij") and user_input.strip():
    prompt = st.session_state.history + f"Użytkownik: {user_input}\nAI:"
    model_reply = _generate_reply(prompt)
    # Commit to the transcript only after generation succeeded, so a failure
    # does not leave a dangling "AI:" line in the history.
    st.session_state.history = f"{prompt} {model_reply}\n"
def _clear_history() -> None:
    """Reset the stored conversation transcript to an empty string."""
    st.session_state.history = ""


st.subheader("🗨️ Historia rozmów")
st.text_area("📖", st.session_state.history.strip(), height=300)

# The reset is wired as an on_click callback: callbacks run before the script
# reruns, so the text area above is already empty on the rerun triggered by
# the click. Checking the button's return value here would clear the state
# only after the old transcript had been rendered.
st.button("🧹 Wyczyść historię", on_click=_clear_history)