from typing import Optional, Union

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class DiffusionTextModel(nn.Module, PyTorchModelHubMixin):
    """Transformer encoder mapping (token ids, time step) to per-position logits.

    Each position's input is the sum of a token embedding, a learned
    positional embedding and a learned time-step embedding. The sum is passed
    through a stack of ``nn.TransformerEncoderLayer`` blocks and projected to
    ``vocab_size`` logits.

    Args:
        vocab_size: Number of token ids; also the number of output logits per
            position.
        max_seq_len: Number of learned positional embeddings. Inputs longer
            than this cannot be embedded.
        max_time_steps: Largest valid time-step index. The time embedding table
            has ``max_time_steps + 1`` rows, so ``t`` may range over
            ``[0, max_time_steps]``.
        embed_dim: Model width. Must be divisible by ``n_heads`` (a
            requirement of ``nn.MultiheadAttention``).
        n_layers: Number of encoder layers.
        n_heads: Number of attention heads per layer.
    """

    def __init__(self, vocab_size, max_seq_len, max_time_steps, embed_dim=128,
                 n_layers=4, n_heads=4):
        super().__init__()
        # Plain dict copy of the constructor arguments.
        self.config = {
            "vocab_size": vocab_size,
            "max_seq_len": max_seq_len,
            "max_time_steps": max_time_steps,
            "embed_dim": embed_dim,
            "n_layers": n_layers,
            "n_heads": n_heads,
        }
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(max_seq_len, embed_dim)
        # +1 so that t == max_time_steps is a valid index.
        self.time_emb = nn.Embedding(max_time_steps + 1, embed_dim)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=n_heads,
            dim_feedforward=4 * embed_dim,
            activation="gelu",
        )
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.out = nn.Linear(embed_dim, vocab_size)

    def forward(
        self,
        x: torch.Tensor,
        t: Union[torch.Tensor, int],
        padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return logits of shape ``(B, L, vocab_size)``.

        Args:
            x: Integer token ids of shape ``(B, L)`` with ``L <= max_seq_len``.
            t: Integer time-step indices in ``[0, max_time_steps]``. Either one
                per sample (shape ``(B,)`` or ``(B, 1)``), or a single step
                shared by the whole batch (Python int, 0-dim tensor, or shape
                ``(1,)``).
            padding_mask: Optional bool tensor of shape ``(B, L)``. ``True``
                marks key positions that attention should ignore. It is
                forwarded to the encoder as ``src_key_padding_mask``. ``None``
                (the default) applies no mask.
        """
        B, L = x.shape
        tok = self.token_emb(x)  # (B, L, D)
        # (1, L, D): broadcast over the batch instead of materialising an expand.
        pos = self.pos_emb(torch.arange(L, device=x.device)).unsqueeze(0)
        # Normalise t to shape (B,). A single shared step is broadcast to every
        # sample; per-sample (B,) / (B, 1) inputs are just flattened.
        t = torch.as_tensor(t, device=x.device).reshape(-1).expand(B)
        tim = self.time_emb(t).unsqueeze(1)  # (B, 1, D), broadcast over L
        h = tok + pos + tim
        # The encoder layers use the default batch_first=False, so they expect
        # (L, B, D). Transpose in and back out.
        h = self.transformer(
            h.transpose(0, 1), src_key_padding_mask=padding_mask
        ).transpose(0, 1)
        return self.out(h)