"""Gradio demo for an English-to-German Transformer trained from scratch.

At import time this module loads the YAML config, the joint tokenizer and the
model checkpoint, then builds a small Gradio UI (``demo``) that translates
English text with beam search. Run it directly to launch the server.
"""

import torch
import gradio as gr
from tokenizers import Tokenizer

from transformer.config import load_config
from transformer.components.decoding import beam_search
from transformer.transformer import Transformer

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CONFIG_PATH = "configs/config.yaml"
MODEL_PATH = "model_checkpoint.pt"
TOKENIZER_PATH = "tokenizers/tokenizer-joint-de-en-vocab37000.json"
MAX_LEN = 128  # passed to beam_search as ``max_len`` for the generated output

config = load_config(CONFIG_PATH)
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
# ``token_to_id`` returns None when the vocabulary has no "[PAD]" token;
# ``translate`` handles that case explicitly.
padding_idx = tokenizer.token_to_id("[PAD]")

model = Transformer.load_from_checkpoint(
    checkpoint_path=MODEL_PATH, config=config, device=DEVICE
)
# Disable training-only behaviour (e.g. dropout) for inference. It is not
# visible here whether load_from_checkpoint already does this; eval() is
# idempotent, and the isinstance check keeps this safe if the loader ever
# returns something that is not an nn.Module.
if isinstance(model, torch.nn.Module):
    model.eval()


def translate(text: str, beam_size: int = 4) -> str:
    """Translate English ``text`` to German using beam search.

    Args:
        text: English source text. Empty, whitespace-only or ``None`` input
            returns an empty string without running the model.
        beam_size: Number of beams passed to ``beam_search``. Coerced to
            ``int`` because UI sliders may deliver numeric values as floats.

    Returns:
        The decoded German translation with special tokens skipped.
    """
    if not text or not text.strip():
        return ""

    ids = tokenizer.encode(text).ids
    if not ids:
        # Nothing to translate. Also avoids building a float, zero-length
        # tensor that an embedding lookup cannot index with.
        return ""

    # Batch of one sequence: shape (1, src_len). Use an explicit long dtype
    # so the ids are always valid embedding indices.
    src_ids = torch.tensor([ids], dtype=torch.long, device=DEVICE)

    if padding_idx is None:
        # No [PAD] token in the vocabulary. A single unpadded sequence has no
        # padding positions, so every position is attended to.
        src_mask = torch.ones_like(src_ids, dtype=torch.bool)
    else:
        src_mask = src_ids != padding_idx
    # (1, src_len) -> (1, 1, 1, src_len). Presumably broadcast over heads and
    # query positions inside the model's attention -- TODO confirm.
    src_mask = src_mask.unsqueeze(1).unsqueeze(2)

    with torch.no_grad():
        result_ids = beam_search(
            model,
            src_ids,
            src_mask,
            tokenizer,
            max_len=MAX_LEN,
            beam_size=int(beam_size),
        )[0]  # first entry of beam_search's result is the one decoded
    return tokenizer.decode(result_ids, skip_special_tokens=True)


with gr.Blocks(title="Transformer From Scratch Translation Demo") as demo:
    gr.Markdown(
        "# Transformer From Scratch Translation Demo\n"
        "Translate English to German using a custom Transformer model trained from scratch.\n\n"
        "**Note:** This model was trained on the WMT14 English-German news dataset. It works best on formal, news-style sentences and may not perform well on everyday informal or conversational text."
    )

    with gr.Row(equal_height=True):
        with gr.Column():
            input_text = gr.Textbox(
                label="English Text",
                placeholder="Enter text to translate...",
                lines=3,
            )
            beam_size = gr.Slider(
                minimum=1,
                maximum=8,
                step=1,
                value=4,
                label="Beam Size",
            )
        with gr.Column():
            output_text = gr.Textbox(
                label="German Translation",
                lines=3,
                interactive=False,
                show_copy_button=True,
                show_label=True,
            )

    # Empty side columns centre the controls column.
    with gr.Row():
        with gr.Column(scale=1):
            pass
        with gr.Column(scale=2, min_width=300, elem_id="centered-controls"):
            translate_btn = gr.Button("Translate")
            gr.Examples(
                examples=[
                    ["Hello, how are you?"],
                    ["The weather is nice today."],
                    ["I love machine learning."],
                ],
                inputs=[input_text],
            )
        with gr.Column(scale=1):
            pass

    translate_btn.click(
        translate,
        inputs=[input_text, beam_size],
        outputs=[output_text],
    )


if __name__ == "__main__":
    demo.launch()