"""
RND1 model implementation.
This module implements the RND1 architecture with bidirectional attention for
diffusion-based language modeling. Includes support for Mixture of Experts (MoE)
with multiple backend options (HF, FlashInfer, SGLang).
Based on the Qwen3Moe architecture:
https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
"""
from __future__ import annotations
import os
from typing import Optional, Tuple, List, Union
import torch
from torch import nn
from transformers.utils import logging
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
MoeModelOutputWithPast,
MaskedLMOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationConfig
from .configuration_rnd import RND1Config
from .generation_utils import RND1GenerationMixin
from .generation_config import RND1GenerationConfig
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
Qwen3MoeConfig,
Qwen3MoeRMSNorm,
Qwen3MoeRotaryEmbedding,
Qwen3MoeSparseMoeBlock,
Qwen3MoeMLP,
apply_rotary_pos_emb
)
import torch.nn.functional as F
try:
import flashinfer.fused_moe as fused_moe
except Exception:
fused_moe = None
try:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe as sglang_fused_moe
from sglang.srt.layers.moe.topk import StandardTopKOutput
except Exception:
sglang_fused_moe = None
StandardTopKOutput = None
logger = logging.get_logger(__name__)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""Expand key/value heads to match query heads for grouped-query attention."""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class RND1Attention(nn.Module):
"""RND1 attention layer with bidirectional attention for diffusion modeling."""
def __init__(self, config: RND1Config, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_heads = config.num_attention_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.scaling = self.head_dim ** -0.5
self.attention_dropout = config.attention_dropout
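        # Non-causal by design: diffusion-style denoising lets every position attend to the full sequence.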
self.is_causal = False
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.sliding_window = getattr(config, "sliding_window", None)
self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
dual_cache: Optional[bool] = False,
replace_position: Optional[torch.Tensor] = None,
**kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
bsz, q_len, _ = hidden_states.size()
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
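        # Grouped-query attention: replicate K/V heads so each group of query heads shares one K/V head.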
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
use_sdpa = (getattr(self.config, "_attn_implementation", "eager") == "sdpa")
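        # Two execution paths: PyTorch's fused scaled_dot_product_attention when configured, otherwise the eager fallback below.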
if use_sdpa:
if attention_mask is not None and isinstance(attention_mask, torch.Tensor):
if attention_mask.dtype not in [torch.bool, torch.float32, torch.float16, torch.bfloat16]:
attention_mask = attention_mask.to(dtype=query_states.dtype)
assert not self.is_causal, f"Attention layer {self.layer_idx} is causal"
attn_out = torch.nn.functional.scaled_dot_product_attention(
query_states, key_states, value_states,
attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None,
dropout_p=self.attention_dropout if self.training else 0.0,
is_causal=self.is_causal,
)
attn_out = attn_out.transpose(1, 2).contiguous()
attn_out = attn_out.view(bsz, q_len, self.num_heads * self.head_dim)
attn_out = self.o_proj(attn_out)
return attn_out, None
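        # Eager path: explicit scaled QK^T, additive mask, float32 softmax, then dropout.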
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_out = torch.matmul(attn_weights, value_states)
attn_out = attn_out.transpose(1, 2).contiguous().view(hidden_states.size(0), hidden_states.size(1), -1)
attn_out = self.o_proj(attn_out)
return attn_out, None
class RND1DecoderLayer(nn.Module):
"""RND1 decoder layer with bidirectional attention for diffusion language modeling."""
def __init__(self, config: RND1Config, layer_idx: int):
super().__init__()
self.self_attn = RND1Attention(config, layer_idx)
self.mlp = RND1SparseMoeBlock(config)
self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
replace_position: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[torch.Tensor]]:
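        # Pre-norm block: norm -> bidirectional attention -> residual add, then norm -> sparse MoE MLP -> residual add.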
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attn_out, attn_weights = self.self_attn(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
position_embeddings=position_embeddings,
replace_position=replace_position,
)
hidden_states = residual + attn_out
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
ff_out = self.mlp(hidden_states)
if isinstance(ff_out, tuple):
ff_out = ff_out[0]
hidden_states = residual + ff_out
return hidden_states, attn_weights
class RND1SparseMoeBlock(nn.Module):
"""RND1 Sparse MoE block with multiple backend support (HF, FlashInfer, SGLang)."""
def __init__(self, config: RND1Config):
super().__init__()
self.config = config
self.backend = getattr(config, "moe_backend", "hf")
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.hidden_size = config.hidden_size
self.intermediate_size = getattr(config, "moe_intermediate_size", config.intermediate_size)
self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
self.experts = nn.ModuleList(
[Qwen3MoeMLP(config, intermediate_size=self.intermediate_size) for _ in range(self.num_experts)]
)
# Cached weight tensors for optimized backends
self._flashinfer_fc1_weights = None
self._flashinfer_fc2_weights = None
self._sglang_w1 = None
self._sglang_w2 = None
if self.backend == "sglang":
if sglang_fused_moe is None or StandardTopKOutput is None:
raise RuntimeError("sglang is not available, cannot use sglang backend")
elif self.backend == "flashinfer":
if fused_moe is None:
raise RuntimeError("flashinfer is not available, cannot use flashinfer backend")
def _initialize_flashinfer_weights(self):
"""Initialize FlashInfer-compatible weight format."""
fc1_list = []
fc2_list = []
for expert in self.experts:
gate_w = expert.gate_proj.weight # [I, H]
up_w = expert.up_proj.weight # [I, H]
down_w = expert.down_proj.weight # [H, I]
# FlashInfer expects [up; gate] ordering
fc1_list.append(torch.cat([up_w, gate_w], dim=0)) # [2I, H]
fc2_list.append(down_w) # [H, I]
self._flashinfer_fc1_weights = torch.stack(fc1_list, dim=0).contiguous()
self._flashinfer_fc2_weights = torch.stack(fc2_list, dim=0).contiguous()
def _initialize_sglang_weights(self):
"""Initialize SGLang-compatible weight format."""
w1_list = []
w2_list = []
for expert in self.experts:
gate_w = expert.gate_proj.weight # [I, H]
up_w = expert.up_proj.weight # [I, H]
down_w = expert.down_proj.weight # [H, I]
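            # SGLang's fused kernel expects [gate; up] concatenation (opposite order from the FlashInfer layout above).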
w1 = torch.cat([gate_w, up_w], dim=0) # [2I, H]
w1_list.append(w1)
w2_list.append(down_w)
self._sglang_w1 = torch.stack(w1_list, dim=0).contiguous()
self._sglang_w2 = torch.stack(w2_list, dim=0).contiguous()
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass with expert routing and computation."""
batch_size, sequence_length, hidden_dim = hidden_states.shape
x = hidden_states.view(-1, hidden_dim)
# Expert routing
router_logits = self.gate(x)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
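        # Keep only the top-k experts per token; optionally renormalize their weights to sum to 1.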
if self.norm_topk_prob:
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
if self.backend == "hf":
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)
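            # One-hot routing mask of shape [num_experts, top_k, tokens]; loop only over experts that received tokens.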
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
current_state = x[top_x]
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
out = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return out, router_logits.view(batch_size, sequence_length, -1)
elif self.backend == "flashinfer":
if self._flashinfer_fc1_weights is None or self._flashinfer_fc2_weights is None:
self._initialize_flashinfer_weights()
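            # Single fused CUTLASS kernel evaluates all selected experts using the pre-stacked weight tensors.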
result = fused_moe.cutlass_fused_moe(
input=x,
token_selected_experts=selected_experts.to(torch.int),
token_final_scales=routing_weights.to(torch.float32),
fc1_expert_weights=self._flashinfer_fc1_weights,
fc2_expert_weights=self._flashinfer_fc2_weights,
output_dtype=x.dtype,
quant_scales=None,
)
if isinstance(result, (list, tuple)):
out_flat = result[0]
else:
out_flat = result
out = out_flat.view(batch_size, sequence_length, hidden_dim)
return out, router_logits.view(batch_size, sequence_length, -1)
elif self.backend == "sglang":
if self._sglang_w1 is None or self._sglang_w2 is None:
self._initialize_sglang_weights()
topk_output = StandardTopKOutput(
topk_weights=routing_weights,
topk_ids=selected_experts,
router_logits=router_logits,
)
out_flat = sglang_fused_moe(
hidden_states=x,
w1=self._sglang_w1,
w2=self._sglang_w2,
topk_output=topk_output,
)
out = out_flat.view(batch_size, sequence_length, hidden_dim)
return out, router_logits.view(batch_size, sequence_length, -1)
else:
raise ValueError(f"Invalid backend: {self.backend}")
class RND1PreTrainedModel(PreTrainedModel):
"""Base class for RND1 models with weight initialization and loading support."""
config_class = RND1Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["RND1DecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
"""Initialize weights using normal distribution."""
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
*model_args,
config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
ignore_mismatched_sizes: bool = False,
force_download: bool = False,
local_files_only: bool = False,
token: Optional[Union[str, bool]] = None,
revision: str = "main",
use_safetensors: Optional[bool] = None,
weights_only: bool = True,
**kwargs,
):
"""Load pretrained model with generation config."""
_model = super().from_pretrained(
pretrained_model_name_or_path,
*model_args,
config=config,
cache_dir=cache_dir,
ignore_mismatched_sizes=ignore_mismatched_sizes,
force_download=force_download,
local_files_only=local_files_only,
token=token,
revision=revision,
use_safetensors=use_safetensors,
weights_only=weights_only,
**kwargs,
)
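        # Explicitly re-load the checkpoint's generation config so its generation defaults are attached to the model.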
resume_download = kwargs.get("resume_download", None)
proxies = kwargs.get("proxies", None)
subfolder = kwargs.get("subfolder", "")
from_auto_class = kwargs.get("_from_auto", False)
from_pipeline = kwargs.get("_from_pipeline", None)
_model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
)
return _model
class RND1Model(RND1PreTrainedModel):
"""RND1 transformer model with bidirectional attention for diffusion language modeling."""
def __init__(self, config: RND1Config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([RND1DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs,
) -> MoeModelOutputWithPast:
"""Forward pass through the RND1 model."""
if (input_ids is None) == (inputs_embeds is None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if position_ids is None:
position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
hidden_states = inputs_embeds
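        # No causal mask is built here; any attention_mask passed in is forwarded as-is for bidirectional attention.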
for layer in self.layers:
hidden_states, _ = layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
position_embeddings=position_embeddings,
)
hidden_states = self.norm(hidden_states)
return MoeModelOutputWithPast(
last_hidden_state=hidden_states,
router_logits=None,
)
class RND1LM(RND1PreTrainedModel, RND1GenerationMixin):
"""Radical Numerics Diffusion Language Model with bidirectional attention."""
def __init__(self, config: RND1Config):
super().__init__(config)
self.model = RND1Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def get_input_embeddings(self):
"""Get the input embeddings layer."""
return self.model.embed_tokens
def set_input_embeddings(self, value):
"""Set the input embeddings layer."""
self.model.embed_tokens = value
def get_output_embeddings(self):
"""Get the output embeddings layer (lm_head)."""
return self.lm_head
def set_output_embeddings(self, new_embeddings):
"""Set the output embeddings layer (lm_head)."""
self.lm_head = new_embeddings
@classmethod
def can_generate(cls) -> bool:
"""Indicates this model can generate text."""
return True
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
**kwargs,
) -> MaskedLMOutput:
"""Forward pass with optional loss computation."""
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
**kwargs,
)
logits = self.lm_head(outputs.last_hidden_state)
loss = None
if labels is not None:
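            # Standard token-level cross-entropy; positions labeled -100 are ignored (CrossEntropyLoss default).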
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
return MaskedLMOutput(
loss=loss,
logits=logits,
)
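# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed on import). The checkpoint
# path below is a placeholder assumption; substitute the actual RND1 repo id
# or local directory. When loading from the Hub, pass trust_remote_code=True
# so this module is picked up.
#
#     from transformers import AutoTokenizer
#     from modeling_rnd import RND1LM
#
#     tokenizer = AutoTokenizer.from_pretrained("<path-to-RND1-Base-0910>")
#     model = RND1LM.from_pretrained("<path-to-RND1-Base-0910>", torch_dtype="auto")
#     inputs = tokenizer("Diffusion language models denoise text.", return_tensors="pt")
#     outputs = model(**inputs)  # MaskedLMOutput with logits of shape [batch, seq, vocab]
# ---------------------------------------------------------------------------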