pooyanrg committed on
Commit ad4721b
1 Parent(s): 0884186

initial commit
Dockerfile ADDED
@@ -0,0 +1,13 @@
1
+ FROM python:3.9
2
+
3
+ RUN useradd -m -u 1000 user
4
+ USER user
5
+ ENV PATH="/home/user/.local/bin:$PATH"
6
+
7
+ WORKDIR /app
8
+
9
+ COPY --chown=user ./requirements.txt requirements.txt
10
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
11
+
12
+ COPY --chown=user . /app
13
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
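For quick testing outside Docker, the container's CMD can be reproduced locally. A minimal sketch, assuming the pinned requirements are installed and port 7860 is free:

# Local equivalent of the image's CMD (illustrative).
import uvicorn

if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860)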
app.py ADDED
@@ -0,0 +1,324 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image, ImageDraw
4
+ import torch
5
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize
6
+ from utils.model import init_model
7
+ from utils.tokenization_clip import SimpleTokenizer as ClipTokenizer
8
+
9
+ from fastapi.staticfiles import StaticFiles
10
+ from fileservice import app
11
+
12
+
13
+ html_text = """
14
+ <div id="container">
15
+ <canvas id="canvas" width="512" height="512"></canvas><img id="canvas-background" style="display:none;"/>
16
+ </div>
17
+ """
18
+
19
+ def image_to_tensor(image_path):
20
+ image = Image.open(image_path).convert('RGB')
21
+
22
+ preprocess = Compose([
23
+ Resize([224, 224], interpolation=Image.BICUBIC),
24
+ lambda image: image.convert("RGB"),
25
+ ToTensor(),
26
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
27
+ ])
28
+ image_data = preprocess(image)
29
+
30
+ return {'image': image_data}
31
+
32
+ def get_image_data(image_path):
33
+ image_input = image_to_tensor(image_path)
34
+ return image_input
35
+
36
+ def get_intervention_vector(selected_cells_bef, selected_cells_aft):
37
+ left = np.reshape(np.zeros((1, 14 * 14)), (14, 14))
38
+ right = np.reshape(np.zeros((1, 14 * 14)), (14, 14))
39
+
40
+ for (i, j) in selected_cells_bef:
41
+ left[i, j] = 1.
42
+ for (i, j) in selected_cells_aft:
43
+ right[i, j] = 1.
44
+
45
+
46
+ left_map = np.zeros((1, 14 * 14 + 1))
47
+ right_map = np.zeros((1, 14 * 14 + 1))
48
+
49
+ left_map[0, 1:] = np.reshape(left, (1, 14 * 14))
50
+ right_map[0, 1:] = np.reshape(right, (1, 14 * 14))
51
+
52
+
53
+ if len(selected_cells_bef) == 0:
54
+ left_map[0, 0] = 0.0
55
+
56
+ if len(selected_cells_aft) == 0:
57
+ right_map[0, 0] = 0.0
58
+
59
+
60
+ return left_map, right_map
61
+
62
+ def _get_rawimage(image_path):
63
+ # Pair x L x T x 3 x H x W
64
+ image = np.zeros((1, 3, 224,
65
+ 224), dtype=np.float64)  # np.float was removed from NumPy; use an explicit dtype
66
+
67
+ for i in range(1):
68
+
69
+ raw_image_data = get_image_data(image_path)
70
+ raw_image_data = raw_image_data['image']
71
+
72
+ image[i] = raw_image_data
73
+
74
+ return image
75
+
76
+
77
+ def greedy_decode(model, tokenizer, video, video_mask, gt_left_map, gt_right_map):
78
+ visual_output, left_map, right_map = model.get_sequence_visual_output(video, video_mask,
79
+ gt_left_map[:, 0, :].squeeze(), gt_right_map[:, 0, :].squeeze())
80
+
81
+ video_mask = torch.ones(visual_output.shape[0], visual_output.shape[1], device=visual_output.device).long()
82
+ input_caption_ids = torch.zeros(visual_output.shape[0], device=visual_output.device).data.fill_(tokenizer.vocab["<|startoftext|>"])
83
+ input_caption_ids = input_caption_ids.long().unsqueeze(1)
84
+ decoder_mask = torch.ones_like(input_caption_ids)
85
+ for i in range(32):
86
+ decoder_scores = model.decoder_caption(visual_output, video_mask, input_caption_ids, decoder_mask, get_logits=True)
87
+ next_words = decoder_scores[:, -1].max(1)[1].unsqueeze(1)
88
+ input_caption_ids = torch.cat([input_caption_ids, next_words], 1)
89
+ next_mask = torch.ones_like(next_words)
90
+ decoder_mask = torch.cat([decoder_mask, next_mask], 1)
91
+
92
+ return input_caption_ids[:, 1:].tolist(), left_map, right_map
93
+
94
+ # Run the model on the image pair, using the selected patches as interventions
95
+ def predict_image(image_bef, image_aft, selected_cells_bef, selected_cells_aft):
96
+ if image_bef is None:
97
+ return "No image provided", "", ""
98
+ if image_aft is None:
99
+ return "No image provided", "", ""
100
+
101
+
102
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
103
+
104
+ model = init_model('data/pytorch_model.pt', device)
105
+
106
+ tokenizer = ClipTokenizer()
107
+
108
+ left_map, right_map = get_intervention_vector(selected_cells_bef, selected_cells_aft)
109
+
110
+ left_map, right_map = torch.from_numpy(left_map).unsqueeze(0), torch.from_numpy(right_map).unsqueeze(0)
111
+
112
+ bef_image = torch.from_numpy(_get_rawimage(image_bef)).unsqueeze(1)
113
+ aft_image = torch.from_numpy(_get_rawimage(image_aft)).unsqueeze(1)
114
+
115
+ image_pair = torch.cat([bef_image, aft_image], 1)
116
+
117
+ image_mask = torch.from_numpy(np.ones(2, dtype=np.int64)).unsqueeze(0)  # np.long was removed from NumPy
118
+
119
+ result_list, left_map, right_map = greedy_decode(model, tokenizer, image_pair, image_mask, left_map, right_map)
120
+
121
+
122
+ decode_text_list = tokenizer.convert_ids_to_tokens(result_list[0])
123
+ if "<|endoftext|>" in decode_text_list:
124
+ SEP_index = decode_text_list.index("<|endoftext|>")
125
+ decode_text_list = decode_text_list[:SEP_index]
126
+ if "!" in decode_text_list:
127
+ PAD_index = decode_text_list.index("!")
128
+ decode_text_list = decode_text_list[:PAD_index]
129
+ decode_text = " ".join(decode_text_list).strip()  # join the token list; a list has no .strip()
130
+
131
+ # Final predicted caption
132
+ pred = decode_text
133
+
134
+ # Include information about selected cells
135
+ selected_info_bef = f"{selected_cells_bef}" if selected_cells_bef else "No image patch was selected"
136
+ selected_info_aft = f"{selected_cells_aft}" if selected_cells_aft else "No image patch was selected"
137
+
138
+ return pred, selected_info_bef, selected_info_aft
139
+
140
+ # Add grid to the image
141
+ def add_grid_to_image(image_path, grid_size=14):
142
+ if image_path is None:
143
+ return None
144
+
145
+ image = Image.open(image_path)
146
+ w, h = image.size
147
+
148
+ image = image.convert('RGBA')
149
+
150
+ draw = ImageDraw.Draw(image)
151
+
152
+ x_positions = np.linspace(0, w, grid_size + 1)
153
+ y_positions = np.linspace(0, h, grid_size + 1)
154
+
155
+ # Draw the vertical lines
156
+ for x in x_positions[1:-1]:
157
+ line = ((x, 0), (x, h))
158
+ draw.line(line, fill='white')
159
+
160
+ # Draw the horizontal lines
161
+ for y in y_positions[1:-1]:
162
+ line = ((0, y), (w, y))
163
+ draw.line(line, fill='white')
164
+
165
+ return image, h, w
166
+
167
+ # Handle cell selection
168
+ def handle_click(image, evt: gr.SelectData, selected_cells, image_path):
169
+ if image is None:
170
+ return None, []
171
+
172
+ grid_size = 14
173
+
174
+ image, h, w = add_grid_to_image(image_path, grid_size)
175
+
176
+ x_positions = np.linspace(0, w, grid_size + 1)
177
+ y_positions = np.linspace(0, h, grid_size + 1)
178
+
179
+ # Calculate which cell was clicked
180
+ for index, x in enumerate(x_positions[:-1]):
181
+ if evt.index[0] >= x and evt.index[0] <= x_positions[index+1]:
182
+ row = index
183
+
184
+ for index, y in enumerate(y_positions[:-1]):
185
+ if evt.index[1] >= y and evt.index[1] <= y_positions[index+1]:
186
+ col = index
187
+
188
+ cell_idx = (row, col)
189
+
190
+ # Toggle selection
191
+ if cell_idx in selected_cells:
192
+ selected_cells.remove(cell_idx)
193
+ else:
194
+ selected_cells.append(cell_idx)
195
+
196
+ # Add semi-transparent overlay for selected cells
197
+ highlight_layer = Image.new('RGBA', (w, h), (0, 0, 0, 0)) # Fully transparent layer
198
+ highlight_draw = ImageDraw.Draw(highlight_layer)
199
+
200
+ # Define a lighter green color with 40% transparency
201
+ light_green = (144, 238, 144, 102) # RGB = (144, 238, 144), Alpha = 102 (40% of 255)
202
+
203
+
204
+ for (row, col) in selected_cells:
205
+ cell_top_left = (x_positions[row], y_positions[col])
206
+ cell_bottom_right = (x_positions[row + 1], y_positions[col + 1])
207
+
208
+ highlight_draw.rectangle([cell_top_left, cell_bottom_right], fill=light_green, outline='white')
209
+
210
+ result_img = Image.alpha_composite(image.convert('RGBA'), highlight_layer)
211
+
212
+ return result_img, selected_cells
213
+
214
+
215
+ # Process example images
216
+ def process_example(image_path_bef, image_path_aft):
217
+ # Add grid to the example image
218
+ image_bef_grid, _, _ = add_grid_to_image(image_path_bef, 14)
219
+ image_aft_grid, _, _ = add_grid_to_image(image_path_aft, 14)
220
+ return image_bef_grid, image_aft_grid  # gridded versions of the example pair
221
+
222
+ def display_image(image_path):
223
+ image_grid, _, _ = add_grid_to_image(image_path, 14)
224
+ return image_grid, []
225
+
226
+ with gr.Blocks() as demo:
227
+ gr.Markdown("# TAB: Transformer Attention Bottleneck")
228
+
229
+ # Instructions
230
+ gr.Markdown("""
231
+ ## Instructions:
232
+ 1. Upload an image or select one from the examples
233
+ 2. Click on grid cells to select/deselect them
234
+ 3. Click the 'Predict' button to get model predictions
235
+ """)
236
+
237
+ selected_cells_bef = gr.State([])
238
+ selected_cells_aft = gr.State([])
239
+
240
+ with gr.Row():
241
+ with gr.Column(scale=1):
242
+ # Input components with grid overlay
243
+ image_bef = gr.Image(type="filepath")
244
+ image_aft = gr.Image(type="filepath")
245
+
246
+ predict_btn = gr.Button("Predict")
247
+
248
+ with gr.Column(scale=1):
249
+
250
+ image_display_with_grid_bef = gr.Image(type="pil", label="Before Image with Grid")
251
+ image_display_with_grid_aft = gr.Image(type="pil", label="After Image with Grid")
252
+
253
+ # Add click event to the displayed image
254
+ image_display_with_grid_bef.select(
255
+ handle_click,
256
+ inputs=[image_display_with_grid_bef, selected_cells_bef, image_bef],
257
+ outputs=[image_display_with_grid_bef, selected_cells_bef]
258
+ )
259
+
260
+ image_display_with_grid_aft.select(
261
+ handle_click,
262
+ inputs=[image_display_with_grid_aft, selected_cells_aft, image_aft],
263
+ outputs=[image_display_with_grid_aft, selected_cells_aft]
264
+ )
265
+
266
+ with gr.Row():
267
+ with gr.Column(scale=1):
268
+ # Example images
269
+ examples = gr.Examples(
270
+ examples=[["data/images/CLEVR_default_000572.png", "data/images/CLEVR_semantic_000572.png"],
271
+ ["data/images/CLEVR_default_003339.png", "data/images/CLEVR_semantic_003339.png"]],
272
+ inputs=[image_bef, image_aft],
273
+ outputs=[image_display_with_grid_bef, image_display_with_grid_aft],
274
+ label="Example Images",
275
+ fn=process_example,
276
+ examples_per_page=5
277
+ )
278
+
279
+ # image_bef.change(
280
+ # fn=display_image,
281
+ # inputs=[image_bef],
282
+ # outputs=[image_display_with_grid_bef, selected_cells_bef]
283
+ # )
284
+
285
+ # image_aft.change(
286
+ # fn=display_image,
287
+ # inputs=[image_aft],
288
+ # outputs=[image_display_with_grid_aft, selected_cells_aft]
289
+ # )
290
+
291
+ image_bef.change(
292
+ fn=None,
293
+ inputs=[image_bef],
294
+ outputs=[],
295
+ js="(image) => { initializeEditor(); importBackground(image); return []; }",
296
+ )
297
+
298
+ image_aft.change(
299
+ fn=None,
300
+ inputs=[image_aft],
301
+ outputs=[],
302
+ js="(image) => { initializeEditor(); importBackground(image); return []; }",
303
+ )
304
+
305
+ with gr.Column(scale=1):
306
+ # Output components
307
+ prediction = gr.Textbox(label="Predicted caption")
308
+ selected_info_bef = gr.Textbox(label="Selected patches on before")
309
+ selected_info_aft = gr.Textbox(label="Selected patches on after")
310
+
311
+ # Connect the predict button to the prediction function
312
+ predict_btn.click(
313
+ fn=predict_image,
314
+ inputs=[image_bef, image_aft, selected_cells_bef, selected_cells_aft],
315
+ outputs=[prediction, selected_info_bef, selected_info_aft]
316
+ )
317
+
318
+
319
+
320
+ app.mount("/js", StaticFiles(directory="js"), name="js")
321
+ gr.mount_gradio_app(app, demo, path="/")
322
+
323
+
324
+
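Note on the patch interventions: the 14x14 cell selections collected by the click handlers are flattened into 197-dimensional maps (one CLS slot plus 196 patches) by get_intervention_vector before being passed to the model. A minimal standalone sketch of that mapping (illustrative; nothing here is imported from app.py):

import numpy as np

def cells_to_map(selected_cells, grid_size=14):
    patch = np.zeros((grid_size, grid_size))
    for (i, j) in selected_cells:
        patch[i, j] = 1.0
    full = np.zeros((1, grid_size * grid_size + 1))  # index 0 is reserved for the CLS token
    full[0, 1:] = patch.reshape(-1)
    return full

print(cells_to_map([(0, 0), (3, 7)]).shape)  # (1, 197)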
fileservice.py ADDED
@@ -0,0 +1,48 @@
1
+ from fastapi import FastAPI, Request, Response
2
+
3
+ filenames = ["js/interactive_grid.js"]
4
+ contents = "\n".join(
5
+ [f"<script type='text/javascript' src='{x}'></script>" for x in filenames]
6
+ )
7
+
8
+ ga_script = """
9
+ <!-- Google tag (gtag.js) -->
10
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-11ZHMNWP9Y"></script>
11
+ <script>
12
+ window.dataLayer = window.dataLayer || [];
13
+ function gtag(){dataLayer.push(arguments);}
14
+ gtag('js', new Date());
15
+
16
+ gtag('config', 'G-11ZHMNWP9Y');
17
+ </script>
18
+ """
19
+
20
+ app = FastAPI()
21
+
22
+
23
+ @app.middleware("http")
24
+ async def insert_js(request: Request, call_next):
25
+ path = request.scope["path"] # get the request route
26
+ response = await call_next(request)
27
+
28
+ if path == "/":
29
+ response_body = ""
30
+ async for chunk in response.body_iterator:
31
+ response_body += chunk.decode()
32
+
33
+ charset_tag = '<meta charset="utf-8" />'
34
+ if charset_tag in response_body:
35
+ response_body = response_body.replace(charset_tag, charset_tag + ga_script)
36
+
37
+ response_body = response_body.replace("</body>", contents + "</body>")
38
+
39
+ del response.headers["content-length"]
40
+
41
+ return Response(
42
+ content=response_body,
43
+ status_code=response.status_code,
44
+ headers=dict(response.headers),
45
+ media_type=response.media_type,
46
+ )
47
+
48
+ return response
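The middleware above rewrites the Gradio index page so the browser also loads the grid script and the analytics tag. A self-contained toy sketch of the same injection pattern, checkable with FastAPI's TestClient (assumes httpx is available; all names here are illustrative):

from fastapi import FastAPI, Request, Response
from fastapi.responses import HTMLResponse
from fastapi.testclient import TestClient

toy = FastAPI()

@toy.get("/")
def index():
    return HTMLResponse("<html><body>hello</body></html>")

@toy.middleware("http")
async def inject(request: Request, call_next):
    response = await call_next(request)
    if request.scope["path"] != "/":
        return response
    body = b""
    async for chunk in response.body_iterator:  # drain the streamed body
        body += chunk
    body = body.replace(b"</body>", b"<script src='/js/interactive_grid.js'></script></body>")
    headers = dict(response.headers)
    headers.pop("content-length", None)  # length changed; let the new Response recompute it
    return Response(content=body, status_code=response.status_code,
                    headers=headers, media_type=response.media_type)

client = TestClient(toy)
assert "interactive_grid.js" in client.get("/").text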
js/interactive_grid.js ADDED
@@ -0,0 +1,194 @@
1
+ const gridSize = 14;
2
+ var cellSize = null;
3
+ var inputImage = null;
4
+ var image = null;
5
+ var canvas = null;
6
+ var ctx = null;
7
+ var canvasBg = null;
8
+ let grid = new Array(gridSize).fill(null).map(() => new Array(gridSize).fill(false));
9
+ var isInitialized = false;
10
+
11
+ let selectedCells = 0;
12
+
13
+ function createGrid() {
14
+ console.log('createGrid')
15
+
16
+ for (let i = 0; i < 196; i++) {
17
+ const div = document.createElement('div');
18
+ div.classList.add('checkbox');
19
+ div.innerHTML = '<input type="checkbox">';
20
+ grid.appendChild(div);
21
+ }
22
+ }
23
+
24
+ function loadImage(event) {
25
+ const file = event.target.files[0];
26
+ const reader = new FileReader();
27
+ reader.onload = function (e) {
28
+ image.src = e.target.result;
29
+ }
30
+ reader.readAsDataURL(file);
31
+ }
32
+
33
+
34
+ function handleMouseDown(event) {
35
+ // console.log("handleMouseDown");
36
+ }
37
+
38
+ function handleMouseMove(event) {
39
+ // console.log("handleMouseMove");
40
+ }
41
+
42
+ function handleMouseUp(event) {
43
+ // console.log("handleMouseUp");
44
+ }
45
+
46
+ function handleMouseLeave(event) {
47
+ // console.log("handleMouseLeave");
48
+ }
49
+
50
+
51
+ function drawGrid() {
52
+ ctx.clearRect(0, 0, canvas.width, canvas.height);
53
+ drawBackground();
54
+ for (let row = 0; row < gridSize; row++) {
55
+ for (let col = 0; col < gridSize; col++) {
56
+ ctx.beginPath();
57
+ ctx.rect(col * cellSize, row * cellSize, cellSize, cellSize);
58
+ ctx.strokeStyle = 'black';
59
+ ctx.lineWidth = 2;
60
+ ctx.stroke();
61
+
62
+ if (grid[row][col]) {
63
+ ctx.fillStyle = 'rgba(0, 255, 0, 0.5)';
64
+ ctx.fillRect(col * cellSize, row * cellSize, cellSize, cellSize);
65
+ }
66
+ }
67
+ }
68
+ }
69
+
70
+
71
+ function initializeEditor() {
72
+ console.log("initializeEditor");
73
+
74
+ if (isInitialized) {
75
+ return;
76
+ }
77
+ isInitialized = true;
78
+
79
+ image = document.getElementById('image');
80
+ canvas = document.getElementById('canvas');
81
+ ctx = canvas.getContext('2d');
82
+
83
+ // Add click event listener to canvas
84
+ canvas.addEventListener('mousedown', handleMouseDown);
85
+ canvas.addEventListener('mousemove', handleMouseMove);
86
+ canvas.addEventListener('mouseup', handleMouseUp);
87
+ canvas.addEventListener('mouseleave', handleMouseLeave);
88
+
89
+ cellSize = canvas.width / gridSize;
90
+
91
+ canvas.addEventListener('click', (event) => {
92
+ const rect = canvas.getBoundingClientRect();
93
+ const scaleX = canvas.width / rect.width;
94
+ const scaleY = canvas.height / rect.height;
95
+ const x = (event.clientX - rect.left) * scaleX;
96
+ const y = (event.clientY - rect.top) * scaleY;
97
+ const row = Math.floor(y / cellSize);
98
+ const col = Math.floor(x / cellSize);
99
+
100
+ // If the cell is already selected, it's always allowed to deselect it
101
+ if (grid[row][col]) {
102
+ grid[row][col] = false;
103
+ selectedCells--; // Decrement the selected cell count
104
+ } else {
105
+ // Only select a new cell if less than 50 cells are already selected
106
+ if (selectedCells < 50) {
107
+ grid[row][col] = true;
108
+ selectedCells++; // Increment the selected cell count
109
+ }
110
+ }
111
+ drawGrid();
112
+ });
113
+
114
+ drawGrid();
115
+ }
116
+
117
+
118
+ function drawBackground() {
119
+ if (canvasBg != null) {
120
+ const canvasWidth = canvas.width;
121
+ const canvasHeight = canvas.height;
122
+
123
+ const bgWidth = canvasBg.width;
124
+ const bgHeight = canvasBg.height;
125
+
126
+ const scaleX = canvasWidth / bgWidth;
127
+ const scaleY = canvasHeight / bgHeight;
128
+
129
+ const scale = Math.max(scaleX, scaleY);
130
+
131
+ const newWidth = bgWidth * scale;
132
+ const newHeight = bgHeight * scale;
133
+
134
+ const xOffset = (canvasWidth - newWidth) / 2;
135
+ const yOffset = (canvasHeight - newHeight) / 2;
136
+
137
+ ctx.drawImage(canvasBg, 0, 0, bgWidth, bgHeight, xOffset, yOffset, newWidth, newHeight);
138
+ }
139
+ }
140
+
141
+ function importBackground(image) {
142
+ if (image == null) {
143
+ canvasBg = null;
144
+ drawGrid();
145
+ return;
146
+ }
147
+
148
+ let m = new Image();
149
+ m.src = image;
150
+ m.onload = function () {
151
+ canvasBg = m;
152
+ drawGrid();
153
+ }
154
+ }
155
+
156
+ function read_js_Data() {
157
+ console.log("read_js_Data");
158
+ console.log("read_js_Data");
159
+ console.log("read_js_Data");
160
+ console.log("read_js_Data");
161
+ console.log("read_js_Data");
162
+ return grid;
163
+ }
164
+
165
+
166
+ function set_grid_from_data(data) {
167
+ if (data.length !== gridSize || data[0].length !== gridSize) {
168
+ throw new Error('Invalid data dimensions. Expected ' + gridSize + 'x' + gridSize);
169
+ }
170
+
171
+ selectedCells = 0; // Reset the selected cell count
172
+ for (let row = 0; row < gridSize; row++) {
173
+ for (let col = 0; col < gridSize; col++) {
174
+ grid[row][col] = data[row][col];
175
+ if (grid[row][col]) {
176
+ selectedCells++; // Count the number of initially selected cells
177
+ }
178
+ }
179
+ }
180
+
181
+ drawGrid();
182
+ }
183
+
184
+
185
+ function clear_grid() {
186
+ console.log("clearGrid");
187
+ for (let row = 0; row < gridSize; row++) {
188
+ for (let col = 0; col < gridSize; col++) {
189
+ grid[row][col] = false;
190
+ }
191
+ }
192
+ selectedCells = 0; // Reset the selected cell count
193
+ drawGrid();
194
+ }
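app.py mounts this directory at /js with StaticFiles, and fileservice.py injects a matching <script> tag into the index page, so the browser fetches this file when the demo loads. A small smoke test from Python (assumes the Space is already running locally on port 7860):

import requests

resp = requests.get("http://localhost:7860/js/interactive_grid.js", timeout=5)
assert resp.status_code == 200 and "initializeEditor" in resp.text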
requirements.txt ADDED
@@ -0,0 +1,20 @@
1
+ boto3==1.36.9
2
+ botocore==1.36.9
3
+ certifi==2024.12.14
4
+ charset-normalizer==3.4.1
5
+ ftfy==6.2.3
6
+ idna==3.10
7
+ jmespath==1.0.1
8
+ numpy==1.23.0
9
+ opencv-python==4.11.0.86
10
+ Pillow==9.3.0
11
+ regex==2024.11.6
12
+ requests==2.32.3
13
+ s3transfer==0.11.2
14
+ gradio
15
+ torch
16
+ torchvision
17
+ torchaudio
18
+ tqdm==4.67.1
19
+ fastapi
20
+ uvicorn[standard]
utils/cross-base/cross_config.json ADDED
@@ -0,0 +1,12 @@
1
+ {
2
+ "attention_probs_dropout_prob": 0.1,
3
+ "hidden_act": "gelu",
4
+ "hidden_dropout_prob": 0.1,
5
+ "hidden_size": 512,
6
+ "initializer_range": 0.02,
7
+ "intermediate_size": 2048,
8
+ "max_position_embeddings": 150,
9
+ "num_attention_heads": 8,
10
+ "num_hidden_layers": 2,
11
+ "vocab_size": 512
12
+ }
utils/decoder-base/decoder_config.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "attention_probs_dropout_prob": 0.1,
3
+ "hidden_act": "gelu",
4
+ "hidden_dropout_prob": 0.1,
5
+ "hidden_size": 512,
6
+ "initializer_range": 0.02,
7
+ "intermediate_size": 2048,
8
+ "num_attention_heads": 8,
9
+ "num_hidden_layers": 12,
10
+ "type_vocab_size": 2,
11
+ "vocab_size": 49408,
12
+ "num_decoder_layers": 3,
13
+ "max_target_embeddings": 77
14
+ }
utils/file_utils.py ADDED
@@ -0,0 +1,239 @@
1
+ """
2
+ Utilities for working with the local dataset cache.
3
+ This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
4
+ Copyright by the AllenNLP authors.
5
+ """
6
+
7
+ import os
8
+ import logging
9
+ import shutil
10
+ import tempfile
11
+ import json
12
+ from urllib.parse import urlparse
13
+ from pathlib import Path
14
+ from typing import Optional, Tuple, Union, IO, Callable, Set
15
+ from hashlib import sha256
16
+ from functools import wraps
17
+
18
+ from tqdm import tqdm
19
+
20
+ import boto3
21
+ from botocore.exceptions import ClientError
22
+ import requests
23
+
24
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
25
+
26
+ PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
27
+ Path.home() / '.pytorch_pretrained_bert'))
28
+
29
+
30
+ def url_to_filename(url: str, etag: str = None) -> str:
31
+ """
32
+ Convert `url` into a hashed filename in a repeatable way.
33
+ If `etag` is specified, append its hash to the url's, delimited
34
+ by a period.
35
+ """
36
+ url_bytes = url.encode('utf-8')
37
+ url_hash = sha256(url_bytes)
38
+ filename = url_hash.hexdigest()
39
+
40
+ if etag:
41
+ etag_bytes = etag.encode('utf-8')
42
+ etag_hash = sha256(etag_bytes)
43
+ filename += '.' + etag_hash.hexdigest()
44
+
45
+ return filename
46
+
47
+
48
+ def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]:
49
+ """
50
+ Return the url and etag (which may be ``None``) stored for `filename`.
51
+ Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
52
+ """
53
+ if cache_dir is None:
54
+ cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
55
+ if isinstance(cache_dir, Path):
56
+ cache_dir = str(cache_dir)
57
+
58
+ cache_path = os.path.join(cache_dir, filename)
59
+ if not os.path.exists(cache_path):
60
+ raise FileNotFoundError("file {} not found".format(cache_path))
61
+
62
+ meta_path = cache_path + '.json'
63
+ if not os.path.exists(meta_path):
64
+ raise FileNotFoundError("file {} not found".format(meta_path))
65
+
66
+ with open(meta_path) as meta_file:
67
+ metadata = json.load(meta_file)
68
+ url = metadata['url']
69
+ etag = metadata['etag']
70
+
71
+ return url, etag
72
+
73
+
74
+ def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str:
75
+ """
76
+ Given something that might be a URL (or might be a local path),
77
+ determine which. If it's a URL, download the file and cache it, and
78
+ return the path to the cached file. If it's already a local path,
79
+ make sure the file exists and then return the path.
80
+ """
81
+ if cache_dir is None:
82
+ cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
83
+ if isinstance(url_or_filename, Path):
84
+ url_or_filename = str(url_or_filename)
85
+ if isinstance(cache_dir, Path):
86
+ cache_dir = str(cache_dir)
87
+
88
+ parsed = urlparse(url_or_filename)
89
+
90
+ if parsed.scheme in ('http', 'https', 's3'):
91
+ # URL, so get it from the cache (downloading if necessary)
92
+ return get_from_cache(url_or_filename, cache_dir)
93
+ elif os.path.exists(url_or_filename):
94
+ # File, and it exists.
95
+ return url_or_filename
96
+ elif parsed.scheme == '':
97
+ # File, but it doesn't exist.
98
+ raise FileNotFoundError("file {} not found".format(url_or_filename))
99
+ else:
100
+ # Something unknown
101
+ raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
102
+
103
+
104
+ def split_s3_path(url: str) -> Tuple[str, str]:
105
+ """Split a full s3 path into the bucket name and path."""
106
+ parsed = urlparse(url)
107
+ if not parsed.netloc or not parsed.path:
108
+ raise ValueError("bad s3 path {}".format(url))
109
+ bucket_name = parsed.netloc
110
+ s3_path = parsed.path
111
+ # Remove '/' at beginning of path.
112
+ if s3_path.startswith("/"):
113
+ s3_path = s3_path[1:]
114
+ return bucket_name, s3_path
115
+
116
+
117
+ def s3_request(func: Callable):
118
+ """
119
+ Wrapper function for s3 requests in order to create more helpful error
120
+ messages.
121
+ """
122
+
123
+ @wraps(func)
124
+ def wrapper(url: str, *args, **kwargs):
125
+ try:
126
+ return func(url, *args, **kwargs)
127
+ except ClientError as exc:
128
+ if int(exc.response["Error"]["Code"]) == 404:
129
+ raise FileNotFoundError("file {} not found".format(url))
130
+ else:
131
+ raise
132
+
133
+ return wrapper
134
+
135
+
136
+ @s3_request
137
+ def s3_etag(url: str) -> Optional[str]:
138
+ """Check ETag on S3 object."""
139
+ s3_resource = boto3.resource("s3")
140
+ bucket_name, s3_path = split_s3_path(url)
141
+ s3_object = s3_resource.Object(bucket_name, s3_path)
142
+ return s3_object.e_tag
143
+
144
+
145
+ @s3_request
146
+ def s3_get(url: str, temp_file: IO) -> None:
147
+ """Pull a file directly from S3."""
148
+ s3_resource = boto3.resource("s3")
149
+ bucket_name, s3_path = split_s3_path(url)
150
+ s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
151
+
152
+
153
+ def http_get(url: str, temp_file: IO) -> None:
154
+ req = requests.get(url, stream=True)
155
+ content_length = req.headers.get('Content-Length')
156
+ total = int(content_length) if content_length is not None else None
157
+ progress = tqdm(unit="B", total=total)
158
+ for chunk in req.iter_content(chunk_size=1024):
159
+ if chunk: # filter out keep-alive new chunks
160
+ progress.update(len(chunk))
161
+ temp_file.write(chunk)
162
+ progress.close()
163
+
164
+
165
+ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
166
+ """
167
+ Given a URL, look for the corresponding dataset in the local cache.
168
+ If it's not there, download it. Then return the path to the cached file.
169
+ """
170
+ if cache_dir is None:
171
+ cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
172
+ if isinstance(cache_dir, Path):
173
+ cache_dir = str(cache_dir)
174
+
175
+ os.makedirs(cache_dir, exist_ok=True)
176
+
177
+ # Get eTag to add to filename, if it exists.
178
+ if url.startswith("s3://"):
179
+ etag = s3_etag(url)
180
+ else:
181
+ response = requests.head(url, allow_redirects=True)
182
+ if response.status_code != 200:
183
+ raise IOError("HEAD request failed for url {} with status code {}"
184
+ .format(url, response.status_code))
185
+ etag = response.headers.get("ETag")
186
+
187
+ filename = url_to_filename(url, etag)
188
+
189
+ # get cache path to put the file
190
+ cache_path = os.path.join(cache_dir, filename)
191
+
192
+ if not os.path.exists(cache_path):
193
+ # Download to temporary file, then copy to cache dir once finished.
194
+ # Otherwise you get corrupt cache entries if the download gets interrupted.
195
+ with tempfile.NamedTemporaryFile() as temp_file:
196
+ logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
197
+
198
+ # GET file object
199
+ if url.startswith("s3://"):
200
+ s3_get(url, temp_file)
201
+ else:
202
+ http_get(url, temp_file)
203
+
204
+ # we are copying the file before closing it, so flush to avoid truncation
205
+ temp_file.flush()
206
+ # shutil.copyfileobj() starts at the current position, so go to the start
207
+ temp_file.seek(0)
208
+
209
+ logger.info("copying %s to cache at %s", temp_file.name, cache_path)
210
+ with open(cache_path, 'wb') as cache_file:
211
+ shutil.copyfileobj(temp_file, cache_file)
212
+
213
+ logger.info("creating metadata file for %s", cache_path)
214
+ meta = {'url': url, 'etag': etag}
215
+ meta_path = cache_path + '.json'
216
+ with open(meta_path, 'w') as meta_file:
217
+ json.dump(meta, meta_file)
218
+
219
+ logger.info("removing temp file %s", temp_file.name)
220
+
221
+ return cache_path
222
+
223
+
224
+ def read_set_from_file(filename: str) -> Set[str]:
225
+ '''
226
+ Extract a de-duped collection (set) of text from a file.
227
+ Expected file format is one item per line.
228
+ '''
229
+ collection = set()
230
+ with open(filename, 'r', encoding='utf-8') as file_:
231
+ for line in file_:
232
+ collection.add(line.rstrip())
233
+ return collection
234
+
235
+
236
+ def get_file_extension(path: str, dot=True, lower: bool = True):
237
+ ext = os.path.splitext(path)[1]
238
+ ext = ext if dot else ext[1:]
239
+ return ext.lower() if lower else ext
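A short usage sketch of the caching helpers above (illustrative; the URL and ETag are placeholders, and no download happens for an existing local path):

from utils.file_utils import cached_path, url_to_filename

local = cached_path("requirements.txt")  # an existing local path is returned unchanged

name = url_to_filename("https://example.com/weights.bin", etag='"abc123"')  # hashed cache filename

# cached_path("https://example.com/weights.bin") would download the file into
# ~/.pytorch_pretrained_bert (or $PYTORCH_PRETRAINED_BERT_CACHE) and return the cached path.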
utils/model.py ADDED
@@ -0,0 +1,204 @@
1
+ from __future__ import absolute_import
2
+ from __future__ import division
3
+ from __future__ import print_function
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from .until_module import PreTrainedModel
9
+ from .module_cross import CrossModel, CrossConfig
10
+ from .module_decoder import DecoderModel, DecoderConfig
11
+
12
+ from utils.module_clip import CLIP, convert_weights
13
+ from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
14
+
15
+
16
+ def update_attr(target_name, target_config, target_attr_name, source_config, source_attr_name, default_value=None):
17
+ if hasattr(source_config, source_attr_name):
18
+ if default_value is None or getattr(source_config, source_attr_name) != default_value:
19
+ setattr(target_config, target_attr_name, getattr(source_config, source_attr_name))
20
+ return target_config
21
+
22
+ class CLIP4IDCPreTrainedModel(PreTrainedModel, nn.Module):
23
+ """ An abstract class to handle weights initialization and
24
+ a simple interface for downloading and loading pretrained models.
25
+ """
26
+ def __init__(self, cross_config, decoder_config, *inputs, **kwargs):
27
+ super(CLIP4IDCPreTrainedModel, self).__init__(cross_config, decoder_config)
28
+ self.cross_config = cross_config
29
+ self.decoder_config = decoder_config
30
+ self.clip = None
31
+ self.cross = None
32
+
33
+ @classmethod
34
+ def from_pretrained(cls, cross_model_name, decoder_model_name, state_dict=None, cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
35
+
36
+
37
+ if state_dict is None: state_dict = {}
38
+ pretrained_clip_name = "ViT-B/16"
39
+ clip_state_dict = CLIP.get_config(pretrained_clip_name=pretrained_clip_name)
40
+ for key, val in clip_state_dict.items():
41
+ new_key = "clip." + key
42
+ if new_key not in state_dict:
43
+ state_dict[new_key] = val.clone()
44
+
45
+ cross_config, _ = CrossConfig.get_config(cross_model_name, cache_dir, type_vocab_size, state_dict=None)
46
+ decoder_config, _ = DecoderConfig.get_config(decoder_model_name, cache_dir, type_vocab_size, state_dict=None)
47
+
48
+ model = cls(cross_config, decoder_config, clip_state_dict, *inputs, **kwargs)
49
+
50
+ ## ===> Initialization trick [HARD CODE]
51
+ if model.linear_patch == "3d":
52
+ contain_conv2 = False
53
+ for key in state_dict.keys():
54
+ if key.find("visual.conv2.weight") > -1:
55
+ contain_conv2 = True
56
+ break
57
+ if contain_conv2 is False and hasattr(model.clip.visual, "conv2"):
58
+ cp_weight = state_dict["clip.visual.conv1.weight"].clone()
59
+ kernel_size = model.clip.visual.conv2.weight.size(2)
60
+ conv2_size = model.clip.visual.conv2.weight.size()
61
+ conv2_size = list(conv2_size)
62
+
63
+ left_conv2_size = conv2_size.copy()
64
+ right_conv2_size = conv2_size.copy()
65
+ left_conv2_size[2] = (kernel_size - 1) // 2
66
+ right_conv2_size[2] = kernel_size - 1 - left_conv2_size[2]
67
+
68
+ left_zeros, right_zeros = None, None
69
+ if left_conv2_size[2] > 0:
70
+ left_zeros = torch.zeros(*tuple(left_conv2_size), dtype=cp_weight.dtype, device=cp_weight.device)
71
+ if right_conv2_size[2] > 0:
72
+ right_zeros = torch.zeros(*tuple(right_conv2_size), dtype=cp_weight.dtype, device=cp_weight.device)
73
+
74
+ cat_list = []
75
+ if left_zeros is not None: cat_list.append(left_zeros)
76
+ cat_list.append(cp_weight.unsqueeze(2))
77
+ if right_zeros is not None: cat_list.append(right_zeros)
78
+ cp_weight = torch.cat(cat_list, dim=2)
79
+
80
+ state_dict["clip.visual.conv2.weight"] = cp_weight
81
+
82
+ ## <=== End of initialization trick
83
+
84
+ if state_dict is not None:
85
+ model = cls.init_preweight(model, state_dict)
86
+
87
+ return model
88
+
89
+
90
+
91
+ class CLIP4IDC(CLIP4IDCPreTrainedModel):
92
+ def __init__(self, cross_config, decoder_config, clip_state_dict):
93
+ super(CLIP4IDC, self).__init__(cross_config, decoder_config)
94
+ self.ignore_video_index = -1
95
+
96
+ # assert self.task_config.max_words <= cross_config.max_position_embeddings
97
+
98
+ # CLIP Encoders: From OpenAI: CLIP [https://github.com/openai/CLIP] ===>
99
+ vit = "visual.proj" in clip_state_dict
100
+ assert vit
101
+ if vit:
102
+ vision_width = clip_state_dict["visual.conv1.weight"].shape[0]
103
+ vision_layers = len(
104
+ [k for k in clip_state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
105
+ vision_patch_size = clip_state_dict["visual.conv1.weight"].shape[-1]
106
+ grid_size = round((clip_state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
107
+ image_resolution = vision_patch_size * grid_size
108
+ else:
109
+ counts: list = [len(set(k.split(".")[2] for k in clip_state_dict if k.startswith(f"visual.layer{b}"))) for b in
110
+ [1, 2, 3, 4]]
111
+ vision_layers = tuple(counts)
112
+ vision_width = clip_state_dict["visual.layer1.0.conv1.weight"].shape[0]
113
+ output_width = round((clip_state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
114
+ vision_patch_size = None
115
+ assert output_width ** 2 + 1 == clip_state_dict["visual.attnpool.positional_embedding"].shape[0]
116
+ image_resolution = output_width * 32
117
+
118
+ embed_dim = clip_state_dict["text_projection"].shape[1]
119
+ context_length = clip_state_dict["positional_embedding"].shape[0]
120
+ vocab_size = clip_state_dict["token_embedding.weight"].shape[0]
121
+ transformer_width = clip_state_dict["ln_final.weight"].shape[0]
122
+ transformer_heads = transformer_width // 64
123
+ transformer_layers = len(set(k.split(".")[2] for k in clip_state_dict if k.startswith(f"transformer.resblocks")))
124
+
125
+ self.linear_patch = '2d'
126
+
127
+ # use .float() to avoid overflow/underflow from fp16 weight. https://github.com/openai/CLIP/issues/40
128
+ cut_top_layer = 0
129
+ self.clip = CLIP(
130
+ embed_dim,
131
+ image_resolution, vision_layers-cut_top_layer, vision_width, vision_patch_size,
132
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers-cut_top_layer,
133
+ linear_patch=self.linear_patch, intra_layers=9
134
+ ).float()
135
+
136
+ bert_word_embeddings_weight = self.clip.token_embedding.weight
137
+ bert_position_embeddings_weight = self.clip.positional_embedding
138
+
139
+ for key in ["input_resolution", "context_length", "vocab_size"]:
140
+ if key in clip_state_dict:
141
+ del clip_state_dict[key]
142
+
143
+ convert_weights(self.clip)
144
+ # <=== End of CLIP Encoders
145
+
146
+ self.decoder = DecoderModel(decoder_config, bert_word_embeddings_weight, bert_position_embeddings_weight)
147
+
148
+ self.apply(self.init_weights)
149
+
150
+ def get_visual_output(self, video, visual_mask, left_gt_map, right_gt_map, shaped=False, video_frame=-1):
151
+
152
+ bs_pair = visual_mask.size(0)
153
+ visual_hidden, visual_output, left_map, right_map = self.clip.encode_image(video, left_gt_map, right_gt_map, video_frame=video_frame, return_hidden=True)
154
+ visual_hidden = visual_hidden.float()
155
+ visual_output = visual_output.float()
156
+ visual_hidden = visual_hidden.view(bs_pair, -1, visual_hidden.size(-1))
157
+
158
+ left_map = left_map.float()
159
+ right_map = right_map.float()
160
+
161
+ return visual_hidden, visual_output, left_map, right_map
162
+
163
+ def get_sequence_visual_output(self, video, visual_mask, left_gt_map, right_gt_map, shaped=False, video_frame=-1):
164
+ if shaped is False:
165
+ visual_mask = visual_mask.view(-1, visual_mask.shape[-1])
166
+ video = torch.as_tensor(video).float()
167
+ b, pair, channel, h, w = video.shape
168
+ video = video.view(b * pair, channel, h, w)
169
+ video_frame = pair
170
+
171
+ _, visual_hidden, left_map, right_map = self.get_visual_output(video, visual_mask, left_gt_map, right_gt_map, shaped=True, video_frame=video_frame)
172
+
173
+ return visual_hidden, left_map, right_map
174
+
175
+ def _get_decoder_score(self, visual_output, visual_mask, input_caption_ids, decoder_mask):
176
+ res_tuples = ()
177
+ decoder_scores = self.decoder(input_caption_ids, encoder_outs=visual_output, answer_mask=decoder_mask, encoder_mask=visual_mask)
178
+
179
+ return decoder_scores, res_tuples
180
+
181
+ def decoder_caption(self, visual_output, visual_mask, input_caption_ids, decoder_mask, get_logits=False):
182
+
183
+ decoder_scores, _ = self._get_decoder_score(visual_output, visual_mask,
184
+ input_caption_ids, decoder_mask)
185
+
186
+ if get_logits:
187
+ return decoder_scores
188
+
189
+ _, decoder_scores_result = torch.max(decoder_scores, -1)
190
+
191
+ return decoder_scores_result
192
+
193
+
194
+ def init_model(model_path, device):
195
+
196
+ model_state_dict = torch.load(model_path, map_location='cpu')
197
+
198
+ # Prepare model
199
+ cache_dir = ""
200
+ model = CLIP4IDC.from_pretrained("cross-base", "decoder-base", cache_dir=cache_dir, state_dict=model_state_dict)
201
+
202
+ model.to(device)
203
+
204
+ return model
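init_model is the entry point used by app.py. A minimal load sketch (assumes data/pytorch_model.pt is present, as in the demo; decoding follows greedy_decode in app.py):

import torch
from utils.model import init_model
from utils.tokenization_clip import SimpleTokenizer as ClipTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = init_model("data/pytorch_model.pt", device)
model.eval()
tokenizer = ClipTokenizer()
# Image pairs and intervention maps are then fed through model.get_sequence_visual_output
# and model.decoder_caption, as shown in greedy_decode() in app.py.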
utils/module_clip.py ADDED
@@ -0,0 +1,679 @@
1
+ """
2
+ Adapted from: https://github.com/openai/CLIP/blob/main/clip/clip.py
3
+ """
4
+ import warnings
5
+ from collections import OrderedDict
6
+ from typing import Tuple, Union, Optional
7
+
8
+ import hashlib
9
+ import os
10
+ import urllib
11
+
12
+ from tqdm import tqdm
13
+
14
+ import torch
15
+ import torch.nn.functional as F
16
+
17
+ from torch import Tensor
18
+ from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
19
+ from torch.nn.init import xavier_uniform_
20
+ from torch.nn.init import constant_
21
+ from torch.nn.init import xavier_normal_
22
+ from torch.nn.parameter import Parameter
23
+
24
+ from torch.nn.modules.module import Module
25
+ from .module_gated_attention import gated_coattention
26
+ from torch import nn
27
+
28
+
29
+ _MODELS = {
30
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
31
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
32
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
33
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
34
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
35
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
36
+ }
37
+ _PT_NAME = {
38
+ "RN50": "RN50.pt",
39
+ "RN101": "RN101.pt",
40
+ "RN50x4": "RN50x4.pt",
41
+ "RN50x16": "RN50x16.pt",
42
+ "ViT-B/32": "ViT-B-32.pt",
43
+ "ViT-B/16": "ViT-B-16.pt",
44
+ }
45
+
46
+ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
47
+ os.makedirs(root, exist_ok=True)
48
+ filename = os.path.basename(url)
49
+
50
+ expected_sha256 = url.split("/")[-2]
51
+ download_target = os.path.join(root, filename)
52
+
53
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
54
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
55
+
56
+ if os.path.isfile(download_target):
57
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
58
+ return download_target
59
+ else:
60
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
61
+
62
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
63
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
64
+ while True:
65
+ buffer = source.read(8192)
66
+ if not buffer:
67
+ break
68
+
69
+ output.write(buffer)
70
+ loop.update(len(buffer))
71
+
72
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
73
+ raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
74
+
75
+ return download_target
76
+
77
+ def available_models():
78
+ """Returns the names of available CLIP models"""
79
+ return list(_MODELS.keys())
80
+
81
+ # =============================
82
+
83
+
84
+ class TABAttention(Module):
85
+ r"""Allows the model to jointly attend to information
86
+ from different representation subspaces.
87
+ See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_
88
+
89
+ .. math::
90
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
91
+
92
+ where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
93
+
94
+ Args:
95
+ embed_dim: total dimension of the model.
96
+ num_heads: parallel attention heads.
97
+ dropout: a Dropout layer on attn_output_weights. Default: 0.0.
98
+ bias: add bias as module parameter. Default: True.
99
+ add_bias_kv: add bias to the key and value sequences at dim=0.
100
+ add_zero_attn: add a new batch of zeros to the key and
101
+ value sequences at dim=1.
102
+ kdim: total number of features in key. Default: None.
103
+ vdim: total number of features in value. Default: None.
104
+
105
+ Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
106
+ to :attr:`embed_dim` such that query, key, and value have the same
107
+ number of features.
108
+
109
+ Examples::
110
+
111
+ >>> multihead_attn = TABAttention(embed_dim, num_heads)
112
+ >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
113
+
114
+
115
+
116
+ This is a version of multihead attention written to comply with the definition of TAB.
117
+ """
118
+ bias_k: Optional[torch.Tensor]
119
+ bias_v: Optional[torch.Tensor]
120
+
121
+ def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
122
+ super(TABAttention, self).__init__()
123
+ self.embed_dim = embed_dim
124
+ self.kdim = kdim if kdim is not None else embed_dim
125
+ self.vdim = vdim if vdim is not None else embed_dim
126
+ self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
127
+
128
+ self.num_heads = num_heads
129
+ self.dropout = dropout
130
+ self.head_dim = embed_dim // num_heads
131
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
132
+
133
+ if self._qkv_same_embed_dim is False:
134
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
135
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
136
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
137
+ self.register_parameter('in_proj_weight', None)
138
+ else:
139
+ self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
140
+ self.register_parameter('q_proj_weight', None)
141
+ self.register_parameter('k_proj_weight', None)
142
+ self.register_parameter('v_proj_weight', None)
143
+
144
+ if bias:
145
+ self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
146
+ else:
147
+ self.register_parameter('in_proj_bias', None)
148
+ self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias)
149
+
150
+ if add_bias_kv:
151
+ self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
152
+ self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
153
+ else:
154
+ self.bias_k = self.bias_v = None
155
+
156
+ self.add_zero_attn = add_zero_attn
157
+
158
+ self._reset_parameters()
159
+
160
+ def _reset_parameters(self):
161
+ if self._qkv_same_embed_dim:
162
+ xavier_uniform_(self.in_proj_weight)
163
+ else:
164
+ xavier_uniform_(self.q_proj_weight)
165
+ xavier_uniform_(self.k_proj_weight)
166
+ xavier_uniform_(self.v_proj_weight)
167
+
168
+ if self.in_proj_bias is not None:
169
+ constant_(self.in_proj_bias, 0.)
170
+ constant_(self.out_proj.bias, 0.)
171
+ if self.bias_k is not None:
172
+ xavier_normal_(self.bias_k)
173
+ if self.bias_v is not None:
174
+ xavier_normal_(self.bias_v)
175
+
176
+ def __setstate__(self, state):
177
+ # Support loading old TABAttention checkpoints generated by v1.1.0
178
+ if '_qkv_same_embed_dim' not in state:
179
+ state['_qkv_same_embed_dim'] = True
180
+
181
+ super(TABAttention, self).__setstate__(state)
182
+
183
+ def forward(self, query: Tensor, key: Tensor, value: Tensor, gt_attention_map: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None,
184
+ need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
185
+ r"""
186
+ Args:
187
+ query, key, value: map a query and a set of key-value pairs to an output.
188
+ See "Attention Is All You Need" for more details.
189
+ key_padding_mask: if provided, specified padding elements in the key will
190
+ be ignored by the attention. When given a binary mask and a value is True,
191
+ the corresponding value on the attention layer will be ignored. When given
192
+ a byte mask and a value is non-zero, the corresponding value on the attention
193
+ layer will be ignored
194
+ need_weights: output attn_output_weights.
195
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
196
+ the batches while a 3D mask allows to specify a different mask for the entries of each batch.
197
+
198
+ Shapes for inputs:
199
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
200
+ the embedding dimension.
201
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
202
+ the embedding dimension.
203
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
204
+ the embedding dimension.
205
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
206
+ If a ByteTensor is provided, the non-zero positions will be ignored while the position
207
+ with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
208
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
209
+ - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the
210
+ source sequence length.
211
+
212
+ If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence
213
+ length, S is the source sequence length. ``attn_mask`` ensure that position i is allowed to attend
214
+ the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
215
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
216
+ is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
217
+ is provided, it will be added to the attention weight.
218
+
219
+ Shapes for outputs:
220
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
221
+ E is the embedding dimension.
222
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
223
+ L is the target sequence length, S is the source sequence length.
224
+ """
225
+ if not self._qkv_same_embed_dim:
226
+ return gated_coattention(
227
+ query, key, value, self.embed_dim, self.num_heads,
228
+ self.in_proj_weight.half(), self.in_proj_bias.half(),
229
+ self.bias_k, self.bias_v, self.add_zero_attn,
230
+ self.dropout, self.out_proj.weight.half(), self.out_proj.bias.half(),
231
+ training=self.training, gt_attention_map=gt_attention_map,
232
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
233
+ attn_mask=attn_mask, use_separate_proj_weight=True,
234
+ q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
235
+ v_proj_weight=self.v_proj_weight)
236
+ else:
237
+ return gated_coattention(
238
+ query, key, value, self.embed_dim, self.num_heads,
239
+ self.in_proj_weight.half(), self.in_proj_bias.half(),
240
+ self.bias_k, self.bias_v, self.add_zero_attn,
241
+ self.dropout, self.out_proj.weight.half(), self.out_proj.bias.half(),
242
+ training=self.training, gt_attention_map=gt_attention_map,
243
+ key_padding_mask=key_padding_mask, need_weights=need_weights,
244
+ attn_mask=attn_mask)
245
+
246
+
247
+ class LayerNorm(nn.LayerNorm):
248
+ """Subclass torch's LayerNorm to handle fp16."""
249
+
250
+ def forward(self, x: torch.Tensor):
251
+ orig_type = x.dtype
252
+ ret = super().forward(x.type(torch.float32))
253
+ return ret.type(orig_type)
254
+
255
+
256
+ class QuickGELU(nn.Module):
257
+ def forward(self, x: torch.Tensor):
258
+ return x * torch.sigmoid(1.702 * x)
259
+
260
+
261
+ class ResidualAttentionBlock(nn.Module):
262
+ def __init__(self, d_model: int, n_head: int, attn_mask=None):
263
+ super().__init__()
264
+
265
+ self.attn = nn.MultiheadAttention(d_model, n_head)
266
+ self.ln_1 = LayerNorm(d_model)
267
+ self.mlp = nn.Sequential(OrderedDict([
268
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
269
+ ("gelu", QuickGELU()),
270
+ ("c_proj", nn.Linear(d_model * 4, d_model))
271
+ ]))
272
+ self.ln_2 = LayerNorm(d_model)
273
+ self.attn_mask = attn_mask
274
+
275
+ def attention(self, x: torch.Tensor):
276
+ attn_mask_ = self.attn_mask
277
+ if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'):
278
+ attn_mask_ = self.attn_mask(x.size(0)) # LND
279
+
280
+ attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None
281
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]
282
+
283
+ def forward(self, x_tuple:tuple):
284
+ x, video_frame = x_tuple
285
+ x = x + self.attention(self.ln_1(x))
286
+ x = x + self.mlp(self.ln_2(x))
287
+ return (x, video_frame)
288
+
289
+ def visualize_attention(self, x: torch.Tensor):
290
+ attn_outputs, attn_weights = self.attn(x, x, x, need_weights=True, attn_mask=None)
291
+ return attn_outputs, attn_weights
292
+
293
+ def visualize_forward(self, x_tuple:tuple):
294
+ x, video_frame = x_tuple
295
+ attn_outputs, attn_weights = self.visualize_attention(self.ln_1(x))
296
+ x = x + attn_outputs
297
+ x = x + self.mlp(self.ln_2(x))
298
+ return (x, video_frame, attn_weights)
299
+
300
+ class TABLayer(nn.Module):
301
+ def __init__(self, d_model: int, n_head: int, attn_mask=None):
302
+ super().__init__()
303
+
304
+ self.attn = TABAttention(d_model, n_head)
305
+ self.ln_1 = LayerNorm(d_model)
306
+ self.mlp = nn.Sequential(OrderedDict([
307
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
308
+ ("gelu", QuickGELU()),
309
+ ("c_proj", nn.Linear(d_model * 4, d_model))
310
+ ]))
311
+ self.ln_2 = LayerNorm(d_model)
312
+ self.attn_mask = attn_mask
313
+
314
+ def attention(self, x: torch.Tensor, y: torch.Tensor):
315
+ attn_mask_ = self.attn_mask
316
+ if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'):
317
+ attn_mask_ = self.attn_mask(x.size(0)) # LND
318
+
319
+ attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None
320
+ return self.attn(x, y, y, need_weights=False, attn_mask=attn_mask_)[0]
321
+
322
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
323
+ x = self.attention(self.ln_1(x), self.ln_1(y))
324
+ x = x + self.mlp(self.ln_2(x))
325
+ return x
326
+
327
+ def visualize_attention(self, x: torch.Tensor, y: torch.Tensor, gt_attention_map):
328
+ attn_outputs, attn_weights = self.attn(x, y, y, gt_attention_map=gt_attention_map, need_weights=True, attn_mask=None)
329
+ return attn_outputs, attn_weights
330
+
331
+ def visualize_forward(self, x: torch.Tensor, y: torch.Tensor, gt_attention_map):
332
+ attn_outputs, attn_weights = self.visualize_attention(self.ln_1(x), self.ln_1(y), gt_attention_map)
333
+ x = attn_outputs
334
+ x = x + self.mlp(self.ln_2(x))
335
+ return (x, attn_weights)
336
+
337
+ class visionTransformer(nn.Module):
338
+ def __init__(self, width: int, layers: int, heads: int, attn_mask = None):
339
+ super().__init__()
340
+ self.width = width
341
+ self.layers = layers
342
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) if i < (layers - 1) else TABLayer(width, 1, attn_mask) for i in range(layers)])
343
+
344
+ def forward(self, x: torch.Tensor, video_frame=-1):
345
+ return self.resblocks((x, video_frame))[0]
346
+
347
+ class Transformer(nn.Module):
348
+ def __init__(self, width: int, layers: int, heads: int, attn_mask = None):
349
+ super().__init__()
350
+ self.width = width
351
+ self.layers = layers
352
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
353
+
354
+ def forward(self, x: torch.Tensor, video_frame=-1):
355
+ return self.resblocks((x, video_frame))[0]
356
+
357
+
358
+ class VisualTransformer(nn.Module):
359
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int,
360
+ linear_patch: str = '2d', intra_layers: int = 9):
361
+ super().__init__()
362
+ self.input_resolution = input_resolution
363
+ self.output_dim = output_dim
364
+ self.intra_layers = intra_layers
365
+
366
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
367
+
368
+ scale = width ** -0.5
369
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
370
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
371
+ self.ln_pre = LayerNorm(width)
372
+
373
+ self.joint_positional_embedding = nn.Parameter(scale * torch.randn(2 * ((input_resolution // patch_size) ** 2 + 1), width))
374
+ self.bef_embedding = nn.Parameter(scale * torch.randn(width))
375
+ self.aft_embedding = nn.Parameter(scale * torch.randn(width))
376
+ self.ln_mid = LayerNorm(width)
377
+
378
+ self.transformer = visionTransformer(width, layers, heads)
379
+
380
+ self.ln_post = LayerNorm(width)
381
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
382
+
383
+ # For 3D
384
+ assert linear_patch in ['2d', '3d']
385
+ self.linear_patch = linear_patch
386
+ if self.linear_patch == '3d':
387
+ self.conv2 = nn.Conv3d(in_channels=3, out_channels=width, kernel_size=(3, patch_size, patch_size),
388
+ stride=(1, patch_size, patch_size), padding=(1, 0, 0), bias=False)
389
+
390
+ def forward(self, x: torch.Tensor, left_gt_map, right_gt_map, video_frame=-1, visualize=False):
391
+
392
+ if self.linear_patch == '3d':
393
+ assert video_frame != -1
394
+ x_3d = x.reshape(-1, video_frame, x.shape[-3], x.shape[-2], x.shape[-1])
395
+ x_3d = x_3d.permute(0, 2, 1, 3, 4)
396
+ x_3d = self.conv2(x_3d) # shape = [*, width, frame, grid, grid]
397
+ x_3d = x_3d.permute(0, 2, 1, 3, 4) # shape = [*, frame, width, grid, grid]
398
+ x = x_3d.reshape(-1, x_3d.shape[-3], x_3d.shape[-2], x_3d.shape[-1]).contiguous() # shape = [*, width, grid, grid]
399
+ else:
400
+ x = self.conv1(x) # shape = [*, width, grid, grid]
401
+
402
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
403
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
404
+
405
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
406
+ x = x + self.positional_embedding.to(x.dtype)
407
+ x = self.ln_pre(x)
408
+
409
+ x = x.permute(1, 0, 2) # NLD -> LND
410
+
411
+ if visualize is True:
412
+ all_attn_weights = []
413
+ for i in range(self.intra_layers):
414
+ x, _, attn_weights = self.transformer.resblocks[i].visualize_forward((x, video_frame))
415
+ attn_weights = attn_weights.view(x.size(1) // video_frame, -1, attn_weights.size(-2),
416
+ attn_weights.size(-1))
417
+ all_attn_weights.append(attn_weights)
418
+ else:
419
+ for i in range(self.intra_layers):
420
+ x = self.transformer.resblocks[i]((x, video_frame))[0]
421
+ x = x.permute(1, 0, 2) # LND -> NLD
422
+
423
+ bs = x.size(0) // video_frame
424
+ x = x.view(bs, video_frame, x.size(-2), x.size(-1))
425
+ x = torch.cat([x[:, 0] + self.bef_embedding.to(x.dtype),
426
+ x[:, 1] + self.aft_embedding.to(x.dtype)], dim=1)
427
+
428
+ x = x + self.joint_positional_embedding.to(x.dtype)
429
+ x = self.ln_mid(x)
430
+
431
+ x = x.permute(1, 0, 2) # NLD -> LND
432
+
433
+ if visualize is True:
434
+ for i in range(self.intra_layers, self.transformer.layers - 1):
435
+ x, _, attn_weights = self.transformer.resblocks[i].visualize_forward((x, video_frame))
436
+ all_attn_weights.append(attn_weights)
437
+ cls_index = int(x.size(0) / 2)
438
+ left_features, left_attn_weights = self.transformer.resblocks[-1].visualize_forward(x[:cls_index, :, :], x[cls_index:, :, :], right_gt_map)
439
+ right_features, right_attn_weights = self.transformer.resblocks[-1].visualize_forward(x[cls_index:, :, :], x[:cls_index, :, :], left_gt_map)
440
+
441
+ all_attn_weights.append(left_attn_weights)
442
+ all_attn_weights.append(right_attn_weights)
443
+ else:
444
+ for i in range(self.intra_layers, self.transformer.layers - 1):
445
+ x = self.transformer.resblocks[i]((x, video_frame))[0]
446
+ cls_index = int(x.size(0) / 2)
447
+ left_features, left_attn_weights = self.transformer.resblocks[-1].visualize_forward(x[:cls_index, :, :], x[cls_index:, :, :], right_gt_map)
448
+ right_features, right_attn_weights = self.transformer.resblocks[-1].visualize_forward(x[cls_index:, :, :], x[:cls_index, :, :], left_gt_map)
449
+
450
+ left_features = left_features.permute(1, 0, 2) # LND -> NLD
451
+ right_features = right_features.permute(1, 0, 2) # LND -> NLD
452
+ x = torch.cat([left_features, right_features], 1)
453
+
454
+ # Move the three lines below to `encode_image` for entire hidden sequence
455
+ # x = self.ln_post(x[:, 0, :])
456
+ # if self.proj is not None:
457
+ # x = x @ self.proj
458
+
459
+ if visualize is True:
460
+ return x, all_attn_weights
461
+ return x, left_attn_weights, right_attn_weights
462
+
463
+
464
+ class CLIP(nn.Module):
465
+ def __init__(self,
466
+ embed_dim: int,
467
+ # vision
468
+ image_resolution: int,
469
+ vision_layers: Union[Tuple[int, int, int, int], int],
470
+ vision_width: int,
471
+ vision_patch_size: int,
472
+ # text
473
+ context_length: int,
474
+ vocab_size: int,
475
+ transformer_width: int,
476
+ transformer_heads: int,
477
+ transformer_layers: int,
478
+ # vision linear of patch
479
+ linear_patch: str = '2d',
480
+ intra_layers: int = 9,
481
+ ):
482
+ super().__init__()
483
+
484
+ self.context_length = context_length
485
+
486
+ vision_heads = vision_width // 64
487
+ self.visual = VisualTransformer(
488
+ input_resolution=image_resolution,
489
+ patch_size=vision_patch_size,
490
+ width=vision_width,
491
+ layers=vision_layers,
492
+ heads=vision_heads,
493
+ output_dim=embed_dim,
494
+ linear_patch=linear_patch,
495
+ intra_layers=intra_layers,
496
+ )
497
+
498
+ self.transformer = Transformer(
499
+ width=transformer_width,
500
+ layers=transformer_layers,
501
+ heads=transformer_heads,
502
+ attn_mask=self.build_attention_mask
503
+ )
504
+
505
+ self.vocab_size = vocab_size
506
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
507
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
508
+ self.ln_final = LayerNorm(transformer_width)
509
+
510
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
511
+ self.logit_scale = nn.Parameter(torch.ones([]))
512
+
513
+ self.initialize_parameters()
514
+
515
+ def initialize_parameters(self):
516
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
517
+ nn.init.normal_(self.positional_embedding, std=0.01)
518
+
519
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
520
+ attn_std = self.transformer.width ** -0.5
521
+ fc_std = (2 * self.transformer.width) ** -0.5
522
+ for block in self.transformer.resblocks:
523
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
524
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
525
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
526
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
527
+
528
+ if self.text_projection is not None:
529
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
530
+
531
+ @staticmethod
532
+ def get_config(pretrained_clip_name="ViT-B/32"):
533
+ model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViT-B-32.pt")
534
+ if pretrained_clip_name in _MODELS and pretrained_clip_name in _PT_NAME:
535
+ model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), _PT_NAME[pretrained_clip_name])
536
+
537
+ if pretrained_clip_name in ["ViT-B/32", "ViT-B/16"] and os.path.exists(model_path):
538
+ pass
539
+ else:
540
+ if pretrained_clip_name in _MODELS:
541
+ model_path = _download(_MODELS[pretrained_clip_name])
542
+ elif os.path.isfile(pretrained_clip_name):
543
+ model_path = pretrained_clip_name
544
+ else:
545
+ raise RuntimeError(f"Model {pretrained_clip_name} not found; available models = {available_models()}")
546
+
547
+ try:
548
+ # loading JIT archive
549
+ model = torch.jit.load(model_path, map_location="cpu").eval()
550
+ state_dict = model.state_dict()
551
+ except RuntimeError:
552
+ state_dict = torch.load(model_path, map_location="cpu")
553
+
554
+ return state_dict
555
+
556
+ def build_attention_mask(self, context_length):
557
+ # lazily create the causal attention mask used by the text transformer
558
+ # pytorch uses additive attention mask; fill with -inf
559
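+ # e.g. build_attention_mask(3) returns
+ # [[0., -inf, -inf],
+ # [0., 0., -inf],
+ # [0., 0., 0.]]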
+ mask = torch.zeros(context_length, context_length)
560
+ mask.fill_(float("-inf"))
561
+ mask.triu_(1) # zero out the lower diagonal
562
+ return mask
563
+
564
+ @property
565
+ def dtype(self):
566
+ return self.visual.conv1.weight.dtype
567
+
568
+ def encode_image(self, image, left_gt_map, right_gt_map, return_hidden=False, video_frame=-1):
569
+ hidden, left_map, right_map = self.visual(image.type(self.dtype), left_gt_map, right_gt_map, video_frame=video_frame)
570
+ hidden = self.visual.ln_post(hidden) @ self.visual.proj
571
+
572
+ cls_index = int(hidden.size(1) / 2)
573
+ hidden2 = torch.cat([hidden[:, 0, :].unsqueeze(1), hidden[:, cls_index, :].unsqueeze(1)], 1)
574
+ x = torch.mean(hidden2, 1)
575
+
576
+ if return_hidden:
577
+ return x, hidden2, left_map, right_map
578
+
579
+ return x, left_map, right_map
580
+
581
+ def encode_text(self, text, return_hidden=False):
582
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
583
+
584
+ pos_emd = self.positional_embedding[:x.size(1), :].type(self.dtype)
585
+ x = x + pos_emd
586
+ x = x.permute(1, 0, 2) # NLD -> LND
587
+ x = self.transformer(x)
588
+ x = x.permute(1, 0, 2) # LND -> NLD
589
+
590
+ hidden = self.ln_final(x).type(self.dtype) @ self.text_projection
591
+
592
+ # x.shape = [batch_size, n_ctx, transformer.width]
593
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
594
+ x = hidden[torch.arange(hidden.shape[0]), text.argmax(dim=-1)]
595
+
596
+ if return_hidden:
597
+ return x, hidden
598
+
599
+ return x
600
+
601
+ def forward(self, image, text, left_gt_map=None, right_gt_map=None):
602
+ image_features, _, _ = self.encode_image(image, left_gt_map, right_gt_map)
603
+ text_features = self.encode_text(text)
604
+
605
+ # normalized features
606
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
607
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
608
+
609
+ # cosine similarity as logits
610
+ logit_scale = self.logit_scale.exp()
611
+ logits_per_image = logit_scale * image_features @ text_features.t()
612
+ logits_per_text = logit_scale * text_features @ image_features.t()
613
+
614
+ # shape = [global_batch_size, global_batch_size]
615
+ return logits_per_image, logits_per_text
616
+
617
+
618
+ def convert_weights(model: nn.Module):
619
+ """Convert applicable model parameters to fp16"""
620
+
621
+ def _convert_weights_to_fp16(l):
622
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
623
+ l.weight.data = l.weight.data.half()
624
+ if l.bias is not None:
625
+ l.bias.data = l.bias.data.half()
626
+
627
+ if isinstance(l, nn.MultiheadAttention):
628
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
629
+ tensor = getattr(l, attr)
630
+ if tensor is not None:
631
+ tensor.data = tensor.data.half()
632
+
633
+ for name in ["text_projection", "proj"]:
634
+ if hasattr(l, name):
635
+ attr = getattr(l, name)
636
+ if attr is not None:
637
+ attr.data = attr.data.half()
638
+
639
+ model.apply(_convert_weights_to_fp16)
640
+
641
+
642
+ def build_model(state_dict: dict):
643
+ vit = "visual.proj" in state_dict
644
+
645
+ if vit:
646
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
647
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
648
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
649
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
650
+ image_resolution = vision_patch_size * grid_size
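+ # e.g. for ViT-B/32: conv1.weight is [768, 3, 32, 32] and positional_embedding is [50, 768],
+ # giving width 768, patch size 32, grid 7 and input resolution 224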
651
+ else:
652
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
653
+ vision_layers = tuple(counts)
654
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
655
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
656
+ vision_patch_size = None
657
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
658
+ image_resolution = output_width * 32
659
+
660
+ embed_dim = state_dict["text_projection"].shape[1]
661
+ context_length = state_dict["positional_embedding"].shape[0]
662
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
663
+ transformer_width = state_dict["ln_final.weight"].shape[0]
664
+ transformer_heads = transformer_width // 64
665
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
666
+
667
+ model = CLIP(
668
+ embed_dim,
669
+ image_resolution, vision_layers, vision_width, vision_patch_size,
670
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
671
+ )
672
+
673
+ for key in ["input_resolution", "context_length", "vocab_size"]:
674
+ if key in state_dict:
675
+ del state_dict[key]
676
+
677
+ convert_weights(model)
678
+ model.load_state_dict(state_dict)
679
+ return model.eval()
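A rough sketch of how these helpers fit together (illustrative only: the variable name below is hypothetical, it assumes the `CLIP` class above is importable and the checkpoint is reachable, and note that the stock OpenAI checkpoint does not contain the extra TAB / before-after parameters added above, so loading it into this modified class may require a non-strict load):

    state_dict = CLIP.get_config(pretrained_clip_name="ViT-B/32")     # downloads on first use
    print(state_dict["visual.conv1.weight"].shape)                    # [768, 3, 32, 32] -> width 768, patch 32
    print(state_dict["visual.positional_embedding"].shape)            # [50, 768] -> grid 7, resolution 224
    print(state_dict["text_projection"].shape)                        # [512, 512] -> embed_dim 512

`build_model` reads exactly these shapes to infer the architecture hyper-parameters before instantiating `CLIP`.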
utils/module_cross.py ADDED
@@ -0,0 +1,394 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import os
23
+ import copy
24
+ import json
25
+ import math
26
+ import logging
27
+ import tarfile
28
+ import tempfile
29
+ import shutil
30
+
31
+ import torch
32
+ from torch import nn
33
+ import torch.nn.functional as F
34
+ from .file_utils import cached_path
35
+ from .until_config import PretrainedConfig
36
+ from .until_module import PreTrainedModel, LayerNorm, ACT2FN
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ PRETRAINED_MODEL_ARCHIVE_MAP = {}
41
+ CONFIG_NAME = 'cross_config.json'
42
+ WEIGHTS_NAME = 'cross_pytorch_model.bin'
43
+
44
+
45
+ class CrossConfig(PretrainedConfig):
46
+ """Configuration class to store the configuration of a `CrossModel`.
47
+ """
48
+ pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
49
+ config_name = CONFIG_NAME
50
+ weights_name = WEIGHTS_NAME
51
+ def __init__(self,
52
+ vocab_size_or_config_json_file,
53
+ hidden_size=768,
54
+ num_hidden_layers=12,
55
+ num_attention_heads=12,
56
+ intermediate_size=3072,
57
+ hidden_act="gelu",
58
+ hidden_dropout_prob=0.1,
59
+ attention_probs_dropout_prob=0.1,
60
+ max_position_embeddings=512,
61
+ type_vocab_size=2,
62
+ initializer_range=0.02):
63
+ """Constructs CrossConfig.
64
+
65
+ Args:
66
+ vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CrossModel`.
67
+ hidden_size: Size of the encoder layers and the pooler layer.
68
+ num_hidden_layers: Number of hidden layers in the Transformer encoder.
69
+ num_attention_heads: Number of attention heads for each attention layer in
70
+ the Transformer encoder.
71
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
72
+ layer in the Transformer encoder.
73
+ hidden_act: The non-linear activation function (function or string) in the
74
+ encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
75
+ hidden_dropout_prob: The dropout probability for all fully connected
76
+ layers in the embeddings, encoder, and pooler.
77
+ attention_probs_dropout_prob: The dropout ratio for the attention
78
+ probabilities.
79
+ max_position_embeddings: The maximum sequence length that this model might
80
+ ever be used with. Typically set this to something large just in case
81
+ (e.g., 512 or 1024 or 2048).
82
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
83
+ `CrossModel`.
84
+ initializer_range: The stddev of the truncated_normal_initializer for
85
+ initializing all weight matrices.
86
+ """
87
+ if isinstance(vocab_size_or_config_json_file, str):
88
+ with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
89
+ json_config = json.loads(reader.read())
90
+ for key, value in json_config.items():
91
+ self.__dict__[key] = value
92
+ elif isinstance(vocab_size_or_config_json_file, int):
93
+ self.vocab_size = vocab_size_or_config_json_file
94
+ self.hidden_size = hidden_size
95
+ self.num_hidden_layers = num_hidden_layers
96
+ self.num_attention_heads = num_attention_heads
97
+ self.hidden_act = hidden_act
98
+ self.intermediate_size = intermediate_size
99
+ self.hidden_dropout_prob = hidden_dropout_prob
100
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
101
+ self.max_position_embeddings = max_position_embeddings
102
+ self.type_vocab_size = type_vocab_size
103
+ self.initializer_range = initializer_range
104
+ else:
105
+ raise ValueError("First argument must be either a vocabulary size (int)"
106
+ "or the path to a pretrained model config file (str)")
107
+
108
+
109
+ class CrossEmbeddings(nn.Module):
110
+ """Construct the embeddings from word, position and token_type embeddings.
111
+ """
112
+ def __init__(self, config):
113
+ super(CrossEmbeddings, self).__init__()
114
+
115
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
116
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
117
+
118
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
119
+ # any TensorFlow checkpoint file
120
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
121
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
122
+
123
+ def forward(self, concat_embeddings, concat_type=None):
124
+
125
+ batch_size, seq_length = concat_embeddings.size(0), concat_embeddings.size(1)
126
+ if concat_type is None:
127
+ concat_type = torch.zeros(batch_size, seq_length, dtype=torch.long, device=concat_embeddings.device)
128
+
129
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=concat_embeddings.device)
130
+ position_ids = position_ids.unsqueeze(0).expand(concat_embeddings.size(0), -1)
131
+
132
+ token_type_embeddings = self.token_type_embeddings(concat_type)
133
+ position_embeddings = self.position_embeddings(position_ids)
134
+
135
+ embeddings = concat_embeddings + position_embeddings + token_type_embeddings
136
+ embeddings = self.LayerNorm(embeddings)
137
+ embeddings = self.dropout(embeddings)
138
+ return embeddings
139
+
140
+ class CrossSelfAttention(nn.Module):
141
+ def __init__(self, config):
142
+ super(CrossSelfAttention, self).__init__()
143
+ if config.hidden_size % config.num_attention_heads != 0:
144
+ raise ValueError(
145
+ "The hidden size (%d) is not a multiple of the number of attention "
146
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
147
+ self.num_attention_heads = config.num_attention_heads
148
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
149
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
150
+
151
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
152
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
153
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
154
+
155
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
156
+
157
+ def transpose_for_scores(self, x):
158
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
159
+ x = x.view(*new_x_shape)
160
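+ # [batch, seq_len, all_head_size] -> [batch, num_heads, seq_len, head_size]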
+ return x.permute(0, 2, 1, 3)
161
+
162
+ def forward(self, hidden_states, attention_mask):
163
+ mixed_query_layer = self.query(hidden_states)
164
+ mixed_key_layer = self.key(hidden_states)
165
+ mixed_value_layer = self.value(hidden_states)
166
+
167
+ query_layer = self.transpose_for_scores(mixed_query_layer)
168
+ key_layer = self.transpose_for_scores(mixed_key_layer)
169
+ value_layer = self.transpose_for_scores(mixed_value_layer)
170
+
171
+ # Take the dot product between "query" and "key" to get the raw attention scores.
172
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
173
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
174
+ # Apply the attention mask (precomputed for all layers in CrossModel's forward() function)
175
+ attention_scores = attention_scores + attention_mask
176
+
177
+ # Normalize the attention scores to probabilities.
178
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
179
+
180
+ # This is actually dropping out entire tokens to attend to, which might
181
+ # seem a bit unusual, but is taken from the original Transformer paper.
182
+ attention_probs = self.dropout(attention_probs)
183
+
184
+ context_layer = torch.matmul(attention_probs, value_layer)
185
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
186
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
187
+ context_layer = context_layer.view(*new_context_layer_shape)
188
+ return context_layer
189
+
190
+
191
+ class CrossSelfOutput(nn.Module):
192
+ def __init__(self, config):
193
+ super(CrossSelfOutput, self).__init__()
194
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
195
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
196
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
197
+
198
+ def forward(self, hidden_states, input_tensor):
199
+ hidden_states = self.dense(hidden_states)
200
+ hidden_states = self.dropout(hidden_states)
201
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
202
+ return hidden_states
203
+
204
+
205
+ class CrossAttention(nn.Module):
206
+ def __init__(self, config):
207
+ super(CrossAttention, self).__init__()
208
+ self.self = CrossSelfAttention(config)
209
+ self.output = CrossSelfOutput(config)
210
+
211
+ def forward(self, input_tensor, attention_mask):
212
+ self_output = self.self(input_tensor, attention_mask)
213
+ attention_output = self.output(self_output, input_tensor)
214
+ return attention_output
215
+
216
+
217
+ class CrossIntermediate(nn.Module):
218
+ def __init__(self, config):
219
+ super(CrossIntermediate, self).__init__()
220
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
221
+ self.intermediate_act_fn = ACT2FN[config.hidden_act] \
222
+ if isinstance(config.hidden_act, str) else config.hidden_act
223
+
224
+ def forward(self, hidden_states):
225
+ hidden_states = self.dense(hidden_states)
226
+ hidden_states = self.intermediate_act_fn(hidden_states)
227
+ return hidden_states
228
+
229
+
230
+ class CrossOutput(nn.Module):
231
+ def __init__(self, config):
232
+ super(CrossOutput, self).__init__()
233
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
234
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
235
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
236
+
237
+ def forward(self, hidden_states, input_tensor):
238
+ hidden_states = self.dense(hidden_states)
239
+ hidden_states = self.dropout(hidden_states)
240
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
241
+ return hidden_states
242
+
243
+
244
+ class CrossLayer(nn.Module):
245
+ def __init__(self, config):
246
+ super(CrossLayer, self).__init__()
247
+ self.attention = CrossAttention(config)
248
+ self.intermediate = CrossIntermediate(config)
249
+ self.output = CrossOutput(config)
250
+
251
+ def forward(self, hidden_states, attention_mask):
252
+ attention_output = self.attention(hidden_states, attention_mask)
253
+ intermediate_output = self.intermediate(attention_output)
254
+ layer_output = self.output(intermediate_output, attention_output)
255
+ return layer_output
256
+
257
+
258
+ class CrossEncoder(nn.Module):
259
+ def __init__(self, config):
260
+ super(CrossEncoder, self).__init__()
261
+ layer = CrossLayer(config)
262
+ self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
263
+
264
+ def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
265
+ all_encoder_layers = []
266
+ for layer_module in self.layer:
267
+ hidden_states = layer_module(hidden_states, attention_mask)
268
+ if output_all_encoded_layers:
269
+ all_encoder_layers.append(hidden_states)
270
+ if not output_all_encoded_layers:
271
+ all_encoder_layers.append(hidden_states)
272
+ return all_encoder_layers
273
+
274
+
275
+ class CrossPooler(nn.Module):
276
+ def __init__(self, config):
277
+ super(CrossPooler, self).__init__()
278
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
279
+ self.activation = nn.Tanh()
280
+
281
+ def forward(self, hidden_states):
282
+ # We "pool" the model by simply taking the hidden state corresponding
283
+ # to the first token.
284
+ first_token_tensor = hidden_states[:, 0]
285
+ pooled_output = self.dense(first_token_tensor)
286
+ pooled_output = self.activation(pooled_output)
287
+ return pooled_output
288
+
289
+
290
+ class CrossPredictionHeadTransform(nn.Module):
291
+ def __init__(self, config):
292
+ super(CrossPredictionHeadTransform, self).__init__()
293
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
294
+ self.transform_act_fn = ACT2FN[config.hidden_act] \
295
+ if isinstance(config.hidden_act, str) else config.hidden_act
296
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
297
+
298
+ def forward(self, hidden_states):
299
+ hidden_states = self.dense(hidden_states)
300
+ hidden_states = self.transform_act_fn(hidden_states)
301
+ hidden_states = self.LayerNorm(hidden_states)
302
+ return hidden_states
303
+
304
+
305
+ class CrossLMPredictionHead(nn.Module):
306
+ def __init__(self, config, cross_model_embedding_weights):
307
+ super(CrossLMPredictionHead, self).__init__()
308
+ self.transform = CrossPredictionHeadTransform(config)
309
+
310
+ # The output weights are the same as the input embeddings, but there is
311
+ # an output-only bias for each token.
312
+ self.decoder = nn.Linear(cross_model_embedding_weights.size(1),
313
+ cross_model_embedding_weights.size(0),
314
+ bias=False)
315
+ self.decoder.weight = cross_model_embedding_weights
316
+ self.bias = nn.Parameter(torch.zeros(cross_model_embedding_weights.size(0)))
317
+
318
+ def forward(self, hidden_states):
319
+ hidden_states = self.transform(hidden_states)
320
+ hidden_states = self.decoder(hidden_states) + self.bias
321
+ return hidden_states
322
+
323
+
324
+ class CrossOnlyMLMHead(nn.Module):
325
+ def __init__(self, config, cross_model_embedding_weights):
326
+ super(CrossOnlyMLMHead, self).__init__()
327
+ self.predictions = CrossLMPredictionHead(config, cross_model_embedding_weights)
328
+
329
+ def forward(self, sequence_output):
330
+ prediction_scores = self.predictions(sequence_output)
331
+ return prediction_scores
332
+
333
+
334
+ class CrossOnlyNSPHead(nn.Module):
335
+ def __init__(self, config):
336
+ super(CrossOnlyNSPHead, self).__init__()
337
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
338
+
339
+ def forward(self, pooled_output):
340
+ seq_relationship_score = self.seq_relationship(pooled_output)
341
+ return seq_relationship_score
342
+
343
+
344
+ class CrossPreTrainingHeads(nn.Module):
345
+ def __init__(self, config, cross_model_embedding_weights):
346
+ super(CrossPreTrainingHeads, self).__init__()
347
+ self.predictions = CrossLMPredictionHead(config, cross_model_embedding_weights)
348
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
349
+
350
+ def forward(self, sequence_output, pooled_output):
351
+ prediction_scores = self.predictions(sequence_output)
352
+ seq_relationship_score = self.seq_relationship(pooled_output)
353
+ return prediction_scores, seq_relationship_score
354
+
355
+
356
+ class CrossModel(PreTrainedModel):
357
+ def __init__(self, config):
358
+ super(CrossModel, self).__init__(config)
359
+ self.embeddings = CrossEmbeddings(config)
360
+ self.encoder = CrossEncoder(config)
361
+ self.pooler = CrossPooler(config)
362
+ self.apply(self.init_weights)
363
+
364
+ def forward(self, concat_input, concat_type=None, attention_mask=None, output_all_encoded_layers=True):
365
+
366
+ if attention_mask is None:
367
+ attention_mask = torch.ones(concat_input.size(0), concat_input.size(1), device=concat_input.device)
368
+ if concat_type is None:
369
+ concat_type = torch.zeros_like(attention_mask, dtype=torch.long)
370
+
371
+ # We create a 3D attention mask from a 2D tensor mask.
372
+ # Sizes are [batch_size, 1, 1, to_seq_length]
373
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
374
+ # this attention mask is more simple than the triangular masking of causal attention
375
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
376
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
377
+
378
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
379
+ # masked positions, this operation will create a tensor which is 0.0 for
380
+ # positions we want to attend and -10000.0 for masked positions.
381
+ # Since we are adding it to the raw scores before the softmax, this is
382
+ # effectively the same as removing these entirely.
383
+ extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
384
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
385
+
386
+ embedding_output = self.embeddings(concat_input, concat_type)
387
+ encoded_layers = self.encoder(embedding_output,
388
+ extended_attention_mask,
389
+ output_all_encoded_layers=output_all_encoded_layers)
390
+ sequence_output = encoded_layers[-1]
391
+ pooled_output = self.pooler(sequence_output)
392
+ if not output_all_encoded_layers:
393
+ encoded_layers = encoded_layers[-1]
394
+ return encoded_layers, pooled_output
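A minimal usage sketch for the cross encoder above (hypothetical shapes and names; it assumes the `utils` package is importable and that `PreTrainedModel` in `until_module` provides the usual `init_weights`/`dtype` helpers, as in UniVL-style code):

    import torch
    from utils.module_cross import CrossConfig, CrossModel

    config = CrossConfig(vocab_size_or_config_json_file=512, num_hidden_layers=2)
    cross = CrossModel(config)

    features = torch.randn(2, 10, config.hidden_size)   # e.g. concatenated visual/text features
    mask = torch.ones(2, 10)                             # 1 = attend, 0 = padding
    layers, pooled = cross(features, attention_mask=mask)
    print(layers[-1].shape, pooled.shape)                # (2, 10, 768), (2, 768)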
utils/module_decoder.py ADDED
@@ -0,0 +1,447 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import os
23
+ import copy
24
+ import json
25
+ import math
26
+ import logging
27
+ import tarfile
28
+ import tempfile
29
+ import shutil
30
+ import numpy as np
31
+
32
+ import torch
33
+ from torch import nn
34
+ from .file_utils import cached_path
35
+ from .until_config import PretrainedConfig
36
+ from .until_module import PreTrainedModel, LayerNorm, ACT2FN
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+ PRETRAINED_MODEL_ARCHIVE_MAP = {}
41
+ CONFIG_NAME = 'decoder_config.json'
42
+ WEIGHTS_NAME = 'decoder_pytorch_model.bin'
43
+
44
+
45
+ class DecoderConfig(PretrainedConfig):
46
+ """Configuration class to store the configuration of a `DecoderModel`.
47
+ """
48
+ pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
49
+ config_name = CONFIG_NAME
50
+ weights_name = WEIGHTS_NAME
51
+ def __init__(self,
52
+ vocab_size_or_config_json_file,
53
+ hidden_size=768,
54
+ num_hidden_layers=12,
55
+ num_attention_heads=12,
56
+ intermediate_size=3072,
57
+ hidden_act="gelu",
58
+ hidden_dropout_prob=0.1,
59
+ attention_probs_dropout_prob=0.1,
60
+ type_vocab_size=2,
61
+ initializer_range=0.02,
62
+ max_target_embeddings=128,
63
+ num_decoder_layers=1):
64
+ """Constructs DecoderConfig.
65
+
66
+ Args:
67
+ vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `DecoderModel`.
68
+ hidden_size: Size of the encoder layers and the pooler layer.
69
+ num_hidden_layers: Number of hidden layers in the Transformer encoder.
70
+ num_attention_heads: Number of attention heads for each attention layer in
71
+ the Transformer encoder.
72
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
73
+ layer in the Transformer encoder.
74
+ hidden_act: The non-linear activation function (function or string) in the
75
+ encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
76
+ hidden_dropout_prob: The dropout probability for all fully connected
77
+ layers in the embeddings, encoder, and pooler.
78
+ attention_probs_dropout_prob: The dropout ratio for the attention
79
+ probabilities.
80
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
81
+ `DecoderModel`.
82
+ initializer_range: The stddev of the truncated_normal_initializer for
83
+ initializing all weight matrices.
84
+ max_target_embeddings: The maximum sequence length that this model might
85
+ ever be used with. Typically set this to something large just in case
86
+ (e.g., 512 or 1024 or 2048).
87
+ num_decoder_layers: Number of layers stacked in both the feature `Encoder` and the caption `Decoder`.
88
+ """
89
+ if isinstance(vocab_size_or_config_json_file, str):
90
+ with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
91
+ json_config = json.loads(reader.read())
92
+ for key, value in json_config.items():
93
+ self.__dict__[key] = value
94
+ elif isinstance(vocab_size_or_config_json_file, int):
95
+ self.vocab_size = vocab_size_or_config_json_file
96
+ self.hidden_size = hidden_size
97
+ self.num_hidden_layers = num_hidden_layers
98
+ self.num_attention_heads = num_attention_heads
99
+ self.hidden_act = hidden_act
100
+ self.intermediate_size = intermediate_size
101
+ self.hidden_dropout_prob = hidden_dropout_prob
102
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
103
+ self.type_vocab_size = type_vocab_size
104
+ self.initializer_range = initializer_range
105
+ self.max_target_embeddings = max_target_embeddings
106
+ self.num_decoder_layers = num_decoder_layers
107
+ else:
108
+ raise ValueError("First argument must be either a vocabulary size (int)"
109
+ "or the path to a pretrained model config file (str)")
110
+
111
+
112
+ class BertSelfOutput(nn.Module):
113
+ def __init__(self, config):
114
+ super(BertSelfOutput, self).__init__()
115
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
116
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
117
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
118
+
119
+ def forward(self, hidden_states, input_tensor):
120
+ hidden_states = self.dense(hidden_states)
121
+ hidden_states = self.dropout(hidden_states)
122
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
123
+ return hidden_states
124
+
125
+ class BertIntermediate(nn.Module):
126
+ def __init__(self, config):
127
+ super(BertIntermediate, self).__init__()
128
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
129
+ self.intermediate_act_fn = ACT2FN[config.hidden_act] \
130
+ if isinstance(config.hidden_act, str) else config.hidden_act
131
+
132
+ def forward(self, hidden_states):
133
+ hidden_states = self.dense(hidden_states)
134
+ hidden_states = self.intermediate_act_fn(hidden_states)
135
+ return hidden_states
136
+
137
+
138
+ class BertOutput(nn.Module):
139
+ def __init__(self, config):
140
+ super(BertOutput, self).__init__()
141
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
142
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
143
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
144
+
145
+ def forward(self, hidden_states, input_tensor):
146
+ hidden_states = self.dense(hidden_states)
147
+ hidden_states = self.dropout(hidden_states)
148
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
149
+ return hidden_states
150
+
151
+
152
+ class BertPredictionHeadTransform(nn.Module):
153
+ def __init__(self, config):
154
+ super(BertPredictionHeadTransform, self).__init__()
155
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
156
+ self.transform_act_fn = ACT2FN[config.hidden_act] \
157
+ if isinstance(config.hidden_act, str) else config.hidden_act
158
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
159
+
160
+ def forward(self, hidden_states):
161
+ hidden_states = self.dense(hidden_states)
162
+ hidden_states = self.transform_act_fn(hidden_states)
163
+ hidden_states = self.LayerNorm(hidden_states)
164
+ return hidden_states
165
+
166
+
167
+ class BertLMPredictionHead(nn.Module):
168
+ def __init__(self, config, decoder_model_embedding_weights):
169
+ super(BertLMPredictionHead, self).__init__()
170
+ self.transform = BertPredictionHeadTransform(config)
171
+
172
+ # The output weights are the same as the input embeddings, but there is
173
+ # an output-only bias for each token.
174
+ self.decoder = nn.Linear(decoder_model_embedding_weights.size(1),
175
+ decoder_model_embedding_weights.size(0),
176
+ bias=False)
177
+ self.decoder.weight = decoder_model_embedding_weights
178
+ self.bias = nn.Parameter(torch.zeros(decoder_model_embedding_weights.size(0)))
179
+
180
+ def forward(self, hidden_states):
181
+ hidden_states = self.transform(hidden_states)
182
+ hidden_states = self.decoder(hidden_states) + self.bias
183
+ return hidden_states
184
+
185
+
186
+ class BertOnlyMLMHead(nn.Module):
187
+ def __init__(self, config, decoder_model_embedding_weights):
188
+ super(BertOnlyMLMHead, self).__init__()
189
+ self.predictions = BertLMPredictionHead(config, decoder_model_embedding_weights)
190
+
191
+ def forward(self, sequence_output):
192
+ prediction_scores = self.predictions(sequence_output)
193
+ return prediction_scores
194
+
195
+ class MultiHeadAttention(nn.Module):
196
+ ''' Multi-Head Attention module '''
197
+
198
+ def __init__(self, config):
199
+ super(MultiHeadAttention, self).__init__()
200
+
201
+ if config.hidden_size % config.num_attention_heads != 0:
202
+ raise ValueError(
203
+ "The hidden size (%d) is not a multiple of the number of attention "
204
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
205
+ self.num_attention_heads = config.num_attention_heads
206
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
207
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
208
+
209
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
210
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
211
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
212
+
213
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
214
+
215
+ def transpose_for_scores(self, x):
216
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
217
+ x = x.view(*new_x_shape)
218
+ return x.permute(0, 2, 1, 3)
219
+
220
+ def forward(self, q, k, v, attention_mask):
221
+ mixed_query_layer = self.query(q)
222
+ mixed_key_layer = self.key(k)
223
+ mixed_value_layer = self.value(v)
224
+
225
+ query_layer = self.transpose_for_scores(mixed_query_layer)
226
+ key_layer = self.transpose_for_scores(mixed_key_layer)
227
+ value_layer = self.transpose_for_scores(mixed_value_layer)
228
+
229
+ # Take the dot product between "query" and "key" to get the raw attention scores.
230
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
231
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
232
+ # Apply the attention mask (precomputed in DecoderModel's forward() function)
233
+ attention_scores = attention_scores + attention_mask
234
+
235
+ # Normalize the attention scores to probabilities.
236
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
237
+
238
+ # This is actually dropping out entire tokens to attend to, which might
239
+ # seem a bit unusual, but is taken from the original Transformer paper.
240
+ attention_probs = self.dropout(attention_probs)
241
+
242
+ context_layer = torch.matmul(attention_probs, value_layer)
243
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
244
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
245
+ context_layer = context_layer.view(*new_context_layer_shape)
246
+
247
+ return context_layer, attention_scores
248
+
249
+ class PositionwiseFeedForward(nn.Module):
250
+ ''' A two-feed-forward-layer module '''
251
+
252
+ def __init__(self, d_in, d_hid, dropout=0.1):
253
+ super().__init__()
254
+ self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise
255
+ self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise
256
+ self.layer_norm = nn.LayerNorm(d_in)
257
+ self.dropout = nn.Dropout(dropout)
258
+
259
+ def forward(self, x):
260
+ residual = x
261
+ output = x.transpose(1, 2)
262
+ output = self.w_2(ACT2FN["gelu"](self.w_1(output)))
263
+ output = output.transpose(1, 2)
264
+ output = self.dropout(output)
265
+ output = self.layer_norm(output + residual)
266
+ return output
267
+
268
+ class DecoderAttention(nn.Module):
269
+ def __init__(self, config):
270
+ super(DecoderAttention, self).__init__()
271
+ self.att = MultiHeadAttention(config)
272
+ self.output = BertSelfOutput(config)
273
+
274
+ def forward(self, q, k, v, attention_mask):
275
+ att_output, attention_probs = self.att(q, k, v, attention_mask)
276
+ attention_output = self.output(att_output, q)
277
+ return attention_output, attention_probs
278
+
279
+ class EncoderLayer(nn.Module):
280
+ def __init__(self, config):
281
+ super(EncoderLayer, self).__init__()
282
+ self.slf_attn = DecoderAttention(config)
283
+ self.intermediate = BertIntermediate(config)
284
+ self.output = BertOutput(config)
285
+
286
+ def forward(self, dec_input, slf_attn_mask=None):
287
+ slf_output, slf_att_scores = self.slf_attn(dec_input, dec_input, dec_input, slf_attn_mask)
288
+ intermediate_output = self.intermediate(slf_output)
289
+ dec_output = self.output(intermediate_output, slf_output)
290
+ return dec_output, slf_att_scores
291
+
292
+ class DecoderLayer(nn.Module):
293
+ def __init__(self, config):
294
+ super(DecoderLayer, self).__init__()
295
+ self.slf_attn = DecoderAttention(config)
296
+ self.enc_attn = DecoderAttention(config)
297
+ self.intermediate = BertIntermediate(config)
298
+ self.output = BertOutput(config)
299
+
300
+ def forward(self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None):
301
+ slf_output, _ = self.slf_attn(dec_input, dec_input, dec_input, slf_attn_mask)
302
+ dec_output, dec_att_scores = self.enc_attn(slf_output, enc_output, enc_output, dec_enc_attn_mask)
303
+ intermediate_output = self.intermediate(dec_output)
304
+ dec_output = self.output(intermediate_output, dec_output)
305
+ return dec_output, dec_att_scores
306
+
307
+ class DecoderEmbeddings(nn.Module):
308
+ """Construct the embeddings from word, position and token_type embeddings.
309
+ """
310
+ def __init__(self, config, decoder_word_embeddings_weight, decoder_position_embeddings_weight):
311
+ super(DecoderEmbeddings, self).__init__()
312
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
313
+ self.position_embeddings = nn.Embedding(config.max_target_embeddings, config.hidden_size)
314
+ self.word_embeddings.weight = decoder_word_embeddings_weight
315
+ self.position_embeddings.weight = decoder_position_embeddings_weight
316
+
317
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
318
+ # any TensorFlow checkpoint file
319
+ self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
320
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
321
+
322
+ def forward(self, input_ids):
323
+ seq_length = input_ids.size(1)
324
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
325
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
326
+
327
+ words_embeddings = self.word_embeddings(input_ids)
328
+ position_embeddings = self.position_embeddings(position_ids)
329
+
330
+ embeddings = words_embeddings + position_embeddings
331
+ embeddings = self.LayerNorm(embeddings)
332
+ embeddings = self.dropout(embeddings)
333
+ return embeddings
334
+
335
+ class Encoder(nn.Module):
336
+ def __init__(self, config):
337
+ super(Encoder, self).__init__()
338
+ layer = EncoderLayer(config)
339
+ self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_decoder_layers)])
340
+
341
+ def forward(self, hidden_states, self_attn_mask, output_all_encoded_layers=False):
342
+ dec_att_scores = None
343
+ all_encoder_layers = []
344
+ all_dec_att_probs = []
345
+ for layer_module in self.layer:
346
+ hidden_states, dec_att_scores = layer_module(hidden_states, self_attn_mask)
347
+ if output_all_encoded_layers:
348
+ all_encoder_layers.append(hidden_states)
349
+ all_dec_att_probs.append(dec_att_scores)
350
+ if not output_all_encoded_layers:
351
+ all_encoder_layers.append(hidden_states)
352
+ all_dec_att_probs.append(dec_att_scores)
353
+ return all_encoder_layers, all_dec_att_probs
354
+
355
+ class Decoder(nn.Module):
356
+ def __init__(self, config):
357
+ super(Decoder, self).__init__()
358
+ layer = DecoderLayer(config)
359
+ self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_decoder_layers)])
360
+
361
+ def forward(self, hidden_states, encoder_outs, self_attn_mask, attention_mask, output_all_encoded_layers=False):
362
+ dec_att_scores = None
363
+ all_encoder_layers = []
364
+ all_dec_att_probs = []
365
+ for i, layer_module in enumerate(self.layer):
366
+ if isinstance(encoder_outs, list):
367
+ hidden_states, dec_att_scores = layer_module(hidden_states, encoder_outs[i], self_attn_mask, attention_mask)
368
+ else:
369
+ hidden_states, dec_att_scores = layer_module(hidden_states, encoder_outs, self_attn_mask, attention_mask)
370
+ if output_all_encoded_layers:
371
+ all_encoder_layers.append(hidden_states)
372
+ all_dec_att_probs.append(dec_att_scores)
373
+ if not output_all_encoded_layers:
374
+ all_encoder_layers.append(hidden_states)
375
+ all_dec_att_probs.append(dec_att_scores)
376
+ return all_encoder_layers, all_dec_att_probs
377
+
378
+ class DecoderClassifier(nn.Module):
379
+ def __init__(self, config, embedding_weights):
380
+ super(DecoderClassifier, self).__init__()
381
+ self.cls = BertOnlyMLMHead(config, embedding_weights)
382
+
383
+ def forward(self, hidden_states):
384
+ cls_scores = self.cls(hidden_states)
385
+ return cls_scores
386
+
387
+ class DecoderModel(PreTrainedModel):
388
+
389
+ """
390
+ Transformer decoder consisting of *config.num_decoder_layers* :class:`DecoderLayer` blocks,
391
+ preceded by a stack of :class:`EncoderLayer` blocks applied to the encoder features.
392
+
393
+ Args:
394
+ config (DecoderConfig): model configuration.
395
+ decoder_word_embeddings_weight, decoder_position_embeddings_weight: weight tensors
396
+ reused for the decoder's word and position embeddings (see :class:`DecoderEmbeddings`).
397
+ """
398
+
399
+ def __init__(self, config, decoder_word_embeddings_weight, decoder_position_embeddings_weight):
400
+ super(DecoderModel, self).__init__(config)
401
+ self.config = config
402
+ self.max_target_length = config.max_target_embeddings
403
+ self.embeddings = DecoderEmbeddings(config, decoder_word_embeddings_weight, decoder_position_embeddings_weight)
404
+ self.decoder = Decoder(config)
405
+ self.encoder = Encoder(config)
406
+ self.classifier = DecoderClassifier(config, decoder_word_embeddings_weight)
407
+ self.apply(self.init_weights)
408
+
409
+ def forward(self, input_ids, encoder_outs=None, answer_mask=None, encoder_mask=None):
410
+ """
411
+ Args:
412
+ input_ids (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing
413
+ encoder_outs (Tensor, optional): output from the encoder, used for encoder-side attention
414
+
415
+ Returns:
416
+ Tensor: vocabulary logits over the target sequence, of shape `(batch, tgt_len, vocab)`
418
+ (per-layer attention weights are computed internally but are not returned).
419
+ """
420
+
421
+ embedding_output = self.embeddings(input_ids)
422
+
423
+ extended_encoder_mask = encoder_mask.unsqueeze(1).unsqueeze(2) # b x 1 x 1 x ls
424
+ extended_encoder_mask = extended_encoder_mask.to(dtype=self.dtype) # fp16 compatibility
425
+ extended_encoder_mask = (1.0 - extended_encoder_mask) * -10000.0
426
+
427
+ extended_answer_mask = answer_mask.unsqueeze(1).unsqueeze(2)
428
+ extended_answer_mask = extended_answer_mask.to(dtype=self.dtype) # fp16 compatibility
429
+
430
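+ # Build the additive self-attention mask: a position is blocked if it is padding in
+ # answer_mask or lies in the future relative to the query position (causal masking).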
+ sz_b, len_s, _ = embedding_output.size()
431
+ subsequent_mask = torch.triu(torch.ones((len_s, len_s), device=embedding_output.device, dtype=embedding_output.dtype), diagonal=1)
432
+ self_attn_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1).unsqueeze(1) # b x 1 x ls x ls
433
+ slf_attn_mask = ((1.0 - extended_answer_mask) + self_attn_mask).gt(0).to(dtype=self.dtype)
434
+ self_attn_mask = slf_attn_mask * -10000.0
435
+
436
+ encoder_outs, _ = self.encoder(encoder_outs, extended_encoder_mask, output_all_encoded_layers=True)
437
+ # encoder_outs = encoder_outs[-1]
438
+
439
+ decoded_layers, dec_att_scores = self.decoder(embedding_output,
440
+ encoder_outs,
441
+ self_attn_mask,
442
+ extended_encoder_mask,
443
+ )
444
+ sequence_output = decoded_layers[-1]
445
+ cls_scores = self.classifier(sequence_output)
446
+
447
+ return cls_scores
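A small teacher-forcing sketch for the decoder above (hypothetical sizes; as before it assumes the `until_module` helpers, and the fresh `nn.Embedding` layers below exist purely for illustration, standing in for whatever embedding weights the surrounding pipeline passes in):

    import torch
    from torch import nn
    from utils.module_decoder import DecoderConfig, DecoderModel

    config = DecoderConfig(vocab_size_or_config_json_file=49408, num_decoder_layers=1)
    word_emb = nn.Embedding(config.vocab_size, config.hidden_size)
    pos_emb = nn.Embedding(config.max_target_embeddings, config.hidden_size)
    decoder = DecoderModel(config, word_emb.weight, pos_emb.weight)

    input_ids = torch.randint(0, config.vocab_size, (2, 5))   # previously generated tokens
    encoder_outs = torch.randn(2, 7, config.hidden_size)      # fused visual/text features
    scores = decoder(input_ids, encoder_outs=encoder_outs,
                     answer_mask=torch.ones(2, 5), encoder_mask=torch.ones(2, 7))
    print(scores.shape)                                        # (2, 5, 49408)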
utils/module_gated_attention.py ADDED
@@ -0,0 +1,301 @@
1
+ r"""Gated Co-attention interface"""
2
+ from typing import Callable, List, Optional, Tuple
3
+ import math
4
+ import warnings
5
+
6
+ import torch
7
+ from torch.nn.functional import *
8
+
9
+
10
+ def gated_coattention(
11
+ query: Tensor,
12
+ key: Tensor,
13
+ value: Tensor,
14
+ embed_dim_to_check: int,
15
+ num_heads: int,
16
+ in_proj_weight: Tensor,
17
+ in_proj_bias: Tensor,
18
+ bias_k: Optional[Tensor],
19
+ bias_v: Optional[Tensor],
20
+ add_zero_attn: bool,
21
+ dropout_p: float,
22
+ out_proj_weight: Tensor,
23
+ out_proj_bias: Tensor,
24
+ training: bool = True,
25
+ gt_attention_map: Optional[Tensor] = None,
26
+ key_padding_mask: Optional[Tensor] = None,
27
+ need_weights: bool = True,
28
+ attn_mask: Optional[Tensor] = None,
29
+ use_separate_proj_weight: bool = False,
30
+ q_proj_weight: Optional[Tensor] = None,
31
+ k_proj_weight: Optional[Tensor] = None,
32
+ v_proj_weight: Optional[Tensor] = None,
33
+ static_k: Optional[Tensor] = None,
34
+ static_v: Optional[Tensor] = None,
35
+ ) -> Tuple[Tensor, Optional[Tensor]]:
36
+ r"""
37
+ Args:
38
+ query, key, value: map a query and a set of key-value pairs to an output.
39
+ See "Attention Is All You Need" for more details.
40
+ embed_dim_to_check: total dimension of the model.
41
+ num_heads: parallel attention heads.
42
+ in_proj_weight, in_proj_bias: input projection weight and bias.
43
+ bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
44
+ add_zero_attn: add a new batch of zeros to the key and
45
+ value sequences at dim=1.
46
+ dropout_p: probability of an element to be zeroed.
47
+ out_proj_weight, out_proj_bias: the output projection weight and bias.
48
+ training: apply dropout if is ``True``.
49
+ key_padding_mask: if provided, specified padding elements in the key will
50
+ be ignored by the attention. This is a binary mask. When the value is True,
51
+ the corresponding value on the attention layer will be filled with -inf.
52
+ need_weights: output attn_output_weights.
53
+ attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask is broadcast across all
54
+ batches, while a 3D mask allows a different mask to be specified for each entry in the batch.
55
+ use_separate_proj_weight: the function accepts the projection weights for query, key,
56
+ and value in different forms. If false, in_proj_weight will be used, which is
57
+ a combination of q_proj_weight, k_proj_weight, v_proj_weight.
58
+ q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
59
+ static_k, static_v: static key and value used for attention operators.
60
+
61
+
62
+ Shape:
63
+ Inputs:
64
+ - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
65
+ the embedding dimension.
66
+ - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
67
+ the embedding dimension.
68
+ - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
69
+ the embedding dimension.
70
+ - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
71
+ If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
72
+ will be unchanged. If a BoolTensor is provided, the positions with the
73
+ value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
74
+ - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
75
+ 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
76
+ S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
77
+ positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
78
+ while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
79
+ are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
80
+ is provided, it will be added to the attention weight.
81
+ - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
82
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
83
+ - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
84
+ N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
85
+
86
+ Outputs:
87
+ - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
88
+ E is the embedding dimension.
89
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
90
+ L is the target sequence length, S is the source sequence length.
91
+ """
92
+ tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
93
+ if has_torch_function(tens_ops):
94
+ return handle_torch_function(
95
+ multi_head_attention_forward,
96
+ tens_ops,
97
+ query,
98
+ key,
99
+ value,
100
+ embed_dim_to_check,
101
+ num_heads,
102
+ in_proj_weight,
103
+ in_proj_bias,
104
+ bias_k,
105
+ bias_v,
106
+ add_zero_attn,
107
+ dropout_p,
108
+ out_proj_weight,
109
+ out_proj_bias,
110
+ training=training,
111
+ gt_attention_map=gt_attention_map,
112
+ key_padding_mask=key_padding_mask,
113
+ need_weights=need_weights,
114
+ attn_mask=attn_mask,
115
+ use_separate_proj_weight=use_separate_proj_weight,
116
+ q_proj_weight=q_proj_weight,
117
+ k_proj_weight=k_proj_weight,
118
+ v_proj_weight=v_proj_weight,
119
+ static_k=static_k,
120
+ static_v=static_v,
121
+ )
122
+ tgt_len, bsz, embed_dim = query.size()
123
+ assert embed_dim == embed_dim_to_check
124
+ # allow MHA to have different sizes for the feature dimension
125
+ assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
126
+
127
+ head_dim = embed_dim // num_heads
128
+ assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
129
+ scaling = float(head_dim) ** -0.5
130
+
131
+ if not use_separate_proj_weight:
132
+ # encoder-decoder attention ---->>>>>>>> co-attention style
133
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
134
+ _b = in_proj_bias
135
+ _start = 0
136
+ _end = embed_dim
137
+ _w = in_proj_weight[_start:_end, :]
138
+ if _b is not None:
139
+ _b = _b[_start:_end]
140
+ q = linear(query, _w, _b)
141
+
142
+ if key is None:
143
+ assert value is None
144
+ k = None
145
+ v = None
146
+ else:
147
+
148
+ # This is inline in_proj function with in_proj_weight and in_proj_bias
149
+ _b = in_proj_bias
150
+ _start = embed_dim
151
+ _end = None
152
+ _w = in_proj_weight[_start:, :]
153
+ if _b is not None:
154
+ _b = _b[_start:]
155
+ k, v = linear(key, _w, _b).chunk(2, dim=-1)
156
+
157
+ else:
158
+ q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
159
+ len1, len2 = q_proj_weight_non_opt.size()
160
+ assert len1 == embed_dim and len2 == query.size(-1)
161
+
162
+ k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
163
+ len1, len2 = k_proj_weight_non_opt.size()
164
+ assert len1 == embed_dim and len2 == key.size(-1)
165
+
166
+ v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
167
+ len1, len2 = v_proj_weight_non_opt.size()
168
+ assert len1 == embed_dim and len2 == value.size(-1)
169
+
170
+ if in_proj_bias is not None:
171
+ q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
172
+ k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)])
173
+ v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :])
174
+ else:
175
+ q = linear(query, q_proj_weight_non_opt, in_proj_bias)
176
+ k = linear(key, k_proj_weight_non_opt, in_proj_bias)
177
+ v = linear(value, v_proj_weight_non_opt, in_proj_bias)
178
+ q = q * scaling
179
+
180
+ if attn_mask is not None:
181
+ assert (
182
+ attn_mask.dtype == torch.float32
183
+ or attn_mask.dtype == torch.float64
184
+ or attn_mask.dtype == torch.float16
185
+ or attn_mask.dtype == torch.uint8
186
+ or attn_mask.dtype == torch.bool
187
+ ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(attn_mask.dtype)
188
+ if attn_mask.dtype == torch.uint8:
189
+ warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
190
+ attn_mask = attn_mask.to(torch.bool)
191
+
192
+ if attn_mask.dim() == 2:
193
+ attn_mask = attn_mask.unsqueeze(0)
194
+ if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
195
+ raise RuntimeError("The size of the 2D attn_mask is not correct.")
196
+ elif attn_mask.dim() == 3:
197
+ if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
198
+ raise RuntimeError("The size of the 3D attn_mask is not correct.")
199
+ else:
200
+ raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
201
+ # attn_mask's dim is 3 now.
202
+
203
+ # convert ByteTensor key_padding_mask to bool
204
+ if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
205
+ warnings.warn(
206
+ "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
207
+ )
208
+ key_padding_mask = key_padding_mask.to(torch.bool)
209
+
210
+ if bias_k is not None and bias_v is not None:
211
+ if static_k is None and static_v is None:
212
+ k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
213
+ v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
214
+ if attn_mask is not None:
215
+ attn_mask = pad(attn_mask, (0, 1))
216
+ if key_padding_mask is not None:
217
+ key_padding_mask = pad(key_padding_mask, (0, 1))
218
+ else:
219
+ assert static_k is None, "bias cannot be added to static key."
220
+ assert static_v is None, "bias cannot be added to static value."
221
+ else:
222
+ assert bias_k is None
223
+ assert bias_v is None
224
+
225
+ q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
226
+ if k is not None:
227
+ k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
228
+ if v is not None:
229
+ v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
230
+
231
+ if static_k is not None:
232
+ assert static_k.size(0) == bsz * num_heads
233
+ assert static_k.size(2) == head_dim
234
+ k = static_k
235
+
236
+ if static_v is not None:
237
+ assert static_v.size(0) == bsz * num_heads
238
+ assert static_v.size(2) == head_dim
239
+ v = static_v
240
+
241
+ src_len = k.size(1)
242
+
243
+ if key_padding_mask is not None:
244
+ assert key_padding_mask.size(0) == bsz
245
+ assert key_padding_mask.size(1) == src_len
246
+
247
+ if add_zero_attn:
248
+ src_len += 1
249
+ k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
250
+ v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
251
+ if attn_mask is not None:
252
+ attn_mask = pad(attn_mask, (0, 1))
253
+ if key_padding_mask is not None:
254
+ key_padding_mask = pad(key_padding_mask, (0, 1))
255
+
256
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
257
+ assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
258
+
259
+ if attn_mask is not None:
260
+ if attn_mask.dtype == torch.bool:
261
+ attn_output_weights.masked_fill_(attn_mask, float("-inf"))
262
+ else:
263
+ attn_output_weights += attn_mask
264
+
265
+ if key_padding_mask is not None:
266
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
267
+ attn_output_weights = attn_output_weights.masked_fill(
268
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
269
+ float("-inf"),
270
+ )
271
+ attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
272
+
273
+ attn_output_weights = softmax(attn_output_weights, dim=-1)
274
+
275
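+ # Intervention: replace the attention distribution of the first query position with the supplied ground-truth map.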
+ attn_output_weights_new = attn_output_weights.clone()
276
+ attn_output_weights_new[:, 0, :] = gt_attention_map
277
+ attn_output_weights = attn_output_weights_new
278
+
279
+ ##################### TAB Forget Gate (start) #########################
280
+ # temp_weights = torch.ones_like(attn_output_weights[:, 0, 0], device=attn_output_weights.device)
281
+ # temp_weights = temp_weights - attn_output_weights[:, 0, 0]
282
+ # weights = torch.mul(attn_output_weights[:, 0, :], temp_weights[:, None])
283
+ # attn_output_weights_new = attn_output_weights.clone() # Create a copy
284
+ # attn_output_weights_new[:, 0, :] = weights # Modify the copy
285
+ ##################### TAB Forget Gate (end) #########################
286
+
287
+ attn_output_weights_new = dropout(attn_output_weights_new, p=dropout_p, training=training)
288
+ attn_output = torch.bmm(attn_output_weights_new, v)
289
+
290
+
291
+
292
+ assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
293
+ attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
294
+ attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
295
+
296
+ if need_weights:
297
+ # average attention weights over heads
298
+ attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
299
+ return attn_output, attn_output_weights.sum(dim=1) / num_heads
300
+ else:
301
+ return attn_output, None
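The only functional change relative to the stock PyTorch multi_head_attention_forward is the row override applied after the softmax, where gt_attention_map replaces the attention distribution of the first query position. A minimal, self-contained sketch of just that step (tensor sizes are illustrative, not taken from this commit):

import torch

# hypothetical sizes: batch of 2, 8 heads, 50 query tokens, 197 key/value tokens
bsz, num_heads, tgt_len, src_len = 2, 8, 50, 197
attn = torch.softmax(torch.randn(bsz * num_heads, tgt_len, src_len), dim=-1)
gt_attention_map = torch.softmax(torch.randn(bsz * num_heads, src_len), dim=-1)

attn_new = attn.clone()                  # keep the original weights intact
attn_new[:, 0, :] = gt_attention_map     # override the first (CLS) row, as done above
assert torch.allclose(attn_new[:, 0, :], gt_attention_map)

In the function above, the overridden weights then pass through dropout and torch.bmm with the value projection exactly as in the unmodified implementation.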
utils/tokenization_clip.py ADDED
@@ -0,0 +1,149 @@
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns a list of utf-8 bytes and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a significant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ The tables also avoid mapping to whitespace/control characters that the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ self.vocab = self.encoder
81
+
82
+ def bpe(self, token):
83
+ if token in self.cache:
84
+ return self.cache[token]
85
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
86
+ pairs = get_pairs(word)
87
+
88
+ if not pairs:
89
+ return token+'</w>'
90
+
91
+ while True:
92
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
93
+ if bigram not in self.bpe_ranks:
94
+ break
95
+ first, second = bigram
96
+ new_word = []
97
+ i = 0
98
+ while i < len(word):
99
+ try:
100
+ j = word.index(first, i)
101
+ new_word.extend(word[i:j])
102
+ i = j
103
+ except ValueError:  # 'first' does not occur in the rest of the word
104
+ new_word.extend(word[i:])
105
+ break
106
+
107
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
108
+ new_word.append(first+second)
109
+ i += 2
110
+ else:
111
+ new_word.append(word[i])
112
+ i += 1
113
+ new_word = tuple(new_word)
114
+ word = new_word
115
+ if len(word) == 1:
116
+ break
117
+ else:
118
+ pairs = get_pairs(word)
119
+ word = ' '.join(word)
120
+ self.cache[token] = word
121
+ return word
122
+
123
+ def encode(self, text):
124
+ bpe_tokens = []
125
+ text = whitespace_clean(basic_clean(text)).lower()
126
+ for token in re.findall(self.pat, text):
127
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
128
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
129
+ return bpe_tokens
130
+
131
+ def decode(self, tokens):
132
+ text = ''.join([self.decoder[token] for token in tokens])
133
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
134
+ return text
135
+
136
+ def tokenize(self, text):
137
+ tokens = []
138
+ text = whitespace_clean(basic_clean(text)).lower()
139
+ for token in re.findall(self.pat, text):
140
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
141
+ tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
142
+ return tokens
143
+
144
+ def convert_tokens_to_ids(self, tokens):
145
+ return [self.encoder[bpe_token] for bpe_token in tokens]
146
+
147
+ def convert_ids_to_tokens(self, ids):
148
+ """Converts a sequence of ids in tokens using the vocab."""
149
+ return self.decode(ids)
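For reference, a short usage sketch of the tokenizer defined above. It assumes bpe_simple_vocab_16e6.txt.gz sits next to this module (the default_bpe() path) and that ftfy and regex are installed, as the imports require; the example string is arbitrary.

from utils.tokenization_clip import SimpleTokenizer

tokenizer = SimpleTokenizer()
ids = tokenizer.encode("a photo of a cat")         # list of BPE ids
tokens = tokenizer.tokenize("a photo of a cat")    # list of BPE strings such as 'photo</w>'
assert tokenizer.convert_tokens_to_ids(tokens) == ids
print(tokenizer.decode(ids))                       # 'a photo of a cat ' ('</w>' becomes a space)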
utils/until_config.py ADDED
@@ -0,0 +1,126 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model."""
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import os
23
+ import copy
24
+ import json
25
+ import logging
26
+ import tarfile
27
+ import tempfile
28
+ import shutil
29
+ import torch
30
+ from .file_utils import cached_path
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ class PretrainedConfig(object):
35
+
36
+ pretrained_model_archive_map = {}
37
+ config_name = ""
38
+ weights_name = ""
39
+
40
+ @classmethod
41
+ def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
42
+ archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
43
+ if os.path.exists(archive_file) is False:
44
+ if pretrained_model_name in cls.pretrained_model_archive_map:
45
+ archive_file = cls.pretrained_model_archive_map[pretrained_model_name]
46
+ else:
47
+ archive_file = pretrained_model_name
48
+
49
+ # redirect to the cache, if necessary
50
+ try:
51
+ resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
52
+ except FileNotFoundError:
53
+ if task_config is None or task_config.local_rank == 0:
54
+ logger.error(
55
+ "Model name '{}' was not found in model name list. "
56
+ "We assumed '{}' was a path or url but couldn't find any file "
57
+ "associated to this path or url.".format(
58
+ pretrained_model_name,
59
+ archive_file))
60
+ return None
61
+ if resolved_archive_file == archive_file:
62
+ if task_config is None or task_config.local_rank == 0:
63
+ logger.info("loading archive file {}".format(archive_file))
64
+ else:
65
+ if task_config is None or task_config.local_rank == 0:
66
+ logger.info("loading archive file {} from cache at {}".format(
67
+ archive_file, resolved_archive_file))
68
+ tempdir = None
69
+ if os.path.isdir(resolved_archive_file):
70
+ serialization_dir = resolved_archive_file
71
+ else:
72
+ # Extract archive to temp dir
73
+ tempdir = tempfile.mkdtemp()
74
+ if task_config is None or task_config.local_rank == 0:
75
+ logger.info("extracting archive file {} to temp dir {}".format(
76
+ resolved_archive_file, tempdir))
77
+ with tarfile.open(resolved_archive_file, 'r:gz') as archive:
78
+ archive.extractall(tempdir)
79
+ serialization_dir = tempdir
80
+ # Load config
81
+ config_file = os.path.join(serialization_dir, cls.config_name)
82
+ config = cls.from_json_file(config_file)
83
+ config.type_vocab_size = type_vocab_size
84
+ if task_config is None or task_config.local_rank == 0:
85
+ logger.info("Model config {}".format(config))
86
+
87
+ if state_dict is None:
88
+ weights_path = os.path.join(serialization_dir, cls.weights_name)
89
+ if os.path.exists(weights_path):
90
+ state_dict = torch.load(weights_path, map_location='cpu')
91
+ else:
92
+ if task_config is None or task_config.local_rank == 0:
93
+ logger.info("Weight doesn't exsits. {}".format(weights_path))
94
+
95
+ if tempdir:
96
+ # Clean up temp dir
97
+ shutil.rmtree(tempdir)
98
+
99
+ return config, state_dict
100
+
101
+ @classmethod
102
+ def from_dict(cls, json_object):
103
+ """Constructs a `BertConfig` from a Python dictionary of parameters."""
104
+ config = cls(vocab_size_or_config_json_file=-1)
105
+ for key, value in json_object.items():
106
+ config.__dict__[key] = value
107
+ return config
108
+
109
+ @classmethod
110
+ def from_json_file(cls, json_file):
111
+ """Constructs a `BertConfig` from a json file of parameters."""
112
+ with open(json_file, "r", encoding='utf-8') as reader:
113
+ text = reader.read()
114
+ return cls.from_dict(json.loads(text))
115
+
116
+ def __repr__(self):
117
+ return str(self.to_json_string())
118
+
119
+ def to_dict(self):
120
+ """Serializes this instance to a Python dictionary."""
121
+ output = copy.deepcopy(self.__dict__)
122
+ return output
123
+
124
+ def to_json_string(self):
125
+ """Serializes this instance to a JSON string."""
126
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
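PretrainedConfig itself defines no fields; concrete subclasses supply config_name, weights_name and an __init__ that accepts vocab_size_or_config_json_file, which from_dict assumes. A hedged sketch with a made-up DemoConfig subclass (the real subclass used elsewhere in this Space is not shown in this file):

from utils.until_config import PretrainedConfig

class DemoConfig(PretrainedConfig):
    config_name = "demo_config.json"      # hypothetical file names
    weights_name = "pytorch_model.bin"

    def __init__(self, vocab_size_or_config_json_file=-1, hidden_size=512):
        self.vocab_size = vocab_size_or_config_json_file
        self.hidden_size = hidden_size

config = DemoConfig.from_dict({"vocab_size": 49408, "hidden_size": 512})
print(config.to_json_string())            # pretty-printed JSON of the fields set above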
utils/until_module.py ADDED
@@ -0,0 +1,278 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model."""
17
+
18
+ import logging
19
+ import numpy as np
20
+ import torch
21
+ from torch import nn
22
+ import torch.nn.functional as F
23
+ import math
24
+ from .until_config import PretrainedConfig
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ def gelu(x):
29
+ """Implementation of the gelu activation function.
30
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
31
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
32
+ """
33
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
34
+
35
+ def swish(x):
36
+ return x * torch.sigmoid(x)
37
+
38
+ ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
39
+
40
+ class LayerNorm(nn.Module):
41
+ def __init__(self, hidden_size, eps=1e-12):
42
+ """Construct a layernorm module in the TF style (epsilon inside the square root).
43
+ """
44
+ super(LayerNorm, self).__init__()
45
+ self.weight = nn.Parameter(torch.ones(hidden_size))
46
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
47
+ self.variance_epsilon = eps
48
+
49
+ def forward(self, x):
50
+ u = x.mean(-1, keepdim=True)
51
+ s = (x - u).pow(2).mean(-1, keepdim=True)
52
+ x = (x - u) / torch.sqrt(s + self.variance_epsilon)
53
+ return self.weight * x + self.bias
54
+
55
+ class PreTrainedModel(nn.Module):
56
+ """ An abstract class to handle weights initialization and
57
+ a simple interface for downloading and loading pretrained models.
58
+ """
59
+ def __init__(self, config, *inputs, **kwargs):
60
+ super(PreTrainedModel, self).__init__()
61
+ if not isinstance(config, PretrainedConfig):
62
+ raise ValueError(
63
+ "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
64
+ "To create a model from a Google pretrained model use "
65
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
66
+ self.__class__.__name__, self.__class__.__name__
67
+ ))
68
+ self.config = config
69
+
70
+ def init_weights(self, module):
71
+ """ Initialize the weights.
72
+ """
73
+ if isinstance(module, (nn.Linear, nn.Embedding)):
74
+ # Slightly different from the TF version which uses truncated_normal for initialization
75
+ # cf https://github.com/pytorch/pytorch/pull/5617
76
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
77
+ elif isinstance(module, LayerNorm):
78
+ if 'beta' in dir(module) and 'gamma' in dir(module):
79
+ module.beta.data.zero_()
80
+ module.gamma.data.fill_(1.0)
81
+ else:
82
+ module.bias.data.zero_()
83
+ module.weight.data.fill_(1.0)
84
+ if isinstance(module, nn.Linear) and module.bias is not None:
85
+ module.bias.data.zero_()
86
+
87
+ def resize_token_embeddings(self, new_num_tokens=None):
88
+ raise NotImplementedError
89
+
90
+ @classmethod
91
+ def init_preweight(cls, model, state_dict, prefix=None):
92
+ old_keys = []
93
+ new_keys = []
94
+ for key in state_dict.keys():
95
+ new_key = None
96
+ if 'gamma' in key:
97
+ new_key = key.replace('gamma', 'weight')
98
+ if 'beta' in key:
99
+ new_key = key.replace('beta', 'bias')
100
+ if new_key:
101
+ old_keys.append(key)
102
+ new_keys.append(new_key)
103
+ for old_key, new_key in zip(old_keys, new_keys):
104
+ state_dict[new_key] = state_dict.pop(old_key)
105
+
106
+ if prefix is not None:
107
+ old_keys = []
108
+ new_keys = []
109
+ for key in state_dict.keys():
110
+ old_keys.append(key)
111
+ new_keys.append(prefix + key)
112
+ for old_key, new_key in zip(old_keys, new_keys):
113
+ state_dict[new_key] = state_dict.pop(old_key)
114
+
115
+ missing_keys = []
116
+ unexpected_keys = []
117
+ error_msgs = []
118
+ # copy state_dict so _load_from_state_dict can modify it
119
+ metadata = getattr(state_dict, '_metadata', None)
120
+ state_dict = state_dict.copy()
121
+ if metadata is not None:
122
+ state_dict._metadata = metadata
123
+
124
+ def load(module, prefix=''):
125
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
126
+ module._load_from_state_dict(
127
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
128
+ for name, child in module._modules.items():
129
+ if child is not None:
130
+ load(child, prefix + name + '.')
131
+
132
+ load(model, prefix='')
133
+ return model
134
+
135
+ @property
136
+ def dtype(self):
137
+ """
138
+ :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
139
+ """
140
+ try:
141
+ return next(self.parameters()).dtype
142
+ except StopIteration:
143
+ # For nn.DataParallel compatibility in PyTorch 1.5
144
+ def find_tensor_attributes(module: nn.Module):
145
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
146
+ return tuples
147
+
148
+ gen = self._named_members(get_members_fn=find_tensor_attributes)
149
+ first_tuple = next(gen)
150
+ return first_tuple[1].dtype
151
+
152
+ @classmethod
153
+ def from_pretrained(cls, config, state_dict=None, *inputs, **kwargs):
154
+ """
155
+ Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
156
+ Download and cache the pre-trained model file if needed.
157
+ """
158
+ # Instantiate model.
159
+ model = cls(config, *inputs, **kwargs)
160
+ if state_dict is None:
161
+ return model
162
+ model = cls.init_preweight(model, state_dict)
163
+
164
+ return model
165
+
166
+ ##################################
167
+ ###### LOSS FUNCTION #############
168
+ ##################################
169
+
170
+
171
+ import itertools
172
+
173
+ class GroupCEn(nn.Module):
174
+ def __init__(self,):
175
+ super(GroupCEn, self).__init__()
176
+
177
+ def forward(self, sim_matrix, target):
178
+ mask = torch.eye(sim_matrix.size()[0], device=sim_matrix.device)
179
+ indices = torch.nonzero(target)
180
+ com = list(itertools.product(indices, indices))
181
+ for x, y in com:
182
+ mask[x, y] = 1.0
183
+ logpt = F.log_softmax(sim_matrix, dim=-1)
184
+ logpt = logpt * mask
185
+ logpt = torch.sum(logpt, dim=-1) / torch.sum(mask, dim=-1)
186
+ logpt = torch.diag(logpt)
187
+ nce_loss = -logpt
188
+ sim_loss = nce_loss.mean()
189
+ return sim_loss
190
+
191
+ class CrossEn(nn.Module):
192
+ def __init__(self,):
193
+ super(CrossEn, self).__init__()
194
+
195
+ def forward(self, sim_matrix):
196
+ logpt = F.log_softmax(sim_matrix, dim=-1)
197
+ logpt = torch.diag(logpt)
198
+ nce_loss = -logpt
199
+ sim_loss = nce_loss.mean()
200
+ return sim_loss
201
+
202
+ class MILNCELoss(nn.Module):
203
+ def __init__(self, batch_size=1, n_pair=1,):
204
+ super(MILNCELoss, self).__init__()
205
+ self.batch_size = batch_size
206
+ self.n_pair = n_pair
207
+ # Parse (major, minor) as integers so versions like "1.10" compare correctly.
+ torch_v = tuple(int(v) for v in torch.__version__.split(".")[:2])
208
+ self.bool_dtype = torch.bool if torch_v >= (1, 3) else torch.uint8
209
+
210
+ def forward(self, sim_matrix):
211
+ mm_mask = np.eye(self.batch_size)
212
+ mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair)))
213
+ mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device)
214
+
215
+ from_text_matrix = sim_matrix + mm_mask * -1e12
216
+ from_video_matrix = sim_matrix.transpose(1, 0)
217
+
218
+ new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1)
219
+ logpt = F.log_softmax(new_sim_matrix, dim=-1)
220
+
221
+ mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1)
222
+ masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12
223
+
224
+ new_logpt = -torch.logsumexp(masked_logpt, dim=-1)
225
+
226
+ logpt_choice = torch.zeros_like(new_logpt)
227
+ mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair//2)
228
+ logpt_choice[mark_ind] = 1
229
+ sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=self.bool_dtype)).mean()
230
+ return sim_loss
231
+
232
+ class MaxMarginRankingLoss(nn.Module):
233
+ def __init__(self,
234
+ margin=1.0,
235
+ negative_weighting=False,
236
+ batch_size=1,
237
+ n_pair=1,
238
+ hard_negative_rate=0.5,
239
+ ):
240
+ super(MaxMarginRankingLoss, self).__init__()
241
+ self.margin = margin
242
+ self.n_pair = n_pair
243
+ self.batch_size = batch_size
244
+ easy_negative_rate = 1 - hard_negative_rate
245
+ self.easy_negative_rate = easy_negative_rate
246
+ self.negative_weighting = negative_weighting
247
+ if n_pair > 1 and batch_size > 1:
248
+ alpha = easy_negative_rate / ((batch_size - 1) * (1 - easy_negative_rate))
249
+ mm_mask = (1 - alpha) * np.eye(self.batch_size) + alpha
250
+ mm_mask = np.kron(mm_mask, np.ones((n_pair, n_pair)))
251
+ mm_mask = torch.tensor(mm_mask) * (batch_size * (1 - easy_negative_rate))
252
+ self.mm_mask = mm_mask.float()
253
+
254
+ def forward(self, x):
255
+ d = torch.diag(x)
256
+ max_margin = F.relu(self.margin + x - d.view(-1, 1)) + \
257
+ F.relu(self.margin + x - d.view(1, -1))
258
+ if self.negative_weighting and self.n_pair > 1 and self.batch_size > 1:
259
+ max_margin = max_margin * self.mm_mask.to(max_margin.device)
260
+ return max_margin.mean()
261
+
262
+ class AllGather(torch.autograd.Function):
263
+ """An autograd function that performs allgather on a tensor."""
264
+
265
+ @staticmethod
266
+ def forward(ctx, tensor, args):
267
+ output = [torch.empty_like(tensor) for _ in range(args.world_size)]
268
+ torch.distributed.all_gather(output, tensor)
269
+ ctx.rank = args.rank
270
+ ctx.batch_size = tensor.shape[0]
271
+ return torch.cat(output, dim=0)
272
+
273
+ @staticmethod
274
+ def backward(ctx, grad_output):
275
+ return (
276
+ grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)],
277
+ None,
278
+ )
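A quick sanity check (not part of the commit) for the contrastive loss defined above: CrossEn treats the diagonal of a text-video similarity matrix as the positive pairs, so a clearly separable matrix drives the loss towards zero.

import torch
from utils.until_module import CrossEn    # assumes the package layout of this Space

sim_matrix = torch.randn(4, 4)            # 4 text-video pairs with random scores
print(CrossEn()(sim_matrix).item())       # mean of -log softmax over the diagonal

ideal = 100.0 * torch.eye(4)              # positives far above all negatives
print(CrossEn()(ideal).item())            # ~0.0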