File size: 10,928 Bytes
7f37159 6fc5e4e 7f37159 6fc5e4e 7f37159 6fc5e4e 7f37159 6fc5e4e 7f37159 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 |
import argparse
import json
from pathlib import Path
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp
from safetensors.flax import save_file
from tqdm import tqdm
# Checkpoint key prefix under which the SigLIP vision-encoder parameters live.
# convert_to_hf treats the checkpoint as text-only when it finds no encoder
# blocks under this prefix.
SIGLIP_PREFIX = "SigLiPFromPatches_0/siglip_encoder"
def flatten(x: np.ndarray, start: int = 0, end: int = -1):
    """Merge axes ``start`` through ``end`` (both inclusive) of ``x`` into a single axis.

    Negative axis indices count from the last axis, as in NumPy. With the
    defaults, the whole array is flattened to 1-D.
    """
    ndim = x.ndim
    first = start + ndim if start < 0 else start
    last = end + ndim if end < 0 else end
    merged_shape = (*x.shape[:first], -1, *x.shape[last + 1 :])
    return x.reshape(merged_shape)
def unflatten(x: np.ndarray, dim: int, sizes: tuple[int, ...]):
    """Split axis ``dim`` of ``x`` into several axes of the given ``sizes``.

    This is the single-axis inverse of ``flatten``. ``sizes`` may contain one
    ``-1``, which NumPy infers. A negative ``dim`` counts from the last axis,
    as in NumPy.
    """
    if dim < 0:
        # Normalize so the slice arithmetic below is correct for negative axes.
        dim += x.ndim
    new_shape = x.shape[:dim] + tuple(sizes) + x.shape[dim + 1 :]
    return x.reshape(new_shape)
def check_groups(groups: np.ndarray, scales: np.ndarray, dim: int):
    """Test whether ``scales`` reproduce ``groups`` exactly after rounding.

    A scale is correct for a group when quantizing with it
    (``round(g / s) * s``) gives back the original values, up to an absolute
    error of 1e-6.

    Args:
        groups: values with the group members laid out along axis ``dim``,
            e.g. shape ``(a, b, c, 32, d, e, f)``.
        scales: broadcastable to ``groups`` with size 1 along ``dim``,
            e.g. ``(a, b, c, 1, d, e, f)``.
        dim: the group axis.

    Returns:
        ``(ok, max_diff)``: the per-group maximum absolute reconstruction error
        (kept as a size-1 axis at ``dim``) and whether it is below 1e-6.
    """
    # Clip before inverting so zero scales don't divide by zero.
    inverse = 1.0 / np.clip(scales, 1e-12, None)
    reconstructed = np.round(groups * inverse) * scales
    error = np.abs(reconstructed - groups)
    max_diff = error.max(axis=dim, keepdims=True)
    return max_diff < 1e-6, max_diff
def find_scales(w: np.ndarray, dim: int):
    """Recover exact per-group quantization scales from already-quantized weights.

    ``w`` is split along ``dim`` into contiguous groups of 32 values. For each
    group, the candidate scales ``range / q`` are tried for q = 15, 14, ..., 1.
    The last candidate that reproduces the group exactly (see ``check_groups``)
    is kept, i.e. the largest valid scale. A group whose values are all equal
    has zero range; such a group is reproduced by the scale ``|value|``
    (quantized value +/-1), which is 0 for an all-zero group.

    Args:
        w: weights that are expected to lie exactly on a per-group integer grid.
        dim: non-negative axis along which the groups of 32 are formed; its
            size must be a multiple of 32.

    Returns:
        Scales with the same shape as ``w``, except axis ``dim`` is 32x smaller.

    Raises:
        RuntimeError: if some group cannot be reproduced exactly by any scale.
    """
    grouped = unflatten(w, dim, (-1, 32))
    group_axis = dim + 1
    group_range = grouped.max(group_axis, keepdims=True) - grouped.min(group_axis, keepdims=True)
    scales = np.zeros_like(group_range)
    # Iterate from fine to coarse so larger valid scales overwrite smaller ones.
    for q in range(15, 0, -1):
        try_scale = group_range / q
        ok, _ = check_groups(grouped, try_scale, group_axis)
        scales[ok] = try_scale[ok]
    # For constant groups, range / q == 0 only reproduces all-zero groups.
    # A constant non-zero group c is reproduced exactly by scale |c|.
    constant = group_range == 0
    abs_max = np.abs(grouped).max(group_axis, keepdims=True)
    scales[constant] = abs_max[constant]
    ok, max_diff = check_groups(grouped, scales, group_axis)
    # Explicit check (not `assert`) so it is not stripped under `python -O`.
    if not ok.all():
        raise RuntimeError(
            f"could not recover exact scales for {int((~ok).sum())} group(s); "
            f"max reconstruction error {float(max_diff.max())}"
        )
    return scales.squeeze(group_axis)
