GPT2-PBE / app.py
tymbos's picture
Update app.py
59d3e79 verified
raw
history blame
5.97 kB
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from io import BytesIO
import tempfile
import os
from train_tokenizer import train_tokenizer
from tokenizers import Tokenizer
from datasets import load_dataset
def create_iterator(files=None, dataset_name=None, dataset_config=None, split="train", streaming=True):
if dataset_name:
try:
# Επεξεργασία ονόματος dataset με έλεγχο εγκυρότητας
if not re.match(r'^[\w\-\.]+(/[\w\-\.]+)*$', dataset_name):
raise ValueError(f"Μη έγκυρο όνομα dataset: {dataset_name}")
# Φόρτωση dataset με config αν υπάρχει
dataset = load_dataset(
dataset_name,
name=dataset_config if dataset_config else None,
split=split,
streaming=streaming
)
for example in dataset:
yield example['text']
except Exception as e:
raise gr.Error(f"Σφάλμα φόρτωσης dataset: {str(e)}")
elif files:
for file in files:
with open(file.name, 'r', encoding='utf-8') as f:
for line in f:
yield line.strip()
def enhanced_validation(tokenizer, test_text):
"""
Εκτελεί επικύρωση του tokenizer με ένα roundtrip test και παρέχει στατιστικά.
"""
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded.ids)
# Μέτρηση των Unknown tokens
unknown_tokens = sum(1 for t in encoded.tokens if t == "<unk>")
unknown_percent = (unknown_tokens / len(encoded.tokens) * 100) if encoded.tokens else 0
# Υπολογισμός μήκους των tokens
token_lengths = [len(t) for t in encoded.tokens]
avg_length = np.mean(token_lengths) if token_lengths else 0
# Έλεγχος κάλυψης κώδικα: παραδείγματα συμβόλων
code_symbols = ['{', '}', '(', ')', ';', '//', 'printf']
code_coverage = {sym: (sym in test_text and sym in encoded.tokens) for sym in code_symbols}
# Δημιουργία histogram για την κατανομή του μήκους των tokens
fig = plt.figure()
plt.hist(token_lengths, bins=20, color='skyblue', edgecolor='black')
plt.xlabel('Μήκος Token')
plt.ylabel('Συχνότητα')
plt.title('Κατανομή Μήκους Tokens')
img_buffer = BytesIO()
plt.savefig(img_buffer, format='png')
plt.close()
img_buffer.seek(0)
return {
"roundtrip_success": test_text == decoded,
"unknown_tokens": f"{unknown_tokens} ({unknown_percent:.2f}%)",
"average_token_length": f"{avg_length:.2f}",
"code_coverage": code_coverage,
"token_length_distribution": img_buffer.getvalue()
}
# ... (προηγούμενο imports και functions παραμένουν ίδια)
def train_and_test(files, dataset_name, split, vocab_size, min_freq, test_text):
if not files and not dataset_name:
raise gr.Error("Πρέπει να παρέχετε αρχεία ή όνομα dataset!")
try:
# Δημιουργία iterator με fallback
iterator = create_iterator(files, dataset_name, split)
# Προσθήκη progress bar
with gr.Progress() as progress:
progress(0.1, desc="Προεπεξεργασία δεδομένων...")
tokenizer = train_tokenizer(iterator, vocab_size, min_freq)
# ... (υπόλοιπη λειτουργία παραμένει ίδια)
except Exception as e:
raise gr.Error(f"Σφάλμα εκπαίδευσης: {str(e)}")
# Αποθήκευση και φόρτωση του tokenizer για επικύρωση
with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp:
tokenizer.save(tmp.name)
trained_tokenizer = Tokenizer.from_file(tmp.name)
os.unlink(tmp.name)
# Εκτενής επικύρωση με το δοκιμαστικό κείμενο
validation = enhanced_validation(trained_tokenizer, test_text)
return {
"validation_metrics": {k: v for k, v in validation.items() if k != "token_length_distribution"},
"histogram": validation["token_length_distribution"]
}
# Ενημερωμένο Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("## Προχωρημένος BPE Tokenizer Trainer")
with gr.Row():
with gr.Column():
with gr.Tab("Local Files"):
file_input = gr.File(file_count="multiple", label="Ανέβασμα αρχείων")
with gr.Tab("Hugging Face Dataset"):
dataset_name = gr.Textbox(label="Όνομα Dataset (π.χ. 'wikitext', 'codeparrot/github-code')")
split = gr.Textbox(value="train", label="Split")
vocab_size = gr.Slider(1000, 100000, value=32000, label="Μέγεθος Λεξιλογίου")
min_freq = gr.Slider(1, 100, value=2, label="Ελάχιστη Συχνότητα")
test_text = gr.Textbox(
value='function helloWorld() { console.log("Γειά σου Κόσμε!"); } // Ελληνικά + κώδικας',
label="Test Text"
)
train_btn = gr.Button("Εκπαίδευση Tokenizer", variant="primary")
with gr.Column():
results_json = gr.JSON(label="Μετρικές")
results_plot = gr.Image(label="Κατανομή Μήκους Tokens")
train_btn.click(
fn=train_and_test,
inputs=[file_input, dataset_name, split, vocab_size, min_freq, test_text],
outputs=[results_json, results_plot]
)
if __name__ == "__main__":
demo.launch()