# MoodSpace / dino_clip_featextract.py
import gc

import torch
from torchvision import transforms

from my_ipadapter_model import extract_clip_embedding_tensor, load_ipadapter


def free_memory():
    """Release GPU memory held by dead Python objects.

    Collect garbage first so that any tensors kept alive only by collectable
    references are freed, then return their cached CUDA blocks to the driver.
    """
    gc.collect()
    torch.cuda.empty_cache()


@torch.no_grad()
def extract_dino_image_embeds(images, batch_size=32):
    """Extract DINO ViT-S/16 token embeddings for an (N, 3, H, W) image tensor."""
    dino = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
    dino = dino.eval().cuda()
    num_batches = (images.shape[0] + batch_size - 1) // batch_size
    dino_image_embeds = []
    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = min((i + 1) * batch_size, images.shape[0])
        batch = images[start_idx:end_idx].cuda()
        # get_intermediate_layers returns the token sequences (CLS + patch
        # tokens) of the last n blocks; take the final block's output.
        batch_embeds = dino.get_intermediate_layers(batch)[-1]
        # Offload each batch to CPU immediately so embeddings for large
        # image sets do not accumulate on the GPU.
        dino_image_embeds.append(batch_embeds.cpu())
    dino_image_embeds = torch.cat(dino_image_embeds, dim=0)
    del dino
    free_memory()
    return dino_image_embeds


@torch.no_grad()
def extract_clip_image_embeds(images, batch_size=32):
    """Extract IP-Adapter CLIP image embeddings for an (N, 3, H, W) image tensor."""
    ipmodel = load_ipadapter()
    num_batches = (images.shape[0] + batch_size - 1) // batch_size
    clip_image_embeds = []
    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = min((i + 1) * batch_size, images.shape[0])
        batch = images[start_idx:end_idx].cuda()
        batch_embeds = extract_clip_embedding_tensor(batch, ipmodel)
        # As above, offload per batch to keep GPU memory bounded.
        clip_image_embeds.append(batch_embeds.cpu())
    clip_image_embeds = torch.cat(clip_image_embeds, dim=0)
    del ipmodel
    free_memory()
    return clip_image_embeds


# Preprocessing: resize to 256x256 and normalize with ImageNet statistics.
img_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Inverse of img_transform: undo the normalization (multiply by std, then
# add back the mean) and convert the tensor to a PIL image.
img_transform_inv = transforms.Compose([
    transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1/0.229, 1/0.224, 1/0.225]),
    transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1.0, 1.0, 1.0]),
    transforms.ToPILImage(),
])
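

# A minimal usage sketch (illustrative only, not part of the original module):
# preprocess a few images with img_transform, stack them into one tensor, and
# run both extractors. The file paths below are hypothetical placeholders.
if __name__ == "__main__":
    from PIL import Image

    paths = ["example_0.jpg", "example_1.jpg"]  # hypothetical input images
    images = torch.stack(
        [img_transform(Image.open(p).convert("RGB")) for p in paths]
    )
    dino_embeds = extract_dino_image_embeds(images)  # (N, tokens, 384) for ViT-S/16
    clip_embeds = extract_clip_image_embeds(images)
    print(dino_embeds.shape, clip_embeds.shape)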