| | """ |
| | Functions in this file are courtesy of @ashen-sensored on GitHub - thank you so much! <3 |
| | |
| | Used to merge DreamSim LoRA weights into the base ViT models manually, so we don't need |
| | to use an ancient version of PEFT that is no longer supported (and somewhat broken). |
| | """ |
| | import logging |
| | from os import PathLike |
| | from pathlib import Path |
| |
|
| | import torch |
| | from safetensors.torch import load_file |
| | from torch import Tensor, nn |
| |
|
| | from .model import DreamsimModel |
| |
|
# Module-level logger, named after this module so handlers/levels can target it.
logger = logging.getLogger(name=__name__)
| |
|
| |
|
@torch.no_grad()
def calculate_merged_weight(
    lora_a: Tensor,
    lora_b: Tensor,
    base: Tensor,
    scale: float,
    qkv_switches: list[bool],
) -> Tensor:
    """Merge a grouped LoRA delta into a fused (e.g. ``qkv``) weight matrix.

    ``base`` is split along dim 0 into ``len(qkv_switches)`` equally sized row
    chunks. Chunk ``j`` is adapted only when ``qkv_switches[j]`` is true. The
    adapted chunks are numbered ``i = 0 .. n_groups - 1`` in order, and chunk
    ``i`` receives the low-rank update ``lora_b[i-th rows] @ lora_a[i-th rows]``.
    This mirrors the grouped LoRA layout of the older PEFT release referenced in
    the module docstring (assumed; verify against the checkpoint producer).

    Args:
        lora_a: LoRA "A" weight of shape ``(rank * n_groups, in_features)``,
            where ``n_groups = sum(qkv_switches)``.
        lora_b: LoRA "B" weight whose first dim is ``n_groups * chunk`` (with
            ``chunk = base.shape[0] // len(qkv_switches)``). All trailing dims are
            flattened to ``rank``, so e.g. a ``(out, rank, 1)`` kernel-size-1
            conv weight and a ``(out, rank)`` matrix are both accepted.
        base: Fused base weight of shape ``(out_features, in_features)``.
        scale: Multiplier applied to the LoRA delta before it is added.
        qkv_switches: One flag per row chunk of ``base``.

    Returns:
        ``base + scale * delta`` with the same dtype and device as ``base``.
        If no switch is set, an unmodified copy of ``base`` is returned.

    Raises:
        ValueError: If ``qkv_switches`` is empty or the tensor shapes are
            inconsistent with it.
    """
    n_switches = len(qkv_switches)
    if n_switches == 0:
        raise ValueError("qkv_switches must not be empty")
    if base.shape[0] % n_switches != 0:
        raise ValueError(
            f"base has {base.shape[0]} rows, which cannot be split into {n_switches} equal chunks"
        )
    chunk = base.shape[0] // n_switches

    n_groups = sum(bool(s) for s in qkv_switches)
    if n_groups == 0:
        # Nothing is adapted: the merged weight is just the base weight.
        return base.clone()

    lora_a = lora_a.to(base.device)
    # Flatten trailing dims explicitly. A bare ``squeeze()`` would also drop the
    # rank dim when rank == 1 and break the row slicing below.
    lora_b = lora_b.to(base.device).reshape(lora_b.shape[0], -1)

    if lora_a.ndim != 2 or lora_a.shape[0] % n_groups != 0 or lora_a.shape[1] != base.shape[1]:
        raise ValueError(
            f"lora_a shape {tuple(lora_a.shape)} is incompatible with base shape "
            f"{tuple(base.shape)} and {n_groups} adapted groups"
        )
    rank = lora_a.shape[0] // n_groups
    if tuple(lora_b.shape) != (n_groups * chunk, rank):
        raise ValueError(
            f"lora_b shape {tuple(lora_b.shape)} does not match expected "
            f"({n_groups * chunk}, {rank})"
        )

    # Per-group low-rank products, stacked in the order of the enabled chunks.
    delta_w = base.new_zeros(n_groups * chunk, base.shape[1])
    for i in range(n_groups):
        a_rows = slice(i * rank, (i + 1) * rank)
        b_rows = slice(i * chunk, (i + 1) * chunk)
        delta_w[b_rows] = lora_b[b_rows] @ lora_a[a_rows]

    # Row mask over ``base``: each switch repeated ``chunk`` times, e.g. the
    # switches [T, F, T] cover rows q..q, k..k, v..v.
    row_mask = torch.tensor(
        list(qkv_switches), dtype=torch.bool, device=base.device
    ).repeat_interleave(chunk)
    delta_w_full = base.new_zeros(base.shape)
    delta_w_full[row_mask] = delta_w

    merged = base + scale * delta_w_full
    return merged.to(base)
