import torch
import torch.nn as nn
import clip
import copy
from torch.autograd import Function
from collections import OrderedDict
from torchvision import transforms
def convert_state_dict(state_dict):
    """Strip the "module." prefix that DataParallel/DistributedDataParallel adds to parameter names."""
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith("module."):
            k = k.replace("module.", "")
        new_state_dict[k] = v
    return new_state_dict
def convert_weights_float(model: nn.Module):
    """Convert applicable model parameters to fp32"""

    def _convert_weights_to_fp32(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.float()
            if l.bias is not None:
                l.bias.data = l.bias.data.float()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [
                *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
                "in_proj_bias",
                "bias_k",
                "bias_v",
            ]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.float()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.float()

    model.apply(_convert_weights_to_fp32)
class ReverseLayerF(Function):
    """Gradient reversal: identity in the forward pass, negates and scales the
    gradient by alpha in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha
        return output, None
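

# Illustrative sketch, not part of the original file: ReverseLayerF leaves activations
# unchanged but flips the sign of gradients (scaled by alpha) on the way back. The
# hypothetical helper below is a self-check and is not called anywhere in this module.
def _gradient_reversal_demo(alpha=0.5):
    x = torch.randn(2, 4, requires_grad=True)
    y = ReverseLayerF.apply(x, alpha)
    y.sum().backward()
    # A plain identity would give a gradient of +1 everywhere; here it is -alpha.
    assert torch.allclose(x.grad, torch.full_like(x, -alpha))
    return x.grad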
## taken from https://github.com/moein-shariatnia/OpenAI-CLIP/blob/master/modules.py
class ProjectionHead(nn.Module):
    def __init__(self, embedding_dim, projection_dim, dropout=0):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x
def init_weights(m):  # TODO: do we need init for layernorm?
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.normal_(m.bias, std=1e-6)
class CSD_CLIP(nn.Module):
    """backbone + projection head"""

    def __init__(self, name="vit_large", content_proj_head="default", model_path=None):
        super(CSD_CLIP, self).__init__()
        self.content_proj_head = content_proj_head
        if name == "vit_large":
            if model_path is None:
                clipmodel, _ = clip.load("ViT-L/14")
            else:
                clipmodel, _ = clip.load(model_path)
            self.backbone = clipmodel.visual
            self.embedding_dim = 1024
        elif name == "vit_base":
            if model_path is None:
                clipmodel, _ = clip.load("ViT-B/16")
            else:
                clipmodel, _ = clip.load(model_path)
            self.backbone = clipmodel.visual
            self.embedding_dim = 768
            self.feat_dim = 512
        else:
            raise Exception("This model is not implemented")

        convert_weights_float(self.backbone)
        self.last_layer_style = copy.deepcopy(self.backbone.proj)
        if content_proj_head == "custom":
            self.last_layer_content = ProjectionHead(self.embedding_dim, self.feat_dim)
            self.last_layer_content.apply(init_weights)
        else:
            self.last_layer_content = copy.deepcopy(self.backbone.proj)
        self.backbone.proj = None
    @property
    def dtype(self):
        return self.backbone.conv1.weight.dtype
    def forward(self, input_data, alpha=None):
        feature = self.backbone(input_data)

        # When alpha is given, apply the gradient reversal layer before the content
        # projection; otherwise use the backbone features directly.
        if alpha is not None:
            reverse_feature = ReverseLayerF.apply(feature, alpha)
        else:
            reverse_feature = feature

        style_output = feature @ self.last_layer_style
        style_output = nn.functional.normalize(style_output, dim=1, p=2)

        if self.content_proj_head == "custom":
            content_output = self.last_layer_content(reverse_feature)
        else:
            content_output = reverse_feature @ self.last_layer_content
        content_output = nn.functional.normalize(content_output, dim=1, p=2)

        return feature, content_output, style_output
def create_model_and_transforms(model_path="models/csd_clip.pth"):
    # init model
    model = CSD_CLIP("vit_large", "default")

    # load model
    checkpoint = torch.load(model_path, map_location="cpu")
    state_dict = convert_state_dict(checkpoint["model_state_dict"])
    model.load_state_dict(state_dict, strict=False)

    # normalization
    normalize = transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
    )
    preprocess = transforms.Compose(
        [
            transforms.Resize(
                size=224, interpolation=transforms.functional.InterpolationMode.BICUBIC
            ),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]
    )
    return model, preprocess, preprocess
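

# Minimal usage sketch (an assumption, not part of the original file): load the model
# and compute L2-normalized content/style embeddings for one image. The checkpoint path
# and "example.jpg" are placeholders.
if __name__ == "__main__":
    from PIL import Image

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess, _ = create_model_and_transforms("models/csd_clip.pth")
    model = model.to(device).eval()

    image = preprocess(Image.open("example.jpg").convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        feature, content_emb, style_emb = model(image)

    # style_emb and content_emb are L2-normalized, so cosine similarity between two
    # images reduces to a dot product of their embeddings.
    print(feature.shape, content_emb.shape, style_emb.shape)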