athms committed
Commit 3e7a3bf · verified · 1 Parent(s): 1e53c59

Upload folder using huggingface_hub
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
added_tokens.json ADDED
@@ -0,0 +1,28 @@
+ {
+   "</think>": 151668,
+   "</tool_call>": 151658,
+   "</tool_response>": 151666,
+   "<think>": 151667,
+   "<tool_call>": 151657,
+   "<tool_response>": 151665,
+   "<|box_end|>": 151649,
+   "<|box_start|>": 151648,
+   "<|endoftext|>": 151643,
+   "<|file_sep|>": 151664,
+   "<|fim_middle|>": 151660,
+   "<|fim_pad|>": 151662,
+   "<|fim_prefix|>": 151659,
+   "<|fim_suffix|>": 151661,
+   "<|im_end|>": 151645,
+   "<|im_start|>": 151644,
+   "<|image_pad|>": 151655,
+   "<|object_ref_end|>": 151647,
+   "<|object_ref_start|>": 151646,
+   "<|quad_end|>": 151651,
+   "<|quad_start|>": 151650,
+   "<|repo_name|>": 151663,
+   "<|video_pad|>": 151656,
+   "<|vision_end|>": 151653,
+   "<|vision_pad|>": 151654,
+   "<|vision_start|>": 151652
+ }
config.json ADDED
@@ -0,0 +1,102 @@
+ {
+   "vocab_size": 151936,
+   "max_position_embeddings": 40960,
+   "hidden_size": 2048,
+   "intermediate_size": 6144,
+   "num_hidden_layers": 48,
+   "num_attention_heads": 32,
+   "use_sliding_window": false,
+   "sliding_window": null,
+   "num_key_value_heads": 4,
+   "hidden_act": "silu",
+   "initializer_range": 0.02,
+   "rms_norm_eps": 1e-06,
+   "use_cache": false,
+   "rope_theta": 1000000.0,
+   "rope_scaling": null,
+   "attention_bias": false,
+   "attention_dropout": 0.0,
+   "decoder_sparse_step": 1,
+   "moe_intermediate_size": 768,
+   "num_experts_per_tok": 8,
+   "num_experts": 128,
+   "norm_topk_prob": true,
+   "output_router_logits": false,
+   "router_aux_loss_coef": 0.001,
+   "mlp_only_layers": [],
+   "return_dict": true,
+   "output_hidden_states": false,
+   "torchscript": false,
+   "dtype": "bfloat16",
+   "pruned_heads": {},
+   "tie_word_embeddings": false,
+   "chunk_size_feed_forward": 0,
+   "is_encoder_decoder": false,
+   "is_decoder": false,
+   "cross_attention_hidden_size": null,
+   "add_cross_attention": false,
+   "tie_encoder_decoder": false,
+   "architectures": [
+     "Qwen3MoeForCausalLM"
+   ],
+   "finetuning_task": null,
+   "id2label": {
+     "0": "LABEL_0",
+     "1": "LABEL_1"
+   },
+   "label2id": {
+     "LABEL_0": 0,
+     "LABEL_1": 1
+   },
+   "task_specific_params": null,
+   "problem_type": null,
+   "tokenizer_class": null,
+   "prefix": null,
+   "bos_token_id": 151643,
+   "pad_token_id": null,
+   "eos_token_id": 151645,
+   "sep_token_id": null,
+   "decoder_start_token_id": null,
+   "max_length": 20,
+   "min_length": 0,
+   "do_sample": false,
+   "early_stopping": false,
+   "num_beams": 1,
+   "num_beam_groups": 1,
+   "diversity_penalty": 0.0,
+   "temperature": 1.0,
+   "top_k": 50,
+   "top_p": 1.0,
+   "typical_p": 1.0,
+   "repetition_penalty": 1.0,
+   "length_penalty": 1.0,
+   "no_repeat_ngram_size": 0,
+   "encoder_no_repeat_ngram_size": 0,
+   "bad_words_ids": null,
+   "num_return_sequences": 1,
+   "output_scores": false,
+   "return_dict_in_generate": false,
+   "forced_bos_token_id": null,
+   "forced_eos_token_id": null,
+   "remove_invalid_values": false,
+   "exponential_decay_length_penalty": null,
+   "suppress_tokens": null,
+   "begin_suppress_tokens": null,
+   "_name_or_path": "",
+   "transformers_version": "4.56.1",
+   "head_dim": 128,
+   "max_window_layers": 48,
+   "model_type": "rnd1",
+   "is_causal": false,
+   "tf_legacy_loss": false,
+   "use_bfloat16": false,
+   "moe_backend": "hf",
+   "num_diffusion_steps": 256,
+   "mask_token_id": 151669,
+   "output_attentions": false,
+   "auto_map": {
+     "AutoConfig": "configuration_rnd.RND1Config",
+     "AutoModel": "modeling_rnd.RND1Model",
+     "AutoModelForMaskedLM": "modeling_rnd.RND1LM"
+   }
+ }
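
Because auto_map above points AutoModelForMaskedLM at modeling_rnd.RND1LM, loading this checkpoint goes through the remote-code path. A minimal loading sketch; the Hub id "user/rnd1-repo" is a placeholder, and torch_dtype mirrors the "dtype" field above:

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

repo = "user/rnd1-repo"  # placeholder: substitute the actual Hub id
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForMaskedLM.from_pretrained(
    repo,
    trust_remote_code=True,       # resolves configuration_rnd / modeling_rnd via auto_map
    torch_dtype=torch.bfloat16,   # matches "dtype": "bfloat16" in config.json
)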
configuration_rnd.py ADDED
@@ -0,0 +1,63 @@
+ """
+ RND1 Model Configuration.
+
+ This module defines the configuration class for RND1 models,
+ extending Qwen3MoeConfig with RND1-specific parameters.
+ """
+
+ from typing import Optional
+ from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig
+
+
+ class RND1Config(Qwen3MoeConfig):
+     """
+     Configuration class for RND1 models.
+
+     This configuration extends Qwen3MoeConfig with additional parameters
+     specific to the RND1 (Radical Numerics Diffusion v1) architecture.
+
+     Args:
+         moe_backend: Backend for MoE computation ("hf", "flashinfer", or "sglang")
+         num_diffusion_steps: Default number of diffusion steps for generation
+         mask_token_id: Token ID used for masking (default: 151669 for Qwen)
+         **kwargs: Additional arguments passed to Qwen3MoeConfig
+     """
+
+     model_type = "rnd1"
+
+     def __init__(
+         self,
+         moe_backend: str = "hf",
+         num_diffusion_steps: int = 256,
+         mask_token_id: int = 151669,  # Default for Qwen-based RND1 models
+         use_cache: bool = False,
+         **kwargs,
+     ):
+         # Force non-causal and no caching for RND1
+         kwargs['use_cache'] = False
+         kwargs['is_causal'] = False
+         super().__init__(**kwargs)
+
+         # RND1-specific parameters
+         self.moe_backend = moe_backend
+         self.num_diffusion_steps = num_diffusion_steps
+         self.mask_token_id = mask_token_id
+
+         # Ensure bidirectional attention and no caching
+         self.is_causal = False
+         self.use_cache = False
+
+     def to_dict(self):
+         """
+         Serializes configuration to dictionary with auto_map for Hub.
+
+         The auto_map ensures that when users load from HuggingFace Hub,
+         the correct custom classes are automatically resolved.
+         """
+         data = super().to_dict()
+         data.setdefault("auto_map", {
+             "AutoConfig": "configuration_rnd.RND1Config",
+             "AutoModel": "modeling_rnd.RND1Model",
+             "AutoModelForMaskedLM": "modeling_rnd.RND1LM",
+         })
+         return data
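
For reference, a minimal sketch of how the forced flags behave; the import assumes the module sits next to this file on sys.path:

from configuration_rnd import RND1Config

cfg = RND1Config(moe_backend="hf", num_diffusion_steps=128, use_cache=True)
assert cfg.use_cache is False and cfg.is_causal is False  # forced off regardless of caller input
assert cfg.to_dict()["auto_map"]["AutoModelForMaskedLM"] == "modeling_rnd.RND1LM"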
generation_config.json ADDED
@@ -0,0 +1,17 @@
+ {
+   "bos_token_id": 151643,
+   "eos_token_id": 151645,
+   "pad_token_id": 151643,
+   "mask_token_id": 151669,
+   "max_length": 256,
+   "max_new_tokens": 256,
+   "num_diffusion_steps": 256,
+   "temperature": 1.0,
+   "top_k": null,
+   "top_p": null,
+   "do_sample": true,
+   "greedy": true,
+   "use_cache": false,
+   "_from_model_config": true,
+   "transformers_version": "4.45.2"
+ }
generation_config.py ADDED
@@ -0,0 +1,77 @@
+ """
+ RND1 Generation Configuration.
+
+ This module defines the generation configuration for RND1 models,
+ controlling the diffusion-based generation process.
+ """
+
+ from typing import Optional
+ from transformers.generation.configuration_utils import GenerationConfig
+
+
+ class RND1GenerationConfig(GenerationConfig):
+     """
+     Configuration class for RND1 generation parameters.
+
+     This class extends the base GenerationConfig to include parameters
+     specific to diffusion-based language generation.
+
+     Args:
+         max_length: Maximum sequence length
+         num_diffusion_steps: Number of denoising steps in the diffusion process
+         mask_token_id: Token ID used for masking during diffusion
+         temperature: Temperature for sampling (higher = more random)
+         top_k: Optional top-k filtering
+         top_p: Optional nucleus (top-p) filtering
+         greedy: Whether to use greedy decoding (True) or stochastic sampling (False)
+         **kwargs: Additional arguments passed to GenerationConfig
+     """
+
+     def __init__(
+         self,
+         max_length: int = 256,
+         num_diffusion_steps: int = 256,
+         mask_token_id: int = 151669,  # Default for Qwen-based RND1 models
+         temperature: float = 1.0,
+         top_k: Optional[int] = None,
+         top_p: Optional[float] = None,
+         greedy: bool = True,
+         seed: Optional[int] = None,  # For reproducible generation
+         bos_token_id: Optional[int] = None,
+         eos_token_id: Optional[int] = None,
+         pad_token_id: Optional[int] = None,
+         use_cache: bool = False,
+         **kwargs,
+     ):
+         # Force no caching for RND1 generation - remove from kwargs if present
+         kwargs.pop('use_cache', None)
+
+         super().__init__(
+             max_length=max_length,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             pad_token_id=pad_token_id,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+             do_sample=not greedy,
+             use_cache=False,  # Always False for RND1
+             **kwargs,
+         )
+
+         # RND1-specific parameters
+         self.num_diffusion_steps = num_diffusion_steps
+         self.mask_token_id = mask_token_id
+         self.greedy = greedy
+         self.temperature = float(temperature)  # Ensure it's a float
+         self.seed = seed
+
+     def to_dict(self):
+         """Convert configuration to dictionary."""
+         output = super().to_dict()
+         output["num_diffusion_steps"] = self.num_diffusion_steps
+         output["mask_token_id"] = self.mask_token_id
+         output["greedy"] = self.greedy
+         if self.seed is not None:
+             output["seed"] = self.seed
+         return output
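
A small sketch of how this class behaves: greedy maps onto do_sample on the base class, and the diffusion-specific fields survive round-tripping through to_dict():

from generation_config import RND1GenerationConfig  # local module from this commit

gc = RND1GenerationConfig(max_length=512, num_diffusion_steps=64, greedy=False, seed=0)
assert gc.do_sample is True and gc.use_cache is False
d = gc.to_dict()
assert d["num_diffusion_steps"] == 64 and d["seed"] == 0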
generation_utils.py ADDED
@@ -0,0 +1,149 @@
+ """
+ RND1 Generation Utilities.
+
+ This module provides generation utilities and mixins for RND1 models,
+ including the main GenerationMixin class that integrates with HuggingFace.
+ """
+
+ import torch
+ import torch.nn as nn
+ from typing import Optional, Union, Dict, Any
+ from transformers import GenerationMixin as HFGenerationMixin
+ from transformers.generation import GenerationConfig
+
+ from .sampling import diffusion_sample, apply_top_k_filtering, apply_top_p_filtering
+
+
+ class RND1GenerationMixin(HFGenerationMixin):
+     """
+     Generation mixin for RND1 models.
+
+     This mixin provides generation methods compatible with HuggingFace's
+     generation API while using RND1's diffusion-based sampling internally.
+     """
+
+     def generate(
+         self,
+         inputs: Optional[torch.LongTensor] = None,
+         generation_config: Optional[GenerationConfig] = None,
+         # RND1-specific parameters
+         prefix_ids: Optional[torch.LongTensor] = None,
+         suffix_ids: Optional[torch.LongTensor] = None,
+         infill_length: Optional[int] = None,
+         return_dict_in_generate: Optional[bool] = None,
+         **kwargs,  # Accept all kwargs to be compatible with pipelines
+     ) -> Union[torch.LongTensor, Dict[str, Any]]:
+         """
+         Generate text using RND1's diffusion-based sampling.
+
+         Follows HuggingFace's standard generate API, using diffusion sampling
+         internally. Supports both standard generation and infilling.
+
+         Args:
+             inputs: Input token IDs to use as prefix (standard HF parameter)
+             generation_config: Generation configuration object
+             prefix_ids: Alternative to inputs for infilling tasks
+             suffix_ids: Optional suffix for infilling tasks
+             infill_length: Length of infill region (for infilling)
+             return_dict_in_generate: Whether to return GenerateDecoderOnlyOutput
+             **kwargs: Additional arguments (accepted for compatibility)
+
+         Returns:
+             Generated token IDs or GenerateDecoderOnlyOutput
+         """
+         if generation_config is not None:
+             gen_config = generation_config
+             model_kwargs = kwargs.copy()
+         else:
+             # Only prepare config from kwargs if no config was provided
+             gen_config, model_kwargs = self._prepare_generation_config(None, **kwargs)
+
+         device = next(self.parameters()).device
+
+         if inputs is not None:
+             prefix_ids = inputs.to(device)
+         elif prefix_ids is not None:
+             prefix_ids = prefix_ids.to(device)
+         else:
+             prefix_ids = None
+
+         if suffix_ids is not None:
+             suffix_ids = suffix_ids.to(device)
+
+         eos_token_id = gen_config.eos_token_id or getattr(self.config, "eos_token_id", 151645)
+         pad_token_id = gen_config.pad_token_id or getattr(self.config, "pad_token_id", None)
+         bos_token_id = gen_config.bos_token_id or getattr(self.config, "bos_token_id", None)
+         mask_token_id = getattr(gen_config, "mask_token_id", getattr(self.config, "mask_token_id", 151669))
+
+         if infill_length is not None and prefix_ids is not None:
+             # Infilling mode: use specified infill_length
+             prefix_len = prefix_ids.shape[1] if prefix_ids is not None else 0
+             suffix_len = suffix_ids.shape[1] if suffix_ids is not None else 0
+             seq_len = prefix_len + infill_length + suffix_len
+         else:
+             # Standard generation mode
+             if prefix_ids is not None:
+                 prefix_len = prefix_ids.shape[1]
+                 if gen_config.max_new_tokens is not None:
+                     seq_len = prefix_len + gen_config.max_new_tokens
+                 else:
+                     seq_len = gen_config.max_length or self.config.max_position_embeddings
+             else:
+                 seq_len = gen_config.max_length or self.config.max_position_embeddings
+
+         num_diffusion_steps = getattr(gen_config, "num_diffusion_steps",
+                                       getattr(self.config, "num_diffusion_steps", 256))
+
+         temperature = float(getattr(gen_config, "temperature", 1.0))
+         top_k = getattr(gen_config, "top_k", None)
+         top_p = getattr(gen_config, "top_p", None)
+
+         greedy = getattr(gen_config, "greedy",
+                          not bool(gen_config.do_sample) if hasattr(gen_config, "do_sample") else True)
+
+         generator = model_kwargs.get("generator", None)
+         if generator is None:
+             seed = getattr(gen_config, 'seed', None)
+             if seed is not None:
+                 generator = torch.Generator(device=device)
+                 generator.manual_seed(seed)
+
+         with torch.inference_mode():
+             sequences = diffusion_sample(
+                 model=self,
+                 seq_len=seq_len,
+                 num_steps=num_diffusion_steps,
+                 mask_token_id=mask_token_id,
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+                 greedy=greedy,
+                 prefix_ids=prefix_ids,
+                 suffix_ids=suffix_ids,
+                 infill_length=infill_length,
+                 eos_token_id=eos_token_id,
+                 pad_token_id=pad_token_id,
+                 bos_token_id=bos_token_id,
+                 device=device,
+                 generator=generator,
+                 visualizer=model_kwargs.get("visualizer", None),  # Optional visualizer from kwargs
+             )
+
+         if return_dict_in_generate or getattr(gen_config, "return_dict_in_generate", False):
+             from transformers.generation.utils import GenerateDecoderOnlyOutput
+             return GenerateDecoderOnlyOutput(sequences=sequences)
+
+         return sequences
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids: torch.LongTensor,
+         **kwargs,
+     ) -> Dict[str, Any]:
+         """
+         Prepare inputs for generation (required by HuggingFace).
+
+         For RND1, we don't use the standard autoregressive generation,
+         so this just returns the input_ids.
+         """
+         return {"input_ids": input_ids}
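
A usage sketch for generate(), assuming model and tokenizer were loaded as in the earlier sketch; exact kwarg plumbing through _prepare_generation_config varies with the installed transformers version:

prompt = tokenizer("Write a haiku about entropy:", return_tensors="pt").input_ids

out = model.generate(inputs=prompt, max_new_tokens=64)  # seq_len = prefix_len + 64
print(tokenizer.decode(out[0], skip_special_tokens=True))

# Infilling: denoise 8 masked positions between a prefix and a suffix.
suffix = tokenizer(" The end.", return_tensors="pt").input_ids
out = model.generate(prefix_ids=prompt, suffix_ids=suffix, infill_length=8)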
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
modeling_rnd.py ADDED
@@ -0,0 +1,529 @@
+ """
+ RND1 model implementation.
+
+ This module implements the RND1 architecture with bidirectional attention for
+ diffusion-based language modeling. Includes support for Mixture of Experts (MoE)
+ with multiple backend options (HF, FlashInfer, SGLang).
+
+ Based on the Qwen3Moe architecture:
+ https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
+ """
+
+ from __future__ import annotations
+
+ import os
+ from typing import Optional, Tuple, List, Union
+
+ import torch
+ from torch import nn
+
+ from transformers.utils import logging
+ from transformers.cache_utils import Cache
+ from transformers.modeling_outputs import (
+     MoeModelOutputWithPast,
+     MaskedLMOutput,
+ )
+ from transformers.modeling_utils import PreTrainedModel
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.generation import GenerationConfig
+
+ from .configuration_rnd import RND1Config
+ from .generation_utils import RND1GenerationMixin
+ from .generation_config import RND1GenerationConfig
+
+ from transformers.models.qwen3_moe.modeling_qwen3_moe import (
+     Qwen3MoeConfig,
+     Qwen3MoeRMSNorm,
+     Qwen3MoeRotaryEmbedding,
+     Qwen3MoeSparseMoeBlock,
+     Qwen3MoeMLP,
+     apply_rotary_pos_emb
+ )
+ import torch.nn.functional as F
+
+ try:
+     import flashinfer.fused_moe as fused_moe
+ except Exception:
+     fused_moe = None
+
+ try:
+     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe as sglang_fused_moe
+     from sglang.srt.layers.moe.topk import StandardTopKOutput
+ except Exception:
+     sglang_fused_moe = None
+     StandardTopKOutput = None
+
+ logger = logging.get_logger(__name__)
+
+
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """Expand key/value heads to match query heads for grouped-query attention."""
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(
+         batch, num_key_value_heads, n_rep, slen, head_dim
+     )
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+ class RND1Attention(nn.Module):
+     """RND1 attention layer with bidirectional attention for diffusion modeling."""
+
+     def __init__(self, config: RND1Config, layer_idx: int):
+         super().__init__()
+         self.config = config
+         self.layer_idx = layer_idx
+
+         self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+         self.num_heads = config.num_attention_heads
+         self.num_key_value_heads = config.num_key_value_heads
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+
+         self.scaling = self.head_dim ** -0.5
+         self.attention_dropout = config.attention_dropout
+         self.is_causal = False
+
+         self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+         self.k_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+         self.v_proj = nn.Linear(config.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+         self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
+
+         self.q_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+         self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+         self.sliding_window = getattr(config, "sliding_window", None)
+
+         self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         dual_cache: Optional[bool] = False,
+         replace_position: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+
+         bsz, q_len, _ = hidden_states.size()
+         input_shape = hidden_states.shape[:-1]
+         hidden_shape = (*input_shape, -1, self.head_dim)
+
+         query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+         key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+         cos, sin = position_embeddings
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+         key_states = repeat_kv(key_states, self.num_key_value_groups)
+         value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+         use_sdpa = (getattr(self.config, "_attn_implementation", "eager") == "sdpa")
+
+         if use_sdpa:
+             if attention_mask is not None and isinstance(attention_mask, torch.Tensor):
+                 if attention_mask.dtype not in [torch.bool, torch.float32, torch.float16, torch.bfloat16]:
+                     attention_mask = attention_mask.to(dtype=query_states.dtype)
+
+             assert not self.is_causal, f"Attention layer {self.layer_idx} is causal"
+             attn_out = torch.nn.functional.scaled_dot_product_attention(
+                 query_states, key_states, value_states,
+                 attn_mask=attention_mask if isinstance(attention_mask, torch.Tensor) else None,
+                 dropout_p=self.attention_dropout if self.training else 0.0,
+                 is_causal=self.is_causal,
+             )
+             attn_out = attn_out.transpose(1, 2).contiguous()
+             attn_out = attn_out.view(bsz, q_len, self.num_heads * self.head_dim)
+             attn_out = self.o_proj(attn_out)
+             return attn_out, None
+
+         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
+
+         if attention_mask is not None:
+             attn_weights = attn_weights + attention_mask[:, :, :, : key_states.shape[-2]]
+
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
+         attn_out = torch.matmul(attn_weights, value_states)
+         attn_out = attn_out.transpose(1, 2).contiguous().view(hidden_states.size(0), hidden_states.size(1), -1)
+         attn_out = self.o_proj(attn_out)
+
+         return attn_out, None
+
+
+ class RND1DecoderLayer(nn.Module):
+     """RND1 decoder layer with bidirectional attention for diffusion language modeling."""
+
+     def __init__(self, config: RND1Config, layer_idx: int):
+         super().__init__()
+         self.self_attn = RND1Attention(config, layer_idx)
+         self.mlp = RND1SparseMoeBlock(config)
+         self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         replace_position: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor, Optional[torch.Tensor]]:
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+
+         attn_out, attn_weights = self.self_attn(
+             hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             position_embeddings=position_embeddings,
+             replace_position=replace_position,
+         )
+         hidden_states = residual + attn_out
+
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         ff_out = self.mlp(hidden_states)
+         if isinstance(ff_out, tuple):
+             ff_out = ff_out[0]
+         hidden_states = residual + ff_out
+
+         return hidden_states, attn_weights
+
+
+ class RND1SparseMoeBlock(nn.Module):
+     """RND1 Sparse MoE block with multiple backend support (HF, FlashInfer, SGLang)."""
+
+     def __init__(self, config: RND1Config):
+         super().__init__()
+         self.config = config
+         self.backend = getattr(config, "moe_backend", "hf")
+         self.num_experts = config.num_experts
+         self.top_k = config.num_experts_per_tok
+         self.norm_topk_prob = config.norm_topk_prob
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = getattr(config, "moe_intermediate_size", config.intermediate_size)
+
+         self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
+         self.experts = nn.ModuleList(
+             [Qwen3MoeMLP(config, intermediate_size=self.intermediate_size) for _ in range(self.num_experts)]
+         )
+
+         # Cached weight tensors for optimized backends
+         self._flashinfer_fc1_weights = None
+         self._flashinfer_fc2_weights = None
+         self._sglang_w1 = None
+         self._sglang_w2 = None
+         if self.backend == "sglang":
+             if sglang_fused_moe is None or StandardTopKOutput is None:
+                 raise RuntimeError("sglang is not available, cannot use sglang backend")
+         elif self.backend == "flashinfer":
+             if fused_moe is None:
+                 raise RuntimeError("flashinfer is not available, cannot use flashinfer backend")
+
+     def _initialize_flashinfer_weights(self):
+         """Initialize FlashInfer-compatible weight format."""
+         fc1_list = []
+         fc2_list = []
+
+         for expert in self.experts:
+             gate_w = expert.gate_proj.weight  # [I, H]
+             up_w = expert.up_proj.weight      # [I, H]
+             down_w = expert.down_proj.weight  # [H, I]
+             # FlashInfer expects [up; gate] ordering
+             fc1_list.append(torch.cat([up_w, gate_w], dim=0))  # [2I, H]
+             fc2_list.append(down_w)  # [H, I]
+
+         self._flashinfer_fc1_weights = torch.stack(fc1_list, dim=0).contiguous()
+         self._flashinfer_fc2_weights = torch.stack(fc2_list, dim=0).contiguous()
+
+     def _initialize_sglang_weights(self):
+         """Initialize SGLang-compatible weight format."""
+         w1_list = []
+         w2_list = []
+
+         for expert in self.experts:
+             gate_w = expert.gate_proj.weight  # [I, H]
+             up_w = expert.up_proj.weight      # [I, H]
+             down_w = expert.down_proj.weight  # [H, I]
+             w1 = torch.cat([gate_w, up_w], dim=0)  # [2I, H]
+             w1_list.append(w1)
+             w2_list.append(down_w)
+
+         self._sglang_w1 = torch.stack(w1_list, dim=0).contiguous()
+         self._sglang_w2 = torch.stack(w2_list, dim=0).contiguous()
+
+     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Forward pass with expert routing and computation."""
+         batch_size, sequence_length, hidden_dim = hidden_states.shape
+         x = hidden_states.view(-1, hidden_dim)
+
+         # Expert routing
+         router_logits = self.gate(x)
+         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+         routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+         if self.norm_topk_prob:
+             routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+
+         if self.backend == "hf":
+             final_hidden_states = torch.zeros(
+                 (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+             )
+
+             expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+
+             for expert_idx in expert_hit:
+                 expert_layer = self.experts[expert_idx]
+                 idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+                 current_state = x[top_x]
+                 current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
+                 final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
+             out = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
+             return out, router_logits.view(batch_size, sequence_length, -1)
+
+         elif self.backend == "flashinfer":
+             if self._flashinfer_fc1_weights is None or self._flashinfer_fc2_weights is None:
+                 self._initialize_flashinfer_weights()
+
+             result = fused_moe.cutlass_fused_moe(
+                 input=x,
+                 token_selected_experts=selected_experts.to(torch.int),
+                 token_final_scales=routing_weights.to(torch.float32),
+                 fc1_expert_weights=self._flashinfer_fc1_weights,
+                 fc2_expert_weights=self._flashinfer_fc2_weights,
+                 output_dtype=x.dtype,
+                 quant_scales=None,
+             )
+             if isinstance(result, (list, tuple)):
+                 out_flat = result[0]
+             else:
+                 out_flat = result
+             out = out_flat.view(batch_size, sequence_length, hidden_dim)
+             return out, router_logits.view(batch_size, sequence_length, -1)
+
+         elif self.backend == "sglang":
+             if self._sglang_w1 is None or self._sglang_w2 is None:
+                 self._initialize_sglang_weights()
+
+             topk_output = StandardTopKOutput(
+                 topk_weights=routing_weights,
+                 topk_ids=selected_experts,
+                 router_logits=router_logits,
+             )
+
+             out_flat = sglang_fused_moe(
+                 hidden_states=x,
+                 w1=self._sglang_w1,
+                 w2=self._sglang_w2,
+                 topk_output=topk_output,
+             )
+             out = out_flat.view(batch_size, sequence_length, hidden_dim)
+             return out, router_logits.view(batch_size, sequence_length, -1)
+
+         else:
+             raise ValueError(f"Invalid backend: {self.backend}")
+
+
+ class RND1PreTrainedModel(PreTrainedModel):
+     """Base class for RND1 models with weight initialization and loading support."""
+     config_class = RND1Config
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["RND1DecoderLayer"]
+     _skip_keys_device_placement = "past_key_values"
+     _supports_flash_attn_2 = True
+     _supports_sdpa = True
+     _supports_cache_class = True
+     _supports_quantized_cache = True
+     _supports_static_cache = True
+
+     def _init_weights(self, module):
+         """Initialize weights using normal distribution."""
+         std = self.config.initializer_range
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+         *model_args,
+         config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
+         cache_dir: Optional[Union[str, os.PathLike]] = None,
+         ignore_mismatched_sizes: bool = False,
+         force_download: bool = False,
+         local_files_only: bool = False,
+         token: Optional[Union[str, bool]] = None,
+         revision: str = "main",
+         use_safetensors: Optional[bool] = None,
+         weights_only: bool = True,
+         **kwargs,
+     ):
+         """Load pretrained model with generation config."""
+         _model = super().from_pretrained(
+             pretrained_model_name_or_path,
+             *model_args,
+             config=config,
+             cache_dir=cache_dir,
+             ignore_mismatched_sizes=ignore_mismatched_sizes,
+             force_download=force_download,
+             local_files_only=local_files_only,
+             token=token,
+             revision=revision,
+             use_safetensors=use_safetensors,
+             weights_only=weights_only,
+             **kwargs,
+         )
+
+         resume_download = kwargs.get("resume_download", None)
+         proxies = kwargs.get("proxies", None)
+         subfolder = kwargs.get("subfolder", "")
+         from_auto_class = kwargs.get("_from_auto", False)
+         from_pipeline = kwargs.get("_from_pipeline", None)
+
+         _model.generation_config = GenerationConfig.from_pretrained(
+             pretrained_model_name_or_path,
+             cache_dir=cache_dir,
+             force_download=force_download,
+             resume_download=resume_download,
+             proxies=proxies,
+             local_files_only=local_files_only,
+             token=token,
+             revision=revision,
+             subfolder=subfolder,
+             _from_auto=from_auto_class,
+             _from_pipeline=from_pipeline,
+         )
+
+         return _model
+
+
+ class RND1Model(RND1PreTrainedModel):
+     """RND1 transformer model with bidirectional attention for diffusion language modeling."""
+
+     def __init__(self, config: RND1Config):
+         super().__init__(config)
+
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+         self.layers = nn.ModuleList([RND1DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
+         self.norm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)
+
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         **kwargs,
+     ) -> MoeModelOutputWithPast:
+         """Forward pass through the RND1 model."""
+
+         if (input_ids is None) == (inputs_embeds is None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(input_ids)
+
+         if position_ids is None:
+             position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0)
+
+         position_embeddings = self.rotary_emb(inputs_embeds, position_ids)
+
+         hidden_states = inputs_embeds
+
+         for layer in self.layers:
+             hidden_states, _ = layer(
+                 hidden_states,
+                 attention_mask=attention_mask,
+                 position_ids=position_ids,
+                 position_embeddings=position_embeddings,
+             )
+
+         hidden_states = self.norm(hidden_states)
+
+         return MoeModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             router_logits=None,
+         )
+
+
+ class RND1LM(RND1PreTrainedModel, RND1GenerationMixin):
+     """Radical Numerics Diffusion Language Model with bidirectional attention."""
+
+     def __init__(self, config: RND1Config):
+         super().__init__(config)
+         self.model = RND1Model(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.post_init()
+
+     def get_input_embeddings(self):
+         """Get the input embeddings layer."""
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         """Set the input embeddings layer."""
+         self.model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         """Get the output embeddings layer (lm_head)."""
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         """Set the output embeddings layer (lm_head)."""
+         self.lm_head = new_embeddings
+
+     @classmethod
+     def can_generate(cls) -> bool:
+         """Indicates this model can generate text."""
+         return True
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> MaskedLMOutput:
+         """Forward pass with optional loss computation."""
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             inputs_embeds=inputs_embeds,
+             **kwargs,
+         )
+         logits = self.lm_head(outputs.last_hidden_state)
+
+         loss = None
+         if labels is not None:
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
+
+         return MaskedLMOutput(
+             loss=loss,
+             logits=logits,
+         )
+ 
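
A scoring sketch under the same model/tokenizer assumptions as above. Note that, per the shift applied in sampling.py ("pos i predicts token i+1"), the logits at position i-1 score the token at position i, so a masked slot is predicted from the preceding position; the prompt is illustrative:

import torch

ids = tokenizer("Paris is the<|mask|> capital", return_tensors="pt").input_ids
mask_pos = (ids[0] == 151669).nonzero().item()  # 151669 = <|mask|>
with torch.inference_mode():
    logits = model(input_ids=ids).logits        # [1, seq_len, 151936]
pred = logits[0, mask_pos - 1].argmax()         # position before the mask predicts it
print(tokenizer.decode(pred))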
sampling.py ADDED
@@ -0,0 +1,271 @@
+ """
+ RND1 sampling module for masked diffusion generation.
+
+ This module implements entropy-based token selection for iterative denoising
+ in diffusion language models. Supports both greedy and stochastic sampling
+ with optional prefix/suffix constraints and infilling.
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional, Tuple, Union
+
+
+ def apply_top_k_filtering(logits: torch.Tensor, k: int) -> torch.Tensor:
+     """
+     Apply top-k filtering to logits, setting non-top-k values to -inf.
+     """
+     top_k_values, top_k_indices = torch.topk(logits, min(k, logits.size(-1)), dim=-1)
+     filtered_logits = torch.full_like(logits, float('-inf'))
+     filtered_logits.scatter_(-1, top_k_indices, top_k_values)
+     return filtered_logits
+
+
+ def apply_top_p_filtering(logits: torch.Tensor, p: float) -> torch.Tensor:
+     """
+     Apply top-p (nucleus) filtering to logits, setting tokens beyond the cumulative threshold to -inf.
+     """
+     sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
+     cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+
+     # Remove tokens with cumulative probability above the threshold,
+     # shifting right so the first token that crosses it is still kept
+     sorted_indices_to_remove = cumulative_probs > p
+     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+     sorted_indices_to_remove[..., 0] = False  # Keep at least one token
+
+     indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
+     return logits.masked_fill(indices_to_remove, float('-inf'))
+
+
+ @torch.no_grad()
+ def diffusion_sample(
+     model: nn.Module,
+     seq_len: int = 256,
+     num_steps: int = 256,
+     top_k: Optional[int] = None,
+     top_p: Optional[float] = None,
+     temperature: float = 1.0,
+     greedy: bool = True,
+     mask_token_id: int = 151669,
+     prefix_ids: Optional[torch.LongTensor] = None,
+     suffix_ids: Optional[torch.LongTensor] = None,
+     infill_length: Optional[int] = None,
+     eos_token_id: int = 151645,
+     pad_token_id: Optional[int] = None,
+     bos_token_id: Optional[int] = None,
+     device: Optional[Union[str, torch.device]] = None,
+     generator: Optional[torch.Generator] = None,
+     visualizer: Optional['TerminalVisualizer'] = None,
+ ) -> torch.LongTensor:
+     """
+     Perform masked diffusion sampling with entropy-based token selection.
+
+     Args:
+         model: The RND1 language model
+         seq_len: Target sequence length
+         num_steps: Number of denoising steps
+         top_k: Optional top-k filtering for sampling (None = no filtering)
+         top_p: Optional nucleus (top-p) filtering for sampling (None = no filtering)
+             When both top_k and top_p are set, top_k is applied first, then top_p
+         temperature: Temperature for sampling (higher = more random, lower = more deterministic)
+             Values close to 0 are clamped to 1e-8 to avoid division by zero
+         greedy: Whether to use greedy sampling (True) or stochastic (False)
+         mask_token_id: Token ID for masked positions (default: 151669)
+         prefix_ids: Optional prefix token IDs to preserve
+         suffix_ids: Optional suffix token IDs to preserve
+         infill_length: Length of infill region between prefix/suffix
+         eos_token_id: End of sequence token ID (default: 151645)
+         pad_token_id: Padding token ID (default: None, uses 0 if needed)
+         bos_token_id: Beginning of sequence token ID (default: None)
+         device: Device for computation (None = infer from model)
+         generator: Optional torch generator for reproducible sampling
+         visualizer: Optional TerminalVisualizer for live visualization
+
+     Returns:
+         Generated token IDs as LongTensor
+     """
+     model.eval()
+
+     if device is None:
+         device = next(model.parameters()).device
+     else:
+         device = torch.device(device)
+     dtype = next(model.parameters()).dtype
+
+     if pad_token_id is None:
+         pad_token_id = 0
+
+     # Build initial masked sequence
+     # When prefix_ids is provided, we create a sequence of length seq_len where:
+     #   - The prefix occupies the first pre_len positions
+     #   - The remaining (seq_len - pre_len) positions are filled with mask tokens to be generated
+     if prefix_ids is not None or suffix_ids is not None:
+         if prefix_ids is not None:
+             prefix_ids = prefix_ids.to(device) if isinstance(prefix_ids, torch.Tensor) else torch.tensor(prefix_ids, device=device)
+             pre_len = prefix_ids.shape[-1] if prefix_ids.dim() > 0 else 0
+         else:
+             pre_len = 0
+
+         if suffix_ids is not None:
+             suffix_ids = suffix_ids.to(device) if isinstance(suffix_ids, torch.Tensor) else torch.tensor(suffix_ids, device=device)
+             suf_len = suffix_ids.shape[-1] if suffix_ids.dim() > 0 else 0
+         else:
+             suf_len = 0
+
+         reserved = (1 if bos_token_id is not None else 0) + (1 if eos_token_id is not None else 0)
+         used = pre_len + suf_len + reserved
+
+         if used > seq_len:
+             raise ValueError(
+                 f"Combined length of prefix ({pre_len}), suffix ({suf_len}), "
+                 f"and special tokens ({reserved}) = {used} exceeds seq_len ({seq_len}). "
+                 f"Please increase seq_len or reduce input lengths."
+             )
+         elif used == seq_len:
+             raise ValueError(
+                 f"No space for generation: prefix ({pre_len}) + suffix ({suf_len}) "
+                 f"+ special tokens ({reserved}) = seq_len ({seq_len}). "
+                 f"Need at least 1 position for generation."
+             )
+
+         infill_length = min(infill_length or (seq_len - used), seq_len - used)
+
+         x = torch.full((1, seq_len), pad_token_id, dtype=torch.long, device=device)
+         pos = 0
+         if bos_token_id is not None:
+             x[0, pos] = bos_token_id; pos += 1
+         if pre_len > 0:
+             x[0, pos:pos+pre_len] = prefix_ids.flatten()[:pre_len]; pos += pre_len
+         fill_start, fill_end = pos, pos + infill_length
+         x[0, fill_start:fill_end] = mask_token_id
+         pos = fill_end
+         if suf_len > 0:
+             x[0, pos:pos+suf_len] = suffix_ids.flatten()[:suf_len]; pos += suf_len
+
+         if eos_token_id is not None and pos < seq_len:
+             if isinstance(eos_token_id, (list, tuple)):
+                 x[0, pos] = eos_token_id[0]
+             else:
+                 x[0, pos] = eos_token_id
+
+         init_maskable = torch.zeros_like(x, dtype=torch.bool)
+         init_maskable[0, fill_start:fill_end] = True
+     else:
+         x = torch.full((1, seq_len), mask_token_id, dtype=torch.long, device=device)
+         if bos_token_id is not None:
+             x[0, 0] = bos_token_id
+         if eos_token_id is not None:
+             # If eos_token_id is a list, use the first one
+             if isinstance(eos_token_id, (list, tuple)):
+                 x[0, -1] = eos_token_id[0]
+             else:
+                 x[0, -1] = eos_token_id
+         init_maskable = x.eq(mask_token_id)
+
+     if bos_token_id is not None:
+         init_maskable[:, 0] = False
+     if eos_token_id is not None:
+         # Handle both single token and list of tokens
+         if isinstance(eos_token_id, (list, tuple)):
+             for eos_id in eos_token_id:
+                 init_maskable &= x.ne(eos_id)
+         else:
+             init_maskable &= x.ne(eos_token_id)
+     init_maskable &= x.ne(pad_token_id)
+
+     maskable = init_maskable.clone()
+     xt = x.clone()
+
+     if visualizer:
+         visualizer.start_visualization(xt, maskable, num_steps)
+
+     def forward_scores(tokens):
+         """Compute predictions and entropy scores for next tokens."""
+         # Try with input_ids parameter first (standard HF models)
+         try:
+             model_output = model(input_ids=tokens)
+         except TypeError:
+             # Fall back to positional argument
+             model_output = model(tokens)
+
+         safe_temperature = max(temperature, 1e-8)  # Prevent division by zero
+         logits = model_output.logits / safe_temperature
+
+         # Note: When both top_k and top_p are provided, they are applied sequentially:
+         # first top_k filters to k tokens, then top_p filters from those k tokens
+         if top_k is not None and top_k > 0:
+             logits = apply_top_k_filtering(logits, top_k)
+
+         if top_p is not None and 0 < top_p < 1.0:
+             logits = apply_top_p_filtering(logits, top_p)
+
+         logp = torch.log_softmax(logits, dim=-1)
+
+         if greedy:
+             pred_next = logp.argmax(-1)
+         else:
+             # Sample from categorical distribution with proper RNG handling
+             if generator is not None:
+                 # Use multinomial with generator for reproducible sampling
+                 probs = logp.exp()
+                 pred_next = torch.multinomial(probs.view(-1, probs.size(-1)), 1, generator=generator).squeeze(-1).view(probs.shape[:-1])
+             else:
+                 pred_next = torch.distributions.Categorical(logits=logp).sample()
+
+         conf_next = torch.gather(logp, -1, pred_next.unsqueeze(-1)).squeeze(-1)
+
+         p = logp.exp()
+         ent_next = -(p * logp).sum(-1)
+
+         # Shift predictions: pos i predicts token i+1
+         pred_i = tokens.clone()
+         conf_i = torch.full_like(conf_next, torch.finfo(conf_next.dtype).min)
+         ent_i = torch.zeros_like(ent_next)
+
+         pred_i[:, 1:] = pred_next[:, :-1]
+         conf_i[:, 1:] = conf_next[:, :-1]
+         ent_i[:, 1:] = ent_next[:, :-1]
+
+         return pred_i, conf_i, ent_i
+
+     pred_i, conf_i, ent_i = forward_scores(xt)
+     total_masked = init_maskable.sum(1, keepdim=True)
+     finf = torch.finfo(conf_i.dtype)
+
+     for step in range(num_steps - 1, 0, -1):
+         rate = step / num_steps
+         cutoff_len = (total_masked * rate).long().clamp(min=0)
+
+         # Choose HIGH-entropy tokens to keep masked
+         sel_scores = ent_i.masked_fill(~maskable, -finf.max)
+         B, L = sel_scores.shape
+         k_max = cutoff_len.max().item()
+         if k_max > 0:
+             _, idx = torch.topk(sel_scores, k_max, dim=-1, largest=True)
+             keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)
+             for b in range(B):
+                 k_b = int(cutoff_len[b].item())
+                 if k_b > 0:
+                     keep_mask[b, idx[b, :k_b]] = True
+         else:
+             keep_mask = torch.zeros_like(sel_scores, dtype=torch.bool)
+
+         to_unmask = maskable & ~keep_mask
+         if to_unmask.any():
+             xt[to_unmask] = pred_i[to_unmask]
+             maskable[to_unmask] = False
+
+         if visualizer:
+             visualizer.update_step(xt, maskable, num_steps - step, ent_i, conf_i)
+
+         if maskable.any():
+             pred_i, conf_i, ent_i = forward_scores(xt)
+
+     if maskable.any():
+         xt[maskable] = pred_i[maskable]
+
+     if visualizer:
+         visualizer.stop_visualization()
+
+     return xt
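
The sampler can also be called directly, bypassing generate(); a sketch under the same model/tokenizer assumptions, with an illustrative prefix:

import torch
from sampling import diffusion_sample  # local module from this commit

prefix = tokenizer("def fib(n):", return_tensors="pt").input_ids
seq = diffusion_sample(
    model=model,
    seq_len=128,       # prefix + masked region (+ reserved BOS/EOS slots)
    num_steps=64,      # fewer steps: faster, coarser denoising schedule
    greedy=False,      # stochastic unmasking
    temperature=0.8,
    top_p=0.9,         # applied after top_k when both are given
    prefix_ids=prefix,
    generator=torch.Generator(device=model.device).manual_seed(0),  # reproducible
)
print(tokenizer.decode(seq[0]))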
special_tokens_map.json ADDED
@@ -0,0 +1,38 @@
+ {
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "eos_token": {
+     "content": "<|im_end|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "mask_token": {
+     "content": "<|mask|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
+ size 11422654
tokenizer_config.json ADDED
@@ -0,0 +1,249 @@
+ {
+   "add_bos_token": false,
+   "add_prefix_space": false,
+   "added_tokens_decoder": {
+     "151643": {
+       "content": "<|endoftext|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151644": {
+       "content": "<|im_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151645": {
+       "content": "<|im_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151646": {
+       "content": "<|object_ref_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151647": {
+       "content": "<|object_ref_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151648": {
+       "content": "<|box_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151649": {
+       "content": "<|box_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151650": {
+       "content": "<|quad_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151651": {
+       "content": "<|quad_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151652": {
+       "content": "<|vision_start|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151653": {
+       "content": "<|vision_end|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151654": {
+       "content": "<|vision_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151655": {
+       "content": "<|image_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151656": {
+       "content": "<|video_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "151657": {
+       "content": "<tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151658": {
+       "content": "</tool_call>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151659": {
+       "content": "<|fim_prefix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151660": {
+       "content": "<|fim_middle|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151661": {
+       "content": "<|fim_suffix|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151662": {
+       "content": "<|fim_pad|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151663": {
+       "content": "<|repo_name|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151664": {
+       "content": "<|file_sep|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151665": {
+       "content": "<tool_response>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151666": {
+       "content": "</tool_response>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151667": {
+       "content": "<think>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151668": {
+       "content": "</think>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": false
+     },
+     "151669": {
+       "content": "<|mask|>",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "additional_special_tokens": [
+     "<|im_start|>",
+     "<|im_end|>",
+     "<|object_ref_start|>",
+     "<|object_ref_end|>",
+     "<|box_start|>",
+     "<|box_end|>",
+     "<|quad_start|>",
+     "<|quad_end|>",
+     "<|vision_start|>",
+     "<|vision_end|>",
+     "<|vision_pad|>",
+     "<|image_pad|>",
+     "<|video_pad|>"
+   ],
+   "bos_token": null,
+   "clean_up_tokenization_spaces": false,
+   "eos_token": "<|im_end|>",
+   "errors": "replace",
+   "extra_special_tokens": {},
+   "model_max_length": 131072,
+   "pad_token": "<|endoftext|>",
+   "split_special_tokens": false,
+   "tokenizer_class": "Qwen2Tokenizer",
+   "unk_token": null,
+   "mask_token": "<|mask|>",
+   "mask_token_id": 151669
+ }
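
A quick consistency sketch: the mask wiring here should agree with mask_token_id (151669) in config.json and generation_config.json; the Hub id remains a placeholder:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("user/rnd1-repo")  # placeholder id
assert tok.mask_token == "<|mask|>"
assert tok.convert_tokens_to_ids("<|mask|>") == 151669
assert tok.eos_token == "<|im_end|>" and tok.pad_token == "<|endoftext|>"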
vocab.json ADDED
The diff for this file is too large to render. See raw diff