# app.py — Roman Urdu poetry generator (author: huzaifa113, commit 2a513c1)
# Standard library
from collections import Counter

# Third-party
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torchtext.vocab import Vocab, build_vocab_from_iterator
# -----------------------------
# 1. Model Definition
# -----------------------------
class GRUPoetryModel(nn.Module):
    """GRU language model mapping token ids to per-position vocabulary logits.

    Architecture: embedding -> multi-layer GRU (batch_first) -> linear head.
    Submodule names (``embedding``, ``gru``, ``fc``) are fixed: they define
    the state_dict keys used by the saved checkpoint.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Run the model.

        Args:
            x: LongTensor of token ids, shape (batch, seq_len).
            hidden: optional GRU hidden state carried over from a previous call.

        Returns:
            Tuple of (logits of shape (batch, seq_len, vocab_size), new hidden state).
        """
        embedded = self.embedding(x)
        output, new_hidden = self.gru(embedded, hidden)
        return self.fc(output), new_hidden
# -----------------------------
# 2. Tokenization & Vocab
# -----------------------------
def tokenize(text):
    """Split *text* on runs of whitespace and return the token list."""
    return text.split()
# Tiny seed corpus used only to reconstruct the vocabulary at load time;
# it must match the corpus the checkpoint's vocab was originally built from.
dummy_texts = [
    "main ek hoon",
    "dil se teri yaad aai",
    "ab yahan se kahan jaye hum",
    "tum na aaye to kya ghar kya bazaar",
    "ishq mein zindagi bhar dena",
]
# Build vocab
def yield_tokens(texts):
    """Yield the lowercased token list for each string in *texts*."""
    for line in texts:
        yield tokenize(line.lower())
# Build the vocabulary with the modern torchtext API.
# BUGFIX: the original mixed APIs — ``Vocab(counter, specials=...)`` is the
# legacy constructor, while ``set_default_index`` (and ``get_itos`` used
# below) exist only on new-API Vocab objects, so no torchtext release could
# run both lines. ``build_vocab_from_iterator`` returns a new-style Vocab
# that supports both. ``specials`` are inserted first, so "<unk>" -> 0 and
# "<pad>" -> 1, matching the legacy ordering.
vocab = build_vocab_from_iterator(yield_tokens(dummy_texts), specials=["<unk>", "<pad>"])
# Unknown tokens map to "<unk>" instead of raising.
vocab.set_default_index(vocab["<unk>"])
# -----------------------------
# 3. Poetry Generation Logic
# -----------------------------
def generate_poetry(model, vocab, start_text="main", max_len=50):
    """Greedily generate a poem of up to *max_len* tokens seeded by *start_text*.

    Args:
        model: trained model returning ``(logits, hidden)`` from ``model(x, hidden)``.
        vocab: torchtext Vocab (new API) mapping token <-> index; unknown
            tokens resolve via its default index.
        start_text: prompt whose lowercased tokens seed generation.
        max_len: total output length in tokens (prompt included).

    Returns:
        Space-joined string of prompt tokens followed by generated tokens;
        empty string for an empty/whitespace-only prompt.
    """
    model.eval()
    device = next(model.parameters()).device
    tokens = [vocab[token] for token in tokenize(start_text.lower())]
    if not tokens:
        # Nothing to condition on — a zero-length sequence would crash the GRU.
        return ""
    generated = list(tokens)
    # Prime the recurrent state on the full prompt once.
    input_tensor = torch.tensor([tokens], dtype=torch.long, device=device)
    hidden = None
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        for _ in range(max_len - len(tokens)):
            logits, hidden = model(input_tensor, hidden)
            # Greedy decoding: most likely token at the last time step.
            next_token = logits[:, -1, :].argmax(dim=-1).item()
            generated.append(next_token)
            # BUGFIX: feed only the new token with the carried hidden state.
            # The original re-fed the whole growing sequence *and* the hidden
            # state, so the context was processed twice on every step.
            input_tensor = torch.tensor([[next_token]], dtype=torch.long, device=device)
    itos = vocab.get_itos()
    return " ".join(itos[idx] for idx in generated)
# -----------------------------
# 4. Load Model
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
model = GRUPoetryModel(len(vocab))
# weights_only=True restricts unpickling to tensors/primitives — all a
# state_dict contains — and blocks arbitrary code execution from a
# tampered checkpoint file.
model.load_state_dict(torch.load("gru_poetry_model.pth", map_location=device, weights_only=True))
model.to(device)
model.eval()  # inference mode for serving
# -----------------------------
# 5. Gradio Interface
# -----------------------------
import gradio as gr


def generate(start_text):
    """Gradio callback: delegate to generate_poetry with the loaded model/vocab."""
    return generate_poetry(model, vocab, start_text)


# Wire the callback into a simple single-textbox interface.
demo = gr.Interface(
    fn=generate,
    title="Roman Urdu Poetry Generator",
    description="Generate Urdu poetry using GRU model trained on ghazal and poetry datasets.",
    inputs=gr.Textbox(placeholder="Enter starting line...", label="Start Text"),
    outputs="text",
    examples=[["main ek hoon"], ["dil se teri yaad aai"]],
)
demo.launch()