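"""CLIP-based backbone for the QualityMetric trainer.

`XCLIPModel` subclasses the Hugging Face `CLIPModel` so that the text and
vision towers return projected per-token / per-patch features (plus the
EOS-pooled text feature) instead of only the pooled embeddings. `CLIPModel`
wraps it together with a `Cross_model` cross-attention head that scores
image/text pairs under a condition-derived attention mask.
"""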
import json
import os
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import einsum, nn
from transformers import CLIPConfig
from transformers import CLIPModel as HFCLIPModel

from .base_model import BaseModelConfig
from .cross_modeling import Cross_model

class XCLIPModel(HFCLIPModel):
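    """CLIP model whose feature extractors expose token-level outputs.

    Unlike the parent implementation, `get_text_features` returns the projected
    hidden state of every text token together with the projected EOS (pooled)
    embedding, and `get_image_features` returns the projected hidden state of
    every vision token rather than the pooled image embedding.
    """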
    def __init__(self, config: CLIPConfig):
        super().__init__(config)
    
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
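        """Return projected per-token text features and the projected EOS (pooled) feature."""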

        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project every token's hidden state (not just the pooled output, as the
        # parent class does), plus the pooled EOS token for a sentence-level feature.
        last_hidden_state = text_outputs[0]
        text_features = self.text_projection(last_hidden_state)

        pooled_output = text_outputs[1]
        text_features_EOS = self.text_projection(pooled_output)

        return text_features, text_features_EOS

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
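        """Return the projected per-token vision features (CLS token plus image patches)."""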
        
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project every vision token's hidden state (CLS + patches) instead of
        # only the pooled output used by the parent class.
        last_hidden_state = vision_outputs[0]
        image_features = self.visual_projection(last_hidden_state)

        return image_features



@dataclass
class ClipModelConfig(BaseModelConfig):
    _target_: str = "diffsynth.extensions.QualityMetric.trainer.models.clip_model.CLIPModel"
    pretrained_model_name_or_path: str = "checkpoints/clip-vit-base-patch32"


class CLIPModel(nn.Module):
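    """Wraps `XCLIPModel` with a `Cross_model` cross-attention head for pairwise scoring."""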
    def __init__(self, ckpt, config_file=False):
        super().__init__()
        # The original check (`config_file is None` with a default of False) made the
        # pretrained branch unreachable for default calls; treat any falsy value as
        # "load pretrained weights from the checkpoint".
        if not config_file:
            self.model = XCLIPModel.from_pretrained(ckpt)
        else:
            # Build the architecture from config.json only; no pretrained weights are loaded here.
            with open(os.path.join(ckpt, "config.json"), "r", encoding="utf-8") as f:
                config = CLIPConfig(**json.load(f))
            self.model = XCLIPModel._from_config(config)
        self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)
    
    def get_text_features(self, *args, **kwargs):
        return self.model.get_text_features(*args, **kwargs)

    def get_image_features(self, *args, **kwargs):
        return self.model.get_image_features(*args, **kwargs)

    def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None):
        outputs = ()

        # Per-token text features plus the EOS (sentence-level) feature.
        text_f, text_EOS = self.model.get_text_features(text_inputs)  # (B, 77, 1024)
        outputs += (text_EOS,)

        # The image batch stacks two images per prompt (e.g. a preference pair) along dim 0.
        image_f = self.model.get_image_features(image_inputs.half())  # (2B, 257, 1024)
        condition_f, _ = self.model.get_text_features(condition_inputs)  # (B, 5, 1024)

        # Build an additive attention mask over prompt tokens: keep (bias 0) tokens whose
        # normalized similarity to any condition token exceeds the threshold, drop (bias -inf)
        # the rest.
        sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
        sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
        sim_text_condition = sim_text_condition / sim_text_condition.max()
        mask = torch.where(sim_text_condition > 0.01, 0.0, float('-inf'))  # (B, 1, 77)
        mask = mask.repeat(1, image_f.shape[1], 1)  # (B, 257, 77)

        # Score each image of the pair against the masked text tokens and keep the
        # leading (CLS-position) output of the cross-attention head.
        bc = image_f.shape[0] // 2
        sim0 = self.cross_model(image_f[:bc], text_f, mask.half())
        sim1 = self.cross_model(image_f[bc:], text_f, mask.half())
        outputs += (sim0[:, 0, :],)
        outputs += (sim1[:, 0, :],)

        return outputs

    @property
    def logit_scale(self):
        return self.model.logit_scale

    def save(self, path):
        # Saves only the wrapped CLIP backbone; Cross_model weights are not included.
        self.model.save_pretrained(path)
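

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): it exercises only the
    # text feature extractor. The full forward() pass additionally assumes a
    # checkpoint whose projection_dim matches Cross_model's dim of 1024.
    # The checkpoint path below is the placeholder from ClipModelConfig and may
    # need to be replaced with a locally available CLIP checkpoint directory.
    from transformers import AutoProcessor

    ckpt = "checkpoints/clip-vit-base-patch32"
    model = CLIPModel(ckpt)
    processor = AutoProcessor.from_pretrained(ckpt)

    batch = processor(
        text=["a photo of a cat"],
        padding="max_length",
        truncation=True,
        max_length=77,
        return_tensors="pt",
    )
    with torch.no_grad():
        token_feats, eos_feat = model.get_text_features(
            input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
        )
    print(token_feats.shape, eos_feat.shape)  # (1, 77, proj_dim) and (1, proj_dim)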