# -*- coding: utf-8 -*- import os import gradio as gr import requests import time from io import BytesIO import matplotlib.pyplot as plt from datasets import load_dataset from train_tokenizer import train_tokenizer from tokenizers import Tokenizer from langdetect import detect, DetectorFactory from PIL import Image # Προσθήκη για σωστή διαχείριση εικόνας στο Gradio # Για επαναληψιμότητα στο langdetect DetectorFactory.seed = 0 # Ρυθμίσεις checkpointing και αποθήκευσης του tokenizer CHECKPOINT_FILE = "checkpoint.txt" TOKENIZER_DIR = "tokenizer_model" TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json") MAX_SAMPLES = 3000000 # Όριο δειγμάτων # Παγκόσμια μεταβλητή ελέγχου συλλογής STOP_COLLECTION = False def load_checkpoint(): """Φόρτωση δεδομένων από το checkpoint αν υπάρχει.""" if os.path.exists(CHECKPOINT_FILE): with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f: return f.read().splitlines() return [] def append_to_checkpoint(texts): """Αποθήκευση δεδομένων στο αρχείο checkpoint.""" with open(CHECKPOINT_FILE, "a", encoding="utf-8") as f: for t in texts: f.write(t + "\n") def create_iterator(dataset_name, configs, split): """Φορτώνει το dataset και αποδίδει τα κείμενα ως iterator.""" configs_list = [c.strip() for c in configs.split(",") if c.strip()] for config in configs_list: try: dataset = load_dataset(dataset_name, name=config, split=split, streaming=True) for example in dataset: text = example.get('text', '') if text: yield text except Exception as e: print(f"⚠️ Σφάλμα φόρτωσης dataset για config {config}: {e}") def analyze_checkpoint(num_samples=1000): """Αναλύει τα πρώτα num_samples δείγματα από το checkpoint και επιστρέφει το ποσοστό γλωσσών.""" if not os.path.exists(CHECKPOINT_FILE): return "Το αρχείο checkpoint δεν υπάρχει." with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f: lines = f.read().splitlines() sample_lines = lines[:num_samples] if len(lines) >= num_samples else lines language_counts = {} total = 0 for line in sample_lines: try: lang = detect(line) language_counts[lang] = language_counts.get(lang, 0) + 1 total += 1 except Exception: continue if total == 0: return "Δεν βρέθηκαν έγκυρα δείγματα για ανάλυση." report = "📊 Αποτελέσματα Ανάλυσης:\n" for lang, count in language_counts.items(): report += f" - {lang}: {count / total * 100:.2f}%\n" return report def collect_samples(dataset_name, configs, split, chunk_size): """Ξεκινά τη συλλογή δειγμάτων από το dataset.""" global STOP_COLLECTION STOP_COLLECTION = False total_processed = len(load_checkpoint()) progress_messages = [f"📌 Ξεκινά η συλλογή... Υπάρχουν ήδη {total_processed} δείγματα στο checkpoint."] dataset_iterator = create_iterator(dataset_name, configs, split) new_texts = [] for text in dataset_iterator: if STOP_COLLECTION: progress_messages.append("⏹️ Η συλλογή σταμάτησε από το χρήστη.") break new_texts.append(text) total_processed += 1 if len(new_texts) >= chunk_size: append_to_checkpoint(new_texts) progress_messages.append(f"✅ Αποθηκεύτηκαν {total_processed} δείγματα στο checkpoint.") new_texts = [] if total_processed >= MAX_SAMPLES: progress_messages.append("⚠️ Έφτασε το όριο δειγμάτων.") break if new_texts: append_to_checkpoint(new_texts) progress_messages.append(f"✅ Τελικό batch αποθηκεύτηκε ({total_processed} δείγματα).") return "\n".join(progress_messages) def train_tokenizer_fn(dataset_name, configs, split, vocab_size, min_freq, test_text): """Εκπαιδεύει τον tokenizer χρησιμοποιώντας τα δεδομένα του checkpoint.""" print("🚀 Ξεκινά η εκπαίδευση...") all_texts = load_checkpoint() tokenizer = train_tokenizer(all_texts, vocab_size, min_freq, TOKENIZER_DIR) # Φόρτωση εκπαιδευμένου tokenizer trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE) # Δοκιμή encoded = trained_tokenizer.encode(test_text) decoded = trained_tokenizer.decode(encoded.ids) # Γράφημα κατανομής tokens token_lengths = [len(t) for t in encoded.tokens] fig = plt.figure() plt.hist(token_lengths, bins=20) plt.xlabel('Μήκος Token') plt.ylabel('Συχνότητα') # Αποθήκευση και μετατροπή εικόνας img_buffer = BytesIO() plt.savefig(img_buffer, format='png') plt.close() img_buffer.seek(0) img = Image.open(img_buffer) # Επιστροφή σωστής εικόνας return (f"✅ Εκπαίδευση ολοκληρώθηκε!\nΑποθηκεύτηκε στον φάκελο: {TOKENIZER_DIR}", decoded, img) def stop_collection(): """Σταματά τη συλλογή δειγμάτων.""" global STOP_COLLECTION STOP_COLLECTION = True return "⏹️ Η συλλογή σταμάτησε από το χρήστη." def restart_collection(): """Διαγράφει το checkpoint και επανεκκινεί τη συλλογή.""" global STOP_COLLECTION STOP_COLLECTION = False if os.path.exists(CHECKPOINT_FILE): os.remove(CHECKPOINT_FILE) return "🔄 Το checkpoint διαγράφηκε. Μπορείς να ξεκινήσεις νέα συλλογή." # Gradio Interface with gr.Blocks() as demo: gr.Markdown("## Wikipedia Tokenizer Trainer with Logs & Control") with gr.Row(): with gr.Column(): dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset Name") configs = gr.Textbox(value="20231101.el,20231101.en", label="Configs") split = gr.Dropdown(choices=["train"], value="train", label="Split") chunk_size = gr.Slider(500, 10000, value=5000, label="Chunk Size") vocab_size = gr.Slider(20000, 100000, value=50000, label="Vocabulary Size") min_freq = gr.Slider(1, 100, value=3, label="Minimum Frequency") test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text") start_btn = gr.Button("Start Collection") stop_btn = gr.Button("Stop Collection") restart_btn = gr.Button("Restart Collection") analyze_btn = gr.Button("Analyze Samples") train_btn = gr.Button("Train Tokenizer") progress = gr.Textbox(label="Progress", interactive=False, lines=10) decoded_text = gr.Textbox(label="Decoded Text", interactive=False) token_distribution = gr.Image(label="Token Distribution") start_btn.click(collect_samples, [dataset_name, configs, split, chunk_size], progress) stop_btn.click(stop_collection, [], progress) restart_btn.click(restart_collection, [], progress) analyze_btn.click(analyze_checkpoint, [], progress) train_btn.click(train_tokenizer_fn, [dataset_name, configs, split, vocab_size, min_freq, test_text], [progress, decoded_text, token_distribution]) demo.launch()