def convert_siglip(params, num_layers: int):
    """Convert the SigLIP vision-encoder parameters to HF names and tensor layouts.

    Args:
        params: restored checkpoint mapping whose keys are "/"-joined module
            paths (e.g. ``f"{SIGLIP_PREFIX}/embedding"``). Each maps to a dict of
            arrays keyed by ``"kernel"``, ``"bias"``, ``"scale"`` and similar.
        num_layers: number of encoder blocks to convert.

    Returns:
        dict mapping HF parameter names (relative to the vision model, without
        any ``vision_tower.`` prefix) to arrays.

    Raises:
        RuntimeError: if a layer has an unexpected set of parameters or an
            unsupported kernel/bias rank.
    """
    state_dict = dict()

    def convert_layer(prefix: str, layer: dict[str, np.ndarray]):
        # Store `<prefix>weight` / `<prefix>bias` for one kernel- or scale-bearing layer.
        bias = layer["bias"]
        if "kernel" in layer:
            w = layer["kernel"]
            if w.ndim == 2:  # linear layer: (in, out) -> (out, in)
                w = w.T
            elif w.ndim == 3:  # attention projection
                if bias.ndim == 2:
                    # q/k/v projection: kernel (dim, num_heads, head_dim) -> (num_heads * head_dim, dim)
                    w = flatten(w, 1, 2).T
                    bias = bias.reshape(-1)
                elif bias.ndim == 1:
                    # output projection: kernel (num_heads, head_dim, dim) -> (dim, num_heads * head_dim)
                    w = flatten(w, 0, 1).T
                else:
                    # Previously fell through silently, leaving a 3-D weight.
                    raise RuntimeError(f"Unsupported attention bias for {prefix!r}: {w.shape=} {bias.shape=}")
            elif w.ndim == 4:
                # conv2d: assumes flax layout (H, W, in, out) -> (out, in, H, W)
                w = w.transpose(3, 2, 0, 1)
            else:
                raise RuntimeError(f"Unsupported {w.shape=}")
        elif "scale" in layer:  # layer norm
            w = layer["scale"]
        else:
            raise RuntimeError(f"Unsupported layer {prefix!r} with keys {sorted(layer)}")
        state_dict[f"{prefix}weight"] = w
        state_dict[f"{prefix}bias"] = bias

    convert_layer("embeddings.patch_embedding.", params[f"{SIGLIP_PREFIX}/embedding"])
    # Drop the leading size-1 axis of the stored position embedding.
    state_dict["embeddings.position_embedding.weight"] = params[SIGLIP_PREFIX]["pos_embedding"].squeeze(0)
    convert_layer("post_layernorm.", params[f"{SIGLIP_PREFIX}/Transformer/encoder_norm"])

    for layer_idx in range(num_layers):
        prefix = f"encoder.layers.{layer_idx}."
        layer_prefix = f"{SIGLIP_PREFIX}/Transformer/encoderblock_{layer_idx}/"
        convert_layer(f"{prefix}layer_norm1.", params[f"{layer_prefix}LayerNorm_0"])
        convert_layer(f"{prefix}layer_norm2.", params[f"{layer_prefix}LayerNorm_1"])

        attn_prefix = f"{layer_prefix}MultiHeadDotProductAttention_0/"
        convert_layer(f"{prefix}self_attn.q_proj.", params[f"{attn_prefix}query"])
        convert_layer(f"{prefix}self_attn.k_proj.", params[f"{attn_prefix}key"])
        convert_layer(f"{prefix}self_attn.v_proj.", params[f"{attn_prefix}value"])
        convert_layer(f"{prefix}self_attn.out_proj.", params[f"{attn_prefix}out"])

        mlp_prefix = f"{layer_prefix}MlpBlock_0/"
        convert_layer(f"{prefix}mlp.fc1.", params[f"{mlp_prefix}Dense_0"])
        convert_layer(f"{prefix}mlp.fc2.", params[f"{mlp_prefix}Dense_1"])
    return state_dict
