Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,970 Bytes
12edc27 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import cv2
import numpy as np
import torch
from typing import Union
from einops import rearrange, repeat
from PIL import Image
from safetensors.torch import load_file as load_sft
from torch import nn
from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel
from ..utils.process_util import print_load_warning
class ReduxImageEncoder(nn.Module):
    """Encode a PIL image with SigLIP and project it through the Redux MLP.

    The projection is ``redux_up`` -> SiLU -> ``redux_down``; its weights are
    loaded from a safetensors checkpoint. The SigLIP vision model and its image
    processor are loaded with ``from_pretrained`` from ``siglip_path``.

    Args:
        redux_path: Path to the safetensors file with the projection weights.
        siglip_path: Model id or local path for the SigLIP model and processor.
        redux_dim: Input width of ``redux_up``; presumably must equal the SigLIP
            hidden size for the chosen ``siglip_path`` (TODO confirm).
        txt_in_features: Output width of the projection (``redux_down``).
        device: Device the whole module is moved to after loading; ``None``
            leaves the module where it was loaded.
    """

    def __init__(
        self,
        redux_path: str,
        siglip_path: str = "google/siglip-so400m-patch14-384",
        redux_dim: int = 1152,
        txt_in_features: int = 4096,
        device: Union[str, torch.device, None] = None,
    ) -> None:
        super().__init__()
        self.redux_dim = redux_dim
        self.redux_up = nn.Linear(redux_dim, txt_in_features * 3)
        self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features)

        # Only redux_up/redux_down are registered at this point (siglip is
        # attached below). strict=False tolerates key mismatches, which are
        # reported by print_load_warning. assign=True keeps the checkpoint's
        # tensors (and therefore their dtype) instead of copying into the
        # freshly initialised ones.
        sd = load_sft(redux_path)
        missing, unexpected = self.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)

        self.siglip = SiglipVisionModel.from_pretrained(siglip_path)
        self.normalize = SiglipImageProcessor.from_pretrained(siglip_path)
        if device is not None:
            self.to(device)

    @staticmethod
    def _resolve_dtype(dtype: Union[str, torch.dtype, None]) -> Union[torch.dtype, None]:
        """Turn a dtype name such as ``"bfloat16"`` or ``"torch.bfloat16"`` into a ``torch.dtype``.

        ``torch.dtype`` cannot be constructed from a string, so the name is
        looked up as an attribute of the ``torch`` module. Non-string values
        are returned unchanged.

        Raises:
            ValueError: If the string does not name a ``torch.dtype``.
        """
        if not isinstance(dtype, str):
            return dtype
        name = dtype[len("torch."):] if dtype.startswith("torch.") else dtype
        resolved = getattr(torch, name, None)
        if not isinstance(resolved, torch.dtype):
            raise ValueError(f"Unknown torch dtype name: {dtype!r}")
        return resolved

    def __call__(
        self,
        x: Image.Image,
        device: Union[str, torch.device, None] = None,
        dtype: Union[str, torch.dtype, None] = None,
    ) -> torch.Tensor:
        """Encode one image into projected Redux features.

        Args:
            x: Input image; the processor is asked to resize it and convert it
                to RGB.
            device: Device the processed pixel tensor is moved to. Defaults to
                the device of the SigLIP parameters.
            dtype: Dtype the processed pixel tensor is cast to before SigLIP,
                either a ``torch.dtype`` or its name (``"float16"`` or
                ``"torch.float16"``). Defaults to the dtype of the SigLIP
                parameters.

        Returns:
            ``redux_down(silu(redux_up(h)))`` where ``h`` is SigLIP's
            ``last_hidden_state`` for a batch holding only ``x``; the last
            dimension is ``txt_in_features`` (presumably shaped
            ``(1, num_tokens, txt_in_features)`` — TODO confirm).

        Raises:
            ValueError: If ``dtype`` is a string that does not name a torch dtype.
        """
        if isinstance(device, str):
            device = torch.device(device)
        dtype = self._resolve_dtype(dtype)

        # Default to SigLIP's placement/precision: SigLIP is the model that
        # consumes the pixel tensor, and the projection weights (loaded with
        # assign=True) may carry a different dtype than SigLIP.
        siglip_param = next(self.siglip.parameters())
        if device is None:
            device = siglip_param.device
        if dtype is None:
            dtype = siglip_param.dtype

        imgs = self.normalize.preprocess(
            images=[x],
            do_resize=True,
            return_tensors="pt",
            do_convert_rgb=True,
        )
        _encoded_x = self.siglip(**imgs.to(device=device, dtype=dtype)).last_hidden_state
        # Match the projection weights so a SigLIP/projection dtype mismatch
        # does not fail inside the linear layer; a no-op for a single-dtype module.
        _encoded_x = _encoded_x.to(dtype=self.redux_up.weight.dtype)
        projected_x = self.redux_down(nn.functional.silu(self.redux_up(_encoded_x)))
        return projected_x