File size: 1,970 Bytes
12edc27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import cv2
import numpy as np
import torch
from typing import Union
from einops import rearrange, repeat
from PIL import Image
from safetensors.torch import load_file as load_sft
from torch import nn
from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel

from ..utils.process_util import print_load_warning

class ReduxImageEncoder(nn.Module):
    """Encode a PIL image with a SigLIP vision tower and project the result
    into the text-conditioning feature space.

    The projection is a two-layer MLP (``redux_up`` -> SiLU -> ``redux_down``)
    whose weights are loaded from a separate safetensors checkpoint; the
    SigLIP tower ships its own pretrained weights.
    """

    def __init__(
        self,
        redux_path: str,
        siglip_path: str = "google/siglip-so400m-patch14-384",
        redux_dim: int = 1152,
        txt_in_features: int = 4096,
        device: Union[str, torch.device, None] = None,
    ) -> None:
        """Build the encoder and load its weights.

        Args:
            redux_path: Path to a ``.safetensors`` checkpoint containing the
                ``redux_up`` / ``redux_down`` projection weights.
            siglip_path: HF model id or local path of the SigLIP vision model
                and its matching image processor.
            redux_dim: Width of the SigLIP hidden states fed into ``redux_up``.
            txt_in_features: Output width of the projection (text-embedding dim).
            device: Optional device the whole module is moved to at the end.
        """
        super().__init__()

        self.redux_dim = redux_dim

        # Expand 3x, then project down to the text feature width; the SiLU
        # nonlinearity between the two layers is applied in __call__.
        self.redux_up = nn.Linear(redux_dim, txt_in_features * 3)
        self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features)

        # The checkpoint only holds the projection layers, while the SigLIP
        # tower is loaded pretrained below — hence strict=False here.
        sd = load_sft(redux_path)
        missing, unexpected = self.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)

        self.siglip = SiglipVisionModel.from_pretrained(siglip_path)
        self.normalize = SiglipImageProcessor.from_pretrained(siglip_path)

        self.to(device)

    def __call__(self, x: Image.Image, device: Union[str, torch.device, None] = None, dtype: Union[str, torch.dtype, None] = None) -> torch.Tensor:
        """Encode one image into a projected embedding tensor.

        Args:
            x: Input image (converted to RGB by the SigLIP processor).
            device: Target device for the forward pass; defaults to the
                module's own device.
            dtype: Target dtype (a ``torch.dtype`` or its name, e.g.
                ``"float16"``); defaults to the module's own dtype.

        Returns:
            The SigLIP last hidden state projected through the redux MLP.

        Raises:
            ValueError: If ``dtype`` is a string that does not name a torch dtype.
        """
        if isinstance(device, str):
            device = torch.device(device)

        if isinstance(dtype, str):
            # Bug fix: torch.dtype is not constructible — torch.dtype("float16")
            # raises TypeError. Resolve the name as an attribute of the torch
            # module instead (e.g. "float16" -> torch.float16).
            resolved = getattr(torch, dtype, None)
            if not isinstance(resolved, torch.dtype):
                raise ValueError(f"unknown torch dtype name: {dtype!r}")
            dtype = resolved

        # Fall back to the module's own placement when not specified.
        if device is None:
            device = next(self.parameters()).device

        if dtype is None:
            dtype = next(self.parameters()).dtype

        imgs = self.normalize.preprocess(images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True)
        _encoded_x = self.siglip(**imgs.to(device=device, dtype=dtype)).last_hidden_state
        projected_x = self.redux_down(nn.functional.silu(self.redux_up(_encoded_x)))

        return projected_x