# Hugging Face Spaces app: Roman Urdu poetry generation with a GRU language model.
from collections import Counter

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torchtext.vocab import Vocab
# -----------------------------
# 1. Model Definition
# -----------------------------
class GRUPoetryModel(nn.Module):
    """Word-level language model: embedding -> multi-layer GRU -> vocab logits."""

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Return (per-position logits over the vocabulary, final GRU hidden state)."""
        embedded = self.embedding(x)
        gru_out, hidden = self.gru(embedded, hidden)
        return self.fc(gru_out), hidden
# -----------------------------
# 2. Tokenization & Vocab
# -----------------------------
def tokenize(text):
    """Split *text* on whitespace and return the list of word tokens."""
    words = text.split()
    return words
# Tiny stand-in corpus, used only to rebuild the vocabulary that the
# pretrained checkpoint expects.
dummy_texts = [
    "main ek hoon",
    "dil se teri yaad aai",
    "ab yahan se kahan jaye hum",
    "tum na aaye to kya ghar kya bazaar",
    "ishq mein zindagi bhar dena",
]
# Build vocab
def yield_tokens(texts):
    """Yield one lowercased token list per input text."""
    yield from (tokenize(line.lower()) for line in texts)
# Count token frequencies over the corpus.
counter = Counter()
for tokens in yield_tokens(dummy_texts):
    counter.update(tokens)

# FIX: `Vocab(counter, specials=...)` is the legacy torchtext constructor
# (removed in torchtext >= 0.10), while `set_default_index` exists only on the
# modern Vocab class -- the original mixed both APIs and could not run on any
# torchtext version. The `vocab()` factory takes an ordered token->frequency
# mapping; sorting by (-frequency, token) reproduces the legacy Vocab index
# order so the pretrained checkpoint's rows still line up with token ids.
from torchtext.vocab import vocab as build_vocab

ordered_counts = dict(sorted(counter.items(), key=lambda kv: (-kv[1], kv[0])))
vocab = build_vocab(ordered_counts, specials=["<unk>", "<pad>"])
vocab.set_default_index(vocab["<unk>"])
# -----------------------------
# 3. Poetry Generation Logic
# -----------------------------
def generate_poetry(model, vocab, start_text="main", max_len=50):
    """Greedily extend *start_text* to at most *max_len* tokens.

    Args:
        model: language model whose forward(x, hidden) returns (logits, hidden),
            with logits shaped (batch, seq_len, vocab_size).
        vocab: maps token -> id via __getitem__ and id -> token via get_itos().
        start_text: seed text, lowercased and split on whitespace (same rule as
            the training-side `tokenize`).
        max_len: total number of tokens (seed + generated) in the result.

    Returns:
        The seed plus greedily generated tokens, joined with single spaces.
    """
    model.eval()
    device = next(model.parameters()).device
    token_ids = [vocab[tok] for tok in start_text.lower().split()]
    if not token_ids:
        # An empty seed would give the GRU a zero-length sequence; fall back
        # to a single unknown token instead of crashing.
        token_ids = [vocab["<unk>"]]

    generated = list(token_ids)
    with torch.no_grad():
        # Encode the full prompt exactly once ...
        input_tensor = torch.tensor([token_ids], dtype=torch.long, device=device)
        hidden = None
        for _ in range(max_len - len(token_ids)):
            logits, hidden = model(input_tensor, hidden)
            next_id = logits[:, -1, :].argmax(dim=-1).item()
            generated.append(next_id)
            # FIX: ... then feed only the newest token. `hidden` already
            # summarizes the prefix, so re-feeding the whole growing sequence
            # (as the original did) encoded the prefix twice on top of its own
            # hidden state -- corrupting the GRU state and costing O(n^2) work.
            input_tensor = torch.tensor([[next_id]], dtype=torch.long, device=device)

    itos = vocab.get_itos()
    return " ".join(itos[idx] for idx in generated)
# -----------------------------
# 4. Load Model
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

model = GRUPoetryModel(len(vocab))
# FIX: weights_only=True restricts torch.load to tensor deserialization; a
# plain state_dict needs nothing more, and it avoids executing arbitrary
# pickled code from the checkpoint file (requires torch >= 1.13).
state_dict = torch.load("gru_poetry_model.pth", map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
# -----------------------------
# 5. Gradio Interface
# -----------------------------
import gradio as gr


def generate(start_text):
    """Gradio callback: generate poetry from the user's seed line."""
    return generate_poetry(model, vocab, start_text)


demo = gr.Interface(
    fn=generate,
    inputs=gr.Textbox(placeholder="Enter starting line...", label="Start Text"),
    outputs="text",
    title="Roman Urdu Poetry Generator",
    description="Generate Urdu poetry using GRU model trained on ghazal and poetry datasets.",
    examples=[["main ek hoon"], ["dil se teri yaad aai"]],
)

demo.launch()