import torch
import torch.nn as nn


class PartiallyFrozenEmbedding(nn.Module):
    """Split an existing `nn.Embedding` module that splits the embedding into:

    - A frozen embedding for indices [0..freeze_until_idx].
    - A trainable embedding for indices [freeze_until_idx+1..vocab_size-1].

    This should work with both Zero-2 and Zero-3 seamlessly
    """

    def __init__(self, original_embedding: nn.Embedding, freeze_until_idx: int):
        """
        :param original_embedding: An instance of nn.Embedding (the original embedding layer).
        :param freeze_until_idx: The index up to which embedding rows are frozen (exclusive); the row at freeze_until_idx itself stays trainable.
        """
        super().__init__()
        self.freeze_until_idx = freeze_until_idx
        self.original_vocab_size = original_embedding.num_embeddings
        self.embedding_dim = original_embedding.embedding_dim

        # Split the original embedding into frozen and trainable parts
        self.embedding_frozen = nn.Embedding(
            freeze_until_idx,
            self.embedding_dim,
            dtype=original_embedding.weight.dtype,
            device=original_embedding.weight.device,
        )
        self.embedding_trainable = nn.Embedding(
            self.original_vocab_size - freeze_until_idx,
            self.embedding_dim,
            dtype=original_embedding.weight.dtype,
            device=original_embedding.weight.device,
        )

        # Copy weights from the original embedding into the frozen and trainable parts
        with torch.no_grad():
            self.embedding_frozen.weight.copy_(original_embedding.weight[:freeze_until_idx])
            self.embedding_trainable.weight.copy_(original_embedding.weight[freeze_until_idx:])

        # Freeze the frozen embedding
        self.embedding_frozen.weight.requires_grad = False

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the split embedding wrapper.
        :param input_ids: Tensor of shape [batch_size, seq_len] with indices in [0..original_vocab_size-1].
        """
        # Masks to separate frozen and trainable indices
        # (bsz, seq_len)
        mask_frozen = input_ids < self.freeze_until_idx
        mask_trainable = ~mask_frozen

        # Output tensor for embedding results
        batch_size, seq_len = input_ids.shape
        embeddings = torch.zeros(
            batch_size,
            seq_len,
            self.embedding_dim,
            device=input_ids.device,
            dtype=self.embedding_frozen.weight.dtype,
        )

        # Handle frozen embedding
        if mask_frozen.any():
            frozen_ids = input_ids[mask_frozen]
            frozen_emb = self.embedding_frozen(frozen_ids)
            embeddings[mask_frozen] = frozen_emb

        # Handle trainable embedding
        if mask_trainable.any():
            # Adjust trainable IDs to the local index space of the trainable embedding
            trainable_ids = input_ids[mask_trainable] - self.freeze_until_idx
            trainable_emb = self.embedding_trainable(trainable_ids)
            embeddings[mask_trainable] = trainable_emb

        return embeddings

    def to_unsplit(self) -> nn.Embedding:
        """Merge the frozen and trainable parts back into a single `nn.Embedding`."""
        unsplit_embedding = nn.Embedding(
            self.original_vocab_size,
            self.embedding_dim,
            dtype=self.embedding_frozen.weight.dtype,
            device=self.embedding_frozen.weight.device,
        )

        with torch.no_grad():
            unsplit_embedding.weight[: self.freeze_until_idx].copy_(self.embedding_frozen.weight)
            unsplit_embedding.weight[self.freeze_until_idx :].copy_(self.embedding_trainable.weight)

        return unsplit_embedding
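

# A minimal usage sketch (illustrative only; the layer sizes and the helper name
# below are assumptions, not part of the original module): wrap a small embedding,
# run a forward pass over mixed frozen/trainable ids, and round-trip via to_unsplit.
def _embedding_usage_example() -> None:
    base_embedding = nn.Embedding(10, 4)
    split_embedding = PartiallyFrozenEmbedding(base_embedding, freeze_until_idx=6)

    # ids < 6 hit the frozen table, ids >= 6 hit the trainable table.
    input_ids = torch.tensor([[0, 3, 6, 9]])
    embeddings = split_embedding(input_ids)
    assert embeddings.shape == (1, 4, 4)

    # Round-tripping recovers a plain nn.Embedding with the original weights.
    restored = split_embedding.to_unsplit()
    assert torch.allclose(restored.weight, base_embedding.weight)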


class PartiallyFrozenLinear(nn.Module):
    """A wrapper around nn.Linear to partially freeze part of the weight matrix."""

    def __init__(self, original_linear: nn.Linear, freeze_until_idx: int):
        """
        :param original_linear: The original nn.Linear layer.
        :param freeze_until_idx: The index up to which the rows of the weight matrix are frozen (exclusive).
        """
        super().__init__()
        assert original_linear.bias is None, "Currently only supports linear modules without bias"

        self.freeze_until_idx = freeze_until_idx
        self.input_dim = original_linear.in_features
        self.output_dim = original_linear.out_features

        # Create frozen and trainable linear layers
        self.linear_frozen = nn.Linear(
            self.input_dim,
            freeze_until_idx,
            bias=False,
            dtype=original_linear.weight.dtype,
            device=original_linear.weight.device,
        )
        self.linear_trainable = nn.Linear(
            self.input_dim,
            self.output_dim - freeze_until_idx,
            bias=False,
            dtype=original_linear.weight.dtype,
            device=original_linear.weight.device,
        )

        # Copy weights from the original linear layer
        with torch.no_grad():
            self.linear_frozen.weight.copy_(original_linear.weight[:freeze_until_idx])
            self.linear_trainable.weight.copy_(original_linear.weight[freeze_until_idx:])

        # Freeze the frozen linear layer
        self.linear_frozen.weight.requires_grad = False

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        # input_tensor: (bsz, seq_len, hidden_state_dim)
        frozen_output = self.linear_frozen(input_tensor)
        trainable_output = self.linear_trainable(input_tensor)
        return torch.cat((frozen_output, trainable_output), dim=-1)

    def to_unsplit(self) -> nn.Linear:
        """Merge the frozen and trainable parts back into a single `nn.Linear`."""
        unsplit_linear = nn.Linear(
            self.input_dim,
            self.output_dim,
            bias=False,
            dtype=self.linear_frozen.weight.dtype,
            device=self.linear_frozen.weight.device,
        )

        # Copy weights from the frozen and trainable layers into the unsplit linear layer
        with torch.no_grad():
            unsplit_linear.weight[: self.freeze_until_idx].copy_(self.linear_frozen.weight)
            unsplit_linear.weight[self.freeze_until_idx :].copy_(self.linear_trainable.weight)

        return unsplit_linear
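

# A minimal usage sketch (illustrative only; the layer sizes and the helper name
# below are assumptions, not part of the original module): wrap a bias-free output
# projection and check that the split forward matches the unsplit projection.
def _linear_usage_example() -> None:
    base_linear = nn.Linear(4, 10, bias=False)
    split_linear = PartiallyFrozenLinear(base_linear, freeze_until_idx=6)

    hidden_states = torch.randn(1, 3, 4)  # (bsz, seq_len, hidden_state_dim)
    assert torch.allclose(split_linear(hidden_states), base_linear(hidden_states), atol=1e-6)

    # Round-tripping recovers a plain nn.Linear with the original weights.
    restored = split_linear.to_unsplit()
    assert torch.allclose(restored.weight, base_linear.weight)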