Use input attention mask instead of casual mask in attention
#101
by
CyberZHG
- opened
- modelling_RW.py +2 -2
modelling_RW.py
CHANGED
|
@@ -281,13 +281,14 @@ class Attention(nn.Module):
|
|
| 281 |
else:
|
| 282 |
present = None
|
| 283 |
|
|
|
|
| 284 |
if alibi is None:
|
| 285 |
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
| 286 |
key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
| 287 |
value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
| 288 |
|
| 289 |
attn_output = F.scaled_dot_product_attention(
|
| 290 |
-
query_layer_, key_layer_, value_layer_,
|
| 291 |
)
|
| 292 |
|
| 293 |
x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
|
|
@@ -300,7 +301,6 @@ class Attention(nn.Module):
|
|
| 300 |
assert not output_attentions # not supported.
|
| 301 |
return outputs
|
| 302 |
else:
|
| 303 |
-
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
|
| 304 |
matmul_result = query_layer @ key_layer.transpose(-1, -2)
|
| 305 |
|
| 306 |
# change view to [batch_size, num_heads, q_length, kv_length]
|
|
|
|
| 281 |
else:
|
| 282 |
present = None
|
| 283 |
|
| 284 |
+
attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(query_layer.dtype)
|
| 285 |
if alibi is None:
|
| 286 |
query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
| 287 |
key_layer_ = key_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
| 288 |
value_layer_ = value_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
| 289 |
|
| 290 |
attn_output = F.scaled_dot_product_attention(
|
| 291 |
+
query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
|
| 292 |
)
|
| 293 |
|
| 294 |
x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
|
|
|
|
| 301 |
assert not output_attentions # not supported.
|
| 302 |
return outputs
|
| 303 |
else:
|
|
|
|
| 304 |
matmul_result = query_layer @ key_layer.transpose(-1, -2)
|
| 305 |
|
| 306 |
# change view to [batch_size, num_heads, q_length, kv_length]
|