from transformers import CLIPTokenizer, CLIPImageProcessor, CLIPTextModel, CLIPPreTrainedModel, CLIPTextConfig
from transformers.models.clip.modeling_clip import (
    CLIPTextEmbeddings,
    CLIPEncoder,
    CLIPAttention,
    CLIPMLP,
    CLIPEncoderLayer,
    _create_4d_causal_attention_mask,
    _prepare_4d_attention_mask,
    BaseModelOutputWithPooling,
    CLIPTextModelOutput,
)
from typing import Optional, Union, Tuple

import torch
from torch import nn


class CustomCLIPTokenizer(CLIPTokenizer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Inherit everything from the original tokenizer.
        # No additional initialization needed unless you want to add specific features.


class CustomCLIPImageProcessor(CLIPImageProcessor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Inherit everything from the original processor.
        # No additional initialization needed unless you want to add specific features.
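
# --- Illustrative usage sketch (not part of the original module) -------------
# The two subclasses above add no behaviour yet, so they load and act exactly
# like CLIPTokenizer / CLIPImageProcessor. The checkpoint id below is only an
# assumption for the example; any CLIP checkpoint works the same way.
def _example_tokenizer_and_processor():
    tokenizer = CustomCLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    image_processor = CustomCLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
    # Tokenize a prompt exactly as the stock tokenizer would.
    tokens = tokenizer(["a photo of a cat"], padding=True, return_tensors="pt")
    return tokenizer, image_processor, tokens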
class CustomCLIPTextTransformer(nn.Module):
    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        # For `pooled_output` computation
        self.eos_token_id = config.eos_token_id

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You must provide either input_ids or inputs_embeds")

        if inputs_embeds is not None:
            # inputs_embeds are already token embeddings, so only add the positional embeddings.
            inputs_embeds = (
                self.embeddings.position_embedding(self.embeddings.position_ids[:, : inputs_embeds.size(1)])
                + inputs_embeds
            )
        else:
            inputs_embeds = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        # CLIP's text model uses a causal mask, prepare it here.
        causal_attention_mask = _create_4d_causal_attention_mask(
            inputs_embeds.size()[:-1], inputs_embeds.dtype, device=inputs_embeds.device
        )

        # Expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        # The pooled output computation works with both input types.
        if input_ids is not None:
            # Use input_ids to find the EOS token position.
            if self.eos_token_id == 2:
                # Legacy behaviour: in the original CLIP vocab the EOS token has the
                # highest id, so argmax over the ids locates it.
                pooled_output = last_hidden_state[
                    torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                    input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
                ]
            else:
                # Locate the first occurrence of the configured EOS token id.
                pooled_output = last_hidden_state[
                    torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
                    (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
                    .int()
                    .argmax(dim=-1),
                ]
        else:
            # When using inputs_embeds there are no token ids to search, so use the last token.
            pooled_output = last_hidden_state[:, -1]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class CustomCLIPTextModel(CLIPPreTrainedModel):
    config_class = CLIPTextConfig
    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.text_model = CustomCLIPTextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        self.text_model.embeddings.token_embedding = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
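
# --- Illustrative usage sketch (not part of the original module) -------------
# Shows the one behavioural difference from the stock CLIPTextModel: the custom
# transformer also accepts precomputed token embeddings via `inputs_embeds` and
# only adds position embeddings on top. The checkpoint id is an assumption for
# the example; any CLIP text checkpoint works the same way.
def _example_text_model_with_inputs_embeds():
    tokenizer = CustomCLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    model = CustomCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    tokens = tokenizer(["a photo of a cat"], padding=True, return_tensors="pt")

    # Standard path: token ids go through token + position embeddings.
    out_ids = model(input_ids=tokens.input_ids, attention_mask=tokens.attention_mask)

    # Custom path: look up the token embeddings manually and pass them in directly.
    token_embeds = model.get_input_embeddings()(tokens.input_ids)
    out_embeds = model(inputs_embeds=token_embeds, attention_mask=tokens.attention_mask)

    # Both paths should produce the same sequence features; only the pooled output
    # differs (EOS-token pooling vs. last-token pooling when no ids are available).
    assert torch.allclose(out_ids.last_hidden_state, out_embeds.last_hidden_state, atol=1e-5)
    return out_ids, out_embeds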
class CustomCLIPTextModelWithProjection(CLIPPreTrainedModel):
    config_class = CLIPTextConfig
    _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.text_model = CustomCLIPTextTransformer(config)

        # Add the projection layer for SDXL's second text encoder
        projection_dim = getattr(config, "projection_dim", config.hidden_size)
        self.text_projection = nn.Linear(config.hidden_size, projection_dim, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value):
        self.text_model.embeddings.token_embedding = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPTextModelOutput]:
        # Resolve return_dict here as well, so the output format matches what the
        # nested text_model uses instead of silently falling back to a tuple.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = text_outputs[1] if not return_dict else text_outputs.pooler_output

        # Apply the projection to the pooled output
        text_embeds = self.text_projection(pooled_output)

        if not return_dict:
            # Return last_hidden_state, pooler_output, text_embeds, then any remaining outputs.
            return (text_outputs[0], text_outputs[1], text_embeds) + text_outputs[2:]

        return CLIPTextModelOutput(
            text_embeds=text_embeds,  # Projected embeddings (for similarity)
            last_hidden_state=text_outputs.last_hidden_state,  # All token representations
            hidden_states=text_outputs.hidden_states,
            attentions=text_outputs.attentions,
        )
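
# --- Illustrative usage sketch (not part of the original module) -------------
# Loads the projection variant as SDXL's second text encoder. The repo id and
# subfolder layout below are assumptions for the example; any checkpoint saved
# in the CLIPTextModelWithProjection format loads the same way.
def _example_projection_model_for_sdxl():
    tokenizer = CustomCLIPTokenizer.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer_2"
    )
    model = CustomCLIPTextModelWithProjection.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2"
    )

    tokens = tokenizer(
        ["a photo of a cat"],
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    out = model(input_ids=tokens.input_ids, output_hidden_states=True, return_dict=True)

    # out.text_embeds is the projected pooled embedding; the penultimate hidden
    # state is commonly used as the per-token prompt embedding in SDXL pipelines.
    return out.text_embeds, out.hidden_states[-2]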