from collections import Counter

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torchtext.vocab import Vocab, build_vocab_from_iterator

# -----------------------------
# 1. Model Definition
# -----------------------------


class GRUPoetryModel(nn.Module):
    """Token-level language model: embedding -> stacked GRU -> linear layer to vocab logits.

    Args:
        vocab_size: number of tokens; sizes both the embedding table and the output layer.
        embed_dim: embedding width.
        hidden_dim: GRU hidden-state width.
        num_layers: number of stacked GRU layers.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2):
        super().__init__()
        # Attribute names are the state_dict keys of saved checkpoints; do not rename.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Return next-token logits for every position plus the final GRU hidden state.

        Args:
            x: LongTensor of token indices, shape (batch, seq_len) (the GRU is batch_first).
            hidden: optional initial hidden state of shape (num_layers, batch, hidden_dim);
                None lets the GRU start from zeros.

        Returns:
            (logits, hidden): logits of shape (batch, seq_len, vocab_size) and the
            hidden state after the last position, suitable for passing back in.
        """
        x = self.embedding(x)
        out, hidden = self.gru(x, hidden)
        logits = self.fc(out)
        return logits, hidden


# -----------------------------
# 2. Tokenization & Vocab
# -----------------------------


def tokenize(text):
    """Split ``text`` on whitespace (no normalisation beyond what the caller does)."""
    return text.split()


# Dummy data to rebuild vocab
dummy_texts = [
    "main ek hoon",
    "dil se teri yaad aai",
    "ab yahan se kahan jaye hum",
    "tum na aaye to kya ghar kya bazaar",
    "ishq mein zindagi bhar dena",
]


def yield_tokens(texts):
    """Yield the lower-cased token list of each text in ``texts``."""
    for text in texts:
        yield tokenize(text.lower())


# Build vocab.
# The legacy ``Vocab(counter, specials=...)`` constructor does not exist alongside the
# ``set_default_index``/``get_itos`` API used here. ``build_vocab_from_iterator`` is
# the supported factory. It also orders tokens by descending frequency, then
# alphabetically, after the specials.
# NOTE(review): indices must be identical to the vocab the checkpoint was trained with.
# A vocab rebuilt from these five dummy lines is very likely smaller than the training
# vocab. Because the model below is sized with len(vocab), loading the checkpoint
# would then fail with a size mismatch.
# TODO: persist the training vocab and load it here instead of rebuilding it.
# NOTE(review): the special-token names were empty strings in the original
# (apparently lost). "<unk>"/"<pad>" in this order is an assumption; confirm it
# against the training script.
vocab = build_vocab_from_iterator(yield_tokens(dummy_texts), specials=["<unk>", "<pad>"])
vocab.set_default_index(vocab["<unk>"])  # unknown prompt words map to <unk>

# -----------------------------
# 3. Poetry Generation Logic
# -----------------------------


def generate_poetry(model, vocab, start_text="main", max_len=50):
    """Greedily extend ``start_text`` with ``model`` until the text has ``max_len`` tokens.

    Args:
        model: a GRUPoetryModel (or a module with the same ``forward`` signature).
        vocab: vocabulary mapping tokens to indices; must support ``vocab[token]``
            and ``get_itos()``.
        start_text: whitespace-separated prompt; it is lower-cased before lookup.
        max_len: total number of tokens in the result, prompt included. A prompt
            already at least this long is returned without generating anything.

    Returns:
        The prompt tokens followed by the generated tokens, joined by single spaces.
        Prompt words missing from the vocab come back as the vocab's default token.
        An empty or whitespace-only prompt returns "".
    """
    model.eval()
    tokens = [vocab[token] for token in tokenize(start_text.lower())]
    if not tokens:
        # A zero-length sequence cannot be fed to the GRU.
        return ""

    device = next(model.parameters()).device
    itos = vocab.get_itos()  # fetch the index->token list once, not once per output token

    generated = list(tokens)
    with torch.no_grad():  # inference only: don't build an autograd graph
        # Prime the hidden state with the whole prompt once...
        step_input = torch.tensor([tokens], dtype=torch.long, device=device)
        hidden = None
        for _ in range(max_len - len(tokens)):
            logits, hidden = model(step_input, hidden)
            next_token = logits[0, -1].argmax().item()
            generated.append(next_token)
            # ...then feed only the newest token. `hidden` already summarises
            # everything before it, so re-feeding the prefix would apply it twice.
            step_input = torch.tensor([[next_token]], dtype=torch.long, device=device)

    return " ".join(itos[idx] for idx in generated)
# -----------------------------
# 4. Load Model
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): the model's embedding and output sizes come from len(vocab). If that
# vocab differs in size from the one used in training, load_state_dict raises a
# size-mismatch error.
# NOTE(review): torch.load unpickles the file, so only point this at a trusted checkpoint.
model = GRUPoetryModel(len(vocab))
checkpoint = torch.load("gru_poetry_model.pth", map_location=device)
model.load_state_dict(checkpoint)
model = model.to(device)
model.eval()

# -----------------------------
# 5. Gradio Interface
# -----------------------------
import gradio as gr


def generate(start_text):
    """Gradio callback: continue ``start_text`` using the globally loaded model and vocab."""
    poem = generate_poetry(model, vocab, start_text)
    return poem


demo = gr.Interface(
    fn=generate,
    inputs=gr.Textbox(label="Start Text", placeholder="Enter starting line..."),
    outputs="text",
    title="Roman Urdu Poetry Generator",
    description="Generate Urdu poetry using GRU model trained on ghazal and poetry datasets.",
    examples=[
        ["main ek hoon"],
        ["dil se teri yaad aai"],
    ],
)
demo.launch()