"""vLLM-compatible implementation of KORMo MoE.

This file should be placed in:
/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/kormo_moe.py

Usage:
    from vllm import LLM

    llm = LLM(
        model="/path/to/kormo_moe_model",
        trust_remote_code=False,  # Not needed with this implementation
        dtype="float16",
    )
"""
from collections.abc import Iterable
from typing import Any, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

try:
    from transformers import PretrainedConfig
except ImportError:
    # Fallback for environments without transformers
    PretrainedConfig = object

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
    AutoWeightsLoader,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)

logger = init_logger(__name__)


class KORMoMoeConfig(PretrainedConfig):
    """Configuration class for KORMo MoE.

    Holds the hyper-parameters consumed by the model classes in this file.
    Values that are left unset fall back as follows:

    * ``num_key_value_heads`` -> ``num_attention_heads``
    * ``head_dim`` -> ``hidden_size // num_attention_heads``
    * ``moe_intermediate_size`` -> ``intermediate_size``

    Token ids, ``tie_word_embeddings`` and any extra ``kwargs`` are forwarded
    to the base class constructor.
    """

    model_type = "kormo_moe"

    def __init__(
        self,
        vocab_size=112576,
        hidden_size=6144,
        intermediate_size=21504,
        num_hidden_layers=48,
        num_attention_heads=40,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=131072,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=0,
        eos_token_id=1,
        tie_word_embeddings=False,
        rope_theta=500000.0,
        attention_dropout=0.0,
        rope_scaling=None,
        head_dim=128,
        # MoE specific
        num_experts=2,
        num_experts_per_tok=2,
        moe_intermediate_size=None,
        shared_expert_intermediate_size=None,
        norm_topk_prob=True,
        decoder_sparse_step=1,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads or num_attention_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_dropout = attention_dropout
        self.head_dim = head_dim or (self.hidden_size // self.num_attention_heads)

        # MoE specific
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_intermediate_size = (
            moe_intermediate_size
            if moe_intermediate_size is not None
            else intermediate_size
        )
        self.shared_expert_intermediate_size = shared_expert_intermediate_size
        self.norm_topk_prob = norm_topk_prob
        self.decoder_sparse_step = decoder_sparse_step

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class KORMoMoEMLP(nn.Module):
    """MLP for KORMo, used for shared expert.

    Computes ``down_proj(SiluAndMul(gate_up_proj(x)))`` where the gate and up
    projections are stored as one merged column-parallel weight.

    Raises:
        ValueError: if ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        # Validate first so an unsupported activation fails before any
        # projection weights are allocated.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported."
            )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply the gated MLP to ``x`` and return the projected result."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class KORMoSparseMoeBlock(nn.Module):
    """KORMo Sparse MoE Block optimized for vLLM.

    Combines a replicated router (``gate``), a ``FusedMoE`` expert bank and an
    optional sigmoid-gated shared expert.

    Raises:
        ValueError: if the tensor-parallel world size exceeds
            ``config.num_experts``.
    """

    def __init__(
        self,
        config: KORMoMoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}."
            )

        # Use vLLM's FusedMoE for optimized expert routing. Reduction is
        # deferred (reduce_results=False) and performed once in forward().
        self.experts = FusedMoE(
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )

        # Router/gate (never quantized)
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
        )

        # Shared expert (optional)
        if (
            config.shared_expert_intermediate_size
            and config.shared_expert_intermediate_size > 0
        ):
            self.shared_expert = KORMoMoEMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=self.experts.must_reduce_shared_expert_outputs(),
            )
            self.shared_expert_gate = nn.Linear(config.hidden_size, 1, bias=False)
        else:
            self.shared_expert = None
            self.shared_expert_gate = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return a same-shaped tensor."""
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)

        # Shared expert processing
        shared_output = None
        if self.shared_expert is not None:
            shared_output = self.shared_expert(hidden_states)
            if self.shared_expert_gate is not None:
                shared_output = (
                    F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
                )

        # Router logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)

        # Expert routing is performed inside FusedMoE
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
        )

        # Add the shared expert result
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output

        # Tensor parallel reduction
        if self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
                final_hidden_states
            )

        return final_hidden_states.view(orig_shape)


