from typing import Optional import torch, math import torch.nn from einops import rearrange from torch import nn from functools import partial from einops import rearrange def attention(q, k, v, attn_mask, mode="torch"): q = q.transpose(1, 2) k = k.transpose(1, 2) v = v.transpose(1, 2) x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) x = rearrange(x, "b n s d -> b s (n d)") return x class MLP(nn.Module): """MLP as used in Vision Transformer, MLP-Mixer and related networks""" def __init__( self, in_channels, hidden_channels=None, out_features=None, act_layer=nn.GELU, norm_layer=None, bias=True, drop=0.0, use_conv=False, device=None, dtype=None, ): super().__init__() out_features = out_features or in_channels hidden_channels = hidden_channels or in_channels bias = (bias, bias) drop_probs = (drop, drop) linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear self.fc1 = linear_layer( in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype ) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) self.norm = ( norm_layer(hidden_channels, device=device, dtype=dtype) if norm_layer is not None else nn.Identity() ) self.fc2 = linear_layer( hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype ) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop1(x) x = self.norm(x) x = self.fc2(x) x = self.drop2(x) return x class TextProjection(nn.Module): """ Projects text embeddings. Also handles dropout for classifier-free guidance. Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py """ def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None): factory_kwargs = {"dtype": dtype, "device": device} super().__init__() self.linear_1 = nn.Linear( in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs, ) self.act_1 = act_layer() self.linear_2 = nn.Linear( in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs, ) def forward(self, caption): hidden_states = self.linear_1(caption) hidden_states = self.act_1(hidden_states) hidden_states = self.linear_2(hidden_states) return hidden_states class TimestepEmbedder(nn.Module): """ Embeds scalar timesteps into vector representations. """ def __init__( self, hidden_size, act_layer, frequency_embedding_size=256, max_period=10000, out_size=None, dtype=None, device=None, ): factory_kwargs = {"dtype": dtype, "device": device} super().__init__() self.frequency_embedding_size = frequency_embedding_size self.max_period = max_period if out_size is None: out_size = hidden_size self.mlp = nn.Sequential( nn.Linear( frequency_embedding_size, hidden_size, bias=True, **factory_kwargs ), act_layer(), nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs), ) nn.init.normal_(self.mlp[0].weight, std=0.02) # type: ignore nn.init.normal_(self.mlp[2].weight, std=0.02) # type: ignore @staticmethod def timestep_embedding(t, dim, max_period=10000): """ Create sinusoidal timestep embeddings. Args: t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional. dim (int): the dimension of the output. max_period (int): controls the minimum frequency of the embeddings. Returns: embedding (torch.Tensor): An (N, D) Tensor of positional embeddings. .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py """ half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(device=t.device) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat( [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 ) return embedding def forward(self, t): t_freq = self.timestep_embedding( t, self.frequency_embedding_size, self.max_period ).type(self.mlp[0].weight.dtype) # type: ignore t_emb = self.mlp(t_freq) return t_emb def apply_gate(x, gate=None, tanh=False): """AI is creating summary for apply_gate Args: x (torch.Tensor): input tensor. gate (torch.Tensor, optional): gate tensor. Defaults to None. tanh (bool, optional): whether to use tanh function. Defaults to False. Returns: torch.Tensor: the output tensor after apply gate. """ if gate is None: return x if tanh: return x * gate.unsqueeze(1).tanh() else: return x * gate.unsqueeze(1) class RMSNorm(nn.Module): def __init__( self, dim: int, elementwise_affine=True, eps: float = 1e-6, device=None, dtype=None, ): """ Initialize the RMSNorm normalization layer. Args: dim (int): The dimension of the input tensor. eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. Attributes: eps (float): A small value added to the denominator for numerical stability. weight (nn.Parameter): Learnable scaling parameter. """ factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.eps = eps if elementwise_affine: self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) def _norm(self, x): """ Apply the RMSNorm normalization to the input tensor. Args: x (torch.Tensor): The input tensor. Returns: torch.Tensor: The normalized tensor. """ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x): """ Forward pass through the RMSNorm layer. Args: x (torch.Tensor): The input tensor. Returns: torch.Tensor: The output tensor after applying RMSNorm. """ output = self._norm(x.float()).type_as(x) if hasattr(self, "weight"): output = output * self.weight return output def get_norm_layer(norm_layer): """ Get the normalization layer. Args: norm_layer (str): The type of normalization layer. Returns: norm_layer (nn.Module): The normalization layer. """ if norm_layer == "layer": return nn.LayerNorm elif norm_layer == "rms": return RMSNorm else: raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") def get_activation_layer(act_type): """get activation layer Args: act_type (str): the activation type Returns: torch.nn.functional: the activation layer """ if act_type == "gelu": return lambda: nn.GELU() elif act_type == "gelu_tanh": return lambda: nn.GELU(approximate="tanh") elif act_type == "relu": return nn.ReLU elif act_type == "silu": return nn.SiLU else: raise ValueError(f"Unknown activation type: {act_type}") class IndividualTokenRefinerBlock(torch.nn.Module): def __init__( self, hidden_size, heads_num, mlp_width_ratio: str = 4.0, mlp_drop_rate: float = 0.0, act_type: str = "silu", qk_norm: bool = False, qk_norm_type: str = "layer", qkv_bias: bool = True, need_CA: bool = False, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.need_CA = need_CA self.heads_num = heads_num head_dim = hidden_size // heads_num mlp_hidden_dim = int(hidden_size * mlp_width_ratio) self.norm1 = nn.LayerNorm( hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs ) self.self_attn_qkv = nn.Linear( hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs ) qk_norm_layer = get_norm_layer(qk_norm_type) self.self_attn_q_norm = ( qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() ) self.self_attn_k_norm = ( qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() ) self.self_attn_proj = nn.Linear( hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs ) self.norm2 = nn.LayerNorm( hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs ) act_layer = get_activation_layer(act_type) self.mlp = MLP( in_channels=hidden_size, hidden_channels=mlp_hidden_dim, act_layer=act_layer, drop=mlp_drop_rate, **factory_kwargs, ) self.adaLN_modulation = nn.Sequential( act_layer(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), ) if self.need_CA: self.cross_attnblock=CrossAttnBlock(hidden_size=hidden_size, heads_num=heads_num, mlp_width_ratio=mlp_width_ratio, mlp_drop_rate=mlp_drop_rate, act_type=act_type, qk_norm=qk_norm, qk_norm_type=qk_norm_type, qkv_bias=qkv_bias, **factory_kwargs,) # Zero-initialize the modulation nn.init.zeros_(self.adaLN_modulation[1].weight) nn.init.zeros_(self.adaLN_modulation[1].bias) def forward( self, x: torch.Tensor, c: torch.Tensor, # timestep_aware_representations + context_aware_representations attn_mask: torch.Tensor = None, y: torch.Tensor = None, ): gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) norm_x = self.norm1(x) qkv = self.self_attn_qkv(norm_x) q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) # Apply QK-Norm if needed q = self.self_attn_q_norm(q).to(v) k = self.self_attn_k_norm(k).to(v) # Self-Attention attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) x = x + apply_gate(self.self_attn_proj(attn), gate_msa) if self.need_CA: x = self.cross_attnblock(x, c, attn_mask, y) # FFN Layer x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp) return x class CrossAttnBlock(torch.nn.Module): def __init__( self, hidden_size, heads_num, mlp_width_ratio: str = 4.0, mlp_drop_rate: float = 0.0, act_type: str = "silu", qk_norm: bool = False, qk_norm_type: str = "layer", qkv_bias: bool = True, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.heads_num = heads_num head_dim = hidden_size // heads_num self.norm1 = nn.LayerNorm( hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs ) self.norm1_2 = nn.LayerNorm( hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs ) self.self_attn_q = nn.Linear( hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs ) self.self_attn_kv = nn.Linear( hidden_size, hidden_size*2, bias=qkv_bias, **factory_kwargs ) qk_norm_layer = get_norm_layer(qk_norm_type) self.self_attn_q_norm = ( qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() ) self.self_attn_k_norm = ( qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() ) self.self_attn_proj = nn.Linear( hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs ) self.norm2 = nn.LayerNorm( hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs ) act_layer = get_activation_layer(act_type) self.adaLN_modulation = nn.Sequential( act_layer(), nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs), ) # Zero-initialize the modulation nn.init.zeros_(self.adaLN_modulation[1].weight) nn.init.zeros_(self.adaLN_modulation[1].bias) def forward( self, x: torch.Tensor, c: torch.Tensor, # timestep_aware_representations + context_aware_representations attn_mask: torch.Tensor = None, y: torch.Tensor=None, ): gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1) norm_x = self.norm1(x) norm_y = self.norm1_2(y) q = self.self_attn_q(norm_x) q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num) kv = self.self_attn_kv(norm_y) k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num) # Apply QK-Norm if needed q = self.self_attn_q_norm(q).to(v) k = self.self_attn_k_norm(k).to(v) # Self-Attention attn = attention(q, k, v, mode="torch", attn_mask=attn_mask) x = x + apply_gate(self.self_attn_proj(attn), gate_msa) return x class IndividualTokenRefiner(torch.nn.Module): def __init__( self, hidden_size, heads_num, depth, mlp_width_ratio: float = 4.0, mlp_drop_rate: float = 0.0, act_type: str = "silu", qk_norm: bool = False, qk_norm_type: str = "layer", qkv_bias: bool = True, need_CA:bool=False, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.need_CA = need_CA self.blocks = nn.ModuleList( [ IndividualTokenRefinerBlock( hidden_size=hidden_size, heads_num=heads_num, mlp_width_ratio=mlp_width_ratio, mlp_drop_rate=mlp_drop_rate, act_type=act_type, qk_norm=qk_norm, qk_norm_type=qk_norm_type, qkv_bias=qkv_bias, need_CA=self.need_CA, **factory_kwargs, ) for _ in range(depth) ] ) def forward( self, x: torch.Tensor, c: torch.LongTensor, mask: Optional[torch.Tensor] = None, y:torch.Tensor=None, ): self_attn_mask = None if mask is not None: batch_size = mask.shape[0] seq_len = mask.shape[1] mask = mask.to(x.device) # batch_size x 1 x seq_len x seq_len self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat( 1, 1, seq_len, 1 ) # batch_size x 1 x seq_len x seq_len self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() # avoids self-attention weight being NaN for padding tokens self_attn_mask[:, :, :, 0] = True for block in self.blocks: x = block(x, c, self_attn_mask,y) return x class SingleTokenRefiner(torch.nn.Module): """ A single token refiner block for llm text embedding refine. """ def __init__( self, in_channels, hidden_size, heads_num, depth, mlp_width_ratio: float = 4.0, mlp_drop_rate: float = 0.0, act_type: str = "silu", qk_norm: bool = False, qk_norm_type: str = "layer", qkv_bias: bool = True, need_CA:bool=False, attn_mode: str = "torch", dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, ): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.attn_mode = attn_mode self.need_CA = need_CA assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner." self.input_embedder = nn.Linear( in_channels, hidden_size, bias=True, **factory_kwargs ) if self.need_CA: self.input_embedder_CA = nn.Linear( in_channels, hidden_size, bias=True, **factory_kwargs ) act_layer = get_activation_layer(act_type) # Build timestep embedding layer self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs) # Build context embedding layer self.c_embedder = TextProjection( in_channels, hidden_size, act_layer, **factory_kwargs ) self.individual_token_refiner = IndividualTokenRefiner( hidden_size=hidden_size, heads_num=heads_num, depth=depth, mlp_width_ratio=mlp_width_ratio, mlp_drop_rate=mlp_drop_rate, act_type=act_type, qk_norm=qk_norm, qk_norm_type=qk_norm_type, qkv_bias=qkv_bias, need_CA=need_CA, **factory_kwargs, ) def forward( self, x: torch.Tensor, t: torch.LongTensor, mask: Optional[torch.LongTensor] = None, y: torch.LongTensor=None, ): timestep_aware_representations = self.t_embedder(t) if mask is None: context_aware_representations = x.mean(dim=1) else: mask_float = mask.unsqueeze(-1) # [b, s1, 1] context_aware_representations = (x * mask_float).sum( dim=1 ) / mask_float.sum(dim=1) context_aware_representations = self.c_embedder(context_aware_representations) c = timestep_aware_representations + context_aware_representations x = self.input_embedder(x) if self.need_CA: y = self.input_embedder_CA(y) x = self.individual_token_refiner(x, c, mask, y) else: x = self.individual_token_refiner(x, c, mask) return x class Qwen2Connector(torch.nn.Module): def __init__( self, # biclip_dim=1024, in_channels=3584, hidden_size=4096, heads_num=32, depth=2, need_CA=False, device=None, dtype=torch.bfloat16, ): super().__init__() factory_kwargs = {"device": device, "dtype":dtype} self.S =SingleTokenRefiner(in_channels=in_channels,hidden_size=hidden_size,heads_num=heads_num,depth=depth,need_CA=need_CA,**factory_kwargs) self.global_proj_out=nn.Linear(in_channels,768) self.scale_factor = nn.Parameter(torch.zeros(1)) with torch.no_grad(): self.scale_factor.data += -(1 - 0.09) def forward(self, x,t,mask): mask_float = mask.unsqueeze(-1) # [b, s1, 1] x_mean = (x * mask_float).sum( dim=1 ) / mask_float.sum(dim=1) * (1 + self.scale_factor) global_out=self.global_proj_out(x_mean) encoder_hidden_states = self.S(x,t,mask) return encoder_hidden_states,global_out @staticmethod def state_dict_converter(): return Qwen2ConnectorStateDictConverter() class Qwen2ConnectorStateDictConverter: def __init__(self): pass def from_diffusers(self, state_dict): return state_dict def from_civitai(self, state_dict): state_dict_ = {} for name, param in state_dict.items(): if name.startswith("connector."): name_ = name[len("connector."):] state_dict_[name_] = param return state_dict_