import json
import os
import time
from typing import Literal, Optional, Union, List, Tuple

from tqdm import tqdm
from PIL import Image
from torch import nn
from transformers import (
    CLIPVisionModelWithProjection,
    CLIPImageProcessor,
    AutoProcessor,
)
import h5py
import torch
import numpy as np

from ...data.extract_feature.base_extract_feature import BaseFeatureExtractor
from ...data.emb.h5py_emb import save_value_with_h5py
from ..process.image_process import dynamic_crop_resize_image
from ..utils.data_type_util import convert_images

__all__ = [
    "ImageClipVisionFeatureExtractor",
    "ImageClipVisionFeatureExtractorV2",
    "ImageClipVisionFeatureExtractorV3",
    "ImageClipVisionFeatureExtractorV4",
    "VerstailSDLastHiddenState2ImageEmb",
    "OriginLastHiddenState2ImageEmbd",
    "OriginLastHiddenState2Poolout",
]


class ImageClipVisionFeatureExtractor(BaseFeatureExtractor):
    """Uses CLIP's image_embeds: each image yields a feature vector of size N,
    which is 512, 768 or 1024 depending on the chosen model.
    """

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        name: str = None,
        device: str = "cpu",
        dtype=torch.float32,
    ):
        super().__init__(device, dtype, name)
        self.pretrained_model_name_or_path = pretrained_model_name_or_path
        # Keep consistent with IP-Adapter
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            pretrained_model_name_or_path
        ).to(device=device, dtype=dtype)
        # TODO: multiple initialization code paths exist; unify them later
        if os.path.isdir(pretrained_model_name_or_path):
            self.clip_image_processor = CLIPImageProcessor()
        else:
            self.clip_image_processor = AutoProcessor.from_pretrained(
                pretrained_model_name_or_path
            )

    def extract_images(
        self,
        data: Union[str, List[str], Image.Image, List[Image.Image], np.ndarray],
        target_width: int = None,
        target_height: int = None,
        return_type: str = "numpy",
        input_rgb_order: str = "rgb",
    ) -> Union[np.ndarray, torch.Tensor]:
        data = convert_images(data, return_type="pil", input_rgb_order=input_rgb_order)
        if target_height is not None and target_width is not None:
            data = [
                dynamic_crop_resize_image(
                    image,
                    target_height=target_height,
                    target_width=target_width,
                )
                for image in data
            ]
        with torch.no_grad():
            clip_image = self.clip_image_processor(
                images=data, return_tensors="pt"
            ).pixel_values
            emb = self.get_target_emb(
                clip_image.to(device=self.device, dtype=self.dtype)
            )
        if return_type == "numpy":
            emb = emb.cpu().numpy()
        return emb

    def get_target_emb(self, data):
        outputs = self.image_encoder(data).image_embeds
        return outputs

    def extract_video(
        self,
        video_dataset,
        target_width: int = None,
        target_height: int = None,
        return_type: str = "numpy",
        track_performance: bool = False,
        input_rgb_order: str = "rgb",
    ) -> Union[np.ndarray, torch.Tensor]:
        embs = []
        sample_indexs = []
        if track_performance:
            performance = {}
        with torch.no_grad():
            for i, (batch, batch_index) in enumerate(video_dataset):
                # TODO: for now this reuses the extraction code from the Hugging Face
                # diffusers img2img pipeline; since that code only preprocesses PIL
                # Images, convert the numpy.ndarray frames to PIL.Image first.
                batch = [Image.fromarray(batch[b_i]) for b_i in range(len(batch))]
                emb = self.extract_images(
                    data=batch,
                    target_width=target_width,
                    target_height=target_height,
                    return_type=return_type,
                    input_rgb_order=input_rgb_order,
                )
                embs.append(emb)
                sample_indexs.extend(batch_index)
        sample_indexs = np.array(sample_indexs)
        if return_type == "numpy":
            embs = np.concatenate(embs, axis=0)
        elif return_type == "torch":
            embs = torch.concat(embs)
            sample_indexs = torch.from_numpy(sample_indexs)
        return sample_indexs, embs

    def extract(
        self,
        data: Union[str, List[str]],
        data_type: Literal["image", "video"],
        return_type: str = "numpy",
        save_emb_path: str = None,
        save_type: str = "h5py",
        emb_key: str = "image_embeds",
        sample_index_key: str = "sample_indexs",
        insert_name_to_key: bool = False,
        overwrite: bool = False,
        input_rgb_order: str = "rgb",
        save_sample_index: bool = True,
        **kwargs,
    ) -> Union[np.ndarray, torch.Tensor]:
        if self.name is not None and insert_name_to_key:
            emb_key = f"{self.name}_{emb_key}"
            sample_index_key = f"{self.name}_{sample_index_key}"
        if save_emb_path is not None and os.path.exists(save_emb_path):
            with h5py.File(save_emb_path, "r") as f:
                if not overwrite and emb_key in f and sample_index_key in f:
                    return None
        if data_type == "image":
            emb = self.extract_images(
                data=data,
                return_type=return_type,
                input_rgb_order=input_rgb_order,
                **kwargs,
            )
            if save_emb_path is None:
                return emb
            else:
                raise NotImplementedError("save images emb")
        elif data_type == "video":
            sample_indexs, emb = self.extract_video(
                video_dataset=data,
                return_type=return_type,
                input_rgb_order=input_rgb_order,
                **kwargs,
            )
            if save_emb_path is None:
                return sample_indexs, emb
            else:
                if save_type == "h5py":
                    self.save_video_emb_with_h5py(
                        save_emb_path=save_emb_path,
                        emb=emb,
                        emb_key=emb_key,
                        sample_indexs=sample_indexs,
                        sample_index_key=sample_index_key,
                        overwrite=overwrite,
                        save_sample_index=save_sample_index,
                    )
                    return sample_indexs, emb
                else:
                    raise ValueError(f"only support save_type={save_type}")

    @staticmethod
    def save_images_emb_with_h5py(
        save_emb_path: str,
        emb: np.ndarray = None,
        emb_key: str = "image_embeds",
    ) -> h5py.File:
        save_value_with_h5py(save_emb_path, value=emb, key=emb_key)

    @staticmethod
    def save_video_emb_with_h5py(
        save_emb_path: str,
        emb: np.ndarray = None,
        emb_key: str = "image_embeds",
        sample_indexs: np.ndarray = None,
        sample_index_key: str = "sample_indexs",
        overwrite: bool = False,
        save_sample_index: bool = True,
    ) -> h5py.File:
        save_value_with_h5py(
            save_emb_path,
            value=emb,
            key=emb_key,
            overwrite=overwrite,
            dtype=np.float16,
        )
        if save_sample_index:
            save_value_with_h5py(
                save_emb_path,
                value=sample_indexs,
                key=sample_index_key,
                overwrite=overwrite,
                dtype=np.uint32,
            )
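
# Usage sketch (illustrative only; the checkpoint id and image path below are
# placeholders, not something this module ships with):
#
#     extractor = ImageClipVisionFeatureExtractor(
#         pretrained_model_name_or_path="openai/clip-vit-base-patch32",
#         name="clip_vision",
#         device="cpu",
#         dtype=torch.float32,
#     )
#     emb = extractor.extract_images("example.png", return_type="numpy")
#     # emb has shape (1, projection_dim), e.g. (1, 512) for clip-vit-base-patch32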


class ImageClipVisionFeatureExtractorV2(ImageClipVisionFeatureExtractor):
    """Uses CLIP's hidden_states[-2]: each image yields an M*D feature,
    e.g. 257*1280.
    """

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        name: str = None,
        device: str = "cpu",
        dtype=torch.float32,
    ):
        super().__init__(pretrained_model_name_or_path, name, device, dtype)

    def get_target_emb(self, data):
        outputs = self.image_encoder(data, output_hidden_states=True).hidden_states[-2]
        return outputs


class ImageClipVisionFeatureExtractorV3(ImageClipVisionFeatureExtractor):
    """Uses CLIP's last_hidden_state: each image yields an M*D feature,
    e.g. 257*1280.
    """

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        name: str = None,
        device: str = "cpu",
        dtype=torch.float32,
    ):
        super().__init__(pretrained_model_name_or_path, name, device, dtype)

    def get_target_emb(self, data):
        outputs = self.image_encoder(data, output_hidden_states=True).last_hidden_state
        return outputs


class ImageClipVisionFeatureExtractorV4(ImageClipVisionFeatureExtractor):
    """Post-layernorms and projects all tokens of last_hidden_state, then scales
    them by the norm of the class-token embedding.

    Reference:
    https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py#L114
    """

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        name: str = None,
        device: str = "cpu",
        dtype=torch.float32,
    ):
        super().__init__(pretrained_model_name_or_path, name, device, dtype)

    def get_target_emb(self, data):
        encoder_output = self.image_encoder(data, output_hidden_states=True)
        embeds = self.image_encoder.vision_model.post_layernorm(
            encoder_output.last_hidden_state
        )
        embeds = self.image_encoder.visual_projection(embeds)
        embeds_pooled = embeds[:, 0:1]
        embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
        return embeds
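
# Usage sketch contrasting the extractor variants (illustrative; the checkpoint id
# below is a placeholder for any CLIP vision checkpoint with a projection head):
#
#     extractor_v4 = ImageClipVisionFeatureExtractorV4(
#         "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
#     )
#     emb = extractor_v4.extract_images("example.png", return_type="torch")
#     # V1 returns image_embeds (B, projection_dim); V2 returns hidden_states[-2]
#     # (B, num_tokens, hidden_size); V3 returns last_hidden_state; V4 returns the
#     # projected token embeddings scaled by the class-token norm
#     # (B, num_tokens, projection_dim), following the Versatile Diffusion
#     # image-variation pipeline referenced above.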


class OriginLastHiddenState2Poolout(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        projection_dim: int,
        layer_norm_eps: float,
    ):
        super().__init__()
        self.post_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.visual_projection = nn.Linear(hidden_size, projection_dim, bias=False)

    def load_state_dict_from_pretrained(self, pretrained_model_name_or_path):
        model_pretrained = torch.load(
            os.path.join(pretrained_model_name_or_path, "pytorch_model.bin"),
            map_location="cpu",
        )
        post_layernorm_params = {
            k.replace("vision_model.post_layernorm.", ""): v
            for k, v in model_pretrained.items()
            if "vision_model.post_layernorm." in k
        }
        self.post_layernorm.load_state_dict(post_layernorm_params)
        visual_projection_params = {
            k.replace("visual_projection.", ""): v
            for k, v in model_pretrained.items()
            if "visual_projection." in k
        }
        self.visual_projection.load_state_dict(visual_projection_params)

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs
    ):
        cfg_path = os.path.join(pretrained_model_name_or_path, "config.json")
        with open(cfg_path, "r") as f:
            config = json.load(f)
        model = cls(
            hidden_size=config["hidden_size"],
            projection_dim=config["projection_dim"],
            layer_norm_eps=config["layer_norm_eps"],
        )
        model.load_state_dict_from_pretrained(pretrained_model_name_or_path)
        return model

    def forward(self, data):
        last_hidden_state = data
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)
        # image_embeds = self.visual_projection(pooled_output)
        return pooled_output


class OriginLastHiddenState2ImageEmbd(OriginLastHiddenState2Poolout):
    def __init__(self, hidden_size: int, projection_dim: int, layer_norm_eps: float):
        super().__init__(hidden_size, projection_dim, layer_norm_eps)

    def forward(self, data):
        pooled_output = super().forward(data)
        image_embeds = self.visual_projection(pooled_output)
        return image_embeds


class VerstailSDLastHiddenState2ImageEmb(OriginLastHiddenState2ImageEmbd):
    def __init__(self, hidden_size: int, projection_dim: int, layer_norm_eps: float):
        super().__init__(hidden_size, projection_dim, layer_norm_eps)

    def forward(self, data):
        embeds = self.post_layernorm(data)
        embeds = self.visual_projection(embeds)
        embeds_pooled = embeds[:, 0:1]
        embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
        return embeds
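

# Usage sketch for the last_hidden_state converters (illustrative; the checkpoint
# directory is a placeholder and is assumed to contain the CLIP vision encoder's
# config.json and pytorch_model.bin):
#
#     converter = VerstailSDLastHiddenState2ImageEmb.from_pretrained(
#         "/path/to/clip_vision_checkpoint"
#     )
#     last_hidden_state = torch.randn(1, 257, 1280)  # e.g. ViT-H/14 token features
#     image_emb = converter(last_hidden_state)       # (1, 257, projection_dim)
#     # OriginLastHiddenState2Poolout / OriginLastHiddenState2ImageEmbd instead take
#     # the class token (index 0) and return pooled features of shape
#     # (1, hidden_size) / projected image embeddings of shape (1, projection_dim).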