Mohaddz committed
Commit 278d275 · 1 Parent(s): 08b5ccb
demo_rnd_generation.py ADDED
@@ -0,0 +1,359 @@
#!/usr/bin/env python3
"""
Demo script for RND1 generation.
"""

import torch
import argparse
import os
import sys
import random
import numpy as np
from transformers import AutoTokenizer

# Add RND1 module to path for local testing
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))


def set_seed(seed: int):
    """Set random seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def demo_completion(
    model_path: str,
    checkpoint_path: str = None,
    device: str = "cuda:0",
    use_bfloat16: bool = True,
    show_visualization: bool = True,
    num_steps: int = 64,
    max_new_tokens: int = 256,
    custom_prompt: str = None,
    temperature: float = 1.0,
    top_k: int = None,
    top_p: float = None,
    mask_token_id: int = 151669,
    seed: int = 12345,
    moe_backend: str = "hf",
    mode: str = "task",
):
    """
    Demonstrate text completion using RND1.

    Args:
        model_path: Path to base model or HuggingFace model ID
        checkpoint_path: Path to custom checkpoint (if any)
        device: Device to run on (e.g., cuda:0, cpu)
        use_bfloat16: Whether to use bfloat16 precision
        show_visualization: Whether to show live visualization (requires rich)
        num_steps: Number of diffusion steps
        max_new_tokens: Maximum number of tokens to generate
        custom_prompt: Custom prompt to use instead of default examples
        temperature: Temperature for sampling (1.0 = greedy/deterministic)
        top_k: Top-k filtering for sampling (None = disabled)
        top_p: Top-p (nucleus) filtering for sampling (None = disabled)
        mask_token_id: Token ID for mask token
        seed: Random seed for reproducibility
        moe_backend: MoE backend to use ('hf', 'flashinfer', or 'sglang')
        mode: Generation mode ('task' for Q&A format, 'completion' for continuation)
    """
    set_seed(seed)

    from rnd.configuration_rnd import RND1Config
    from rnd.modeling_rnd import RND1LM

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    dtype = torch.bfloat16 if use_bfloat16 else torch.float32
    print(f"Using dtype: {dtype}")

    if moe_backend == "hf":
        print("\n⚠️ Note: the HuggingFace backend is slower. Consider --moe_backend flashinfer or sglang for better performance.\n")

    # Load from the checkpoint if provided, otherwise from model_path
    load_path = checkpoint_path if checkpoint_path else model_path

    print(f"Loading model from {load_path}...")

    # Load config and set RND1-specific settings
    cfg = RND1Config.from_pretrained(load_path)
    cfg.model_type = "rnd1"
    cfg.attn_implementation = "sdpa"
    cfg.moe_backend = moe_backend

    # Load model with RND1LM
    model = RND1LM.from_pretrained(
        load_path,
        config=cfg,
        torch_dtype=dtype,
        device_map="auto" if device == "cuda:0" else device,
        trust_remote_code=True,
        use_safetensors=True,
        low_cpu_mem_usage=True,
    )
    print("Model loaded")
    model = model.eval()

    if custom_prompt:
        prompts = [custom_prompt]
    else:
        # Default prompts based on mode
        if mode == "task":
            prompts = ["Write a Python function that finds the longest common subsequence of two strings. Include comments explaining the algorithm."]
        else:
            prompts = ["The key to understanding quantum computing lies in"]

    greedy = (temperature == 1.0)

    generator = torch.Generator(device=device if device != "auto" else "cuda")
    generator.manual_seed(seed)

    for i, user_prompt in enumerate(prompts):
        print(f"\n{'='*60}")
        print(f"Mode: {mode.upper()}")
        print(f"Prompt {i+1}: {user_prompt[:100]}...")
        print(f"{'='*60}\n")

        if mode == "task":
            # Task mode: add a "Question: " prefix if not already present
            if not user_prompt.strip().startswith("Question:"):
                prompt = f"Question: {user_prompt}\n"
            else:
                prompt = user_prompt
        else:
            # Completion mode: use the prompt as-is for continuation
            prompt = user_prompt

        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs.input_ids.to(device if device != "auto" else "cuda")
        attention_mask = inputs.attention_mask.to(device if device != "auto" else "cuda") if 'attention_mask' in inputs else None

        print("Generation parameters:")
        print(f"  Prompt length: {input_ids.shape[1]} tokens")
        print(f"  Max new tokens: {max_new_tokens}")
        print(f"  Total sequence: {input_ids.shape[1] + max_new_tokens} tokens")
        print(f"  Diffusion steps: {num_steps}")
        print(f"  Temperature: {temperature}")
        print(f"  Greedy: {greedy}")
        if top_k:
            print(f"  Top-k: {top_k}")
        if top_p:
            print(f"  Top-p: {top_p}")
        print()

        # Create an explicit generation config that takes priority over model defaults
        from rnd.generation_config import RND1GenerationConfig
        gen_config = RND1GenerationConfig(
            max_new_tokens=max_new_tokens,
            num_diffusion_steps=num_steps,
            mask_token_id=mask_token_id,
            temperature=temperature if not greedy else 1.0,
            top_k=top_k,
            top_p=top_p,
            greedy=greedy,
            eos_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id else 151645,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
        )

        with torch.no_grad():
            if show_visualization and hasattr(model, 'generate_with_visualization'):
                # Use the method with visualization support (requires tokenizer)
                output = model.generate_with_visualization(
                    tokenizer=tokenizer,
                    inputs=input_ids,
                    generation_config=gen_config,
                    generator=generator,
                )
            else:
                # Use the standard generate method with explicit config
                output = model.generate(
                    inputs=input_ids,
                    generation_config=gen_config,
                    generator=generator,
                )

        generated_tokens = output[0][len(input_ids[0]):]
        generation = tokenizer.decode(
            generated_tokens.tolist(),
            skip_special_tokens=True,
        )

        print("\nGenerated response:")
        print(generation)

        print(f"\n(Generation completed in {num_steps} diffusion steps)")


def main():
    parser = argparse.ArgumentParser(
        description="RND1 diffusion model demo with live visualization",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Model configuration
    model_group = parser.add_argument_group('Model Configuration')
    model_group.add_argument(
        "--model_path",
        type=str,
        default="radicalnumerics/RND1-Base-0910",
        help="Path to model or HuggingFace model ID",
    )
    model_group.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="Path to custom checkpoint file or directory",
    )
    model_group.add_argument(
        "--device",
        type=str,
        default="cuda:0",
        help="Device to run on (e.g., cuda:0, cpu)",
    )
    model_group.add_argument(
        "--fp32",
        action="store_true",
        help="Use FP32 precision instead of BF16",
    )

    # Generation configuration
    gen_group = parser.add_argument_group('Generation Settings')
    gen_group.add_argument(
        "--num_steps",
        type=int,
        default=256,
        help="Number of diffusion steps",
    )
    gen_group.add_argument(
        "--max_new_tokens",
        type=int,
        default=256,
        help="Maximum number of tokens to generate",
    )
    gen_group.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Custom prompt to use for generation",
    )
    gen_group.add_argument(
        "--mode",
        type=str,
        default="task",
        choices=["task", "completion"],
        help="Generation mode: 'task' (Q&A format for instructions) or 'completion' (text continuation)",
    )
    gen_group.add_argument(
        "--mask_token_id",
        type=int,
        default=151669,
        help="Token ID for mask token",
    )

    # Sampling configuration
    sampling_group = parser.add_argument_group('Sampling Parameters')
    sampling_group.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="Temperature for sampling (1.0 = greedy/deterministic)",
    )
    sampling_group.add_argument(
        "--top_k",
        type=int,
        default=None,
        help="Top-k filtering: keep only the k most likely tokens",
    )
    sampling_group.add_argument(
        "--top_p",
        type=float,
        default=None,
        help="Top-p (nucleus) filtering: keep tokens with cumulative probability <= p",
    )

    # Visualization
    viz_group = parser.add_argument_group('Visualization')
    viz_group.add_argument(
        "--no_viz",
        action="store_true",
        help="Disable live visualization during generation (visualization requires the rich library)",
    )

    # Other settings
    other_group = parser.add_argument_group('Other Settings')
    other_group.add_argument(
        "--seed",
        type=int,
        default=12345,
        help="Random seed for reproducibility",
    )

    moe_backend_group = parser.add_argument_group('MoE Backend')
    moe_backend_group.add_argument(
        "--moe_backend",
        type=str,
        default="hf",
        choices=["hf", "flashinfer", "sglang"],
        help="MoE backend to use for sparse mixture-of-experts layers",
    )

    args = parser.parse_args()

    if args.temperature < 0:
        parser.error("Temperature must be non-negative")
    if args.top_k is not None and args.top_k <= 0:
        parser.error("Top-k must be positive")
    if args.top_p is not None and (args.top_p <= 0 or args.top_p > 1):
        parser.error("Top-p must be between 0 and 1")

    print("\n" + "="*60)
    print("RND1 Diffusion Language Model Demo")
    print("="*60)
    print("Configuration:")
    print(f"  Model: {args.model_path}")
    if args.checkpoint:
        print(f"  Checkpoint: {args.checkpoint}")
    print(f"  Device: {args.device}")
    print(f"  Precision: {'FP32' if args.fp32 else 'BF16'}")
    print(f"  Mode: {args.mode.upper()} ({'Q&A format for instructions' if args.mode == 'task' else 'Text continuation'})")
    print(f"  Random seed: {args.seed}")
    print(f"  Diffusion steps: {args.num_steps}")
    print(f"  Max new tokens: {args.max_new_tokens}")
    print("  Algorithm: Entropy-based selection")
    print(f"  Temperature: {args.temperature}")
    if args.top_k:
        print(f"  Top-k: {args.top_k}")
    if args.top_p:
        print(f"  Top-p: {args.top_p}")
    print(f"  MoE Backend: {args.moe_backend}")
    print(f"  Visualization: {'Enabled' if not args.no_viz else 'Disabled'}")
    print("="*60 + "\n")

    demo_completion(
        model_path=args.model_path,
        checkpoint_path=args.checkpoint,
        device=args.device,
        use_bfloat16=not args.fp32,
        show_visualization=not args.no_viz,
        num_steps=args.num_steps,
        max_new_tokens=args.max_new_tokens,
        custom_prompt=args.prompt,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        mask_token_id=args.mask_token_id,
        seed=args.seed,
        moe_backend=args.moe_backend,
        mode=args.mode,
    )


if __name__ == "__main__":
    main()
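For reference, a typical invocation of the demo above (every flag is defined in main(); the prompt text is illustrative):

    python demo_rnd_generation.py \
        --prompt "Explain masked diffusion decoding in two sentences." \
        --mode completion \
        --num_steps 64 \
        --max_new_tokens 128 \
        --no_viz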
pyproject.toml ADDED
@@ -0,0 +1,22 @@
[build-system]
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "rnd"
version = "0.1.0"
dependencies = [
    "accelerate",
    "torch>=2.8",
    "transformers",
    "rich",
]

[project.optional-dependencies]
flashinfer = [
    "flashinfer-python",
]
sglang = ["sglang[all]"]

[tool.setuptools]
packages = ["rnd"]
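Given the extras declared above, an editable install from a checkout would look like this (standard pip extras syntax; comments are illustrative):

    pip install -e .                  # core: accelerate, torch>=2.8, transformers, rich
    pip install -e ".[flashinfer]"    # adds the FlashInfer fused-MoE backend
    pip install -e ".[sglang]"        # adds the SGLang fused-MoE backend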
rnd/__init__.py ADDED
@@ -0,0 +1,53 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.

"""
Radical Numerics Diffusion (RND1) - Diffusion-based Language Model.
"""

from .configuration_rnd import RND1Config
from .modeling_rnd import (
    RND1LM,
    RND1Model,
    RND1PreTrainedModel,
    RND1Attention,
    RND1DecoderLayer,
    RND1SparseMoeBlock,
)
from .generation_config import RND1GenerationConfig
from .generation_utils import RND1GenerationMixin
from .sampling import (
    diffusion_sample,
    apply_top_k_filtering,
    apply_top_p_filtering,
)
from .terminal_visualizer import TerminalVisualizer, SimpleProgressBar

__version__ = "0.1.0"

__all__ = [
    "RND1Config",
    "RND1GenerationConfig",
    "RND1LM",
    "RND1Model",
    "RND1PreTrainedModel",
    "RND1Attention",
    "RND1DecoderLayer",
    "RND1SparseMoeBlock",
    "RND1GenerationMixin",
    "TerminalVisualizer",
    "SimpleProgressBar",
]

# Register with HuggingFace Auto classes for local usage
try:
    from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM

    AutoConfig.register("rnd1", RND1Config)
    AutoModel.register(RND1Config, RND1Model)
    AutoModelForMaskedLM.register(RND1Config, RND1LM)
except ImportError:
    # transformers not available or Auto classes not importable
    pass
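With this registration in place, importing the package is enough for the stock Auto classes to resolve RND1. A minimal sketch, assuming rnd and transformers are installed:

    import rnd  # runs the Auto-class registration at import time
    from transformers import AutoConfig, AutoModelForMaskedLM

    cfg = AutoConfig.for_model("rnd1")   # resolves to rnd.RND1Config
    # Loading actual weights works the same way, e.g.:
    # model = AutoModelForMaskedLM.from_pretrained("radicalnumerics/RND1-Base-0910", trust_remote_code=True)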
rnd/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.07 kB)
rnd/__pycache__/configuration_rnd.cpython-310.pyc ADDED
Binary file (3.49 kB)
rnd/__pycache__/generation_config.cpython-310.pyc ADDED
Binary file (2.34 kB)
rnd/__pycache__/generation_utils.cpython-310.pyc ADDED
Binary file (5.54 kB)
rnd/__pycache__/modeling_rnd.cpython-310.pyc ADDED
Binary file (16.2 kB)
rnd/__pycache__/sampling.cpython-310.pyc ADDED
Binary file (7.07 kB)
rnd/__pycache__/terminal_visualizer.cpython-310.pyc ADDED
Binary file (7.31 kB)
rnd/configuration_rnd.py ADDED
@@ -0,0 +1,123 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.

"""
RND1 Model Configuration.

This module defines the configuration class for RND1 models.
The default settings are derived from Qwen/Qwen3-30B-A3B and augmented
with RND1-specific parameters.
"""

from transformers.configuration_utils import PretrainedConfig

# Qwen3-30B-A3B / checkpoint defaults
CONFIG_DEFAULTS = {
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 151643,
    "decoder_sparse_step": 1,
    "eos_token_id": 151645,
    "head_dim": 128,
    "hidden_act": "silu",
    "hidden_size": 2048,
    "initializer_range": 0.02,
    "intermediate_size": 6144,
    "max_position_embeddings": 40960,
    "max_window_layers": 48,
    "mlp_only_layers": [],
    "moe_intermediate_size": 768,
    "norm_topk_prob": True,
    "num_attention_heads": 32,
    "num_experts": 128,
    "num_experts_per_tok": 8,
    "num_hidden_layers": 48,
    "num_key_value_heads": 4,
    "output_router_logits": False,
    "rms_norm_eps": 1e-06,
    "rope_scaling": False,
    "rope_theta": 1000000.0,
    "router_aux_loss_coef": 0.001,
    "sliding_window": False,
    "tie_word_embeddings": False,
    "torch_dtype": "bfloat16",
    "use_cache": False,
    "use_sliding_window": False,
    "vocab_size": 151936,
}


class RND1Config(PretrainedConfig):
    """
    Configuration class for RND1 models.

    This configuration mirrors Qwen3MoeConfig and adds parameters
    specific to the RND1 (Radical Numerics Diffusion v1) architecture.

    Args:
        moe_backend: Backend for MoE computation ("hf", "flashinfer", or "sglang")
        num_diffusion_steps: Default number of diffusion steps for generation
        mask_token_id: Token ID used for masking (default: 151669 for Qwen)
        **kwargs: Additional arguments passed to PretrainedConfig
    """

    model_type = "rnd1"

    def __init__(
        self,
        moe_backend: str = "hf",
        num_diffusion_steps: int = 256,
        mask_token_id: int = 151669,
        **kwargs,
    ):
        # Force non-causal attention and no KV caching for RND1
        kwargs["use_cache"] = False
        kwargs["is_causal"] = False

        super().__init__(**kwargs)

        # Set defaults after the pretrained init to prevent overrides
        self.set_config_defaults()

        # QoL: set the attention implementation directly from the config
        if "attn_implementation" in kwargs:
            self._attn_implementation = kwargs["attn_implementation"]

        # RND1-specific parameters
        self.moe_backend = moe_backend
        self.num_diffusion_steps = num_diffusion_steps
        self.mask_token_id = mask_token_id

        # Ensure bidirectional attention and no caching
        self.is_causal = False
        self.use_cache = False

    def set_config_defaults(self):
        """
        Ensure model defaults are set according to the final training checkpoint.

        Qwen3MoeConfig defaults don't match the Qwen/Qwen3-30B-A3B settings
        from which RND1 is derived.
        """
        for k, v in CONFIG_DEFAULTS.items():
            setattr(self, k, v)

    def to_dict(self):
        """
        Serializes the configuration to a dictionary with an auto_map for the Hub.

        The auto_map ensures that when users load from the HuggingFace Hub,
        the correct custom classes are automatically resolved.
        """
        data = super().to_dict()
        data.setdefault(
            "auto_map",
            {
                "AutoConfig": "configuration_rnd.RND1Config",
                "AutoModel": "modeling_rnd.RND1Model",
                "AutoModelForMaskedLM": "modeling_rnd.RND1LM",
            },
        )
        return data
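A quick sketch of the forced settings above (constructing the config is cheap and involves no weights):

    from rnd.configuration_rnd import RND1Config

    cfg = RND1Config(moe_backend="hf")
    print(cfg.is_causal, cfg.use_cache)              # False False, forced for diffusion
    print(cfg.num_experts, cfg.num_experts_per_tok)  # 128 8, Qwen3-30B-A3B defaults
    print("auto_map" in cfg.to_dict())               # True, injected for Hub loading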
rnd/generation_config.py ADDED
@@ -0,0 +1,77 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.

"""
RND1 Generation Configuration.

This module defines the generation configuration for RND1 models,
controlling the diffusion-based generation process.
"""

from typing import Optional
from transformers.generation.configuration_utils import GenerationConfig


class RND1GenerationConfig(GenerationConfig):
    """
    Configuration class for RND1 generation parameters.

    This class extends the base GenerationConfig to include parameters
    specific to diffusion-based language generation.

    Args:
        max_length: Maximum sequence length
        num_diffusion_steps: Number of denoising steps in the diffusion process
        mask_token_id: Token ID used for masking during diffusion
        temperature: Temperature for sampling (higher = more random)
        top_k: Optional top-k filtering
        top_p: Optional nucleus (top-p) filtering
        greedy: Whether to use greedy decoding (True) or stochastic sampling (False)
        **kwargs: Additional arguments passed to GenerationConfig
    """

    def __init__(
        self,
        max_length: int = 256,
        num_diffusion_steps: int = 256,
        mask_token_id: int = 151669,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        greedy: bool = True,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        use_cache: bool = False,
        **kwargs,
    ):
        # use_cache is accepted for API compatibility but always forced to
        # False below: RND1 generation never uses a KV cache.
        kwargs.pop('use_cache', None)
        super().__init__(
            max_length=max_length,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=not greedy,
            use_cache=False,
            **kwargs,
        )

        # RND-specific parameters
        self.num_diffusion_steps = num_diffusion_steps
        self.mask_token_id = mask_token_id
        self.greedy = greedy

    def to_dict(self):
        """Convert the configuration to a dictionary."""
        output = super().to_dict()
        output["num_diffusion_steps"] = self.num_diffusion_steps
        output["mask_token_id"] = self.mask_token_id
        output["greedy"] = self.greedy
        return output
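A minimal sketch of how the greedy flag maps onto the stock HF fields, assuming the rnd package is importable:

    from rnd.generation_config import RND1GenerationConfig

    gc = RND1GenerationConfig(
        max_new_tokens=128,       # forwarded to GenerationConfig via **kwargs
        num_diffusion_steps=64,
        greedy=False,             # stochastic sampling
        temperature=0.8,
        top_p=0.9,
    )
    print(gc.do_sample, gc.use_cache)            # True False
    print(gc.to_dict()["num_diffusion_steps"])   # 64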
rnd/generation_utils.py ADDED
@@ -0,0 +1,196 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.

"""
RND1 Generation Utilities.

This module provides generation utilities and mixins for RND1 models,
including the main GenerationMixin class that integrates with HuggingFace.
"""

import torch
import torch.nn as nn
from typing import Optional, Union, Dict, Any
from transformers import GenerationMixin as HFGenerationMixin
from transformers.generation import GenerationConfig

from .sampling import diffusion_sample, apply_top_k_filtering, apply_top_p_filtering


class RND1GenerationMixin(HFGenerationMixin):
    """
    Generation mixin for RND1 models.

    This mixin provides generation methods compatible with HuggingFace's
    generation API while using RND1's diffusion-based sampling internally.
    """

    def generate(
        self,
        inputs: Optional[torch.LongTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        # RND1-specific parameters
        prefix_ids: Optional[torch.LongTensor] = None,
        suffix_ids: Optional[torch.LongTensor] = None,
        infill_length: Optional[int] = None,
        return_dict_in_generate: Optional[bool] = None,
        **kwargs,  # Accept all kwargs to be compatible with pipelines
    ) -> Union[torch.LongTensor, Dict[str, Any]]:
        """
        Generate text using RND1's diffusion-based sampling.

        Follows HuggingFace's standard generate API, using diffusion sampling
        internally. Supports both standard generation and infilling.

        Args:
            inputs: Input token IDs to use as prefix (standard HF parameter)
            generation_config: Generation configuration object
            prefix_ids: Alternative to inputs for infilling tasks
            suffix_ids: Optional suffix for infilling tasks
            infill_length: Length of infill region (for infilling)
            return_dict_in_generate: Whether to return GenerateDecoderOnlyOutput
            **kwargs: Additional arguments (accepted for compatibility)

        Returns:
            Generated token IDs or GenerateDecoderOnlyOutput
        """
        if generation_config is not None:
            gen_config = generation_config
            model_kwargs = kwargs.copy()
        else:
            # Only prepare a config from kwargs if no config was provided
            gen_config, model_kwargs = self._prepare_generation_config(None, **kwargs)

        device = next(self.parameters()).device

        if inputs is not None:
            prefix_ids = inputs.to(device)
        elif prefix_ids is not None:
            prefix_ids = prefix_ids.to(device)
        else:
            prefix_ids = None

        if suffix_ids is not None:
            suffix_ids = suffix_ids.to(device)

        eos_token_id = gen_config.eos_token_id or getattr(self.config, "eos_token_id", 151645)
        pad_token_id = gen_config.pad_token_id or getattr(self.config, "pad_token_id", None)
        bos_token_id = gen_config.bos_token_id or getattr(self.config, "bos_token_id", None)
        mask_token_id = getattr(gen_config, "mask_token_id", getattr(self.config, "mask_token_id", 151669))

        if infill_length is not None and prefix_ids is not None:
            # Infilling mode: use the specified infill_length
            prefix_len = prefix_ids.shape[1] if prefix_ids is not None else 0
            suffix_len = suffix_ids.shape[1] if suffix_ids is not None else 0
            seq_len = prefix_len + infill_length + suffix_len
        else:
            # Standard generation mode
            if prefix_ids is not None:
                prefix_len = prefix_ids.shape[1]
                if gen_config.max_new_tokens is not None:
                    seq_len = prefix_len + gen_config.max_new_tokens
                else:
                    seq_len = gen_config.max_length or self.config.max_position_embeddings
            else:
                seq_len = gen_config.max_length or self.config.max_position_embeddings

        num_diffusion_steps = getattr(gen_config, "num_diffusion_steps",
                                      getattr(self.config, "num_diffusion_steps", 256))

        temperature = float(getattr(gen_config, "temperature", 1.0))
        top_k = getattr(gen_config, "top_k", None)
        top_p = getattr(gen_config, "top_p", None)

        greedy = getattr(gen_config, "greedy",
                         not bool(gen_config.do_sample) if hasattr(gen_config, "do_sample") else True)

        generator = model_kwargs.get("generator", None)
        if generator is None:
            seed = getattr(gen_config, 'seed', None)
            if seed is not None:
                generator = torch.Generator(device=device)
                generator.manual_seed(seed)

        with torch.inference_mode():
            sequences = diffusion_sample(
                model=self,
                seq_len=seq_len,
                num_steps=num_diffusion_steps,
                mask_token_id=mask_token_id,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                greedy=greedy,
                prefix_ids=prefix_ids,
                suffix_ids=suffix_ids,
                infill_length=infill_length,
                eos_token_id=eos_token_id,
                pad_token_id=pad_token_id,
                bos_token_id=bos_token_id,
                device=device,
                generator=generator,
                visualizer=model_kwargs.get("visualizer", None),  # Optional visualizer from kwargs
            )

        if return_dict_in_generate or getattr(gen_config, "return_dict_in_generate", False):
            from transformers.generation.utils import GenerateDecoderOnlyOutput
            return GenerateDecoderOnlyOutput(sequences=sequences)

        return sequences

    def generate_with_visualization(
        self,
        tokenizer,
        inputs: Optional[torch.LongTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        suffix_ids: Optional[torch.LongTensor] = None,
        infill_length: Optional[int] = None,
        generator: Optional[torch.Generator] = None,
        **kwargs,
    ) -> torch.LongTensor:
        """
        Generate with live visualization (for demos).

        This method requires a tokenizer to display the generation process.
        For production use, prefer `generate()`.

        Args:
            tokenizer: Tokenizer for decoding tokens to text
            inputs: Input token IDs to use as prefix
            generation_config: Generation configuration object
            suffix_ids: Optional suffix token IDs
            infill_length: Length of infill region
            generator: Random generator for reproducibility
            **kwargs: Additional arguments for backward compatibility

        Returns:
            Generated token IDs as LongTensor
        """
        from .terminal_visualizer import TerminalVisualizer
        visualizer = TerminalVisualizer(tokenizer, show_visualization=True)

        return self.generate(
            inputs=inputs,
            generation_config=generation_config,
            suffix_ids=suffix_ids,
            infill_length=infill_length,
            generator=generator,
            visualizer=visualizer,
            return_dict_in_generate=False,
            **kwargs,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Prepare inputs for generation (required by HuggingFace).

        RND1 does not use standard autoregressive generation,
        so this just returns the input_ids.
        """
        return {"input_ids": input_ids}
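Because generate() accepts suffix_ids and infill_length on top of the stock HF signature, infilling can be sketched as follows; model, tokenizer, and gen_config are assumed to be set up as in the demo script, and the literal strings are placeholders:

    prefix = tokenizer("def add(a, b):\n", return_tensors="pt").input_ids
    suffix = tokenizer("\n    return result\n", return_tensors="pt").input_ids

    out = model.generate(
        inputs=prefix,
        suffix_ids=suffix,
        infill_length=32,               # 32 masked positions between prefix and suffix
        generation_config=gen_config,
    )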
rnd/modeling_rnd.py ADDED
@@ -0,0 +1,534 @@
# Copyright 2025 Radical Numerics Inc.
#
# This source code is licensed under the Apache License, Version 2.0, found in the
# LICENSE file in the root directory of this source tree.

"""
RND1 model implementation.

This module implements the RND1 architecture with bidirectional attention for
diffusion-based language modeling. Includes support for Mixture of Experts (MoE)
with multiple backend options (HF, FlashInfer, SGLang).

Based on the Qwen3Moe architecture:
https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
"""

from __future__ import annotations

import os
from typing import Optional, Tuple, List, Union

import torch
from torch import nn
import torch.nn.functional as F

from transformers.utils import logging
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
    MoeModelOutputWithPast,
    MaskedLMOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from transformers.generation import GenerationConfig

from .configuration_rnd import RND1Config
from .generation_utils import RND1GenerationMixin
from .generation_config import RND1GenerationConfig

from transformers.models.qwen3_moe.modeling_qwen3_moe import (
    Qwen3MoeConfig,
    Qwen3MoeRMSNorm,
    Qwen3MoeRotaryEmbedding,
    Qwen3MoeSparseMoeBlock,
    Qwen3MoeMLP,
    apply_rotary_pos_emb,
)

try:
    import flashinfer.fused_moe as fused_moe
except Exception:
    fused_moe = None

try:
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe as sglang_fused_moe
    from sglang.srt.layers.moe.topk import StandardTopKOutput
except Exception:
    sglang_fused_moe = None
    StandardTopKOutput = None

logger = logging.get_logger(__name__)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand key/value heads to match query heads for grouped-query attention."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class RND1Attention(nn.Module):
    """RND1 attention layer with bidirectional attention for diffusion modeling."""

    def __init__(self, config: RND1Config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.scaling = self.head_dim ** -0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False

        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)

        self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)

        self.sliding_window = getattr(config, "sliding_window", None)

        self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        dual_cache: Optional[bool] = False,
        replace_position: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:

        bsz, q_len, _ = hidden_states.size()
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        use_sdpa = (getattr(self.config, "_attn_implementation", "eager") == "sdpa")

        if use_sdpa:
            if attention_mask is not None and isinstance(attention_mask, torch.Tensor):
                if attention_mask.dtype not in [torch.bool, torch.float32, torch.float16, torch.bfloat16]:
                    attention_mask = attention_mask.to(dtype=query_states.dtype)

            assert not self.is_causal, f"Attention layer {self.layer_idx} is causal"
            attn_out = torch.nn.functional.scaled_dot_product_attention(
                query_states, key_states, value_states,
                attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None,
                dropout_p=self.attention_dropout if self.training else 0.0,
                is_causal=self.is_causal,
            )
            attn_out = attn_out.transpose(1, 2).contiguous()
            attn_out = attn_out.view(bsz, q_len, self.num_heads * self.head_dim)
            attn_out = self.o_proj(attn_out)
            return attn_out, None

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]

        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)

        attn_out = torch.matmul(attn_weights, value_states)
        attn_out = attn_out.transpose(1, 2).contiguous().view(bsz, q_len, -1)
        attn_out = self.o_proj(attn_out)

        return attn_out, None


class RND1DecoderLayer(nn.Module):
    """RND1 decoder layer with bidirectional attention for diffusion language modeling."""

    def __init__(self, config: RND1Config, layer_idx: int):
        super().__init__()
        self.self_attn = RND1Attention(config, layer_idx)
        self.mlp = RND1SparseMoeBlock(config)
        self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        replace_position: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[torch.Tensor]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attn_out, attn_weights = self.self_attn(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            position_embeddings=position_embeddings,
            replace_position=replace_position,
        )
        hidden_states = residual + attn_out

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        ff_out = self.mlp(hidden_states)
        if isinstance(ff_out, tuple):
            ff_out = ff_out[0]
        hidden_states = residual + ff_out

        return hidden_states, attn_weights


class RND1SparseMoeBlock(nn.Module):
    """RND1 Sparse MoE block with multiple backend support (HF, FlashInfer, SGLang)."""

    def __init__(self, config: RND1Config):
        super().__init__()
        self.config = config
        self.backend = getattr(config, "moe_backend", "hf")
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        self.norm_topk_prob = config.norm_topk_prob
        self.hidden_size = config.hidden_size
        self.intermediate_size = getattr(config, "moe_intermediate_size", config.intermediate_size)

        self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
        self.experts = nn.ModuleList(
            [Qwen3MoeMLP(config, intermediate_size=self.intermediate_size) for _ in range(self.num_experts)]
        )

        # Cached weight tensors for optimized backends
        self._flashinfer_fc1_weights = None
        self._flashinfer_fc2_weights = None
        self._sglang_w1 = None
        self._sglang_w2 = None
        if self.backend == "sglang":
            if sglang_fused_moe is None or StandardTopKOutput is None:
                raise RuntimeError("sglang is not available, cannot use sglang backend")
        elif self.backend == "flashinfer":
            if fused_moe is None:
                raise RuntimeError("flashinfer is not available, cannot use flashinfer backend")

    def _initialize_flashinfer_weights(self):
        """Initialize FlashInfer-compatible weight format."""
        fc1_list = []
        fc2_list = []

        for expert in self.experts:
            gate_w = expert.gate_proj.weight  # [I, H]
            up_w = expert.up_proj.weight      # [I, H]
            down_w = expert.down_proj.weight  # [H, I]
            # FlashInfer expects [up; gate] ordering
            fc1_list.append(torch.cat([up_w, gate_w], dim=0))  # [2I, H]
            fc2_list.append(down_w)                            # [H, I]

        self._flashinfer_fc1_weights = torch.stack(fc1_list, dim=0).contiguous()
        self._flashinfer_fc2_weights = torch.stack(fc2_list, dim=0).contiguous()

    def _initialize_sglang_weights(self):
        """Initialize SGLang-compatible weight format."""
        w1_list = []
        w2_list = []

        for expert in self.experts:
            gate_w = expert.gate_proj.weight  # [I, H]
            up_w = expert.up_proj.weight      # [I, H]
            down_w = expert.down_proj.weight  # [H, I]
            w1 = torch.cat([gate_w, up_w], dim=0)  # [2I, H]
            w1_list.append(w1)
            w2_list.append(down_w)

        self._sglang_w1 = torch.stack(w1_list, dim=0).contiguous()
        self._sglang_w2 = torch.stack(w2_list, dim=0).contiguous()

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass with expert routing and computation."""
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        x = hidden_states.view(-1, hidden_dim)

        # Expert routing
        router_logits = self.gate(x)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        if self.norm_topk_prob:
            routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

        if self.backend == "hf":
            final_hidden_states = torch.zeros(
                (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
            )

            expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

            for expert_idx in expert_hit:
                expert_layer = self.experts[expert_idx]
                idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
                current_state = x[top_x]
                current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
                final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
            out = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
            return out, router_logits.view(batch_size, sequence_length, -1)

        elif self.backend == "flashinfer":
            if self._flashinfer_fc1_weights is None or self._flashinfer_fc2_weights is None:
                self._initialize_flashinfer_weights()

            result = fused_moe.cutlass_fused_moe(
                input=x,
                token_selected_experts=selected_experts.to(torch.int),
                token_final_scales=routing_weights.to(torch.float32),
                fc1_expert_weights=self._flashinfer_fc1_weights,
                fc2_expert_weights=self._flashinfer_fc2_weights,
                output_dtype=x.dtype,
                quant_scales=None,
            )
            if isinstance(result, (list, tuple)):
                out_flat = result[0]
            else:
                out_flat = result
            out = out_flat.view(batch_size, sequence_length, hidden_dim)
            return out, router_logits.view(batch_size, sequence_length, -1)

        elif self.backend == "sglang":
            if self._sglang_w1 is None or self._sglang_w2 is None:
                self._initialize_sglang_weights()

            topk_output = StandardTopKOutput(
                topk_weights=routing_weights,
                topk_ids=selected_experts,
                router_logits=router_logits,
            )

            out_flat = sglang_fused_moe(
                hidden_states=x,
                w1=self._sglang_w1,
                w2=self._sglang_w2,
                topk_output=topk_output,
            )
            out = out_flat.view(batch_size, sequence_length, hidden_dim)
            return out, router_logits.view(batch_size, sequence_length, -1)

        else:
            raise ValueError(f"Invalid backend: {self.backend}")


class RND1PreTrainedModel(PreTrainedModel):
    """Base class for RND1 models with weight initialization and loading support."""
    config_class = RND1Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["RND1DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        """Initialize weights using a normal distribution."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: Optional[bool] = None,
        weights_only: bool = True,
        **kwargs,
    ):
        """Load a pretrained model together with its generation config."""
        _model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            **kwargs,
        )

        resume_download = kwargs.get("resume_download", None)
        proxies = kwargs.get("proxies", None)
        subfolder = kwargs.get("subfolder", "")
        from_auto_class = kwargs.get("_from_auto", False)
        from_pipeline = kwargs.get("_from_pipeline", None)

        _model.generation_config = GenerationConfig.from_pretrained(
            pretrained_model_name_or_path,
            cache_dir=cache_dir,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=subfolder,
            _from_auto=from_auto_class,
            _from_pipeline=from_pipeline,
        )

        return _model


class RND1Model(RND1PreTrainedModel):
    """RND1 transformer model with bidirectional attention for diffusion language modeling."""

    def __init__(self, config: RND1Config):
        super().__init__(config)

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([RND1DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
        self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> MoeModelOutputWithPast:
        """Forward pass through the RND1 model."""

        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if position_ids is None:
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)

        position_embeddings = self.rotary_emb(inputs_embeds, position_ids)

        hidden_states = inputs_embeds

        for layer in self.layers:
            hidden_states, _ = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                position_embeddings=position_embeddings,
            )

        hidden_states = self.norm(hidden_states)

        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            router_logits=None,
        )


class RND1LM(RND1PreTrainedModel, RND1GenerationMixin):
    """Radical Numerics Diffusion Language Model with bidirectional attention."""

    def __init__(self, config: RND1Config):
        super().__init__(config)
        self.model = RND1Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        """Get the input embeddings layer."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Set the input embeddings layer."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Get the output embeddings layer (lm_head)."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Set the output embeddings layer (lm_head)."""
        self.lm_head = new_embeddings

    @classmethod
    def can_generate(cls) -> bool:
        """Indicates this model can generate text."""
        return True

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> MaskedLMOutput:
        """Forward pass with optional loss computation."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )
        logits = self.lm_head(outputs.last_hidden_state)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
        )
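As a standalone illustration of the grouped-query expansion implemented by repeat_kv above (shapes follow the config defaults: 32 query heads over 4 KV heads, head_dim 128):

    import torch
    from rnd.modeling_rnd import repeat_kv

    kv = torch.randn(1, 4, 16, 128)   # [batch, num_kv_heads, seq_len, head_dim]
    q_like = repeat_kv(kv, n_rep=8)   # each KV head is shared by 8 query heads
    print(q_like.shape)               # torch.Size([1, 32, 16, 128])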
rnd/sampling.py ADDED
@@ -0,0 +1,259 @@
1
+ # Copyright 2025 Radical Numerics Inc.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0, found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """
7
+ RND1 sampling module for masked diffusion generation.
8
+
9
+ This module implements entropy-based token selection for iterative denoising
10
+ in diffusion language models. Supports both greedy and stochastic sampling
11
+ with optional prefix/suffix constraints and infilling.
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from typing import Optional, Tuple, Union
18
+
19
+
20
+ def apply_top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
21
+ """
22
+ Apply top-k filtering to logits: with non-top-k values set to -inf
23
+ """
24
+ top_k_values, top_k_indices = torch.topk(logits, min(k, logits.size(-1)), dim=-1)
25
+ filtered_logits = torch.full_like(logits, float('-inf'))
26
+ filtered_logits.scatter_(-1, top_k_indices, top_k_values)
27
+ return filtered_logits
28
+
29
+
30
+ def apply_top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
31
+ """
32
+ Apply top-p (nucleus) filtering to logits: with tokens beyond threshold set to -inf
33
+ """
34
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
35
+ cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
36
+
37
+ # Remove tokens with cumulative probability above threshold
38
+ sorted_indices_to_remove = cumulative_probs > p
39
+ sorted_indices_to_remove[..., 0] = False # Keep at least one token
40
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
41
+
42
+ indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
43
+ return logits.masked_fill(indices_to_remove, float('-inf'))
44
+
45
+
46
+ @torch.no_grad()
47
+ def diffusion_sample(
48
+ model: nn.Module,
49
+ seq_len: int = 256,
50
+ num_steps: int = 256,
51
+ top_k: Optional[int] = None,
52
+ top_p: Optional[float] = None,
53
+ temperature: float = 1.0,
54
+ greedy: bool = True,
55
+ mask_token_id: int = 151669,
56
+ prefix_ids: Optional[torch.LongTensor] = None,
57
+ suffix_ids: Optional[torch.LongTensor] = None,
58
+ infill_length: Optional[int] = None,
59
+ eos_token_id: int = 151645,
60
+ pad_token_id: Optional[int] = None,
61
+ bos_token_id: Optional[int] = None,
62
+ device: Optional[Union[str, torch.device]] = None,
63
+ generator: Optional[torch.Generator] = None,
64
+ visualizer: Optional['TerminalVisualizer'] = None,
65
+ ) -> torch.LongTensor:
66
+ """
67
+ Perform masked diffusion sampling with entropy-based token selection.
68
+
69
+ Args:
70
+ model: The RND1 language model
71
+ seq_len: Target sequence length
72
+ num_steps: Number of denoising steps
73
+ top_k: Optional top-k filtering for sampling (None = no filtering)
74
+ top_p: Optional nucleus (top-p) filtering for sampling (None = no filtering)
75
+ When both top_k and top_p are set, top_k is applied first, then top_p
76
+ temperature: Temperature for sampling (higher = more random, lower = more deterministic)
77
+ Values close to 0 are clamped to 1e-8 to avoid division by zero
78
+ greedy: Whether to use greedy sampling (True) or stochastic (False)
79
+ mask_token_id: Token ID for masked positions (default: 151669)
80
+ prefix_ids: Optional prefix token IDs to preserve
81
+ suffix_ids: Optional suffix token IDs to preserve
82
+ infill_length: Length of infill region between prefix/suffix
83
+ eos_token_id: End of sequence token ID (default: 151645)
84
+ pad_token_id: Padding token ID (default: None, uses 0 if needed)
85
+ bos_token_id: Beginning of sequence token ID (default: None)
86
+ device: Device for computation (None = infer from model)
87
+ generator: Optional torch generator for reproducible sampling
88
+ visualizer: Optional TerminalVisualizer for live visualization
89
+
90
+ Returns:
91
+ Generated token IDs as LongTensor
92
+ """
+     model.eval()
+
+     if device is None:
+         device = next(model.parameters()).device
+     else:
+         device = torch.device(device)
+     dtype = next(model.parameters()).dtype
+
+     if pad_token_id is None:
+         pad_token_id = 0
+
+     # Build the initial masked sequence.
+     # When prefix_ids is provided, we create a sequence of length seq_len where:
+     #   - the prefix occupies the first pre_len positions
+     #   - the remaining (seq_len - pre_len) positions are filled with mask tokens to be generated
+     if prefix_ids is not None or suffix_ids is not None:
+         if prefix_ids is not None:
+             prefix_ids = prefix_ids.to(device) if isinstance(prefix_ids, torch.Tensor) else torch.tensor(prefix_ids, device=device)
+             pre_len = prefix_ids.shape[-1] if prefix_ids.dim() > 0 else 0
+         else:
+             pre_len = 0
+
+         if suffix_ids is not None:
+             suffix_ids = suffix_ids.to(device) if isinstance(suffix_ids, torch.Tensor) else torch.tensor(suffix_ids, device=device)
+             suf_len = suffix_ids.shape[-1] if suffix_ids.dim() > 0 else 0
+         else:
+             suf_len = 0
+
+         reserved = (1 if bos_token_id is not None else 0) + (1 if eos_token_id is not None else 0)
+         used = pre_len + suf_len + reserved
+
+         if used > seq_len:
+             raise ValueError(
+                 f"Combined length of prefix ({pre_len}), suffix ({suf_len}), "
+                 f"and special tokens ({reserved}) = {used} exceeds seq_len ({seq_len}). "
+                 f"Please increase seq_len or reduce input lengths."
+             )
+         elif used == seq_len:
+             raise ValueError(
+                 f"No space for generation: prefix ({pre_len}) + suffix ({suf_len}) "
+                 f"+ special tokens ({reserved}) = seq_len ({seq_len}). "
+                 f"Need at least 1 position for generation."
+             )
+
+         infill_length = min(infill_length or (seq_len - used), seq_len - used)
+
+         x = torch.full((1, seq_len), pad_token_id, dtype=torch.long, device=device)
+         pos = 0
+         if bos_token_id is not None:
+             x[0, pos] = bos_token_id
+             pos += 1
+         if pre_len > 0:
+             x[0, pos:pos + pre_len] = prefix_ids.flatten()[:pre_len]
+             pos += pre_len
+         fill_start, fill_end = pos, pos + infill_length
+         x[0, fill_start:fill_end] = mask_token_id
+         pos = fill_end
+         if suf_len > 0:
+             x[0, pos:pos + suf_len] = suffix_ids.flatten()[:suf_len]
+             pos += suf_len
+
+         init_maskable = torch.zeros_like(x, dtype=torch.bool)
+         init_maskable[0, fill_start:fill_end] = True
+     else:
+         x = torch.full((1, seq_len), mask_token_id, dtype=torch.long, device=device)
+         if bos_token_id is not None:
+             x[0, 0] = bos_token_id
+         if eos_token_id is not None:
+             x[0, -1] = eos_token_id
+         init_maskable = x.eq(mask_token_id)
+
+     # Special tokens are never candidates for remasking/unmasking.
+     if bos_token_id is not None:
+         init_maskable[:, 0] = False
+     if eos_token_id is not None:
+         init_maskable &= x.ne(eos_token_id)
+     init_maskable &= x.ne(pad_token_id)
+
+     maskable = init_maskable.clone()
+     xt = x.clone()
+
+     if visualizer:
+         visualizer.start_visualization(xt, maskable, num_steps)
+
+     def forward_scores(tokens):
+         """Compute predictions and entropy scores for next tokens."""
+         # Try the input_ids keyword first (standard HF models),
+         # falling back to a positional argument.
+         try:
+             model_output = model(input_ids=tokens)
+         except TypeError:
+             model_output = model(tokens)
+
+         # Apply temperature scaling (with safety for near-zero temperature).
+         safe_temperature = max(temperature, 1e-8)  # Prevent division by zero
+         logits = model_output.logits / safe_temperature
+
+         # Apply filtering strategies.
+         # Note: when both top_k and top_p are provided, they are applied sequentially:
+         # top_k first filters to k tokens, then top_p filters from those k tokens.
+         if top_k is not None and top_k > 0:
+             logits = apply_top_k_filtering(logits, top_k)
+
+         if top_p is not None and 0 < top_p < 1.0:
+             logits = apply_top_p_filtering(logits, top_p)
+
+         # Convert to log probabilities.
+         logp = torch.log_softmax(logits, dim=-1)
+
+         # Greedy or stochastic sampling.
+         if greedy:
+             pred_next = logp.argmax(-1)
+         else:
+             # torch.distributions.Categorical.sample() does not accept a generator,
+             # so sample via torch.multinomial to honor the reproducibility seed
+             # (cast to float32, since multinomial does not support all reduced-precision dtypes).
+             probs = logp.float().exp()
+             pred_next = torch.multinomial(
+                 probs.reshape(-1, probs.size(-1)), 1, generator=generator
+             ).reshape(probs.shape[:-1])
+
+         conf_next = torch.gather(logp, -1, pred_next.unsqueeze(-1)).squeeze(-1)
+
+         # Per-position entropy of the predictive distribution.
+         p = logp.exp()
+         ent_next = -(p * logp).sum(-1)
+
+         # Shift predictions: position i predicts token i+1.
+         pred_i = tokens.clone()
+         conf_i = torch.full_like(conf_next, torch.finfo(conf_next.dtype).min)
+         ent_i = torch.zeros_like(ent_next)
+
+         pred_i[:, 1:] = pred_next[:, :-1]
+         conf_i[:, 1:] = conf_next[:, :-1]
+         ent_i[:, 1:] = ent_next[:, :-1]
+
+         return pred_i, conf_i, ent_i
+
+     pred_i, conf_i, ent_i = forward_scores(xt)
+     total_masked = init_maskable.sum(1, keepdim=True)
+     finf = torch.finfo(conf_i.dtype)
+
+     for step in range(num_steps - 1, 0, -1):
+         # Fraction of the originally masked tokens that should remain masked
+         # after this step; it anneals linearly toward zero.
+         rate = step / num_steps
+         cutoff_len = (total_masked * rate).long().clamp(min=0)
+
+         # Choose HIGH-entropy positions to keep masked; everything else is unmasked.
+         sel_scores = ent_i.masked_fill(~maskable, finf.min)
+         B, L = sel_scores.shape
+         k_max = cutoff_len.max().item()
+         if k_max > 0:
+             _, idx = torch.topk(sel_scores, k_max, dim=-1, largest=True)
+             keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)
+             for b in range(B):
+                 k_b = int(cutoff_len[b].item())
+                 if k_b > 0:
+                     keep_mask[b, idx[b, :k_b]] = True
+         else:
+             keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)
+
+         # Commit predictions at the positions being unmasked this step.
+         to_unmask = maskable & ~keep_mask
+         if to_unmask.any():
+             xt[to_unmask] = pred_i[to_unmask]
+             maskable[to_unmask] = False
+
+         if visualizer:
+             visualizer.update_step(xt, maskable, num_steps - step, ent_i, conf_i)
+
+         if maskable.any():
+             pred_i, conf_i, ent_i = forward_scores(xt)
+
+     # Any positions still masked after the last step take their current predictions.
+     if maskable.any():
+         xt[maskable] = pred_i[maskable]
+
+     if visualizer:
+         visualizer.stop_visualization()
+
+     return xt
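To make the linear unmasking schedule in the loop above concrete, here is a self-contained sketch; the counts (32 initially masked positions, 8 denoising steps) are made up for illustration:

```python
import torch

total_masked = torch.tensor([[32]])  # made-up: 32 initially masked positions
num_steps = 8                        # made-up: 8 denoising steps

for step in range(num_steps - 1, 0, -1):
    rate = step / num_steps
    cutoff_len = (total_masked * rate).long().clamp(min=0)
    print(f"after step {num_steps - step}: {cutoff_len.item()} positions remain masked")

# after step 1: 28 positions remain masked
# after step 2: 24 positions remain masked
# ...
# after step 7: 4 positions remain masked
# (the final pass outside the loop commits predictions for whatever is left)
```

At each step the `cutoff_len` highest-entropy positions stay masked and every other masked position is committed, so the easiest (lowest-entropy) tokens are finalized first.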
rnd/terminal_visualizer.py ADDED
@@ -0,0 +1,251 @@
+ # Copyright 2025 Radical Numerics Inc.
+ #
+ # This source code is licensed under the Apache License, Version 2.0, found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Terminal visualization for RND1 generation.
+
+ This module provides real-time visualization of the diffusion denoising process,
+ showing token evolution and generation progress in the terminal using rich
+ formatting when available.
+ """
+
+ import torch
+ from typing import Optional
+ from tqdm import tqdm
+
+ try:
+     from rich.console import Console
+     from rich.live import Live
+     from rich.text import Text
+     from rich.panel import Panel
+     from rich.progress import Progress, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
+     from rich.layout import Layout
+     RICH_AVAILABLE = True
+ except ImportError:
+     RICH_AVAILABLE = False
+
+
+ class TerminalVisualizer:
+     """
+     Rich-based visualization of the diffusion process with live updates.
+
+     Provides real-time visualization of the token denoising process during
+     diffusion-based language generation, with colored highlighting of masked
+     positions and progress tracking.
+     """
+
+     def __init__(self, tokenizer, show_visualization: bool = True):
+         """
+         Initialize the terminal visualizer.
+
+         Args:
+             tokenizer: The tokenizer for decoding tokens to text
+             show_visualization: Whether to show visualization (requires rich)
+         """
+         self.tokenizer = tokenizer
+         self.show_visualization = show_visualization and RICH_AVAILABLE
+         if not RICH_AVAILABLE and show_visualization:
+             print("Warning: Install 'rich' for better visualization. Falling back to simple progress bar.")
+             self.show_visualization = False
+
+         if self.show_visualization:
+             self.console = Console()
+             self.live = None
+             self.progress = None
+             self.layout = None
+         else:
+             self.pbar = None
+
+         self.current_tokens = None
+         self.mask_positions = None
+         self.total_steps = 0
+         self.current_step = 0
+
+     def start_visualization(self, initial_tokens: torch.LongTensor, mask_positions: torch.BoolTensor, total_steps: int):
+         """
+         Start the visualization.
+
+         Args:
+             initial_tokens: Initial token IDs (possibly masked)
+             mask_positions: Boolean mask indicating which positions are masked
+             total_steps: Total number of diffusion steps
+         """
+         if not self.show_visualization:
+             self.pbar = tqdm(total=total_steps, desc="Diffusion")
+             return
+
+         self.current_tokens = initial_tokens.clone()
+         self.mask_positions = mask_positions
+         self.total_steps = total_steps
+         self.current_step = 0
+
+         self.layout = Layout()
+         self.layout.split_column(
+             Layout(name="header", size=3),
+             Layout(name="text", ratio=1),
+             Layout(name="progress", size=3)
+         )
+
+         self.progress = Progress(
+             TextColumn("[bold blue]Diffusion"),
+             BarColumn(),
+             MofNCompleteColumn(),
+             TextColumn("•"),
+             TextColumn("[cyan]Masks: {task.fields[masks]}"),
+             TimeRemainingColumn(),
+         )
+         self.progress_task = self.progress.add_task(
+             "Generating",
+             total=total_steps,
+             masks=mask_positions.sum().item()
+         )
+
+         self.live = Live(self.layout, console=self.console, refresh_per_second=4)
+         self.live.start()
+         self._update_display()
+
+     def update_step(self, tokens: torch.LongTensor, maskable: Optional[torch.BoolTensor], step: int,
+                     entropy: Optional[torch.FloatTensor] = None, confidence: Optional[torch.FloatTensor] = None):
+         """
+         Update visualization for the current step.
+
+         Args:
+             tokens: Current token IDs
+             maskable: Boolean mask of remaining masked positions
+             step: Current step number
+             entropy: Optional entropy scores for each position
+             confidence: Optional confidence scores for each position
+         """
+         if not self.show_visualization:
+             if self.pbar:
+                 self.pbar.update(1)
+                 masks = maskable.sum().item() if maskable is not None else 0
+                 self.pbar.set_postfix({'masks': masks})
+             return
+
+         self.current_tokens = tokens.clone()
+         self.mask_positions = maskable
+         self.current_step = step
+
+         masks_remaining = maskable.sum().item() if maskable is not None else 0
+         self.progress.update(
+             self.progress_task,
+             advance=1,
+             masks=masks_remaining
+         )
+
+         self._update_display()
+
+     def _update_display(self):
+         """Update the live display."""
+         if not self.live:
+             return
+
+         header = Text("RND1-Base Generation", style="bold magenta", justify="center")
+         self.layout["header"].update(Panel(header, border_style="bright_blue"))
+
+         text_display = self._format_text_with_masks()
+         self.layout["text"].update(
+             Panel(
+                 text_display,
+                 title="[bold]Generated Text",
+                 subtitle=f"[dim]Step {self.current_step}/{self.total_steps}[/dim]",
+                 border_style="cyan"
+             )
+         )
+
+         self.layout["progress"].update(Panel(self.progress))
+
+     def _format_text_with_masks(self) -> Text:
+         """
+         Format text with colored masks.
+
+         Returns:
+             Rich Text object with formatted tokens
+         """
+         text = Text()
+
+         if self.current_tokens is None:
+             return text
+
+         token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
+         mask_flags = self.mask_positions[0] if self.mask_positions is not None and self.mask_positions.dim() > 1 else self.mask_positions
+
+         for i, token_id in enumerate(token_ids):
+             if mask_flags is not None and i < len(mask_flags) and mask_flags[i]:
+                 # Alternate colors between steps for a visual effect
+                 text.append("[MASK]", style="bold red on yellow" if self.current_step % 2 == 0 else "bold yellow on red")
+             else:
+                 try:
+                     token_str = self.tokenizer.decode([token_id.item()], skip_special_tokens=False)
+                     # Skip special tokens in the display
+                     if token_str not in ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<s>", "</s>"]:
+                         # Color by position: first half green, second half cyan
+                         text.append(token_str, style="green" if i < len(token_ids) // 2 else "cyan")
+                 except Exception:
+                     continue
+
+         return text
+
+     def stop_visualization(self):
+         """Stop the visualization and display the final result."""
+         if not self.show_visualization:
+             if self.pbar:
+                 self.pbar.close()
+             print("\n✨ Generation complete!\n")
+             return
+
+         if self.live:
+             self.live.stop()
+
+         self.console.print("\n[bold green]✨ Generation complete![/bold green]\n")
+
+         # Display the final decoded text
+         if self.current_tokens is not None:
+             try:
+                 token_ids = self.current_tokens[0] if self.current_tokens.dim() > 1 else self.current_tokens
+                 final_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
+
+                 self.console.print(Panel(
+                     final_text,
+                     title="[bold]Final Generated Text",
+                     border_style="green",
+                     padding=(1, 2)
+                 ))
+             except Exception:
+                 pass
+
+
+ class SimpleProgressBar:
+     """
+     Simple progress bar fallback when rich is not available.
+
+     Provides basic progress tracking using tqdm when the rich library
+     is not installed.
+     """
+
+     def __init__(self, total_steps: int):
+         """
+         Initialize the simple progress bar.
+
+         Args:
+             total_steps: Total number of steps
+         """
+         self.pbar = tqdm(total=total_steps, desc="Diffusion")
+
+     def update(self, masks_remaining: int = 0):
+         """
+         Update the progress bar.
+
+         Args:
+             masks_remaining: Number of masks still remaining
+         """
+         self.pbar.update(1)
+         self.pbar.set_postfix({'masks': masks_remaining})
+
+     def close(self):
+         """Close the progress bar."""
+         self.pbar.close()
+         print("\n✨ Generation complete!\n")
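For completeness, a small smoke test for `TerminalVisualizer` is sketched below. This is illustrative only: the `DummyTokenizer` is a made-up stand-in, the import path assumes the repo root is on `sys.path`, and in actual use the visualizer is driven by `diffusion_sample` rather than called by hand:

```python
import torch
from rnd.terminal_visualizer import TerminalVisualizer  # assumed import path

class DummyTokenizer:
    """Made-up stand-in so the example runs without a real tokenizer."""
    def decode(self, ids, skip_special_tokens=False):
        return " ".join(f"<{int(i)}>" for i in ids)

tokens = torch.randint(0, 100, (1, 16))
masked = torch.zeros(1, 16, dtype=torch.bool)
masked[0, 4:12] = True  # pretend positions 4..11 start out masked

viz = TerminalVisualizer(DummyTokenizer(), show_visualization=True)
viz.start_visualization(tokens, masked, total_steps=8)
for step in range(1, 9):
    masked[0, 3 + step] = False  # unmask one more position per step
    viz.update_step(tokens, masked, step)
viz.stop_visualization()
```

If `rich` is not installed, the same calls fall back to the plain tqdm progress bar, so the snippet exercises both code paths.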