jeromeku committed on
Commit
aec6f57
·
verified ·
1 Parent(s): 2e94fb9

config fixes

Browse files
Files changed (1) hide show
  1. configuration_rnd.py +21 -8
configuration_rnd.py CHANGED
@@ -7,6 +7,7 @@ extending Qwen3MoeConfig with RND1-specific parameters.
7
 
8
  from typing import Optional
9
  from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig
 
10
 
11
 
12
  class RND1Config(Qwen3MoeConfig):
@@ -34,10 +35,19 @@ class RND1Config(Qwen3MoeConfig):
34
  **kwargs,
35
  ):
36
  # Force non-causal and no caching for RND1
37
- kwargs['use_cache'] = False
38
- kwargs['is_causal'] = False
39
  super().__init__(**kwargs)
40
 
 
 
 
 
 
 
 
 
 
41
  # RND1-specific parameters
42
  self.moe_backend = moe_backend
43
  self.num_diffusion_steps = num_diffusion_steps
@@ -55,9 +65,12 @@ class RND1Config(Qwen3MoeConfig):
55
  the correct custom classes are automatically resolved.
56
  """
57
  data = super().to_dict()
58
- data.setdefault("auto_map", {
59
- "AutoConfig": "configuration_rnd.RND1Config",
60
- "AutoModel": "modeling_rnd.RND1Model",
61
- "AutoModelForMaskedLM": "modeling_rnd.RND1LM",
62
- })
63
- return data
 
 
 
 
7
 
8
  from typing import Optional
9
  from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig
10
+ from transformers import AutoConfig
11
 
12
 
13
  class RND1Config(Qwen3MoeConfig):
 
35
  **kwargs,
36
  ):
37
  # Force non-causal and no caching for RND1
38
+ kwargs["use_cache"] = False
39
+ kwargs["is_causal"] = False
40
  super().__init__(**kwargs)
41
 
42
+ # `head_dim` needs to be 128 for Qwen3MoE
43
+ # need to ensure that the config has this attr if directly passing config to RND1LM at instantiation
44
+ if not hasattr(self, "head_dim"):
45
+ self.head_dim = 128
46
+
47
+ # Note that in transformers 4.57.0 there is an error in the config
48
+ # num_hidden_layers is defaulted to 24
49
+ self.num_hidden_layers = 48
50
+
51
  # RND1-specific parameters
52
  self.moe_backend = moe_backend
53
  self.num_diffusion_steps = num_diffusion_steps
 
65
  the correct custom classes are automatically resolved.
66
  """
67
  data = super().to_dict()
68
+ data.setdefault(
69
+ "auto_map",
70
+ {
71
+ "AutoConfig": "configuration_rnd.RND1Config",
72
+ "AutoModel": "modeling_rnd.RND1Model",
73
+ "AutoModelForMaskedLM": "modeling_rnd.RND1LM",
74
+ },
75
+ )
76
+ return data