class KORMoMoeAttention(nn.Module):
    """KORMo MoE Attention mechanism.

    Args:
        hidden_size: model hidden size.
        num_heads: total number of query heads (before tensor parallelism).
        num_kv_heads: total number of key/value heads.
        rope_theta: rotary embedding base.
        rope_scaling: optional rope scaling dict forwarded to ``get_rope``.
        max_position_embeddings: maximum position forwarded to ``get_rope``.
        cache_config: forwarded to ``Attention``.
        quant_config: forwarded to the linear layers and ``Attention``.
        prefix: module name prefix.
        head_dim: per-head dimension. When ``None`` (or 0) it falls back to
            ``hidden_size // num_heads``. Passing it explicitly is required
            for configs where ``num_heads * head_dim != hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 500000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 131072,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        head_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # KV heads are partitioned across tensor-parallel ranks.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than ranks: KV heads are replicated.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Honour an explicitly configured head_dim; deriving it from
        # hidden_size alone is wrong when num_heads * head_dim != hidden_size.
        self.head_dim = head_dim or (hidden_size // self.total_num_heads)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run QKV projection, rotary embedding, attention and output projection."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class KORMoMoeDecoderLayer(nn.Module):
    """KORMo MoE Decoder Layer: pre-norm attention followed by a pre-norm MoE MLP."""

    def __init__(
        self,
        config: KORMoMoeConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        # Attention
        self.self_attn = KORMoMoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=config.rope_theta,
            rope_scaling=config.rope_scaling,
            max_position_embeddings=config.max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            # getattr: the loaded config object may not define head_dim, in
            # which case the attention falls back to hidden_size // num_heads.
            head_dim=getattr(config, "head_dim", None),
        )

        # MoE MLP
        self.mlp = KORMoSparseMoeBlock(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

        # LayerNorms (using KORMo naming convention)
        self.pre_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.pre_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer and return ``(hidden_states, residual)``.

        ``residual`` is ``None`` only for the first layer that sees the
        embeddings; afterwards the norm is called with the running residual.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.pre_attention_layernorm(hidden_states)
        else:
            hidden_states, residual = self.pre_attention_layernorm(
                hidden_states, residual
            )

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # MoE MLP
        hidden_states, residual = self.pre_mlp_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual


@support_torch_compile
class KORMoMoeModel(nn.Module):
    """KORMo MoE Model: token embedding, decoder layer stack and final norm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: KORMoMoeDecoderLayer(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage.

        Returns the normalized hidden states on the last pipeline rank and an
        ``IntermediateTensors`` carrying ``hidden_states``/``residual``
        otherwise.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in self.layers[self.start_layer : self.end_layer]:
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return expert parameter mapping for weight loading"""
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each checkpoint entry is tried, in order, as (1) a shard of a stacked
        parameter (``qkv_proj`` / ``gate_up_proj``), (2) a fused expert weight,
        (3) a regular parameter. Entries that match nothing are skipped, with
        a warning in the regular-parameter case.

        Returns:
            The set of parameter names (after remapping) that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        bias_suffixes = (".bias", "_bias")

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()

        for name, loaded_weight in weights:
            # Handle stacked parameters
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Expert weights are handled by expert_params_mapping below;
                # skip them BEFORE renaming so they are not rewritten twice.
                if "mlp.experts" in name:
                    continue
                # Remap into a candidate and only commit on success, so a
                # failed attempt cannot corrupt `name` for later mappings.
                mapped_name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(mapped_name, self):
                    continue
                # Also covers extra bias tensors that have no parameter.
                if mapped_name not in params_dict:
                    continue

                name = mapped_name
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Handle expert parameters
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    mapped_name = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(mapped_name, self):
                        continue
                    if (
                        mapped_name.endswith(bias_suffixes)
                        and mapped_name not in params_dict
                    ):
                        continue

                    name = mapped_name
                    # Deliberately unguarded: a missing fused expert
                    # parameter should fail loudly.
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Handle regular parameters
                    if name.endswith(bias_suffixes) and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue

                    # Fix gate weight naming: gate.linear.weight -> gate.weight
                    if ".gate.linear.weight" in name:
                        name = name.replace(".gate.linear.weight", ".gate.weight")

                    if name not in params_dict:
                        logger.warning("Parameter %s not found in model", name)
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

            loaded_params.add(name)

        return loaded_params


class KORMoMoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """KORMo MoE for Causal Language Modeling"""

    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        self.model = KORMoMoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )

        if self.config.tie_word_embeddings:
            # Share one parameter between the input embedding and the LM head.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the wrapped ``KORMoMoeModel`` and return its output."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project ``hidden_states`` through the LM head via the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load all weights, delegating per-module handling to AutoWeightsLoader."""
        loader = AutoWeightsLoader(
            self,
            # With tied embeddings lm_head.weight IS embed_tokens.weight, so a
            # separate lm_head entry in the checkpoint must not be loaded.
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert parameter mapping of the wrapped model."""
        return self.model.get_expert_mapping()