""" | |
Utilities for extracting and manipulating attention weights from transformer models, | |
starting from pre-computed hidden states. | |
This module provides functions to compute attention weights from various transformer | |
models (like Llama, Phi, Qwen, Gemma) and use them for attribution. We compute only | |
the relevant attention weights (as specified by `attribution_start` and | |
`attribution_end`) in order to be able to efficiently compute and store them. If we | |
were to use `output_attentions=True` in the forward pass, we would (1) only be able | |
to use the `eager` attention implementation, and (2) would need to store the entire | |
attention matrix which grows quadratically with the sequence length. Most of the | |
logic here is replicated from the `transformers` library. | |
If you'd like to perform attribution on a model that is not currently supported, | |
you can add it yourself by modifying `infer_model_type` and | |
`get_layer_attention_weights`. Please see `tests/attribution/test_attention.py` | |
to ensure that your implementation matches the expected attention weights when | |
using the `output_attentions=True`. | |
""" | |
import math
from typing import Any, Optional

import torch as ch
import transformers.models
def infer_model_type(model):
    """Infer the `transformers` model type from the model's name or path."""
    model_type_to_keyword = {
        "llama": "llama",
        "phi3": "phi",
        "qwen2": "qwen",
        "gemma3": "gemma",
    }
    for model_type, keyword in model_type_to_keyword.items():
        if keyword in model.name_or_path.lower():
            return model_type
    raise ValueError(f"Unknown model: {model.name_or_path}. Specify `model_type`.")
def get_helpers(model_type):
    """Fetch `apply_rotary_pos_emb` and `repeat_kv` from the model's modeling module."""
    if not hasattr(transformers.models, model_type):
        raise ValueError(f"Unknown model: {model_type}")
    model_module = getattr(transformers.models, model_type)
    modeling_module = getattr(model_module, f"modeling_{model_type}")
    return modeling_module.apply_rotary_pos_emb, modeling_module.repeat_kv
def get_position_ids_and_attention_mask(model, hidden_states):
    input_embeds = hidden_states[0]
    _, seq_len, _ = input_embeds.shape
    position_ids = ch.arange(0, seq_len, device=model.device).unsqueeze(0)
    # Additive causal mask: future positions get the most negative representable
    # value so that they vanish after the softmax.
    attention_mask = ch.ones(
        seq_len, seq_len + 1, device=model.device, dtype=model.dtype
    )
    attention_mask = ch.triu(attention_mask, diagonal=1)
    attention_mask *= ch.finfo(model.dtype).min
    attention_mask = attention_mask[None, None]
    return position_ids, attention_mask
def get_attentions_shape(model):
    num_layers = len(model.model.layers)
    num_heads = model.model.config.num_attention_heads
    return num_layers, num_heads
def get_layer_attention_weights(
    model,
    hidden_states,
    layer_index,
    position_ids,
    attention_mask,
    attribution_start=None,
    attribution_end=None,
    model_type=None,
):
    model_type = model_type or infer_model_type(model)
    assert layer_index >= 0 and layer_index < len(model.model.layers)
    layer = model.model.layers[layer_index]
    self_attn = layer.self_attn
    hidden_states = hidden_states[layer_index]
    hidden_states = layer.input_layernorm(hidden_states)
    bsz, q_len, _ = hidden_states.size()
    num_attention_heads = model.model.config.num_attention_heads
    num_key_value_heads = model.model.config.num_key_value_heads
    head_dim = self_attn.head_dim
    if model_type in ("llama", "qwen2", "qwen1.5", "gemma3", "glm"):
        query_states = self_attn.q_proj(hidden_states)
        key_states = self_attn.k_proj(hidden_states)
    elif model_type in ("phi3",):
        # Phi-3 uses a fused QKV projection; slice out the query and key parts.
        qkv = self_attn.qkv_proj(hidden_states)
        query_pos = num_attention_heads * head_dim
        query_states = qkv[..., :query_pos]
        key_states = qkv[..., query_pos : query_pos + num_key_value_heads * head_dim]
    else:
        raise ValueError(f"Unknown model: {model.name_or_path}")
    query_states = query_states.view(bsz, q_len, num_attention_heads, head_dim)
    query_states = query_states.transpose(1, 2)
    key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim)
    key_states = key_states.transpose(1, 2)
    if model_type in ["gemma3"]:
        # Gemma 3 applies QK-norm and uses a separate rotary embedding for its
        # sliding-window (local) attention layers.
        query_states = self_attn.q_norm(query_states)
        key_states = self_attn.k_norm(key_states)
        if self_attn.is_sliding:
            position_embeddings = model.model.rotary_emb_local(
                hidden_states, position_ids
            )
        else:
            position_embeddings = model.model.rotary_emb(hidden_states, position_ids)
    else:
        position_embeddings = model.model.rotary_emb(hidden_states, position_ids)
    cos, sin = position_embeddings
    apply_rotary_pos_emb, repeat_kv = get_helpers(model_type)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
    key_states = repeat_kv(key_states, self_attn.num_key_value_groups)
    causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
    attribution_start = attribution_start if attribution_start is not None else 1
    attribution_end = attribution_end if attribution_end is not None else q_len + 1
    # Keep only the query rows (and mask rows) for the tokens we attribute.
    causal_mask = causal_mask[:, :, attribution_start - 1 : attribution_end - 1]
    query_states = query_states[:, :, attribution_start - 1 : attribution_end - 1]
    attn_weights = ch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
        head_dim
    )
    attn_weights = attn_weights + causal_mask
    dtype = attn_weights.dtype
    attn_weights = ch.softmax(attn_weights, dim=-1, dtype=ch.float32).to(dtype)
    return attn_weights
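# Sanity-check sketch (mirrors the idea behind `tests/attribution/test_attention.py`;
# the exact test code may differ). Assumes the model was loaded with the `eager`
# attention implementation so that `output_attentions=True` is available:
#
#     out = model(input_ids, output_attentions=True, output_hidden_states=True)
#     position_ids, attention_mask = get_position_ids_and_attention_mask(
#         model, out.hidden_states
#     )
#     ours = get_layer_attention_weights(
#         model, out.hidden_states, 0, position_ids, attention_mask
#     )
#     assert ch.allclose(ours, out.attentions[0], atol=1e-3)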
def get_attention_weights(
    model: Any,
    hidden_states: Any,
    attribution_start: Optional[int] = None,
    attribution_end: Optional[int] = None,
    model_type: Optional[str] = None,
) -> Any:
    """
    Compute the attention weights for the given model and hidden states.

    Args:
        model: The model to compute the attention weights for.
        hidden_states: The pre-computed hidden states.
        attribution_start: The start index of the tokens we would like to attribute.
        attribution_end: The end index of the tokens we would like to attribute.
        model_type: The type of model to compute the attention weights for (each model
            in the `transformers` library has its own specific attention implementation).

    Returns:
        A tensor of shape (num_layers, num_heads, num_target_tokens, sequence_length).
    """
    with ch.no_grad():
        position_ids, attention_mask = get_position_ids_and_attention_mask(
            model, hidden_states
        )
        num_layers, num_heads = get_attentions_shape(model)
        num_tokens = hidden_states[0].shape[1] + 1
        attribution_start = attribution_start if attribution_start is not None else 1
        attribution_end = attribution_end if attribution_end is not None else num_tokens
        num_target_tokens = attribution_end - attribution_start
        weights = ch.zeros(
            num_layers,
            num_heads,
            num_target_tokens,
            num_tokens - 1,
            device=model.device,
            dtype=model.dtype,
        )
        for i in range(len(model.model.layers)):
            cur_weights = get_layer_attention_weights(
                model,
                hidden_states,
                i,
                position_ids,
                attention_mask,
                attribution_start=attribution_start,
                attribution_end=attribution_end,
                model_type=model_type,
            )
            weights[i, :, :, :] = cur_weights[0]
        return weights
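# Example (sketch): attributing a specific answer span, where `answer_start` and
# `answer_end` are assumed token indices:
#
#     weights = get_attention_weights(
#         model, hidden_states, attribution_start=answer_start, attribution_end=answer_end
#     )
#     # Rows index the attributed tokens in [answer_start, answer_end); columns index
#     # all source positions in the sequence.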
def get_attention_weights_one_layer(
    model: Any,
    hidden_states: Any,
    layer_index: int,
    attribution_start: Optional[int] = None,
    attribution_end: Optional[int] = None,
    model_type: Optional[str] = None,
) -> Any:
    """
    Compute the attention weights for a single layer of the given model from
    pre-computed hidden states.

    Args:
        model: The model to compute the attention weights for.
        hidden_states: The pre-computed hidden states.
        layer_index: The index of the layer to compute the attention weights for.
        attribution_start: The start index of the tokens we would like to attribute.
        attribution_end: The end index of the tokens we would like to attribute.
        model_type: The type of model to compute the attention weights for (each model
            in the `transformers` library has its own specific attention implementation).
    """
    with ch.no_grad():
        position_ids, attention_mask = get_position_ids_and_attention_mask(
            model, hidden_states
        )
        num_tokens = hidden_states[0].shape[1] + 1
        attribution_start = attribution_start if attribution_start is not None else 1
        attribution_end = attribution_end if attribution_end is not None else num_tokens
        weights = get_layer_attention_weights(
            model,
            hidden_states,
            layer_index,
            position_ids,
            attention_mask,
            attribution_start=attribution_start,
            attribution_end=attribution_end,
            model_type=model_type,
        )
        return weights
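# Per-layer variant (sketch): the returned tensor keeps the batch dimension, with
# shape (batch, num_heads, num_target_tokens, sequence_length).
#
#     layer_weights = get_attention_weights_one_layer(model, hidden_states, layer_index=3)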
def get_hidden_states_one_layer(
    model: Any,
    hidden_states: Any,
    layer_index: int,
    attribution_start: Optional[int] = None,
    attribution_end: Optional[int] = None,
    model_type: Optional[str] = None,
) -> Any:
    """
    Compute the key states for a single layer of the given model from pre-computed
    hidden states, averaged over the batch and key-value head dimensions.

    Args:
        model: The model to compute the key states for.
        hidden_states: The pre-computed hidden states.
        layer_index: The index of the layer to compute the key states for.
        attribution_start: The start index of the tokens we would like to attribute.
        attribution_end: The end index of the tokens we would like to attribute.
        model_type: The type of model to compute the key states for (each model
            in the `transformers` library has its own specific attention implementation).
    """

    def get_hidden_states(
        model,
        hidden_states,
        layer_index,
        position_ids,
        attention_mask,
        attribution_start=None,
        attribution_end=None,
        model_type=None,
    ):
        model_type = model_type or infer_model_type(model)
        assert layer_index >= 0 and layer_index < len(model.model.layers)
        layer = model.model.layers[layer_index]
        self_attn = layer.self_attn
        hidden_states = hidden_states[layer_index]
        hidden_states = layer.input_layernorm(hidden_states)
        bsz, q_len, _ = hidden_states.size()
        num_attention_heads = model.model.config.num_attention_heads
        num_key_value_heads = model.model.config.num_key_value_heads
        head_dim = self_attn.head_dim
        if model_type in ("llama", "qwen2", "qwen1.5", "gemma3", "glm"):
            query_states = self_attn.q_proj(hidden_states)
            key_states = self_attn.k_proj(hidden_states)
        elif model_type in ("phi3",):
            # Phi-3 uses a fused QKV projection; slice out the query and key parts.
            qkv = self_attn.qkv_proj(hidden_states)
            query_pos = num_attention_heads * head_dim
            query_states = qkv[..., :query_pos]
            key_states = qkv[..., query_pos : query_pos + num_key_value_heads * head_dim]
        else:
            raise ValueError(f"Unknown model: {model.name_or_path}")
        query_states = query_states.view(bsz, q_len, num_attention_heads, head_dim)
        query_states = query_states.transpose(1, 2)
        # Average the key states over the batch and key-value head dimensions.
        key_states = key_states.view(bsz, q_len, num_key_value_heads, head_dim)
        key_states = key_states.mean(dim=(0, 2))
        return key_states

    with ch.no_grad():
        position_ids, attention_mask = get_position_ids_and_attention_mask(
            model, hidden_states
        )
        num_tokens = hidden_states[0].shape[1] + 1
        attribution_start = attribution_start if attribution_start is not None else 1
        attribution_end = attribution_end if attribution_end is not None else num_tokens
        hidden_states = get_hidden_states(
            model,
            hidden_states,
            layer_index,
            position_ids,
            attention_mask,
            attribution_start=attribution_start,
            attribution_end=attribution_end,
            model_type=model_type,
        )
        return hidden_states
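# Sketch: per-token key states (averaged over batch and key-value heads) for layer 0;
# the result has shape (sequence_length, head_dim).
#
#     keys = get_hidden_states_one_layer(model, hidden_states, layer_index=0)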