GPT2-PBE / app.py
tymbos's picture
Update app.py
4410500 verified
raw
history blame
7.95 kB
# -*- coding: utf-8 -*-
import os
import gradio as gr
import time
import datetime
from io import BytesIO
import matplotlib.pyplot as plt
from datasets import load_dataset
from train_tokenizer import train_tokenizer
from tokenizers import Tokenizer
from langdetect import detect, DetectorFactory
# Για επαναληψιμότητα στο langdetect
DetectorFactory.seed = 0
# Ρυθμίσεις checkpointing και αποθήκευσης του tokenizer
CHECKPOINT_FILE = "checkpoint.txt"
TOKENIZER_DIR = "tokenizer_model"
TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json")
MAX_SAMPLES = 3000000 # Όριο δειγμάτων
# Παγκόσμια μεταβλητή ελέγχου συλλογής
STOP_COLLECTION = False
# ===== ΕΜΦΑΝΙΣΗ LOG ΕΚΚΙΝΗΣΗΣ =====
startup_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(f"\n===== Application Startup at {startup_time} =====\n")
def load_checkpoint():
"""Φόρτωση δεδομένων από το checkpoint αν υπάρχει."""
if os.path.exists(CHECKPOINT_FILE):
with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
return f.read().splitlines()
return []
def append_to_checkpoint(texts):
"""Αποθήκευση δεδομένων στο αρχείο checkpoint."""
with open(CHECKPOINT_FILE, "a", encoding="utf-8") as f:
for t in texts:
f.write(t + "\n")
def create_iterator(dataset_name, configs, split):
"""Φορτώνει το dataset και αποδίδει τα κείμενα ως iterator."""
configs_list = [c.strip() for c in configs.split(",") if c.strip()]
for config in configs_list:
try:
dataset = load_dataset(dataset_name, name=config, split=split, streaming=True)
for example in dataset:
text = example.get('text', '')
if text:
yield text
except Exception as e:
print(f"⚠️ Σφάλμα φόρτωσης dataset για config {config}: {e}")
def collect_samples(dataset_name, configs, split, chunk_size):
"""
Ξεκινά τη συλλογή δειγμάτων από το dataset μέχρι να φτάσει το MAX_SAMPLES
ή μέχρι να ζητηθεί διακοπή (STOP_COLLECTION).
"""
global STOP_COLLECTION
STOP_COLLECTION = False
total_processed = len(load_checkpoint())
# LOG: Ξεκίνησε η διαδικασία συλλογής
print(f"🚀 Ξεκινά η συλλογή δεδομένων... Υπάρχουν ήδη {total_processed} δείγματα στο checkpoint.")
progress_messages = [f"📌 Ξεκινά η συλλογή... Υπάρχουν ήδη {total_processed} δείγματα στο checkpoint."]
dataset_iterator = create_iterator(dataset_name, configs, split)
new_texts = []
for text in dataset_iterator:
if STOP_COLLECTION:
progress_messages.append("⏹️ Η συλλογή σταμάτησε από το χρήστη.")
print("⏹️ Η συλλογή σταμάτησε από το χρήστη.")
break
new_texts.append(text)
total_processed += 1
if len(new_texts) >= chunk_size:
append_to_checkpoint(new_texts)
progress_messages.append(f"✅ Αποθηκεύτηκαν {total_processed} δείγματα στο checkpoint.")
print(f"✅ Αποθηκεύτηκαν {total_processed} δείγματα στο checkpoint.")
new_texts = []
if total_processed >= MAX_SAMPLES:
progress_messages.append("⚠️ Έφτασε το όριο δειγμάτων.")
print("⚠️ Έφτασε το όριο δειγμάτων.")
break
if new_texts:
append_to_checkpoint(new_texts)
progress_messages.append(f"✅ Τελικό batch αποθηκεύτηκε ({total_processed} δείγματα).")
print(f"✅ Τελικό batch αποθηκεύτηκε ({total_processed} δείγματα).")
return "\n".join(progress_messages)
def train_tokenizer_fn(dataset_name, configs, split, vocab_size, min_freq, test_text):
"""Εκπαιδεύει τον tokenizer χρησιμοποιώντας τα δεδομένα του checkpoint."""
print("\n🚀 Ξεκινά η διαδικασία εκπαίδευσης...")
all_texts = load_checkpoint()
tokenizer = train_tokenizer(all_texts, vocab_size, min_freq, TOKENIZER_DIR)
# LOG: Τέλος εκπαίδευσης
print(f"✅ Εκπαίδευση ολοκληρώθηκε! Το tokenizer αποθηκεύτηκε στο {TOKENIZER_DIR}.")
# Φόρτωση εκπαιδευμένου tokenizer
trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
# Δοκιμή
encoded = trained_tokenizer.encode(test_text)
decoded = trained_tokenizer.decode(encoded.ids)
# Γράφημα κατανομής tokens
token_lengths = [len(t) for t in encoded.tokens]
fig = plt.figure()
plt.hist(token_lengths, bins=20)
plt.xlabel('Μήκος Token')
plt.ylabel('Συχνότητα')
img_buffer = BytesIO()
plt.savefig(img_buffer, format='png')
plt.close()
return (f"✅ Εκπαίδευση ολοκληρώθηκε!\nΑποθηκεύτηκε στον φάκελο: {TOKENIZER_DIR}",
decoded,
img_buffer.getvalue())
def stop_collection():
"""Σταματά τη συλλογή δειγμάτων."""
global STOP_COLLECTION
STOP_COLLECTION = True
print("⏹️ Η συλλογή σταμάτησε από το χρήστη.")
return "⏹️ Η συλλογή σταμάτησε από το χρήστη."
def restart_collection():
"""Διαγράφει το checkpoint και επανεκκινεί τη συλλογή."""
global STOP_COLLECTION
STOP_COLLECTION = False
if os.path.exists(CHECKPOINT_FILE):
os.remove(CHECKPOINT_FILE)
print("🔄 Το checkpoint διαγράφηκε. Έτοιμο για νέα συλλογή.")
return "🔄 Το checkpoint διαγράφηκε. Μπορείς να ξεκινήσεις νέα συλλογή."
# Gradio Interface
with gr.Blocks() as demo:
gr.Markdown("## Wikipedia Tokenizer Trainer with Logs & Control")
with gr.Row():
with gr.Column():
dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset Name")
configs = gr.Textbox(value="20231101.el,20231101.en", label="Configs")
split = gr.Dropdown(choices=["train"], value="train", label="Split")
chunk_size = gr.Slider(500, 50000, value=50000, label="Chunk Size")
vocab_size = gr.Slider(20000, 100000, value=50000, label="Vocabulary Size")
min_freq = gr.Slider(1, 100, value=3, label="Minimum Frequency")
test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text")
start_btn = gr.Button("Start Collection")
stop_btn = gr.Button("Stop Collection")
restart_btn = gr.Button("Restart Collection")
train_btn = gr.Button("Train Tokenizer")
progress = gr.Textbox(label="Progress", interactive=False, lines=10)
decoded_text = gr.Textbox(label="Decoded Text", interactive=False)
token_distribution = gr.Image(label="Token Distribution")
start_btn.click(collect_samples, [dataset_name, configs, split, chunk_size], progress)
stop_btn.click(stop_collection, [], progress)
restart_btn.click(restart_collection, [], progress)
train_btn.click(train_tokenizer_fn, [dataset_name, configs, split, vocab_size, min_freq, test_text],
[progress, decoded_text, token_distribution])
print("\nGradio Interface is launching...")
demo.launch()