| |
|
| |
|
@torch.no_grad()
def merge_dreamsim_lora(
    base_model: nn.Module,
    lora_path: PathLike,
    torch_device: torch.device | str = torch.device("cpu"),
    *,
    scale: float = 0.5 / 16,
    qkv_switches: list[bool] | tuple[bool, ...] = (True, False, True),
) -> nn.Module:
    """Merge DreamSim LoRA weights into every ``*.attn.qkv.weight`` of ``base_model``.

    The model is put in eval mode, frozen, and moved to ``torch_device``. Each
    fused qkv weight is then replaced by its merged counterpart, and the modified
    model is returned.

    Args:
        base_model: Model whose state dict contains ``<layer>.weight`` entries
            with ``"attn.qkv.weight"`` in the key.
        lora_path: LoRA checkpoint (``.pt``/``.pth``/``.bin`` or ``.safetensors``).
            Only keys starting with ``base_model.model.base_model.model.model.``
            are used. After stripping that prefix, each layer must provide
            ``<layer>.lora_A.weight`` and ``<layer>.lora_B.weight``.
        torch_device: Device the model and the loaded LoRA tensors are placed on.
        scale: Multiplier for the LoRA delta. The default preserves the
            previously hard-coded ``0.5 / 16`` (presumably ``lora_alpha / r`` of
            the DreamSim checkpoints; confirm before changing).
        qkv_switches: Which of the fused q/k/v chunks carry LoRA updates.

    Returns:
        ``base_model`` with merged weights and ``requires_grad`` disabled.

    Raises:
        ValueError: If the checkpoint file extension is not supported.
    """
    lora_path = Path(lora_path)
    base_model = base_model.eval().requires_grad_(False).to(torch_device)

    suffix = lora_path.suffix.lower()
    if suffix in (".pt", ".pth", ".bin"):
        lora_sd = torch.load(lora_path, map_location=torch_device, weights_only=True)
    elif suffix == ".safetensors":
        # Load onto the same device as the torch.load branch so both paths agree.
        lora_sd = load_file(lora_path, device=str(torch_device))
    else:
        raise ValueError(f"Unsupported file extension '{lora_path.suffix}'")

    group_prefix = "base_model.model.base_model.model.model."
    group_weights = {
        k.removeprefix(group_prefix): v for k, v in lora_sd.items() if k.startswith(group_prefix)
    }
    # "<layer>.lora_A.weight" -> "<layer>"
    group_layers = {k.rsplit(".", 2)[0] for k in group_weights}

    base_weights = base_model.state_dict()
    qkv_keys = [k for k in base_weights if "attn.qkv.weight" in k]
    for key in qkv_keys:
        param_name = key.rsplit(".", 1)[0]
        if param_name not in group_layers:
            logger.warning("QKV param '%s' not found in lora weights", param_name)
            continue
        base_weights[key] = calculate_merged_weight(
            group_weights[f"{param_name}.lora_A.weight"],
            group_weights[f"{param_name}.lora_B.weight"],
            base_weights[key],
            scale,
            list(qkv_switches),
        )

    base_model.load_state_dict(base_weights)
    return base_model.requires_grad_(False)
| |
|
| |
|
def remap_clip(state_dict: dict[str, Tensor], variant: str) -> dict[str, Tensor]:
    """Remap keys from the original DreamSim checkpoint to match new model structure.

    For variants containing ``"clip"``:
      * ``patch_embed.proj.bias`` is dropped;
      * ``pos_drop.weight`` / ``pos_drop.bias`` are renamed to
        ``norm_pre.weight`` / ``norm_pre.bias``;
      * a zero ``head.bias`` (same dtype/device as ``head.weight``) is added when
        ``head.weight`` exists without a bias.

    For variants ending in ``"single"``, every key is then prefixed with
    ``"extractor."``.

    The input mapping is never modified; a new dict is returned.
    """
    # Shallow copy so the caller's state dict is left untouched.
    remapped = dict(state_dict)

    if "clip" in variant:
        remapped.pop("patch_embed.proj.bias", None)
        for suffix in ("weight", "bias"):
            old_key = f"pos_drop.{suffix}"
            if old_key in remapped:
                remapped[f"norm_pre.{suffix}"] = remapped.pop(old_key)
        head_weight = remapped.get("head.weight")
        if head_weight is not None and "head.bias" not in remapped:
            remapped["head.bias"] = head_weight.new_zeros(head_weight.shape[0])

    if variant.endswith("single"):
        return {f"extractor.{k}": v for k, v in remapped.items()}
    return remapped
| |
|
| |
|
| | def convert_dreamsim_single( |
| | ckpt_path: PathLike, |
| | variant: str, |
| | ensemble: bool = False, |
| | ) -> DreamsimModel: |
| | ckpt_path = Path(ckpt_path) |
| | if ckpt_path.exists(): |
| | if ckpt_path.is_dir(): |
| | ckpt_path = ckpt_path.joinpath("ensemble" if ensemble else variant) |
| | ckpt_path = ckpt_path.joinpath(f"{variant}_merged.safetensors") |
| |
|
| | |
| | patch_size = 16 |
| | layer_norm_eps = 1e-6 |
| | pre_norm = False |
| | act_layer = "gelu" |
| |
|
| | match variant: |
| | case "open_clip_vitb16" | "open_clip_vitb32" | "clip_vitb16" | "clip_vitb32": |
| | patch_size = 32 if "b32" in variant else 16 |
| | layer_norm_eps = 1e-5 |
| | pre_norm = True |
| | img_mean = (0.48145466, 0.4578275, 0.40821073) |
| | img_std = (0.26862954, 0.26130258, 0.27577711) |
| | act_layer = "quick_gelu" if variant.startswith("clip_") else "gelu" |
| | case "dino_vitb16": |
| | img_mean = (0.485, 0.456, 0.406) |
| | img_std = (0.229, 0.224, 0.225) |
| | case _: |
| | raise NotImplementedError(f"Unsupported model variant '{variant}'") |
| |
|
| | model: DreamsimModel = DreamsimModel( |
| | image_size=224, |
| | patch_size=patch_size, |
| | layer_norm_eps=layer_norm_eps, |
| | pre_norm=pre_norm, |
| | act_layer=act_layer, |
| | img_mean=img_mean, |
| | img_std=img_std, |
| | ) |
| | state_dict = load_file(ckpt_path, device="cpu") |
| | state_dict = remap_clip(state_dict) |
| | model.extractor.load_state_dict(state_dict) |
| | return model |
| |
|