import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

import clip

from ..tools.losses import get_loss_function
from ..rotation2xyz import Rotation2xyz

loss_ce = nn.CrossEntropyLoss()
loss_mse = nn.MSELoss()
cosine_sim = nn.CosineSimilarity(dim=1, eps=1e-6)


class MOTIONCLIP(nn.Module):
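    """Motion autoencoder aligned with CLIP's joint embedding space.

    Summary of what this file implements: an encoder maps a motion sequence to a
    latent code, a decoder reconstructs motion from that code, and additional
    losses (cross-entropy, MSE, cosine) tie the latent to CLIP text/image
    features, weighted by `lambdas` and `clip_lambdas`.
    """
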
    def __init__(self, encoder, decoder, device, lambdas, latent_dim, outputxyz,
                 pose_rep, glob, glob_rot, translation, jointstype, vertstrans,
                 clip_lambdas={}, **kwargs):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

        self.outputxyz = outputxyz

        self.lambdas = lambdas
        self.clip_lambdas = clip_lambdas

        self.latent_dim = latent_dim
        self.pose_rep = pose_rep
        self.glob = glob
        self.glob_rot = glob_rot
        self.device = device
        self.translation = translation
        self.jointstype = jointstype
        self.vertstrans = vertstrans

        self.clip_model = kwargs['clip_model']
        self.clip_training = kwargs.get('clip_training', False)
        if self.clip_training and self.clip_model:
            self.clip_model.training = True
        elif self.clip_model:
            # when CLIP is not being fine-tuned it must stay frozen (eval mode)
            assert not self.clip_model.training, "CLIP model is expected to be frozen"

        self.losses = list(self.lambdas) + ["mixed"]

        self.rotation2xyz = Rotation2xyz(device=self.device)
        self.param2xyz = {"pose_rep": self.pose_rep,
                          "glob_rot": self.glob_rot,
                          "glob": self.glob,
                          "jointstype": self.jointstype,
                          "translation": self.translation,
                          "vertstrans": self.vertstrans}

    def rot2xyz(self, x, mask, get_rotations_back=False, **kwargs):
        kargs = self.param2xyz.copy()
        kargs.update(kwargs)
        return self.rotation2xyz(x, mask, get_rotations_back=get_rotations_back, **kargs)

    def compute_loss(self, batch):
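        """Aggregate all configured loss terms into the single training objective.

        Returns (mixed_loss_with_clip, losses): the scalar to backpropagate and a
        dict of detached per-term values for logging.
        """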

        # compute all losses other than clip
        mixed_loss = 0.
        losses = {}
        for ltype, lam in self.lambdas.items():
            loss_function = get_loss_function(ltype)
            loss = loss_function(self, batch)
            mixed_loss += loss * lam
            losses[ltype] = loss.item()

        # compute clip losses
        mixed_clip_loss, clip_losses = self.compute_clip_losses(batch)

        # mix and add clip losses
        mixed_loss_with_clip = mixed_loss + mixed_clip_loss  # the final objective, combining all loss terms
        losses.update(clip_losses)
        losses["mixed_without_clip"] = mixed_loss.item()
        losses["mixed_clip_only"] = mixed_clip_loss if isinstance(mixed_clip_loss, float) else mixed_clip_loss.item()
        losses["mixed_with_clip"] = mixed_loss_with_clip if isinstance(mixed_loss_with_clip,
                                                                       float) else mixed_loss_with_clip.item()

        return mixed_loss_with_clip, losses

    def compute_clip_losses(self, batch):
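        """Compute the CLIP-alignment terms.

        Two modes, following the constructor flags: if `clip_training` is a
        non-empty string such as 'text' or 'text_image', contrastive CE losses are
        computed with gradients flowing into CLIP; otherwise CLIP features are
        computed under no_grad and the per-domain `clip_lambdas` select among
        'ce', 'mse' and 'cosine' terms.
        """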
        mixed_clip_loss = 0.
        clip_losses = {}

        if self.clip_training:
            for d in self.clip_training.split('_'):
                if d == 'image':
                    features = self.clip_model.encode_image(
                        batch['clip_images']).float()  # preprocess is done in dataloader
                elif d == 'text':
                    texts = clip.tokenize(batch['clip_text']).to(self.device)
                    features = self.clip_model.encode_text(texts).float()
                else:
                    raise ValueError(f'Invalid clip domain [{d}]')

                # normalized features
                features_norm = features / features.norm(dim=-1, keepdim=True)
                seq_motion_features_norm = batch["z"] / batch["z"].norm(dim=-1, keepdim=True)
                logit_scale = self.clip_model.logit_scale.exp()
                logits_per_motion = logit_scale * seq_motion_features_norm @ features_norm.t()
                logits_per_d = logits_per_motion.t()

                batch_size = batch['x'].shape[0]
                ground_truth = torch.arange(batch_size, dtype=torch.long, device=self.device)

                ce_from_motion_loss = loss_ce(logits_per_motion, ground_truth)
                ce_from_d_loss = loss_ce(logits_per_d, ground_truth)
                clip_mixed_loss = (ce_from_motion_loss + ce_from_d_loss) / 2.

                clip_losses[f'{d}_ce_from_d'] = ce_from_d_loss.item()
                clip_losses[f'{d}_ce_from_motion'] = ce_from_motion_loss.item()
                clip_losses[f'{d}_mixed_ce'] = clip_mixed_loss.item()
                mixed_clip_loss += clip_mixed_loss
        else:
            for d in self.clip_lambdas.keys():
                if not self.clip_lambdas[d]:
                    continue
                with torch.no_grad():
                    if d == 'image':
                        features = self.clip_model.encode_image(
                            batch['clip_images']).float()  # preprocess is done in dataloader
                    elif d == 'text':
                        texts = clip.tokenize(batch['clip_text']).to(self.device)
                        features = self.clip_model.encode_text(texts).float()
                    else:
                        raise ValueError(f'Invalid clip domain [{d}]')

                # normalized features
                features_norm = features / features.norm(dim=-1, keepdim=True)
                seq_motion_features_norm = batch["z"] / batch["z"].norm(dim=-1, keepdim=True)

                if 'ce' in self.clip_lambdas[d].keys():
                    logit_scale = self.clip_model.logit_scale.exp()
                    logits_per_motion = logit_scale * seq_motion_features_norm @ features_norm.t()
                    logits_per_d = logits_per_motion.t()

                    batch_size = batch['x'].shape[0]
                    ground_truth = torch.arange(batch_size, dtype=torch.long, device=self.device)

                    ce_from_motion_loss = loss_ce(logits_per_motion, ground_truth)
                    ce_from_d_loss = loss_ce(logits_per_d, ground_truth)
                    clip_mixed_loss = (ce_from_motion_loss + ce_from_d_loss) / 2.

                    clip_losses[f'{d}_ce_from_d'] = ce_from_d_loss.item()
                    clip_losses[f'{d}_ce_from_motion'] = ce_from_motion_loss.item()
                    clip_losses[f'{d}_mixed_ce'] = clip_mixed_loss.item()
                    mixed_clip_loss += clip_mixed_loss * self.clip_lambdas[d]['ce']

                if 'mse' in self.clip_lambdas[d].keys():
                    mse_clip_loss = loss_mse(features, batch["z"])
                    clip_losses[f'{d}_mse'] = mse_clip_loss.item()
                    mixed_clip_loss += mse_clip_loss * self.clip_lambdas[d]['mse']

                if 'cosine' in self.clip_lambdas[d].keys():
                    cos = cosine_sim(features_norm, seq_motion_features_norm)
                    cosine_loss = (1 - cos).mean()
                    clip_losses[f'{d}_cosine'] = cosine_loss.item()
                    mixed_clip_loss += cosine_loss * self.clip_lambdas[d]['cosine']

        return mixed_clip_loss, clip_losses

    @staticmethod
    def lengths_to_mask(lengths):
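        """Build a boolean mask of shape [len(lengths), max(lengths)] that is True for valid frames."""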
        max_len = max(lengths)
        if isinstance(max_len, torch.Tensor):
            max_len = max_len.item()
        index = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len)
        mask = index < lengths.unsqueeze(1)
        return mask

    def generate_one(self, cls, duration, fact=1, xyz=False):
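        """Decode a single motion of `duration` frames for class `cls` from a random latent scaled by `fact`."""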
        y = torch.tensor([cls], dtype=int, device=self.device)[None]
        lengths = torch.tensor([duration], dtype=int, device=self.device)
        mask = self.lengths_to_mask(lengths)
        z = torch.randn(self.latent_dim, device=self.device)[None]

        batch = {"z": fact * z, "y": y, "mask": mask, "lengths": lengths}
        batch = self.decoder(batch)

        if not xyz:
            return batch["output"][0]

        output_xyz = self.rot2xyz(batch["output"], batch["mask"])

        return output_xyz[0]

    def generate(self, classes, durations, nspa=1,
                 # noise_same_action="random", noise_diff_action="random",
                 # fact=1,
                 is_amass=False, is_clip_features=False,
                 # input_type="motion",
                 textual_labels=None):
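        """Decode motions conditioned directly on CLIP features.

        When `is_clip_features` is True (the only path this method currently fills
        in), `classes` is expected to hold CLIP features of shape
        [nspa, nats, clip_dim]; `textual_labels`, if given, replaces the `y`
        labels kept for bookkeeping.
        """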
        clip_dim = self.clip_model.ln_final.normalized_shape[0]
        if is_clip_features:
            # assumed dims: classes [nspa, nats, 512]
            assert len(classes.shape) == 3
            assert classes.shape[-1] == clip_dim
            clip_features = classes.reshape([-1, clip_dim])
            nspa, nats = classes.shape[:2]
            # y = torch.zeros(y_action_names.shape, dtype=int)
            y = clip_features
            if textual_labels is not None:
                y = np.array(textual_labels).reshape([-1])

        if len(durations.shape) == 1:
            lengths = durations.to(self.device).repeat(nspa)
        else:
            lengths = durations.to(self.device).reshape(clip_features.shape[0])

        mask = self.lengths_to_mask(lengths)

        batch = {"z": clip_features,  # fact*z,
                 "y": y,
                 "mask": mask, "lengths": lengths}

        if not is_clip_features:
            batch['y'] = y

        batch = self.decoder(batch)

        if is_amass:  # AMASS data: discard the predicted global orientation and fix it to a constant rotation
            batch['output'][:, 0] = torch.tensor([1, 0, 0, 0, -1, 0]).unsqueeze(0).unsqueeze(2)

        if self.outputxyz:
            batch["output_xyz"] = self.rot2xyz(batch["output"], batch["mask"])
        elif self.pose_rep == "xyz":
            batch["output_xyz"] = batch["output"]

        return batch

    def generate_from_embedding(self, classes, durations, nspa=1, is_amass=False, classes_gaussians=None):
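        """Decode motions by sampling latents from per-class Gaussians.

        For each entry in `classes`, a 512-dim latent is drawn from the matching
        ('mu', 'var') entry of `classes_gaussians` and passed to the decoder.
        """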

        if nspa is None:
            nspa = 1
        nats = len(classes)

        y = classes.to(self.device).repeat(nspa)  # (view(nspa, nats))
        if len(durations.shape) == 1:
            lengths = durations.to(self.device).repeat(nspa)
        else:
            lengths = durations.to(self.device).reshape(y.shape)
        mask = self.lengths_to_mask(lengths)
        classes_np = classes.cpu().detach().numpy()

        motion_samples_ = np.zeros((classes_np.shape[0], 512), dtype='float32')
        unique_classes = np.unique(classes_np)
        for class_label in tqdm(unique_classes, total=len(unique_classes)):
            class_mask = np.where(classes_np == class_label)[0]
            sample_mu = classes_gaussians[class_label]['mu']
            sample_var = classes_gaussians[class_label]['var']

            sample = np.random.multivariate_normal(sample_mu, sample_var, size=len(class_mask))
            motion_samples_[class_mask, :] = sample

        zz = torch.from_numpy(motion_samples_).to(self.device)

        batch = {"z": zz,
                 "y": y, "mask": mask, "lengths": lengths}
        batch = self.decoder(batch)

        if is_amass:  # AMASS data: discard the predicted global orientation and fix it to a constant rotation
            batch['output'][:, 0] = torch.tensor([1, 0, 0, 0, -1, 0]).unsqueeze(0).unsqueeze(2)

        if self.outputxyz:
            batch["output_xyz"] = self.rot2xyz(batch["output"], batch["mask"])
        elif self.pose_rep == "xyz":
            batch["output_xyz"] = batch["output"]
        return batch

    def forward(self, batch):
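        """Full pass: optionally lift input rotations to xyz, encode, take `mu` as the
        latent `z`, decode, and expose the reconstructed motion (plus xyz if requested)."""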
        if self.outputxyz:
            batch["x_xyz"] = self.rot2xyz(batch["x"], batch["mask"])
        elif self.pose_rep == "xyz":
            batch["x_xyz"] = batch["x"]
        # encode
        batch.update(self.encoder(batch))

        batch["z"] = batch["mu"]
        # decode
        batch.update(self.decoder(batch))

        # if we want to output xyz
        if self.outputxyz:
            batch["output_xyz"] = self.rot2xyz(batch["output"], batch["mask"])
        elif self.pose_rep == "xyz":
            batch["output_xyz"] = batch["output"]
        return batch