File size: 3,793 Bytes
f056744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import abc
import types

import torch
from diffusers.models.transformers.transformer_flux import (
    FluxSingleTransformerBlock, FluxTransformerBlock)

from .flux_transformer_forward import (joint_transformer_forward,
                                       single_transformer_forward)


class FeatureCollector:
    """Attach a feature controller to every block of a FLUX-style transformer.

    The collector overwrites the ``forward`` attribute of each block found in
    ``transformer.transformer_blocks`` and
    ``transformer.single_transformer_blocks`` with the callable produced by
    ``joint_transformer_forward`` / ``single_transformer_forward``.

    Args:
        transformer: Object exposing ``transformer_blocks`` and
            ``single_transformer_blocks`` iterables.
        controller: Controller handed to every patched forward; its
            ``num_layers`` attribute is set by
            :meth:`register_transformer_control`.
        layer_list: Optional list stored on the instance as ``layer_list``.
            Defaults to a fresh empty list per instance.
    """

    def __init__(self, transformer, controller, layer_list=None):
        self.transformer = transformer
        self.controller = controller
        # A literal ``[]`` default would be one list shared by every instance
        # created without an explicit argument; build a new one instead.
        self.layer_list = [] if layer_list is None else layer_list

    def register_transformer_control(self):
        """Patch every block's ``forward`` so it goes through the controller.

        Blocks are tagged ``joint_<k>`` or ``single_<k>`` where ``k`` is one
        running counter shared by both groups (single-block indices continue
        after the last joint-block index). The final count is written to
        ``controller.num_layers``.
        """
        layer_count = 0

        for joint_block in self.transformer.transformer_blocks:
            joint_block.forward = joint_transformer_forward(
                joint_block, self.controller, f'joint_{layer_count}')
            layer_count += 1

        # The counter deliberately keeps running across both block groups.
        for single_block in self.transformer.single_transformer_blocks:
            single_block.forward = single_transformer_forward(
                single_block, self.controller, f'single_{layer_count}')
            layer_count += 1

        self.controller.num_layers = layer_count

    def restore_orig_transformer(self):
        """Re-patch every block's ``forward`` with no controller attached.

        NOTE(review): this does not put back the original bound ``forward``
        methods; it installs the same wrappers with ``controller=None`` and an
        empty place tag. Presumably the wrappers skip feature control when the
        controller is ``None`` -- confirm in ``flux_transformer_forward``.
        """
        place_in_transformer = ''

        for joint_block in self.transformer.transformer_blocks:
            joint_block.forward = joint_transformer_forward(
                joint_block, None, place_in_transformer)

        for single_block in self.transformer.single_transformer_blocks:
            single_block.forward = single_transformer_forward(
                single_block, None, place_in_transformer)


class FeatureControl(abc.ABC):
    """Abstract base for controllers invoked once per transformer layer.

    The instance tracks which layer (``cur_layer``) and which step
    (``cur_step``) it is on. ``num_layers`` starts at ``-1`` and has to be
    assigned from outside before the step counter can ever advance.
    """

    def __init__(self):
        self.cur_layer = 0
        self.cur_step = 0
        self.num_layers = -1

    def step_callback(self, x_t):
        """Return ``x_t`` unchanged; subclasses may override."""
        return x_t

    def between_steps(self):
        """Hook run after the last layer of a step; does nothing here."""
        return None

    @abc.abstractmethod
    def forward(self, attn, place_in_transformer: str):
        """Process one layer's features; must be implemented by subclasses."""
        raise NotImplementedError

    @torch.no_grad()
    def __call__(self, hidden_state, place_in_transformer: str):
        """Run :meth:`forward` and advance the layer/step bookkeeping.

        Once ``cur_layer`` reaches ``num_layers`` the layer counter wraps to
        zero, ``cur_step`` is incremented and :meth:`between_steps` is called.
        The value produced by :meth:`forward` is returned either way.
        """
        result = self.forward(hidden_state, place_in_transformer)
        self.cur_layer += 1

        # Still in the middle of a step: nothing more to update.
        if self.cur_layer != self.num_layers:
            return result

        self.cur_layer = 0
        self.cur_step += 1
        self.between_steps()
        return result

    def reset(self):
        """Zero both the step counter and the layer counter."""
        self.cur_step = self.cur_layer = 0


class FeatureReplace(FeatureControl):
    """Copy the first sample's hidden states onto the rest of the batch.

    Replacement only happens for layers whose numeric index is listed in
    ``layer_list`` and while ``cur_step`` is below ``feature_steps``.

    Args:
        layer_list: Layer indices (the number after the last ``_`` in the
            place tag) where replacement is active. Defaults to a fresh empty
            list, i.e. replacement is never applied.
        feature_steps: Number of initial steps (``0 .. feature_steps - 1``)
            during which replacement is active.
        t5_dim: Number of leading tokens left untouched in ``single`` blocks.
            The default of 512 matches the value that used to be hard-coded.
            Presumably this is the text-token prefix length -- confirm against
            the pipeline's sequence layout.
    """

    def __init__(
        self,
        layer_list=None,
        feature_steps=7,
        t5_dim=512,
    ):
        super().__init__()
        # Avoid a shared mutable default: each instance gets its own list.
        self.layer_list = [] if layer_list is None else layer_list
        self.feature_steps = feature_steps
        self.t5_dim = t5_dim

    def forward(self, hidden_states, place_in_transformer):
        """Overwrite ``hidden_states[1:]`` with the features of sample 0.

        Args:
            hidden_states: Tensor modified in place. Assumes a
                ``(batch, tokens, channels)`` layout with the source sample at
                batch index 0 -- TODO confirm against the patched forwards.
            place_in_transformer: Tag such as ``'joint_3'`` or ``'single_21'``.

        Returns:
            The same ``hidden_states`` tensor object, possibly modified.

        Raises:
            IndexError: If a ``single`` block's token dimension is shorter
                than ``t5_dim``.
        """
        layer_index = int(place_in_transformer.rsplit('_', 1)[-1])
        if (layer_index not in self.layer_list) or (self.cur_step not in range(0, self.feature_steps)):
            return hidden_states

        if 'single' in place_in_transformer:
            # Only use image latent: the leading t5_dim tokens keep each
            # sample's own values.
            start = self.t5_dim
            if hidden_states.shape[1] < start:
                raise IndexError(
                    f'hidden_states has {hidden_states.shape[1]} tokens, '
                    f'fewer than t5_dim={start}')
        else:
            # Non-single blocks are replaced across the whole token dimension.
            start = 0

        # Direct slice copy (broadcast over the batch dim). For finite values
        # this equals the former ``source * mask + target * (1 - mask)`` but
        # needs no mask tensor and cannot leak NaN/Inf through ``x * 0``.
        hidden_states[1:, start:] = hidden_states[:1, start:]
        return hidden_states