"""
Custom model class for LLM2Vec4CXR that properly handles latent attention pooling.
"""

from llm2vec.models.bidirectional_llama import LlamaBiModel
from llm2vec.pooling import LatentAttentionPooling

import torch
import torch.nn as nn
|
|
class LLM2Vec4CXRModel(LlamaBiModel):
    """
    Custom LlamaBiModel that includes latent attention pooling by default.
    This prevents the warning about unused latent attention weights.
    """

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

        # Latent attention pooling head applied on top of the token-level
        # hidden states.
        self.latent_attn = LatentAttentionPooling(
            d_model=config.hidden_size,
            num_heads=8,
            num_latents=512,
        )

        # Keep the pooling head on the same device and dtype as the
        # backbone's embedding weights.
        if hasattr(self, 'model') and hasattr(self.model, 'embed_tokens'):
            device = self.model.embed_tokens.weight.device
            dtype = self.model.embed_tokens.weight.dtype
            self.latent_attn = self.latent_attn.to(device=device, dtype=dtype)
|
    def forward(self, input_ids, attention_mask=None, embed_mask=None, **kwargs):
        """
        Forward pass that applies latent attention pooling to the token-level
        hidden states and returns the pooled embedding; if no pooling head is
        available, the raw hidden states are returned instead.
        """
        outputs = super().forward(input_ids, attention_mask=attention_mask, **kwargs)

        if hasattr(self, 'latent_attn') and self.latent_attn is not None:
            if embed_mask is not None:
                # Pool only over the tokens selected by embed_mask
                # (e.g. the text to embed, excluding an instruction prefix).
                pooled_output = self.latent_attn(outputs.last_hidden_state, embed_mask)
            else:
                # Fall back to pooling over all attended tokens.
                pooled_output = self.latent_attn(outputs.last_hidden_state, attention_mask)
            return pooled_output

        # No pooling head: return the raw token-level hidden states.
        return outputs.last_hidden_state
|
|
# Register the custom class with the Auto API so AutoModel can resolve it.
# Note: AutoModel.register expects a config class (not a class-name string) as
# its first argument, and exist_ok=True is needed because the backbone's config
# class is already mapped to the stock LlamaModel.
from transformers import AutoModel

AutoModel.register(LLM2Vec4CXRModel.config_class, LLM2Vec4CXRModel, exist_ok=True)
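
if __name__ == "__main__":
    # Minimal usage sketch (not part of the module's API): builds a tiny,
    # randomly initialised model and runs a single forward pass. It assumes
    # that LlamaBiModel accepts a standard LlamaConfig; all sizes below are
    # illustrative only.
    from transformers import LlamaConfig

    config = LlamaConfig(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=8,
        num_key_value_heads=8,
    )
    model = LLM2Vec4CXRModel(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)

    # embed_mask marks the tokens whose representations should be pooled,
    # e.g. the report text without an instruction prefix; here the first
    # four positions stand in for such a prefix.
    embed_mask = attention_mask.clone()
    embed_mask[:, :4] = 0

    with torch.no_grad():
        pooled = model(input_ids, attention_mask=attention_mask, embed_mask=embed_mask)
    print(pooled.shape)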
|