def convert_to_hf(path: Path):
    """Convert an Orbax Gemma 3 checkpoint to HF Transformers names and layouts.

    This is a generator and the first stage of the pipeline: it converts
    names/layouts only, and the caller applies quantization to each yielded
    chunk afterwards. The whole checkpoint is restored once. It then yields
    partial state dicts (name -> array) so the caller can quantize and shard
    them incrementally:

    1. token embeddings and the final norm, plus the multimodal projector when
       the checkpoint contains a SigLIP vision encoder;
    2. one dict per decoder layer;
    3. the SigLIP vision tower (multimodal checkpoints only).

    Each yielded dict is a fresh object and is never modified after being yielded.

    Args:
        path: checkpoint directory; made absolute because Orbax requires
            absolute paths.
    """
    path = path.absolute()  # orbax only works with absolute path
    ckpt = ocp.StandardCheckpointer()
    metadata = dict(ckpt.metadata(path))
    metadata = jax.tree.map(ocp.utils.to_shape_dtype_struct, metadata)

    # Count layers by probing for a key that every layer has.
    num_layers = num_siglip_layers = 0
    while f"transformer/layer_{num_layers}/attn/_key_norm" in metadata:
        num_layers += 1
    while f"{SIGLIP_PREFIX}/Transformer/encoderblock_{num_siglip_layers}/LayerNorm_0" in metadata:
        num_siglip_layers += 1
    print(f"{num_layers=}")
    print(f"{num_siglip_layers=}")

    # NOTE: all gemma3 models use tied embeddings, even for the 27B version,
    # so no separate lm_head weight is emitted.
    params = ckpt.restore(path)

    state_dict = dict()
    if num_siglip_layers > 0:
        # The HF multimodal embedding table has 64 more rows than the JAX one
        # (presumably unused/reserved tokens); zero-pad to match.
        embed = params["transformer/embedder"]["input_embedding"]
        params["transformer/embedder"]["input_embedding"] = np.pad(embed, ((0, 64), (0, 0)))
        gemma_prefix = "language_model."
        prefix = "multi_modal_projector.mm_"
        jax_prefix = "transformer/embedder/"
        state_dict[f"{prefix}input_projection_weight"] = params[f"{jax_prefix}mm_input_projection"]["w"]
        state_dict[f"{prefix}soft_emb_norm.weight"] = params[f"{jax_prefix}mm_soft_embedding_norm"]["scale"]
    else:
        gemma_prefix = ""
    state_dict[f"{gemma_prefix}model.embed_tokens.weight"] = params["transformer/embedder"]["input_embedding"]
    state_dict[f"{gemma_prefix}model.norm.weight"] = params["transformer/final_norm"]["scale"]
    yield state_dict

    for layer_idx in range(num_layers):
        state_dict = dict()

        jax_prefix = f"transformer/layer_{layer_idx}/"
        prefix = f"{gemma_prefix}model.layers.{layer_idx}."
        state_dict[f"{prefix}input_layernorm.weight"] = params[f"{jax_prefix}pre_attention_norm"]["scale"]
        state_dict[f"{prefix}post_attention_layernorm.weight"] = params[f"{jax_prefix}post_attention_norm"]["scale"]
        state_dict[f"{prefix}pre_feedforward_layernorm.weight"] = params[f"{jax_prefix}pre_ffw_norm"]["scale"]
        state_dict[f"{prefix}post_feedforward_layernorm.weight"] = params[f"{jax_prefix}post_ffw_norm"]["scale"]

        prefix = f"{gemma_prefix}model.layers.{layer_idx}.self_attn."
        jax_prefix = f"transformer/layer_{layer_idx}/attn/"
        state_dict[f"{prefix}q_norm.weight"] = params[f"{jax_prefix}_query_norm"]["scale"]
        state_dict[f"{prefix}k_norm.weight"] = params[f"{jax_prefix}_key_norm"]["scale"]
        # (num_heads, hidden_size, head_dim) -> (num_heads * head_dim, hidden_size)
        state_dict[f"{prefix}q_proj.weight"] = flatten(params[f"{jax_prefix}q_einsum"]["w"].transpose(0, 2, 1), end=1)
        # kv_einsum stacks K ([0]) and V ([1]) along its first axis.
        state_dict[f"{prefix}k_proj.weight"] = flatten(
            params[f"{jax_prefix}kv_einsum"]["w"][0].transpose(0, 2, 1), end=1
        )
        state_dict[f"{prefix}v_proj.weight"] = flatten(
            params[f"{jax_prefix}kv_einsum"]["w"][1].transpose(0, 2, 1), end=1
        )
        # (num_heads, head_dim, hidden_size) -> (hidden_size, num_heads * head_dim)
        state_dict[f"{prefix}o_proj.weight"] = flatten(params[f"{jax_prefix}attn_vec_einsum"]["w"], end=1).T

        prefix = f"{gemma_prefix}model.layers.{layer_idx}.mlp."
        jax_prefix = f"transformer/layer_{layer_idx}/mlp/"
        # gating_einsum stacks gate ([0]) and up ([1]) projections.
        state_dict[f"{prefix}gate_proj.weight"] = params[f"{jax_prefix}gating_einsum"]["w"][0]
        state_dict[f"{prefix}up_proj.weight"] = params[f"{jax_prefix}gating_einsum"]["w"][1]
        state_dict[f"{prefix}down_proj.weight"] = params[f"{jax_prefix}linear"]["w"].T
        yield state_dict

    # vision tower
    if num_siglip_layers > 0:
        # Build a fresh dict: the previous one (the last decoder layer) was
        # already yielded, and re-yielding it would emit that layer twice.
        state_dict = {
            f"vision_tower.vision_model.{k}": v
            for k, v in convert_siglip(params, num_siglip_layers).items()
        }
        yield state_dict
