import torch

from transformers import Qwen2_5OmniThinkerTextModel, Qwen2_5OmniThinkerForConditionalGeneration
from transformers.cache_utils import Cache
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import Qwen2_5OmniThinkerConfig


class BidirectQwen2_5OmniThinkerTextModel(Qwen2_5OmniThinkerTextModel):
    """Qwen2.5-Omni thinker text model with causal masking disabled.

    All self-attention layers attend bidirectionally, turning the causal
    decoder into an encoder suitable for producing embeddings.
    """

    def __init__(self, config):
        super().__init__(config)
        # Flip every attention layer to non-causal; the attention mask itself
        # is rebuilt by the _update_causal_mask override below.
        for layer in self.layers:
            layer.self_attn.is_causal = False

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        # Let the parent decide whether any mask is needed at all; None means
        # attention can run fully unmasked (e.g. no padding in the batch).
        calculated_attention_mask = super()._update_causal_mask(
            attention_mask,
            input_tensor,
            cache_position,
            past_key_values,
            output_attentions,
        )
        if calculated_attention_mask is None:
            return None
        # Flash attention consumes the 2D padding mask as-is.
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
        # For the eager/sdpa paths, expand the 2D padding mask into a 4D
        # additive mask with no causal triangle, i.e. fully bidirectional.
        bidirectional_mask = _prepare_4d_attention_mask(
            attention_mask,
            dtype=input_tensor.dtype,
        )
        return bidirectional_mask
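

# A minimal illustration (not used by the classes in this module) of what the
# override above returns on the eager/sdpa paths: _prepare_4d_attention_mask
# expands a 2D padding mask into a (batch, 1, tgt_len, src_len) additive bias
# with no causal triangle, so every non-padding token can attend to every
# other position.
def _demo_bidirectional_mask() -> torch.Tensor:
    padding_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding
    bias = _prepare_4d_attention_mask(padding_mask, dtype=torch.float32)
    # Shape (1, 1, 4, 4): real-token columns are 0.0 (visible) at every query
    # position; only the padded column is filled with the dtype minimum.
    assert bias.shape == (1, 1, 4, 4)
    assert (bias[..., :3] == 0).all()
    return bias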


class NVOmniEmbedConfig(Qwen2_5OmniThinkerConfig):
    model_type = "nvomniembed"


class NVOmniEmbedModel(Qwen2_5OmniThinkerForConditionalGeneration):
    config_class = NVOmniEmbedConfig

    def __init__(self, config):
        super().__init__(config)
        # Swap the causal text backbone for the bidirectional variant defined
        # above, preserving the configured attention implementation.
        self.model = BidirectQwen2_5OmniThinkerTextModel._from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
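

# A hedged usage sketch: the checkpoint path is hypothetical, and mean pooling
# over the last hidden state is one plausible strategy; the actual NVOmniEmbed
# checkpoint and pooling may differ.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    checkpoint = "path/to/nvomniembed-checkpoint"  # hypothetical path
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = NVOmniEmbedModel.from_pretrained(checkpoint).eval()

    batch = tokenizer(["a query", "a passage"], padding=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**batch, output_hidden_states=True, return_dict=True)

    # Mean-pool non-padding positions of the last hidden state, then normalize.
    hidden = outputs.hidden_states[-1]
    mask = batch["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    embeddings = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
    print(embeddings.shape)  # (2, hidden_size)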