Update modeling_cogvlm.py: remove the dependency on triton
modeling_cogvlm.py  (+57 -6)
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional, Tuple, List, Union, Literal, Dict, Any
 import math
 import torch
 from torch import nn
+from torch.nn import functional as F
 from torch.nn import CrossEntropyLoss
 from torchvision import transforms
 from einops import rearrange
@@ -15,7 +16,6 @@ from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 
 from .configuration_cogvlm import CogVLMConfig
-from .util import FastRotaryEmbedding
 from .visual import EVA2CLIPModel
 
 if TYPE_CHECKING:
@@ -144,6 +144,57 @@ def attention_fn(
     return context_layer
 
 
+class RotaryEmbedding(torch.nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+        super().__init__()
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = self._compute_inv_freq(device)
+        self.register_buffer("inv_freq", inv_freq)
+        self.max_seq_len_cached = 0
+
+    def _compute_inv_freq(self, device=None):
+        return 1.0 / (
+            self.base
+            ** (torch.arange(0, self.dim, 2, device=device) / self.dim)
+        )
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos()[:, None, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[:, None, :].to(dtype), persistent=False)
+
+    def forward(self, x, seq_len):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
+        )
+
+
+def rotate_half(x):
+    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=x1.ndim - 1)
+
+
+def apply_rotary_pos_emb_index_bhs(q, k, cos, sin, position_id):
+    # batch_size, num_head, seq_len, hidden_size
+    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(1), \
+               F.embedding(position_id, sin.squeeze(1)).unsqueeze(1)
+    q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+    return q, k
+
+
 class VisionExpertAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
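The hunk above swaps the Triton-backed FastRotaryEmbedding for a plain-PyTorch rotary implementation: RotaryEmbedding caches cos/sin tables as non-persistent buffers, and apply_rotary_pos_emb_index_bhs gathers the per-token rows with F.embedding before rotating q and k. Below is a minimal sanity-check sketch, assuming the definitions from this hunk are in scope (imported from this module) and using made-up toy dimensions rather than the real CogVLM config:

```python
import torch

# Toy sizes for illustration only; the real model derives these from its config.
bs, num_heads, seq_len, head_dim = 1, 2, 5, 8

rotary = RotaryEmbedding(head_dim)                      # class added in this commit
q = torch.randn(bs, num_heads, seq_len, head_dim)
k = torch.randn(bs, num_heads, seq_len, head_dim)
v = torch.randn(bs, num_heads, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0)       # [bs, seq_len]

# cos/sin are cached up to the largest position index, then indexed per token.
cos, sin = rotary(v, seq_len=int(position_ids.max()) + 1)   # each: [seq_len, 1, head_dim]
q_rot, k_rot = apply_rotary_pos_emb_index_bhs(q, k, cos, sin, position_ids)

print(q_rot.shape, k_rot.shape)   # torch.Size([1, 2, 5, 8]) torch.Size([1, 2, 5, 8])
```

Because rotary embedding rotates each (x_i, x_{i+d/2}) pair by the same angle, the per-token norm should be preserved, so `torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)` is a cheap extra check.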
@@ -153,8 +204,7 @@ class VisionExpertAttention(nn.Module):
         self.head_dim = self.hidden_size // self.num_heads
         self.max_position_embeddings = config.max_position_embeddings
 
-
-        self.rotary_emb = FastRotaryEmbedding(dim=self.head_dim, pos_idx_in_fp32=False)
+        self.rotary_emb = RotaryEmbedding(self.head_dim)
         self.vision_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
         self.vision_expert_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
         self.language_expert_query_key_value = nn.Linear(self.hidden_size, self.hidden_size * 3, bias=False)
@@ -193,8 +243,8 @@ class VisionExpertAttention(nn.Module):
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]
-
-        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids=position_ids, max_seqlen=position_ids.max() + 1)
+        cos, sin = self.rotary_emb(value_states, seq_len=position_ids.max() + 1)
+        query_states, key_states = apply_rotary_pos_emb_index_bhs(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
@@ -706,7 +756,8 @@ class CogVLMForCausalLM(CogVLMPreTrainedModel):
         # update token_type_ids with last value
         if "token_type_ids" in model_kwargs:
             token_type_ids = model_kwargs["token_type_ids"]
-            new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype, device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
+            new_token_type_ids = torch.ones(size=(token_type_ids.shape[0], 1), dtype=token_type_ids.dtype,
+                                            device=token_type_ids.device) * LANGUAGE_TOKEN_TYPE
             model_kwargs["token_type_ids"] = torch.cat([token_type_ids, new_token_type_ids], dim=-1)
 
         if not is_encoder_decoder:
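One detail worth noting, read from the diff rather than stated by the author: the attention hunk sizes the cache with `seq_len=position_ids.max() + 1` instead of the tensor's sequence length, presumably because cos/sin rows are looked up per token via `F.embedding(position_ids, ...)`, so position ids only need to stay inside the cached range and do not have to form a contiguous arange. A small sketch with deliberately gapped toy position ids, under the same assumption that RotaryEmbedding and apply_rotary_pos_emb_index_bhs are in scope:

```python
import torch

position_ids = torch.tensor([[0, 1, 5, 6]])      # toy ids with a gap between 1 and 5
bs, num_heads, head_dim = 1, 2, 8
q = torch.randn(bs, num_heads, position_ids.shape[1], head_dim)
k = torch.randn(bs, num_heads, position_ids.shape[1], head_dim)

rotary = RotaryEmbedding(head_dim)
# The cache must cover indices 0..6, hence max()+1 rather than the sequence length (4).
cos, sin = rotary(q, seq_len=int(position_ids.max()) + 1)    # each: [7, 1, head_dim]
q_rot, k_rot = apply_rotary_pos_emb_index_bhs(q, k, cos, sin, position_ids)

print(cos.shape, q_rot.shape)   # torch.Size([7, 1, 8]) torch.Size([1, 2, 4, 8])
```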