def convert_awq(state_dict: dict[str, np.ndarray]):
    """Quantize the linear weights of an HF-format state dict into packed AWQ INT4 tensors.

    Weights are expected to already lie exactly on a per-group INT4 grid, so
    ``find_scales`` can recover exact scales; it fails otherwise.

    Left unquantized (only cast to bfloat16):
    - the token embedding (AWQ doesn't support INT4 embeddings),
    - every tensor under ``vision_tower`` / ``multi_modal_projector``,
    - 1-D tensors (e.g. norm weights).

    Every other tensor must be a 2-D weight named ``<prefix>.weight``. It is
    assumed to be in HF ``(out_features, in_features)`` layout, transposed to
    ``(K, N) = (in, out)``, and replaced by:
    - ``<prefix>.qweight``: int32 ``(K, N / 8)``, 8 nibbles packed per word;
    - ``<prefix>.qzeros``: int32 ``(K / 32, N / 8)``, every zero point = 8;
    - ``<prefix>.scales``: bfloat16 ``(K / 32, N)``.

    Raises:
        ValueError: for a weight that is not 2-D, or whose K is not a multiple
            of 32 or whose N is not a multiple of 8.
    """
    group_size = 32  # quantization group length along K (matches find_scales)
    pack_factor = 8  # 4-bit values per 32-bit word
    # Nibble slot i (bits 4*i .. 4*i+3) holds element awq_order[i] of each group
    # of 8 along N, i.e. [7 5 3 1 6 4 2 0] reading from high bits to low bits.
    awq_order = (0, 2, 4, 6, 1, 3, 5, 7)

    awq_state_dict = dict()
    for k, v in state_dict.items():
        if (
            k.endswith("model.embed_tokens.weight")  # AWQ doesn't support INT4 embeddings
            or k.startswith(("vision_tower", "multi_modal_projector"))  # vision tower is not quantized
            or v.ndim == 1
        ):
            awq_state_dict[k] = v.astype(jnp.bfloat16)
            continue

        # Explicit checks (not `assert`) so they survive `python -O` and give
        # clear messages instead of cryptic reshape errors.
        if v.ndim != 2:
            raise ValueError(f"{k}: expected a 2-D weight, got {v.shape=}")
        v = v.T  # AWQ transpose the weight
        K, N = v.shape
        if K % group_size or N % pack_factor:
            raise ValueError(
                f"{k}: (K, N) = {(K, N)} must be multiples of ({group_size}, {pack_factor})"
            )

        scales = find_scales(v, dim=0)  # (K / group_size, N)
        inv_scale = 1 / scales.clip(1e-12)
        qweight = np.round(v.reshape(K // group_size, group_size, N) * inv_scale[:, None])
        # AWQ is actually UINT4 (instead of INT4), hence shift qweight up by 8
        # (even though Google AQT only uses [-7, 7]) and set zero_point = 8.
        qweight = (qweight + 8).astype(np.uint32)

        qweight = qweight.reshape(K, N // pack_factor, pack_factor)
        qweight_packed = np.zeros((K, N // pack_factor), dtype=np.uint32)
        for slot, elem in enumerate(awq_order):
            qweight_packed |= qweight[..., elem] << (4 * slot)
        qweight_packed = qweight_packed.view(np.int32)

        prefix = k.removesuffix(".weight")
        awq_state_dict[f"{prefix}.qweight"] = qweight_packed
        awq_state_dict[f"{prefix}.qzeros"] = np.full(
            (K // group_size, N // pack_factor), 0x8888_8888, dtype=np.uint32
        ).view(np.int32)
        awq_state_dict[f"{prefix}.scales"] = scales.astype(jnp.bfloat16)
    return awq_state_dict
if __name__ == "__main__":
    # Convert --ckpt_dir to HF format, quantize to AWQ, and write sharded
    # safetensors plus an index file into --save_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_dir", required=True, type=Path)
    parser.add_argument("--save_dir", required=True, type=Path)
    args = parser.parse_args()
    args.save_dir.mkdir(parents=True, exist_ok=True)

    max_shard_size = 5e9  # bytes per safetensors shard
    total_size = 0
    weight_map = dict()  # tensor name -> shard filename, for the index file
    state_dict = dict()  # tensors buffered for the current shard
    size = 0  # bytes buffered for the current shard
    shard_idx = 0
    filename = f"model-{shard_idx + 1:05d}.safetensors"

    # Each converted chunk goes entirely into one shard; chunks are never split.
    for sub_state_dict in tqdm(convert_to_hf(args.ckpt_dir)):
        sub_state_dict = convert_awq(sub_state_dict)
        new_size = sum(v.nbytes for v in sub_state_dict.values())
        # Flush the current shard if this chunk would overflow it, but never
        # write an empty shard: an oversized chunk just gets a shard of its own.
        if state_dict and size + new_size > max_shard_size:
            save_file(state_dict, args.save_dir / filename)
            state_dict = dict()
            size = 0
            shard_idx += 1
            filename = f"model-{shard_idx + 1:05d}.safetensors"
        size += new_size
        total_size += new_size
        for k, v in sub_state_dict.items():
            state_dict[k] = v
            weight_map[k] = filename
    save_file(state_dict, args.save_dir / filename)

    # `with` guarantees the index file is flushed and closed.
    with open(args.save_dir / "model.safetensors.index.json", "w") as f:
        json.dump(dict(metadata=dict(total_size=total_size), weight_map=weight_map), f)
|