initial commit
- Dockerfile +13 -0
- app.py +324 -0
- fileservice.py +48 -0
- js/interactive_grid.js +194 -0
- requirements.txt +20 -0
- utils/cross-base/cross_config.json +12 -0
- utils/decoder-base/decoder_config.json +14 -0
- utils/file_utils.py +239 -0
- utils/model.py +204 -0
- utils/module_clip.py +679 -0
- utils/module_cross.py +394 -0
- utils/module_decoder.py +447 -0
- utils/module_gated_attention.py +301 -0
- utils/tokenization_clip.py +149 -0
- utils/until_config.py +126 -0
- utils/until_module.py +278 -0
Dockerfile
ADDED
@@ -0,0 +1,13 @@
FROM python:3.9

RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY --chown=user . /app
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
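The CMD above points uvicorn at the `app` object that app.py re-exports from fileservice.py. For a quick check outside the container, the same entry point can be started directly; a minimal sketch, assuming the packages from requirements.txt are installed locally (the helper file name is hypothetical and not part of this commit):

# run_local.py (hypothetical helper)
import uvicorn

if __name__ == "__main__":
    # "app:app" is the same module:attribute pair used in the Dockerfile CMD
    uvicorn.run("app:app", host="0.0.0.0", port=7860)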
app.py
ADDED
@@ -0,0 +1,324 @@
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw
import torch
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from utils.model import init_model
from utils.tokenization_clip import SimpleTokenizer as ClipTokenizer

from fastapi.staticfiles import StaticFiles
from fileservice import app


html_text = """
<div id="container">
    <canvas id="canvas" width="512" height="512"></canvas><img id="canvas-background" style="display:none;"/>
</div>
"""

def image_to_tensor(image_path):
    image = Image.open(image_path).convert('RGB')

    preprocess = Compose([
        Resize([224, 224], interpolation=Image.BICUBIC),
        lambda image: image.convert("RGB"),
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
    ])
    image_data = preprocess(image)

    return {'image': image_data}

def get_image_data(image_path):
    image_input = image_to_tensor(image_path)
    return image_input

def get_intervention_vector(selected_cells_bef, selected_cells_aft):
    left = np.reshape(np.zeros((1, 14 * 14)), (14, 14))
    right = np.reshape(np.zeros((1, 14 * 14)), (14, 14))

    for (i, j) in selected_cells_bef:
        left[i, j] = 1.
    for (i, j) in selected_cells_aft:
        right[i, j] = 1.

    left_map = np.zeros((1, 14 * 14 + 1))
    right_map = np.zeros((1, 14 * 14 + 1))

    left_map[0, 1:] = np.reshape(left, (1, 14 * 14))
    right_map[0, 1:] = np.reshape(right, (1, 14 * 14))

    if len(selected_cells_bef) == 0:
        left_map[0, 0] = 0.0

    if len(selected_cells_aft) == 0:
        right_map[0, 0] = 0.0

    return left_map, right_map

def _get_rawimage(image_path):
    # Pair x L x T x 3 x H x W
    image = np.zeros((1, 3, 224, 224), dtype=np.float)

    for i in range(1):

        raw_image_data = get_image_data(image_path)
        raw_image_data = raw_image_data['image']

        image[i] = raw_image_data

    return image


def greedy_decode(model, tokenizer, video, video_mask, gt_left_map, gt_right_map):
    visual_output, left_map, right_map = model.get_sequence_visual_output(video, video_mask,
                                                                          gt_left_map[:, 0, :].squeeze(), gt_right_map[:, 0, :].squeeze())

    video_mask = torch.ones(visual_output.shape[0], visual_output.shape[1], device=visual_output.device).long()
    input_caption_ids = torch.zeros(visual_output.shape[0], device=visual_output.device).data.fill_(tokenizer.vocab["<|startoftext|>"])
    input_caption_ids = input_caption_ids.long().unsqueeze(1)
    decoder_mask = torch.ones_like(input_caption_ids)
    for i in range(32):
        decoder_scores = model.decoder_caption(visual_output, video_mask, input_caption_ids, decoder_mask, get_logits=True)
        next_words = decoder_scores[:, -1].max(1)[1].unsqueeze(1)
        input_caption_ids = torch.cat([input_caption_ids, next_words], 1)
        next_mask = torch.ones_like(next_words)
        decoder_mask = torch.cat([decoder_mask, next_mask], 1)

    return input_caption_ids[:, 1:].tolist(), left_map, right_map

# Dummy prediction function
def predict_image(image_bef, image_aft, selected_cells_bef, selected_cells_aft):
    if image_bef is None:
        return "No image provided", "", ""
    if image_aft is None:
        return "No image provided", "", ""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = init_model('data/pytorch_model.pt', device)

    tokenizer = ClipTokenizer()

    left_map, right_map = get_intervention_vector(selected_cells_bef, selected_cells_aft)

    left_map, right_map = torch.from_numpy(left_map).unsqueeze(0), torch.from_numpy(right_map).unsqueeze(0)

    bef_image = torch.from_numpy(_get_rawimage(image_bef)).unsqueeze(1)
    aft_image = torch.from_numpy(_get_rawimage(image_aft)).unsqueeze(1)

    image_pair = torch.cat([bef_image, aft_image], 1)

    image_mask = torch.from_numpy(np.ones(2, dtype=np.long)).unsqueeze(0)

    result_list, left_map, right_map = greedy_decode(model, tokenizer, image_pair, image_mask, left_map, right_map)

    decode_text_list = tokenizer.convert_ids_to_tokens(result_list[0])
    if "<|endoftext|>" in decode_text_list:
        SEP_index = decode_text_list.index("<|endoftext|>")
        decode_text_list = decode_text_list[:SEP_index]
    if "!" in decode_text_list:
        PAD_index = decode_text_list.index("!")
        decode_text_list = decode_text_list[:PAD_index]
    decode_text = decode_text_list.strip()

    # Generate dummy predictions
    pred = f"{decode_text}"

    # Include information about selected cells
    selected_info_bef = f"{selected_cells_bef}" if selected_cells_bef else "No image patch was selected"
    selected_info_aft = f"{selected_cells_aft}" if selected_cells_aft else "No image patch was selected"

    return pred, selected_info_bef, selected_info_aft

# Add grid to the image
def add_grid_to_image(image_path, grid_size=14):
    if image_path is None:
        return None

    image = Image.open(image_path)
    w, h = image.size

    image = image.convert('RGBA')

    draw = ImageDraw.Draw(image)

    x_positions = np.linspace(0, w, grid_size + 1)
    y_positions = np.linspace(0, h, grid_size + 1)

    # Draw the vertical lines
    for x in x_positions[1:-1]:
        line = ((x, 0), (x, h))
        draw.line(line, fill='white')

    # Draw the horizontal lines
    for y in y_positions[1:-1]:
        line = ((0, y), (w, y))
        draw.line(line, fill='white')

    return image, h, w

# Handle cell selection
def handle_click(image, evt: gr.SelectData, selected_cells, image_path):
    if image is None:
        return None, []

    grid_size = 14

    image, h, w = add_grid_to_image(image_path, grid_size)

    x_positions = np.linspace(0, w, grid_size + 1)
    y_positions = np.linspace(0, h, grid_size + 1)

    # Calculate which cell was clicked
    for index, x in enumerate(x_positions[:-1]):
        if evt.index[0] >= x and evt.index[0] <= x_positions[index+1]:
            row = index

    for index, y in enumerate(y_positions[:-1]):
        if evt.index[1] >= y and evt.index[1] <= y_positions[index+1]:
            col = index

    cell_idx = (row, col)

    # Toggle selection
    if cell_idx in selected_cells:
        selected_cells.remove(cell_idx)
    else:
        selected_cells.append(cell_idx)

    # Add semi-transparent overlay for selected cells
    highlight_layer = Image.new('RGBA', (w, h), (0, 0, 0, 0))  # Fully transparent layer
    highlight_draw = ImageDraw.Draw(highlight_layer)

    # Define a lighter green color with 40% transparency
    light_green = (144, 238, 144, 102)  # RGB = (144, 238, 144), Alpha = 102 (40% of 255)

    for (row, col) in selected_cells:
        cell_top_left = (x_positions[row], y_positions[col])
        cell_bottom_right = (x_positions[row + 1], y_positions[col + 1])

        highlight_draw.rectangle([cell_top_left, cell_bottom_right], fill=light_green, outline='white')

    result_img = Image.alpha_composite(image.convert('RGBA'), highlight_layer)

    return result_img, selected_cells


# Process example images
def process_example(image_path_bef, image_path_aft):
    # Add grid to the example image
    image_bef_grid, _, _ = add_grid_to_image(image_path_bef, 14)
    image_aft_grid, _, _ = add_grid_to_image(image_path_aft, 14)
    return image_bef_grid, image_aft_grid  # Reset selected cells and store original image

def display_image(image_path):
    image_grid, _, _ = add_grid_to_image(image_path, 14)
    return image_grid, []

with gr.Blocks() as demo:
    gr.Markdown("# TAB: Transformer Attention Bottleneck")

    # Instructions
    gr.Markdown("""
    ## Instructions:
    1. Upload an image or select one from the examples
    2. Click on grid cells to select/deselect them
    3. Click the 'Predict' button to get model predictions
    """)

    selected_cells_bef = gr.State([])
    selected_cells_aft = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            # Input components with grid overlay
            image_bef = gr.Image(type="filepath")
            image_aft = gr.Image(type="filepath")

            predict_btn = gr.Button("Predict")

        with gr.Column(scale=1):
            image_display_with_grid_bef = gr.Image(type="pil", label="Before Image with Grid")
            image_display_with_grid_aft = gr.Image(type="pil", label="After Image with Grid")

            # Add click event to the displayed image
            image_display_with_grid_bef.select(
                handle_click,
                inputs=[image_display_with_grid_bef, selected_cells_bef, image_bef],
                outputs=[image_display_with_grid_bef, selected_cells_bef]
            )

            image_display_with_grid_aft.select(
                handle_click,
                inputs=[image_display_with_grid_aft, selected_cells_aft, image_aft],
                outputs=[image_display_with_grid_aft, selected_cells_aft]
            )

    with gr.Row():
        with gr.Column(scale=1):
            # Example images
            examples = gr.Examples(
                examples=[["data/images/CLEVR_default_000572.png", "data/images/CLEVR_semantic_000572.png"],
                          ["data/images/CLEVR_default_003339.png", "data/images/CLEVR_semantic_003339.png"]],
                inputs=[image_bef, image_aft],
                outputs=[image_display_with_grid_bef, image_display_with_grid_aft],
                label="Example Images",
                fn=process_example,
                examples_per_page=5
            )

            # image_bef.change(
            #     fn=display_image,
            #     inputs=[image_bef],
            #     outputs=[image_display_with_grid_bef, selected_cells_bef]
            # )

            # image_aft.change(
            #     fn=display_image,
            #     inputs=[image_aft],
            #     outputs=[image_display_with_grid_aft, selected_cells_aft]
            # )

            image_bef.change(
                fn=None,
                inputs=[image_bef],
                outputs=[],
                js="(image) => { initializeEditor(); importBackground(image); return []; }",
            )

            image_aft.change(
                fn=None,
                inputs=[image_aft],
                outputs=[],
                js="(image) => { initializeEditor(); importBackground(image); return []; }",
            )

        with gr.Column(scale=1):
            # Output components
            prediction = gr.Textbox(label="Predicted caption")
            selected_info_bef = gr.Textbox(label="Selected patches on before")
            selected_info_aft = gr.Textbox(label="Selected patches on after")

    # Connect the predict button to the prediction function
    predict_btn.click(
        fn=predict_image,
        inputs=[image_bef, image_aft, selected_cells_bef, selected_cells_aft],
        outputs=[prediction, selected_info_bef, selected_info_aft]
    )


app.mount("/js", StaticFiles(directory="js"), name="js")
gr.mount_gradio_app(app, demo, path="/")
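The intervention maps built by get_intervention_vector above are vectors of length 14 * 14 + 1 = 197: one leading slot followed by the flattened 14 x 14 patch grid, which matches ViT-B/16's [CLS] token plus 196 patch tokens at 224 x 224 resolution. A standalone sketch of that layout, mirroring the helper (assumes only numpy; the function name is illustrative):

import numpy as np

def intervention_vector(selected_cells, grid_size=14):
    # flags over the 14 x 14 patch grid
    patch_flags = np.zeros((grid_size, grid_size))
    for (i, j) in selected_cells:
        patch_flags[i, j] = 1.0
    # slot 0 is the extra leading position, slots 1..196 hold the flattened grid
    vec = np.zeros((1, grid_size * grid_size + 1))
    vec[0, 1:] = patch_flags.reshape(1, -1)
    return vec

v = intervention_vector([(0, 0), (13, 13)])
assert v.shape == (1, 197)
assert v[0, 1] == 1.0 and v[0, -1] == 1.0  # first and last patch flagged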
fileservice.py
ADDED
@@ -0,0 +1,48 @@
from fastapi import FastAPI, Request, Response

filenames = ["js/interactive_grid.js"]
contents = "\n".join(
    [f"<script type='text/javascript' src='{x}'></script>" for x in filenames]
)

ga_script = """
<!-- Google tag (gtag.js) -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-11ZHMNWP9Y"></script>
<script>
  window.dataLayer = window.dataLayer || [];
  function gtag(){dataLayer.push(arguments);}
  gtag('js', new Date());

  gtag('config', 'G-11ZHMNWP9Y');
</script>
"""

app = FastAPI()


@app.middleware("http")
async def insert_js(request: Request, call_next):
    path = request.scope["path"]  # get the request route
    response = await call_next(request)

    if path == "/":
        response_body = ""
        async for chunk in response.body_iterator:
            response_body += chunk.decode()

        charset_tag = '<meta charset="utf-8" />'
        if charset_tag in response_body:
            response_body = response_body.replace(charset_tag, charset_tag + ga_script)

        response_body = response_body.replace("</body>", contents + "</body>")

        del response.headers["content-length"]

        return Response(
            content=response_body,
            status_code=response.status_code,
            headers=dict(response.headers),
            media_type=response.media_type,
        )

    return response
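The middleware only rewrites responses for the root path: it buffers the streamed body, inserts the analytics snippet after the charset meta tag, and appends the interactive_grid.js script tag before the closing body tag. A minimal sketch of exercising it with FastAPI's test client (assumes httpx is available; the "/" route below is a stand-in for the mounted Gradio page, added purely for this check):

# hypothetical smoke test for the injection middleware
from fastapi.responses import HTMLResponse
from fastapi.testclient import TestClient

from fileservice import app, contents

@app.get("/", response_class=HTMLResponse)
def index():
    # stand-in page containing the markers the middleware looks for
    return '<html><head><meta charset="utf-8" /></head><body>demo</body></html>'

client = TestClient(app)
html = client.get("/").text
assert contents in html                 # interactive_grid.js tag appended before </body>
assert "googletagmanager.com" in html   # gtag snippet inserted after the charset tag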
js/interactive_grid.js
ADDED
@@ -0,0 +1,194 @@
const gridSize = 14;
var cellSize = null;
var inputImage = null;
var image = null;
var canvas = null;
var ctx = null;
var canvasBg = null;
let grid = new Array(gridSize).fill(null).map(() => new Array(gridSize).fill(false));
var isInitialized = false;

let selectedCells = 0;

function createGrid() {
    console.log('createGrid')

    for (let i = 0; i < 196; i++) {
        const div = document.createElement('div');
        div.classList.add('checkbox');
        div.innerHTML = '<input type="checkbox">';
        grid.appendChild(div);
    }
}

function loadImage(event) {
    const file = event.target.files[0];
    const reader = new FileReader();
    reader.onload = function (e) {
        image.src = e.target.result;
    }
    reader.readAsDataURL(file);
}


function handleMouseDown(event) {
    // console.log("handleMouseDown");
}

function handleMouseMove(event) {
    // console.log("handleMouseMove");
}

function handleMouseUp(event) {
    // console.log("handleMouseUp");
}

function handleMouseLeave(event) {
    // console.log("handleMouseLeave");
}


function drawGrid() {
    ctx.clearRect(0, 0, canvas.width, canvas.height);
    drawBackground();
    for (let row = 0; row < gridSize; row++) {
        for (let col = 0; col < gridSize; col++) {
            ctx.beginPath();
            ctx.rect(col * cellSize, row * cellSize, cellSize, cellSize);
            ctx.strokeStyle = 'black';
            ctx.lineWidth = 2;
            ctx.stroke();

            if (grid[row][col]) {
                ctx.fillStyle = 'rgba(0, 255, 0, 0.5)';
                ctx.fillRect(col * cellSize, row * cellSize, cellSize, cellSize);
            }
        }
    }
}


function initializeEditor() {
    console.log("initializeEditor");

    if (isInitialized) {
        return;
    }
    isInitialized = true;

    image = document.getElementById('image');
    canvas = document.getElementById('canvas');
    ctx = canvas.getContext('2d');

    // Add click event listener to canvas
    canvas.addEventListener('mousedown', handleMouseDown);
    canvas.addEventListener('mousemove', handleMouseMove);
    canvas.addEventListener('mouseup', handleMouseUp);
    canvas.addEventListener('mouseleave', handleMouseLeave);

    cellSize = canvas.width / gridSize;

    canvas.addEventListener('click', (event) => {
        const rect = canvas.getBoundingClientRect();
        const scaleX = canvas.width / rect.width;
        const scaleY = canvas.height / rect.height;
        const x = (event.clientX - rect.left) * scaleX;
        const y = (event.clientY - rect.top) * scaleY;
        const row = Math.floor(y / cellSize);
        const col = Math.floor(x / cellSize);

        // If the cell is already selected, it's always allowed to deselect it
        if (grid[row][col]) {
            grid[row][col] = false;
            selectedCells--; // Decrement the selected cell count
        } else {
            // Only select a new cell if less than 50 cells are already selected
            if (selectedCells < 50) {
                grid[row][col] = true;
                selectedCells++; // Increment the selected cell count
            }
        }
        drawGrid();
    });

    drawGrid();
}


function drawBackground() {
    if (canvasBg != null) {
        const canvasWidth = canvas.width;
        const canvasHeight = canvas.height;

        const bgWidth = canvasBg.width;
        const bgHeight = canvasBg.height;

        const scaleX = canvasWidth / bgWidth;
        const scaleY = canvasHeight / bgHeight;

        const scale = Math.max(scaleX, scaleY);

        const newWidth = bgWidth * scale;
        const newHeight = bgHeight * scale;

        const xOffset = (canvasWidth - newWidth) / 2;
        const yOffset = (canvasHeight - newHeight) / 2;

        ctx.drawImage(canvasBg, 0, 0, bgWidth, bgHeight, xOffset, yOffset, newWidth, newHeight);
    }
}

function importBackground(image) {
    if (image == null) {
        canvasBg = null;
        drawGrid();
        return;
    }

    let m = new Image();
    m.src = image;
    m.onload = function () {
        canvasBg = m;
        drawGrid();
    }
}

function read_js_Data() {
    console.log("read_js_Data");
    console.log("read_js_Data");
    console.log("read_js_Data");
    console.log("read_js_Data");
    console.log("read_js_Data");
    return grid;
}


function set_grid_from_data(data) {
    if (data.length !== gridSize || data[0].length !== gridSize) {
        throw new Error('Invalid data dimensions. Expected ' + gridSize + 'x' + gridSize);
    }

    selectedCells = 0; // Reset the selected cell count
    for (let row = 0; row < gridSize; row++) {
        for (let col = 0; col < gridSize; col++) {
            grid[row][col] = data[row][col];
            if (grid[row][col]) {
                selectedCells++; // Count the number of initially selected cells
            }
        }
    }

    drawGrid();
}


function clear_grid() {
    console.log("clearGrid");
    for (let row = 0; row < gridSize; row++) {
        for (let col = 0; col < gridSize; col++) {
            grid[row][col] = false;
        }
    }
    selectedCells = 0; // Reset the selected cell count
    drawGrid();
}
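The canvas click handler above converts a pixel position into a grid index with floor(y / cellSize) for the row and floor(x / cellSize) for the column. The same mapping as a small standalone Python helper, for reference only (hypothetical, not part of the commit):

def click_to_cell(x, y, canvas_size=512, grid_size=14):
    """Map a pixel position on the canvas to a (row, col) cell index."""
    cell_size = canvas_size / grid_size
    row = min(int(y // cell_size), grid_size - 1)
    col = min(int(x // cell_size), grid_size - 1)
    return row, col

assert click_to_cell(0, 0) == (0, 0)
assert click_to_cell(511, 511) == (13, 13)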
requirements.txt
ADDED
@@ -0,0 +1,20 @@
boto3==1.36.9
botocore==1.36.9
certifi==2024.12.14
charset-normalizer==3.4.1
ftfy==6.2.3
idna==3.10
jmespath==1.0.1
numpy==1.23.0
opencv-python==4.11.0.86
Pillow==9.3.0
regex==2024.11.6
requests==2.32.3
s3transfer==0.11.2
gradio
torch
torchvision
torchaudio
tqdm==4.67.1
fastapi
uvicorn[standard]
utils/cross-base/cross_config.json
ADDED
@@ -0,0 +1,12 @@
{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 512,
  "initializer_range": 0.02,
  "intermediate_size": 2048,
  "max_position_embeddings": 150,
  "num_attention_heads": 8,
  "num_hidden_layers": 2,
  "vocab_size": 512
}
utils/decoder-base/decoder_config.json
ADDED
@@ -0,0 +1,14 @@
{
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 512,
  "initializer_range": 0.02,
  "intermediate_size": 2048,
  "num_attention_heads": 8,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 49408,
  "num_decoder_layers": 3,
  "max_target_embeddings": 77
}
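Both JSON files are BERT-style transformer configs; judging from the "cross-base" and "decoder-base" names passed to CrossConfig.get_config and DecoderConfig.get_config in utils/model.py, they are loaded from these directories (an inference, not stated in the commit). A quick way to inspect one, using only the standard library:

import json
from pathlib import Path

cfg = json.loads(Path("utils/decoder-base/decoder_config.json").read_text())
# e.g. decoder depth and the maximum number of target-position embeddings
print(cfg["num_decoder_layers"], cfg["max_target_embeddings"])  # 3 77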
utils/file_utils.py
ADDED
@@ -0,0 +1,239 @@
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""

import os
import logging
import shutil
import tempfile
import json
from urllib.parse import urlparse
from pathlib import Path
from typing import Optional, Tuple, Union, IO, Callable, Set
from hashlib import sha256
from functools import wraps

from tqdm import tqdm

import boto3
from botocore.exceptions import ClientError
import requests

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name

PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
                                               Path.home() / '.pytorch_pretrained_bert'))


def url_to_filename(url: str, etag: str = None) -> str:
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    """
    url_bytes = url.encode('utf-8')
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode('utf-8')
        etag_hash = sha256(etag_bytes)
        filename += '.' + etag_hash.hexdigest()

    return filename


def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]:
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise FileNotFoundError("file {} not found".format(cache_path))

    meta_path = cache_path + '.json'
    if not os.path.exists(meta_path):
        raise FileNotFoundError("file {} not found".format(meta_path))

    with open(meta_path) as meta_file:
        metadata = json.load(meta_file)
    url = metadata['url']
    etag = metadata['etag']

    return url, etag


def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str:
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ('http', 'https', 's3'):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, cache_dir)
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif parsed.scheme == '':
        # File, but it doesn't exist.
        raise FileNotFoundError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))


def split_s3_path(url: str) -> Tuple[str, str]:
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    bucket_name = parsed.netloc
    s3_path = parsed.path
    # Remove '/' at beginning of path.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path


def s3_request(func: Callable):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.
    """

    @wraps(func)
    def wrapper(url: str, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response["Error"]["Code"]) == 404:
                raise FileNotFoundError("file {} not found".format(url))
            else:
                raise

    return wrapper


@s3_request
def s3_etag(url: str) -> Optional[str]:
    """Check ETag on S3 object."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_object = s3_resource.Object(bucket_name, s3_path)
    return s3_object.e_tag


@s3_request
def s3_get(url: str, temp_file: IO) -> None:
    """Pull a file directly from S3."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


def http_get(url: str, temp_file: IO) -> None:
    req = requests.get(url, stream=True)
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit="B", total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()


def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        response = requests.head(url, allow_redirects=True)
        if response.status_code != 200:
            raise IOError("HEAD request failed for url {} with status code {}"
                          .format(url, response.status_code))
        etag = response.headers.get("ETag")

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    if not os.path.exists(cache_path):
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w') as meta_file:
                json.dump(meta, meta_file)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path


def read_set_from_file(filename: str) -> Set[str]:
    '''
    Extract a de-duped collection (set) of text from a file.
    Expected file format is one item per line.
    '''
    collection = set()
    with open(filename, 'r', encoding='utf-8') as file_:
        for line in file_:
            collection.add(line.rstrip())
    return collection


def get_file_extension(path: str, dot=True, lower: bool = True):
    ext = os.path.splitext(path)[1]
    ext = ext if dot else ext[1:]
    return ext.lower() if lower else ext
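cached_path is the module's entry point: local paths are returned as-is, while http(s)/s3 URLs are downloaded into the cache under a name derived from the URL (and ETag, when available) by url_to_filename. A brief usage sketch (the URL is illustrative only):

from utils.file_utils import cached_path, url_to_filename

# an existing local path passes straight through
print(cached_path("requirements.txt"))

# cache keys are sha256 hashes of the URL, optionally suffixed with a hash of the ETag
print(url_to_filename("https://example.com/model.bin", etag="abc123"))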
utils/model.py
ADDED
@@ -0,0 +1,204 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
from torch import nn

from .until_module import PreTrainedModel
from .module_cross import CrossModel, CrossConfig
from .module_decoder import DecoderModel, DecoderConfig

from utils.module_clip import CLIP, convert_weights
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence


def update_attr(target_name, target_config, target_attr_name, source_config, source_attr_name, default_value=None):
    if hasattr(source_config, source_attr_name):
        if default_value is None or getattr(source_config, source_attr_name) != default_value:
            setattr(target_config, target_attr_name, getattr(source_config, source_attr_name))
    return target_config

class CLIP4IDCPreTrainedModel(PreTrainedModel, nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, cross_config, decoder_config, *inputs, **kwargs):
        super(CLIP4IDCPreTrainedModel, self).__init__(cross_config, decoder_config)
        self.cross_config = cross_config
        self.decoder_config = decoder_config
        self.clip = None
        self.cross = None

    @classmethod
    def from_pretrained(cls, cross_model_name, decoder_model_name, state_dict=None, cache_dir=None, type_vocab_size=2, *inputs, **kwargs):

        if state_dict is None: state_dict = {}
        pretrained_clip_name = "ViT-B/16"
        clip_state_dict = CLIP.get_config(pretrained_clip_name=pretrained_clip_name)
        for key, val in clip_state_dict.items():
            new_key = "clip." + key
            if new_key not in state_dict:
                state_dict[new_key] = val.clone()

        cross_config, _ = CrossConfig.get_config(cross_model_name, cache_dir, type_vocab_size, state_dict=None)
        decoder_config, _ = DecoderConfig.get_config(decoder_model_name, cache_dir, type_vocab_size, state_dict=None)

        model = cls(cross_config, decoder_config, clip_state_dict, *inputs, **kwargs)

        ## ===> Initialization trick [HARD CODE]
        if model.linear_patch == "3d":
            contain_conv2 = False
            for key in state_dict.keys():
                if key.find("visual.conv2.weight") > -1:
                    contain_conv2 = True
                    break
            if contain_conv2 is False and hasattr(model.clip.visual, "conv2"):
                cp_weight = state_dict["clip.visual.conv1.weight"].clone()
                kernel_size = model.clip.visual.conv2.weight.size(2)
                conv2_size = model.clip.visual.conv2.weight.size()
                conv2_size = list(conv2_size)

                left_conv2_size = conv2_size.copy()
                right_conv2_size = conv2_size.copy()
                left_conv2_size[2] = (kernel_size - 1) // 2
                right_conv2_size[2] = kernel_size - 1 - left_conv2_size[2]

                left_zeros, right_zeros = None, None
                if left_conv2_size[2] > 0:
                    left_zeros = torch.zeros(*tuple(left_conv2_size), dtype=cp_weight.dtype, device=cp_weight.device)
                if right_conv2_size[2] > 0:
                    right_zeros = torch.zeros(*tuple(right_conv2_size), dtype=cp_weight.dtype, device=cp_weight.device)

                cat_list = []
                if left_zeros != None: cat_list.append(left_zeros)
                cat_list.append(cp_weight.unsqueeze(2))
                if right_zeros != None: cat_list.append(right_zeros)
                cp_weight = torch.cat(cat_list, dim=2)

                state_dict["clip.visual.conv2.weight"] = cp_weight

        ## <=== End of initialization trick

        if state_dict is not None:
            model = cls.init_preweight(model, state_dict)

        return model


class CLIP4IDC(CLIP4IDCPreTrainedModel):
    def __init__(self, cross_config, decoder_config, clip_state_dict):
        super(CLIP4IDC, self).__init__(cross_config, decoder_config)
        self.ignore_video_index = -1

        # assert self.task_config.max_words <= cross_config.max_position_embeddings

        # CLIP Encoders: From OpenAI: CLIP [https://github.com/openai/CLIP] ===>
        vit = "visual.proj" in clip_state_dict
        assert vit
        if vit:
            vision_width = clip_state_dict["visual.conv1.weight"].shape[0]
            vision_layers = len(
                [k for k in clip_state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
            vision_patch_size = clip_state_dict["visual.conv1.weight"].shape[-1]
            grid_size = round((clip_state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
            image_resolution = vision_patch_size * grid_size
        else:
            counts: list = [len(set(k.split(".")[2] for k in clip_state_dict if k.startswith(f"visual.layer{b}"))) for b in
                            [1, 2, 3, 4]]
            vision_layers = tuple(counts)
            vision_width = clip_state_dict["visual.layer1.0.conv1.weight"].shape[0]
            output_width = round((clip_state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
            vision_patch_size = None
            assert output_width ** 2 + 1 == clip_state_dict["visual.attnpool.positional_embedding"].shape[0]
            image_resolution = output_width * 32

        embed_dim = clip_state_dict["text_projection"].shape[1]
        context_length = clip_state_dict["positional_embedding"].shape[0]
        vocab_size = clip_state_dict["token_embedding.weight"].shape[0]
        transformer_width = clip_state_dict["ln_final.weight"].shape[0]
        transformer_heads = transformer_width // 64
        transformer_layers = len(set(k.split(".")[2] for k in clip_state_dict if k.startswith(f"transformer.resblocks")))

        self.linear_patch = '2d'

        # use .float() to avoid overflow/underflow from fp16 weight. https://github.com/openai/CLIP/issues/40
        cut_top_layer = 0
        self.clip = CLIP(
            embed_dim,
            image_resolution, vision_layers-cut_top_layer, vision_width, vision_patch_size,
            context_length, vocab_size, transformer_width, transformer_heads, transformer_layers-cut_top_layer,
            linear_patch=self.linear_patch, intra_layers=9
        ).float()

        bert_word_embeddings_weight = self.clip.token_embedding.weight
        bert_position_embeddings_weight = self.clip.positional_embedding

        for key in ["input_resolution", "context_length", "vocab_size"]:
            if key in clip_state_dict:
                del clip_state_dict[key]

        convert_weights(self.clip)
        # <=== End of CLIP Encoders

        self.decoder = DecoderModel(decoder_config, bert_word_embeddings_weight, bert_position_embeddings_weight)

        self.apply(self.init_weights)

    def get_visual_output(self, video, visual_mask, left_gt_map, right_gt_map, shaped=False, video_frame=-1):

        bs_pair = visual_mask.size(0)
        visual_hidden, visual_output, left_map, right_map = self.clip.encode_image(video, left_gt_map, right_gt_map, video_frame=video_frame, return_hidden=True)
        visual_hidden = visual_hidden.float()
        visual_output = visual_output.float()
        visual_hidden = visual_hidden.view(bs_pair, -1, visual_hidden.size(-1))

        left_map = left_map.float()
        right_map = right_map.float()

        return visual_hidden, visual_output, left_map, right_map

    def get_sequence_visual_output(self, video, visual_mask, left_gt_map, right_gt_map, shaped=False, video_frame=-1):
        if shaped is False:
            visual_mask = visual_mask.view(-1, visual_mask.shape[-1])
            video = torch.as_tensor(video).float()
            b, pair, channel, h, w = video.shape
            video = video.view(b * pair, channel, h, w)
            video_frame = pair

        _, visual_hidden, left_map, right_map = self.get_visual_output(video, visual_mask, left_gt_map, right_gt_map, shaped=True, video_frame=video_frame)

        return visual_hidden, left_map, right_map

    def _get_decoder_score(self, visual_output, visual_mask, input_caption_ids, decoder_mask):
        res_tuples = ()
        decoder_scores = self.decoder(input_caption_ids, encoder_outs=visual_output, answer_mask=decoder_mask, encoder_mask=visual_mask)

        return decoder_scores, res_tuples

    def decoder_caption(self, visual_output, visual_mask, input_caption_ids, decoder_mask, get_logits=False):

        decoder_scores, _ = self._get_decoder_score(visual_output, visual_mask,
                                                    input_caption_ids, decoder_mask)

        if get_logits:
            return decoder_scores

        _, decoder_scores_result = torch.max(decoder_scores, -1)

        return decoder_scores_result


def init_model(model_path, device):

    model_state_dict = torch.load(model_path, map_location='cpu')

    # Prepare model
    cache_dir = ""
    model = CLIP4IDC.from_pretrained("cross-base", "decoder-base", cache_dir=cache_dir, state_dict=model_state_dict)

    model.to(device)

    return model
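init_model builds a CLIP4IDC model around the ViT-B/16 CLIP backbone (obtained through CLIP.get_config in module_clip, which downloads and caches the weights if needed) and then overlays the fine-tuned weights from the checkpoint. A minimal CPU-only loading sketch, assuming the checkpoint shipped with the Space is present at data/pytorch_model.pt:

import torch
from utils.model import init_model

device = torch.device("cpu")
model = init_model("data/pytorch_model.pt", device)
model.eval()  # the demo only runs greedy decoding, no training

# rough sanity check on model size
print(sum(p.numel() for p in model.parameters()))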
utils/module_clip.py
ADDED
@@ -0,0 +1,679 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Adapted from: https://github.com/openai/CLIP/blob/main/clip/clip.py
|
3 |
+
"""
|
4 |
+
import warnings
|
5 |
+
from collections import OrderedDict
|
6 |
+
from typing import Tuple, Union, Optional
|
7 |
+
|
8 |
+
import hashlib
|
9 |
+
import os
|
10 |
+
import urllib
|
11 |
+
import warnings
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
from torch import Tensor
|
18 |
+
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
19 |
+
from torch.nn.init import xavier_uniform_
|
20 |
+
from torch.nn.init import constant_
|
21 |
+
from torch.nn.init import xavier_normal_
|
22 |
+
from torch.nn.parameter import Parameter
|
23 |
+
|
24 |
+
from torch.nn.modules.module import Module
|
25 |
+
from .module_gated_attention import gated_coattention
|
26 |
+
from torch import nn
|
27 |
+
|
28 |
+
|
29 |
+
_MODELS = {
|
30 |
+
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
31 |
+
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
32 |
+
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
33 |
+
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
|
34 |
+
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
35 |
+
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
|
36 |
+
}
|
37 |
+
_PT_NAME = {
|
38 |
+
"RN50": "RN50.pt",
|
39 |
+
"RN101": "RN101.pt",
|
40 |
+
"RN50x4": "RN50x4.pt",
|
41 |
+
"RN50x16": "RN50x16.pt",
|
42 |
+
"ViT-B/32": "ViT-B-32.pt",
|
43 |
+
"ViT-B/16": "ViT-B-16.pt",
|
44 |
+
}
|
45 |
+
|
46 |
+
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
|
47 |
+
os.makedirs(root, exist_ok=True)
|
48 |
+
filename = os.path.basename(url)
|
49 |
+
|
50 |
+
expected_sha256 = url.split("/")[-2]
|
51 |
+
download_target = os.path.join(root, filename)
|
52 |
+
|
53 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
54 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
55 |
+
|
56 |
+
if os.path.isfile(download_target):
|
57 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
58 |
+
return download_target
|
59 |
+
else:
|
60 |
+
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
61 |
+
|
62 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
63 |
+
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
|
64 |
+
while True:
|
65 |
+
buffer = source.read(8192)
|
66 |
+
if not buffer:
|
67 |
+
break
|
68 |
+
|
69 |
+
output.write(buffer)
|
70 |
+
loop.update(len(buffer))
|
71 |
+
|
72 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
73 |
+
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
|
74 |
+
|
75 |
+
return download_target
|
76 |
+
|
77 |
+
def available_models():
|
78 |
+
"""Returns the names of available CLIP models"""
|
79 |
+
return list(_MODELS.keys())
|
80 |
+
|
81 |
+
# =============================
|
82 |
+
|
83 |
+
|
84 |
+
class TABAttention(Module):
|
85 |
+
r"""Allows the model to jointly attend to information
|
86 |
+
from different representation subspaces.
|
87 |
+
See `Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_
|
88 |
+
|
89 |
+
.. math::
|
90 |
+
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
91 |
+
|
92 |
+
where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
embed_dim: total dimension of the model.
|
96 |
+
num_heads: parallel attention heads.
|
97 |
+
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
|
98 |
+
bias: add bias as module parameter. Default: True.
|
99 |
+
add_bias_kv: add bias to the key and value sequences at dim=0.
|
100 |
+
add_zero_attn: add a new batch of zeros to the key and
|
101 |
+
value sequences at dim=1.
|
102 |
+
kdim: total number of features in key. Default: None.
|
103 |
+
vdim: total number of features in value. Default: None.
|
104 |
+
|
105 |
+
Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
|
106 |
+
to :attr:`embed_dim` such that query, key, and value have the same
|
107 |
+
number of features.
|
108 |
+
|
109 |
+
Examples::
|
110 |
+
|
111 |
+
>>> multihead_attn = TABAttention(embed_dim, num_heads)
|
112 |
+
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
113 |
+
|
114 |
+
|
115 |
+
|
116 |
+
This is a version of multihead attention written to comply with the defintion of TAB!!!
|
117 |
+
"""
|
118 |
+
bias_k: Optional[torch.Tensor]
|
119 |
+
bias_v: Optional[torch.Tensor]
|
120 |
+
|
121 |
+
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
|
122 |
+
super(TABAttention, self).__init__()
|
123 |
+
self.embed_dim = embed_dim
|
124 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
125 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
126 |
+
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
127 |
+
|
128 |
+
self.num_heads = num_heads
|
129 |
+
self.dropout = dropout
|
130 |
+
self.head_dim = embed_dim // num_heads
|
131 |
+
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
132 |
+
|
133 |
+
if self._qkv_same_embed_dim is False:
|
134 |
+
self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
|
135 |
+
self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
|
136 |
+
self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
|
137 |
+
self.register_parameter('in_proj_weight', None)
|
138 |
+
else:
|
139 |
+
self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
|
140 |
+
self.register_parameter('q_proj_weight', None)
|
141 |
+
self.register_parameter('k_proj_weight', None)
|
142 |
+
self.register_parameter('v_proj_weight', None)
|
143 |
+
|
144 |
+
if bias:
|
145 |
+
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
|
146 |
+
else:
|
147 |
+
self.register_parameter('in_proj_bias', None)
|
148 |
+
self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias)
|
149 |
+
|
150 |
+
if add_bias_kv:
|
151 |
+
self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
|
152 |
+
self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
|
153 |
+
else:
|
154 |
+
self.bias_k = self.bias_v = None
|
155 |
+
|
156 |
+
self.add_zero_attn = add_zero_attn
|
157 |
+
|
158 |
+
self._reset_parameters()
|
159 |
+
|
160 |
+
def _reset_parameters(self):
|
161 |
+
if self._qkv_same_embed_dim:
|
162 |
+
xavier_uniform_(self.in_proj_weight)
|
163 |
+
else:
|
164 |
+
xavier_uniform_(self.q_proj_weight)
|
165 |
+
xavier_uniform_(self.k_proj_weight)
|
166 |
+
xavier_uniform_(self.v_proj_weight)
|
167 |
+
|
168 |
+
if self.in_proj_bias is not None:
|
169 |
+
constant_(self.in_proj_bias, 0.)
|
170 |
+
constant_(self.out_proj.bias, 0.)
|
171 |
+
if self.bias_k is not None:
|
172 |
+
xavier_normal_(self.bias_k)
|
173 |
+
if self.bias_v is not None:
|
174 |
+
xavier_normal_(self.bias_v)
|
175 |
+
|
176 |
+
def __setstate__(self, state):
|
177 |
+
# Support loading old TABAttention checkpoints generated by v1.1.0
|
178 |
+
if '_qkv_same_embed_dim' not in state:
|
179 |
+
state['_qkv_same_embed_dim'] = True
|
180 |
+
|
181 |
+
super(TABAttention, self).__setstate__(state)
|
182 |
+
|
183 |
+
def forward(self, query: Tensor, key: Tensor, value: Tensor, gt_attention_map: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None,
|
184 |
+
need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
|
185 |
+
r"""
|
186 |
+
Args:
|
187 |
+
query, key, value: map a query and a set of key-value pairs to an output.
|
188 |
+
See "Attention Is All You Need" for more details.
|
189 |
+
key_padding_mask: if provided, specified padding elements in the key will
|
190 |
+
be ignored by the attention. When given a binary mask and a value is True,
|
191 |
+
the corresponding value on the attention layer will be ignored. When given
|
192 |
+
a byte mask and a value is non-zero, the corresponding value on the attention
|
193 |
+
layer will be ignored
|
194 |
+
need_weights: output attn_output_weights.
|
195 |
+
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
196 |
+
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
197 |
+
|
198 |
+
Shapes for inputs:
|
199 |
+
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
200 |
+
the embedding dimension.
|
201 |
+
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
202 |
+
the embedding dimension.
|
203 |
+
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
204 |
+
the embedding dimension.
|
205 |
+
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
206 |
+
If a ByteTensor is provided, the non-zero positions will be ignored while the position
|
207 |
+
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
|
208 |
+
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
209 |
+
- attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the
|
210 |
+
source sequence length.
|
211 |
+
|
212 |
+
If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence
|
213 |
+
length, S is the source sequence length. ``attn_mask`` ensures that position i is allowed to attend
|
214 |
+
the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
215 |
+
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
216 |
+
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
217 |
+
is provided, it will be added to the attention weight.
|
218 |
+
|
219 |
+
Shapes for outputs:
|
220 |
+
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
221 |
+
E is the embedding dimension.
|
222 |
+
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
223 |
+
L is the target sequence length, S is the source sequence length.
|
224 |
+
"""
|
225 |
+
if not self._qkv_same_embed_dim:
|
226 |
+
return gated_coattention(
|
227 |
+
query, key, value, self.embed_dim, self.num_heads,
|
228 |
+
self.in_proj_weight.half(), self.in_proj_bias.half(),
|
229 |
+
self.bias_k, self.bias_v, self.add_zero_attn,
|
230 |
+
self.dropout, self.out_proj.weight.half(), self.out_proj.bias.half(),
|
231 |
+
training=self.training, gt_attention_map=gt_attention_map,
|
232 |
+
key_padding_mask=key_padding_mask, need_weights=need_weights,
|
233 |
+
attn_mask=attn_mask, use_separate_proj_weight=True,
|
234 |
+
q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
|
235 |
+
v_proj_weight=self.v_proj_weight)
|
236 |
+
else:
|
237 |
+
return gated_coattention(
|
238 |
+
query, key, value, self.embed_dim, self.num_heads,
|
239 |
+
self.in_proj_weight.half(), self.in_proj_bias.half(),
|
240 |
+
self.bias_k, self.bias_v, self.add_zero_attn,
|
241 |
+
self.dropout, self.out_proj.weight.half(), self.out_proj.bias.half(),
|
242 |
+
training=self.training, gt_attention_map=gt_attention_map,
|
243 |
+
key_padding_mask=key_padding_mask, need_weights=need_weights,
|
244 |
+
attn_mask=attn_mask)
|
245 |
+
|
246 |
+
|
247 |
+
class LayerNorm(nn.LayerNorm):
|
248 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
249 |
+
|
250 |
+
def forward(self, x: torch.Tensor):
|
251 |
+
orig_type = x.dtype
|
252 |
+
ret = super().forward(x.type(torch.float32))
|
253 |
+
return ret.type(orig_type)
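The subclass above exists so that LayerNorm statistics are computed in fp32 even when the rest of the model runs in fp16. A minimal sketch of the same pattern (not part of the commit; the class name is illustrative):

import torch
import torch.nn as nn

class FP16SafeLayerNorm(nn.LayerNorm):
    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))  # normalize in fp32
        return ret.type(orig_type)                    # cast back to the input dtype

ln = FP16SafeLayerNorm(8)
x = torch.randn(2, 8, dtype=torch.float16)
print(ln(x).dtype)  # torch.float16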
|
254 |
+
|
255 |
+
|
256 |
+
class QuickGELU(nn.Module):
|
257 |
+
def forward(self, x: torch.Tensor):
|
258 |
+
return x * torch.sigmoid(1.702 * x)
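QuickGELU is the sigmoid approximation of GELU used by CLIP. A minimal sketch (not part of the commit) comparing it against the exact erf-based GELU:

import torch
import torch.nn.functional as F

x = torch.linspace(-4, 4, steps=9)
quick = x * torch.sigmoid(1.702 * x)   # QuickGELU
exact = F.gelu(x)                      # exact GELU
print((quick - exact).abs().max())     # approximation error on the order of 1e-2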
|
259 |
+
|
260 |
+
|
261 |
+
class ResidualAttentionBlock(nn.Module):
|
262 |
+
def __init__(self, d_model: int, n_head: int, attn_mask=None):
|
263 |
+
super().__init__()
|
264 |
+
|
265 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
266 |
+
self.ln_1 = LayerNorm(d_model)
|
267 |
+
self.mlp = nn.Sequential(OrderedDict([
|
268 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
269 |
+
("gelu", QuickGELU()),
|
270 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
271 |
+
]))
|
272 |
+
self.ln_2 = LayerNorm(d_model)
|
273 |
+
self.attn_mask = attn_mask
|
274 |
+
|
275 |
+
def attention(self, x: torch.Tensor):
|
276 |
+
attn_mask_ = self.attn_mask
|
277 |
+
if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'):
|
278 |
+
attn_mask_ = self.attn_mask(x.size(0)) # LND
|
279 |
+
|
280 |
+
attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None
|
281 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]
|
282 |
+
|
283 |
+
def forward(self, x_tuple:tuple):
|
284 |
+
x, video_frame = x_tuple
|
285 |
+
x = x + self.attention(self.ln_1(x))
|
286 |
+
x = x + self.mlp(self.ln_2(x))
|
287 |
+
return (x, video_frame)
|
288 |
+
|
289 |
+
def visualize_attention(self, x: torch.Tensor):
|
290 |
+
attn_outputs, attn_weights = self.attn(x, x, x, need_weights=True, attn_mask=None)
|
291 |
+
return attn_outputs, attn_weights
|
292 |
+
|
293 |
+
def visualize_forward(self, x_tuple:tuple):
|
294 |
+
x, video_frame = x_tuple
|
295 |
+
attn_outputs, attn_weights = self.visualize_attention(self.ln_1(x))
|
296 |
+
x = x + attn_outputs
|
297 |
+
x = x + self.mlp(self.ln_2(x))
|
298 |
+
return (x, video_frame, attn_weights)
|
299 |
+
|
300 |
+
class TABLayer(nn.Module):
|
301 |
+
def __init__(self, d_model: int, n_head: int, attn_mask=None):
|
302 |
+
super().__init__()
|
303 |
+
|
304 |
+
self.attn = TABAttention(d_model, n_head)
|
305 |
+
self.ln_1 = LayerNorm(d_model)
|
306 |
+
self.mlp = nn.Sequential(OrderedDict([
|
307 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
308 |
+
("gelu", QuickGELU()),
|
309 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
310 |
+
]))
|
311 |
+
self.ln_2 = LayerNorm(d_model)
|
312 |
+
self.attn_mask = attn_mask
|
313 |
+
|
314 |
+
def attention(self, x: torch.Tensor, y: torch.Tensor):
|
315 |
+
attn_mask_ = self.attn_mask
|
316 |
+
if self.attn_mask is not None and hasattr(self.attn_mask, '__call__'):
|
317 |
+
attn_mask_ = self.attn_mask(x.size(0)) # LND
|
318 |
+
|
319 |
+
attn_mask_ = attn_mask_.to(dtype=x.dtype, device=x.device) if attn_mask_ is not None else None
|
320 |
+
return self.attn(x, y, y, need_weights=False, attn_mask=attn_mask_)[0]
|
321 |
+
|
322 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
323 |
+
x = self.attention(self.ln_1(x), self.ln_1(y))
|
324 |
+
x = x + self.mlp(self.ln_2(x))
|
325 |
+
return x
|
326 |
+
|
327 |
+
def visualize_attention(self, x: torch.Tensor, y: torch.Tensor, gt_attention_map):
|
328 |
+
attn_outputs, attn_weights = self.attn(x, y, y, gt_attention_map=gt_attention_map, need_weights=True, attn_mask=None)
|
329 |
+
return attn_outputs, attn_weights
|
330 |
+
|
331 |
+
def visualize_forward(self, x: torch.Tensor, y: torch.Tensor, gt_attention_map):
|
332 |
+
attn_outputs, attn_weights = self.visualize_attention(self.ln_1(x), self.ln_1(y), gt_attention_map)
|
333 |
+
x = attn_outputs
|
334 |
+
x = x + self.mlp(self.ln_2(x))
|
335 |
+
return (x, attn_weights)
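TABLayer applies cross-attention: tokens from one image half act as queries and tokens from the other half as keys and values. A minimal sketch (not part of the commit) of that pattern, using nn.MultiheadAttention as a stand-in for TABAttention (which additionally accepts a ground-truth attention map and a gated co-attention kernel):

import torch
import torch.nn as nn

d_model, n_head = 64, 1
attn = nn.MultiheadAttention(d_model, n_head)

L, N = 50, 2                      # 49 patch tokens + 1 CLS token per half, batch of 2
x = torch.randn(L, N, d_model)    # "before" half (queries)
y = torch.randn(L, N, d_model)    # "after" half (keys / values)
out, weights = attn(x, y, y)      # x attends over y
print(out.shape, weights.shape)   # torch.Size([50, 2, 64]) torch.Size([2, 50, 50])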
|
336 |
+
|
337 |
+
class visionTransformer(nn.Module):
|
338 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask = None):
|
339 |
+
super().__init__()
|
340 |
+
self.width = width
|
341 |
+
self.layers = layers
|
342 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) if i < (layers - 1) else TABLayer(width, 1, attn_mask) for i in range(layers)])
|
343 |
+
|
344 |
+
def forward(self, x: torch.Tensor, video_frame=-1):
|
345 |
+
return self.resblocks((x, video_frame))[0]
|
346 |
+
|
347 |
+
class Transformer(nn.Module):
|
348 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask = None):
|
349 |
+
super().__init__()
|
350 |
+
self.width = width
|
351 |
+
self.layers = layers
|
352 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
353 |
+
|
354 |
+
def forward(self, x: torch.Tensor, video_frame=-1):
|
355 |
+
return self.resblocks((x, video_frame))[0]
|
356 |
+
|
357 |
+
|
358 |
+
class VisualTransformer(nn.Module):
|
359 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int,
|
360 |
+
linear_patch: str = '2d', intra_layers: int = 9):
|
361 |
+
super().__init__()
|
362 |
+
self.input_resolution = input_resolution
|
363 |
+
self.output_dim = output_dim
|
364 |
+
self.intra_layers = intra_layers
|
365 |
+
|
366 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
367 |
+
|
368 |
+
scale = width ** -0.5
|
369 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
370 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
371 |
+
self.ln_pre = LayerNorm(width)
|
372 |
+
|
373 |
+
self.joint_positional_embedding = nn.Parameter(scale * torch.randn(2 * ((input_resolution // patch_size) ** 2 + 1), width))
|
374 |
+
self.bef_embedding = nn.Parameter(scale * torch.randn(width))
|
375 |
+
self.aft_embedding = nn.Parameter(scale * torch.randn(width))
|
376 |
+
self.ln_mid = LayerNorm(width)
|
377 |
+
|
378 |
+
self.transformer = visionTransformer(width, layers, heads)
|
379 |
+
|
380 |
+
self.ln_post = LayerNorm(width)
|
381 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
382 |
+
|
383 |
+
# For 3D
|
384 |
+
assert linear_patch in ['2d', '3d']
|
385 |
+
self.linear_patch = linear_patch
|
386 |
+
if self.linear_patch == '3d':
|
387 |
+
self.conv2 = nn.Conv3d(in_channels=3, out_channels=width, kernel_size=(3, patch_size, patch_size),
|
388 |
+
stride=(1, patch_size, patch_size), padding=(1, 0, 0), bias=False)
|
389 |
+
|
390 |
+
def forward(self, x: torch.Tensor, left_gt_map, right_gt_map, video_frame=-1, visualize=False):
|
391 |
+
|
392 |
+
if self.linear_patch == '3d':
|
393 |
+
assert video_frame != -1
|
394 |
+
x_3d = x.reshape(-1, video_frame, x.shape[-3], x.shape[-2], x.shape[-1])
|
395 |
+
x_3d = x_3d.permute(0, 2, 1, 3, 4)
|
396 |
+
x_3d = self.conv2(x_3d) # shape = [*, width, frame, grid, grid]
|
397 |
+
x_3d = x_3d.permute(0, 2, 1, 3, 4) # shape = [*, frame, width, grid, grid]
|
398 |
+
x = x_3d.reshape(-1, x_3d.shape[-3], x_3d.shape[-2], x_3d.shape[-1]).contiguous() # shape = [*, width, grid, grid]
|
399 |
+
else:
|
400 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
401 |
+
|
402 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
403 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
404 |
+
|
405 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
406 |
+
x = x + self.positional_embedding.to(x.dtype)
|
407 |
+
x = self.ln_pre(x)
|
408 |
+
|
409 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
410 |
+
|
411 |
+
if visualize is True:
|
412 |
+
all_attn_weights = []
|
413 |
+
for i in range(self.intra_layers):
|
414 |
+
x, _, attn_weights = self.transformer.resblocks[i].visualize_forward((x, video_frame))
|
415 |
+
attn_weights = attn_weights.view(x.size(1) // video_frame, -1, attn_weights.size(-2),
|
416 |
+
attn_weights.size(-1))
|
417 |
+
all_attn_weights.append(attn_weights)
|
418 |
+
else:
|
419 |
+
for i in range(self.intra_layers):
|
420 |
+
x = self.transformer.resblocks[i]((x, video_frame))[0]
|
421 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
422 |
+
|
423 |
+
bs = x.size(0) // video_frame
|
424 |
+
x = x.view(bs, video_frame, x.size(-2), x.size(-1))
|
425 |
+
x = torch.cat([x[:, 0] + self.bef_embedding.to(x.dtype),
|
426 |
+
x[:, 1] + self.aft_embedding.to(x.dtype)], dim=1)
|
427 |
+
|
428 |
+
x = x + self.joint_positional_embedding.to(x.dtype)
|
429 |
+
x = self.ln_mid(x)
|
430 |
+
|
431 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
432 |
+
|
433 |
+
if visualize is True:
|
434 |
+
for i in range(self.intra_layers, self.transformer.layers - 1):
|
435 |
+
x, _, attn_weights = self.transformer.resblocks[i].visualize_forward((x, video_frame))
|
436 |
+
all_attn_weights.append(attn_weights)
|
437 |
+
cls_index = int(x.size(0) / 2)
|
438 |
+
left_features, left_attn_weights = self.transformer.resblocks[-1].visualize_forward(x[:cls_index, :, :], x[cls_index:, :, :], right_gt_map)
|
439 |
+
right_features, right_attn_weights = self.transformer.resblocks[-1].visualize_forward(x[cls_index:, :, :], x[:cls_index, :, :], left_gt_map)
|
440 |
+
|
441 |
+
all_attn_weights.append(left_attn_weights)
|
442 |
+
all_attn_weights.append(right_attn_weights)
|
443 |
+
else:
|
444 |
+
for i in range(self.intra_layers, self.transformer.layers - 1):
|
445 |
+
x = self.transformer.resblocks[i]((x, video_frame))[0]
|
446 |
+
cls_index = int(x.size(0) / 2)
|
447 |
+
left_features, left_attn_weights = self.transformer.resblocks[-1].visualize_forward(x[:cls_index, :, :], x[cls_index:, :, :], right_gt_map)
|
448 |
+
right_features, right_attn_weights = self.transformer.resblocks[-1].visualize_forward(x[cls_index:, :, :], x[:cls_index, :, :], left_gt_map)
|
449 |
+
|
450 |
+
left_features = left_features.permute(1, 0, 2) # LND -> NLD
|
451 |
+
right_features = right_features.permute(1, 0, 2) # LND -> NLD
|
452 |
+
x = torch.cat([left_features, right_features], 1)
|
453 |
+
|
454 |
+
# Move the three lines below to `encode_image` for entire hidden sequence
|
455 |
+
# x = self.ln_post(x[:, 0, :])
|
456 |
+
# if self.proj is not None:
|
457 |
+
# x = x @ self.proj
|
458 |
+
|
459 |
+
if visualize is True:
|
460 |
+
return x, all_attn_weights
|
461 |
+
return x, left_attn_weights, right_attn_weights
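The joint stage above reshapes the per-frame token sequences, tags the two frames with the learned "before"/"after" embeddings, and concatenates them along the token axis before the inter-image blocks run. A minimal shape-level sketch (not part of the commit; sizes are illustrative for ViT-B/32 at 224x224):

import torch

bs, video_frame, tokens, width = 2, 2, 50, 768
x = torch.randn(bs * video_frame, tokens, width)      # NLD output of the intra-image blocks
bef_embedding = torch.randn(width)
aft_embedding = torch.randn(width)

x = x.view(bs, video_frame, tokens, width)
x = torch.cat([x[:, 0] + bef_embedding,
               x[:, 1] + aft_embedding], dim=1)        # (bs, 2 * tokens, width)
print(x.shape)                                         # torch.Size([2, 100, 768])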
|
462 |
+
|
463 |
+
|
464 |
+
class CLIP(nn.Module):
|
465 |
+
def __init__(self,
|
466 |
+
embed_dim: int,
|
467 |
+
# vision
|
468 |
+
image_resolution: int,
|
469 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
470 |
+
vision_width: int,
|
471 |
+
vision_patch_size: int,
|
472 |
+
# text
|
473 |
+
context_length: int,
|
474 |
+
vocab_size: int,
|
475 |
+
transformer_width: int,
|
476 |
+
transformer_heads: int,
|
477 |
+
transformer_layers: int,
|
478 |
+
# vision linear of patch
|
479 |
+
linear_patch: str = '2d',
|
480 |
+
intra_layers: int = 9,
|
481 |
+
):
|
482 |
+
super().__init__()
|
483 |
+
|
484 |
+
self.context_length = context_length
|
485 |
+
|
486 |
+
vision_heads = vision_width // 64
|
487 |
+
self.visual = VisualTransformer(
|
488 |
+
input_resolution=image_resolution,
|
489 |
+
patch_size=vision_patch_size,
|
490 |
+
width=vision_width,
|
491 |
+
layers=vision_layers,
|
492 |
+
heads=vision_heads,
|
493 |
+
output_dim=embed_dim,
|
494 |
+
linear_patch=linear_patch,
|
495 |
+
intra_layers=intra_layers,
|
496 |
+
)
|
497 |
+
|
498 |
+
self.transformer = Transformer(
|
499 |
+
width=transformer_width,
|
500 |
+
layers=transformer_layers,
|
501 |
+
heads=transformer_heads,
|
502 |
+
attn_mask=self.build_attention_mask
|
503 |
+
)
|
504 |
+
|
505 |
+
self.vocab_size = vocab_size
|
506 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
507 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
508 |
+
self.ln_final = LayerNorm(transformer_width)
|
509 |
+
|
510 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
511 |
+
self.logit_scale = nn.Parameter(torch.ones([]))
|
512 |
+
|
513 |
+
self.initialize_parameters()
|
514 |
+
|
515 |
+
def initialize_parameters(self):
|
516 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
517 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
518 |
+
|
519 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
520 |
+
attn_std = self.transformer.width ** -0.5
|
521 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
522 |
+
for block in self.transformer.resblocks:
|
523 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
524 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
525 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
526 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
527 |
+
|
528 |
+
if self.text_projection is not None:
|
529 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
530 |
+
|
531 |
+
@staticmethod
|
532 |
+
def get_config(pretrained_clip_name="ViT-B/32"):
|
533 |
+
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViT-B-32.pt")
|
534 |
+
if pretrained_clip_name in _MODELS and pretrained_clip_name in _PT_NAME:
|
535 |
+
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), _PT_NAME[pretrained_clip_name])
|
536 |
+
|
537 |
+
if pretrained_clip_name in ["ViT-B/32", "ViT-B/16"] and os.path.exists(model_path):
|
538 |
+
pass
|
539 |
+
else:
|
540 |
+
if pretrained_clip_name in _MODELS:
|
541 |
+
model_path = _download(_MODELS[pretrained_clip_name])
|
542 |
+
elif os.path.isfile(pretrained_clip_name):
|
543 |
+
model_path = pretrained_clip_name
|
544 |
+
else:
|
545 |
+
raise RuntimeError(f"Model {pretrained_clip_name} not found; available models = {available_models()}")
|
546 |
+
|
547 |
+
try:
|
548 |
+
# loading JIT archive
|
549 |
+
model = torch.jit.load(model_path, map_location="cpu").eval()
|
550 |
+
state_dict = model.state_dict()
|
551 |
+
except RuntimeError:
|
552 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
553 |
+
|
554 |
+
return state_dict
|
555 |
+
|
556 |
+
def build_attention_mask(self, context_length):
|
557 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
558 |
+
# pytorch uses additive attention mask; fill with -inf
|
559 |
+
mask = torch.zeros(context_length, context_length)
|
560 |
+
mask.fill_(float("-inf"))
|
561 |
+
mask.triu_(1) # zero out the lower diagonal
|
562 |
+
return mask
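A minimal sketch (not part of the commit) of the additive causal mask built here, shown for a context length of 4: entries strictly above the diagonal are -inf, so each text token only attends to itself and earlier positions.

import torch

context_length = 4
mask = torch.zeros(context_length, context_length)
mask.fill_(float("-inf"))
mask.triu_(1)          # keep -inf strictly above the diagonal, zero elsewhere
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])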
|
563 |
+
|
564 |
+
@property
|
565 |
+
def dtype(self):
|
566 |
+
return self.visual.conv1.weight.dtype
|
567 |
+
|
568 |
+
def encode_image(self, image, left_gt_map, right_gt_map, return_hidden=False, video_frame=-1):
|
569 |
+
hidden, left_map, right_map = self.visual(image.type(self.dtype), left_gt_map, right_gt_map, video_frame=video_frame)
|
570 |
+
hidden = self.visual.ln_post(hidden) @ self.visual.proj
|
571 |
+
|
572 |
+
cls_index = int(hidden.size(1) / 2)
|
573 |
+
hidden2 = torch.cat([hidden[:, 0, :].unsqueeze(1), hidden[:, cls_index, :].unsqueeze(1)], 1)
|
574 |
+
x = torch.mean(hidden2, 1)
|
575 |
+
|
576 |
+
if return_hidden:
|
577 |
+
return x, hidden2, left_map, right_map
|
578 |
+
|
579 |
+
return x, left_map, right_map
|
580 |
+
|
581 |
+
def encode_text(self, text, return_hidden=False):
|
582 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
583 |
+
|
584 |
+
pos_emd = self.positional_embedding[:x.size(1), :].type(self.dtype)
|
585 |
+
x = x + pos_emd
|
586 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
587 |
+
x = self.transformer(x)
|
588 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
589 |
+
|
590 |
+
hidden = self.ln_final(x).type(self.dtype) @ self.text_projection
|
591 |
+
|
592 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
593 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
594 |
+
x = hidden[torch.arange(hidden.shape[0]), text.argmax(dim=-1)]
|
595 |
+
|
596 |
+
if return_hidden:
|
597 |
+
return x, hidden
|
598 |
+
|
599 |
+
return x
|
600 |
+
|
601 |
+
def forward(self, image, text):
|
602 |
+
image_features = self.encode_image(image)
|
603 |
+
text_features = self.encode_text(text)
|
604 |
+
|
605 |
+
# normalized features
|
606 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
607 |
+
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
608 |
+
|
609 |
+
# cosine similarity as logits
|
610 |
+
logit_scale = self.logit_scale.exp()
|
611 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
612 |
+
logits_per_text = logit_scale * text_features @ image_features.t()
|
613 |
+
|
614 |
+
# shape = [global_batch_size, global_batch_size]
|
615 |
+
return logits_per_image, logits_per_text
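A minimal sketch (not part of the commit) of the contrastive logits computed above, with random tensors standing in for the encode_image / encode_text outputs:

import torch

image_features = torch.randn(4, 512)
text_features = torch.randn(4, 512)
logit_scale = torch.ones([]).exp()     # stands in for the learned logit_scale

image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
print(logits_per_image.shape)          # torch.Size([4, 4])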
|
616 |
+
|
617 |
+
|
618 |
+
def convert_weights(model: nn.Module):
|
619 |
+
"""Convert applicable model parameters to fp16"""
|
620 |
+
|
621 |
+
def _convert_weights_to_fp16(l):
|
622 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
|
623 |
+
l.weight.data = l.weight.data.half()
|
624 |
+
if l.bias is not None:
|
625 |
+
l.bias.data = l.bias.data.half()
|
626 |
+
|
627 |
+
if isinstance(l, nn.MultiheadAttention):
|
628 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
629 |
+
tensor = getattr(l, attr)
|
630 |
+
if tensor is not None:
|
631 |
+
tensor.data = tensor.data.half()
|
632 |
+
|
633 |
+
for name in ["text_projection", "proj"]:
|
634 |
+
if hasattr(l, name):
|
635 |
+
attr = getattr(l, name)
|
636 |
+
if attr is not None:
|
637 |
+
attr.data = attr.data.half()
|
638 |
+
|
639 |
+
model.apply(_convert_weights_to_fp16)
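A hedged usage sketch (not part of the commit): convert_weights mutates the module in place, halving conv, linear, attention and projection weights while leaving LayerNorm parameters in fp32 (those are handled by the fp16-safe LayerNorm subclass instead).

import torch.nn as nn

tiny = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
convert_weights(tiny)                               # convert_weights as defined above
print(tiny[0].weight.dtype, tiny[1].weight.dtype)   # torch.float16 torch.float32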
|
640 |
+
|
641 |
+
|
642 |
+
def build_model(state_dict: dict):
|
643 |
+
vit = "visual.proj" in state_dict
|
644 |
+
|
645 |
+
if vit:
|
646 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
647 |
+
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
648 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
649 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
650 |
+
image_resolution = vision_patch_size * grid_size
|
651 |
+
else:
|
652 |
+
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
653 |
+
vision_layers = tuple(counts)
|
654 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
655 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
656 |
+
vision_patch_size = None
|
657 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
658 |
+
image_resolution = output_width * 32
|
659 |
+
|
660 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
661 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
662 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
663 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
664 |
+
transformer_heads = transformer_width // 64
|
665 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
|
666 |
+
|
667 |
+
model = CLIP(
|
668 |
+
embed_dim,
|
669 |
+
image_resolution, vision_layers, vision_width, vision_patch_size,
|
670 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
|
671 |
+
)
|
672 |
+
|
673 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
674 |
+
if key in state_dict:
|
675 |
+
del state_dict[key]
|
676 |
+
|
677 |
+
convert_weights(model)
|
678 |
+
model.load_state_dict(state_dict)
|
679 |
+
return model.eval()
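build_model infers the architecture purely from tensor shapes in the checkpoint. A minimal sketch (not part of the commit) of that inference for ViT-B/32-sized tensors:

import torch

conv1_weight = torch.empty(768, 3, 32, 32)   # shape of visual.conv1.weight
pos_embedding = torch.empty(50, 768)         # shape of visual.positional_embedding

vision_width = conv1_weight.shape[0]                       # 768
vision_patch_size = conv1_weight.shape[-1]                 # 32
grid_size = round((pos_embedding.shape[0] - 1) ** 0.5)     # 7
image_resolution = vision_patch_size * grid_size           # 224
print(vision_width, vision_patch_size, image_resolution)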
|
utils/module_cross.py
ADDED
@@ -0,0 +1,394 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""PyTorch BERT model."""
|
17 |
+
|
18 |
+
from __future__ import absolute_import
|
19 |
+
from __future__ import division
|
20 |
+
from __future__ import print_function
|
21 |
+
|
22 |
+
import os
|
23 |
+
import copy
|
24 |
+
import json
|
25 |
+
import math
|
26 |
+
import logging
|
27 |
+
import tarfile
|
28 |
+
import tempfile
|
29 |
+
import shutil
|
30 |
+
|
31 |
+
import torch
|
32 |
+
from torch import nn
|
33 |
+
import torch.nn.functional as F
|
34 |
+
from .file_utils import cached_path
|
35 |
+
from .until_config import PretrainedConfig
|
36 |
+
from .until_module import PreTrainedModel, LayerNorm, ACT2FN
|
37 |
+
|
38 |
+
logger = logging.getLogger(__name__)
|
39 |
+
|
40 |
+
PRETRAINED_MODEL_ARCHIVE_MAP = {}
|
41 |
+
CONFIG_NAME = 'cross_config.json'
|
42 |
+
WEIGHTS_NAME = 'cross_pytorch_model.bin'
|
43 |
+
|
44 |
+
|
45 |
+
class CrossConfig(PretrainedConfig):
|
46 |
+
"""Configuration class to store the configuration of a `CrossModel`.
|
47 |
+
"""
|
48 |
+
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
|
49 |
+
config_name = CONFIG_NAME
|
50 |
+
weights_name = WEIGHTS_NAME
|
51 |
+
def __init__(self,
|
52 |
+
vocab_size_or_config_json_file,
|
53 |
+
hidden_size=768,
|
54 |
+
num_hidden_layers=12,
|
55 |
+
num_attention_heads=12,
|
56 |
+
intermediate_size=3072,
|
57 |
+
hidden_act="gelu",
|
58 |
+
hidden_dropout_prob=0.1,
|
59 |
+
attention_probs_dropout_prob=0.1,
|
60 |
+
max_position_embeddings=512,
|
61 |
+
type_vocab_size=2,
|
62 |
+
initializer_range=0.02):
|
63 |
+
"""Constructs CrossConfig.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CrossModel`.
|
67 |
+
hidden_size: Size of the encoder layers and the pooler layer.
|
68 |
+
num_hidden_layers: Number of hidden layers in the Transformer encoder.
|
69 |
+
num_attention_heads: Number of attention heads for each attention layer in
|
70 |
+
the Transformer encoder.
|
71 |
+
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
|
72 |
+
layer in the Transformer encoder.
|
73 |
+
hidden_act: The non-linear activation function (function or string) in the
|
74 |
+
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
|
75 |
+
hidden_dropout_prob: The dropout probability for all fully connected
|
76 |
+
layers in the embeddings, encoder, and pooler.
|
77 |
+
attention_probs_dropout_prob: The dropout ratio for the attention
|
78 |
+
probabilities.
|
79 |
+
max_position_embeddings: The maximum sequence length that this model might
|
80 |
+
ever be used with. Typically set this to something large just in case
|
81 |
+
(e.g., 512 or 1024 or 2048).
|
82 |
+
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
|
83 |
+
`CrossModel`.
|
84 |
+
initializer_range: The standard deviation of the truncated_normal_initializer for
|
85 |
+
initializing all weight matrices.
|
86 |
+
"""
|
87 |
+
if isinstance(vocab_size_or_config_json_file, str):
|
88 |
+
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
|
89 |
+
json_config = json.loads(reader.read())
|
90 |
+
for key, value in json_config.items():
|
91 |
+
self.__dict__[key] = value
|
92 |
+
elif isinstance(vocab_size_or_config_json_file, int):
|
93 |
+
self.vocab_size = vocab_size_or_config_json_file
|
94 |
+
self.hidden_size = hidden_size
|
95 |
+
self.num_hidden_layers = num_hidden_layers
|
96 |
+
self.num_attention_heads = num_attention_heads
|
97 |
+
self.hidden_act = hidden_act
|
98 |
+
self.intermediate_size = intermediate_size
|
99 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
100 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
101 |
+
self.max_position_embeddings = max_position_embeddings
|
102 |
+
self.type_vocab_size = type_vocab_size
|
103 |
+
self.initializer_range = initializer_range
|
104 |
+
else:
|
105 |
+
raise ValueError("First argument must be either a vocabulary size (int) "
|
106 |
+
"or the path to a pretrained model config file (str)")
|
107 |
+
|
108 |
+
|
109 |
+
class CrossEmbeddings(nn.Module):
|
110 |
+
"""Construct the embeddings from word, position and token_type embeddings.
|
111 |
+
"""
|
112 |
+
def __init__(self, config):
|
113 |
+
super(CrossEmbeddings, self).__init__()
|
114 |
+
|
115 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
116 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
117 |
+
|
118 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
119 |
+
# any TensorFlow checkpoint file
|
120 |
+
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
|
121 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
122 |
+
|
123 |
+
def forward(self, concat_embeddings, concat_type=None):
|
124 |
+
|
125 |
+
batch_size, seq_length = concat_embeddings.size(0), concat_embeddings.size(1)
|
126 |
+
if concat_type is None:
|
127 |
+
concat_type = torch.zeros(batch_size, seq_length, dtype=torch.long, device=concat_embeddings.device)
|
128 |
+
|
129 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=concat_embeddings.device)
|
130 |
+
position_ids = position_ids.unsqueeze(0).expand(concat_embeddings.size(0), -1)
|
131 |
+
|
132 |
+
token_type_embeddings = self.token_type_embeddings(concat_type)
|
133 |
+
position_embeddings = self.position_embeddings(position_ids)
|
134 |
+
|
135 |
+
embeddings = concat_embeddings + position_embeddings + token_type_embeddings
|
136 |
+
embeddings = self.LayerNorm(embeddings)
|
137 |
+
embeddings = self.dropout(embeddings)
|
138 |
+
return embeddings
|
139 |
+
|
140 |
+
class CrossSelfAttention(nn.Module):
|
141 |
+
def __init__(self, config):
|
142 |
+
super(CrossSelfAttention, self).__init__()
|
143 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
144 |
+
raise ValueError(
|
145 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
146 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
|
147 |
+
self.num_attention_heads = config.num_attention_heads
|
148 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
149 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
150 |
+
|
151 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
152 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
153 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
154 |
+
|
155 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
156 |
+
|
157 |
+
def transpose_for_scores(self, x):
|
158 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
159 |
+
x = x.view(*new_x_shape)
|
160 |
+
return x.permute(0, 2, 1, 3)
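A minimal sketch (not part of the commit) of what transpose_for_scores does: the hidden dimension is split into (heads, head_size) and the head axis is moved in front of the sequence axis so attention scores can be computed per head.

import torch

batch, seq, heads, head_size = 2, 5, 12, 64
x = torch.randn(batch, seq, heads * head_size)
x = x.view(batch, seq, heads, head_size).permute(0, 2, 1, 3)
print(x.shape)   # torch.Size([2, 12, 5, 64])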
|
161 |
+
|
162 |
+
def forward(self, hidden_states, attention_mask):
|
163 |
+
mixed_query_layer = self.query(hidden_states)
|
164 |
+
mixed_key_layer = self.key(hidden_states)
|
165 |
+
mixed_value_layer = self.value(hidden_states)
|
166 |
+
|
167 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
168 |
+
key_layer = self.transpose_for_scores(mixed_key_layer)
|
169 |
+
value_layer = self.transpose_for_scores(mixed_value_layer)
|
170 |
+
|
171 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
172 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
173 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
174 |
+
# Apply the attention mask (precomputed for all layers in the CrossModel forward() function)
|
175 |
+
attention_scores = attention_scores + attention_mask
|
176 |
+
|
177 |
+
# Normalize the attention scores to probabilities.
|
178 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
179 |
+
|
180 |
+
# This is actually dropping out entire tokens to attend to, which might
|
181 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
182 |
+
attention_probs = self.dropout(attention_probs)
|
183 |
+
|
184 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
185 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
186 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
187 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
188 |
+
return context_layer
|
189 |
+
|
190 |
+
|
191 |
+
class CrossSelfOutput(nn.Module):
|
192 |
+
def __init__(self, config):
|
193 |
+
super(CrossSelfOutput, self).__init__()
|
194 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
195 |
+
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
|
196 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
197 |
+
|
198 |
+
def forward(self, hidden_states, input_tensor):
|
199 |
+
hidden_states = self.dense(hidden_states)
|
200 |
+
hidden_states = self.dropout(hidden_states)
|
201 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
202 |
+
return hidden_states
|
203 |
+
|
204 |
+
|
205 |
+
class CrossAttention(nn.Module):
|
206 |
+
def __init__(self, config):
|
207 |
+
super(CrossAttention, self).__init__()
|
208 |
+
self.self = CrossSelfAttention(config)
|
209 |
+
self.output = CrossSelfOutput(config)
|
210 |
+
|
211 |
+
def forward(self, input_tensor, attention_mask):
|
212 |
+
self_output = self.self(input_tensor, attention_mask)
|
213 |
+
attention_output = self.output(self_output, input_tensor)
|
214 |
+
return attention_output
|
215 |
+
|
216 |
+
|
217 |
+
class CrossIntermediate(nn.Module):
|
218 |
+
def __init__(self, config):
|
219 |
+
super(CrossIntermediate, self).__init__()
|
220 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
221 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
|
222 |
+
if isinstance(config.hidden_act, str) else config.hidden_act
|
223 |
+
|
224 |
+
def forward(self, hidden_states):
|
225 |
+
hidden_states = self.dense(hidden_states)
|
226 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
227 |
+
return hidden_states
|
228 |
+
|
229 |
+
|
230 |
+
class CrossOutput(nn.Module):
|
231 |
+
def __init__(self, config):
|
232 |
+
super(CrossOutput, self).__init__()
|
233 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
234 |
+
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
|
235 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
236 |
+
|
237 |
+
def forward(self, hidden_states, input_tensor):
|
238 |
+
hidden_states = self.dense(hidden_states)
|
239 |
+
hidden_states = self.dropout(hidden_states)
|
240 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
241 |
+
return hidden_states
|
242 |
+
|
243 |
+
|
244 |
+
class CrossLayer(nn.Module):
|
245 |
+
def __init__(self, config):
|
246 |
+
super(CrossLayer, self).__init__()
|
247 |
+
self.attention = CrossAttention(config)
|
248 |
+
self.intermediate = CrossIntermediate(config)
|
249 |
+
self.output = CrossOutput(config)
|
250 |
+
|
251 |
+
def forward(self, hidden_states, attention_mask):
|
252 |
+
attention_output = self.attention(hidden_states, attention_mask)
|
253 |
+
intermediate_output = self.intermediate(attention_output)
|
254 |
+
layer_output = self.output(intermediate_output, attention_output)
|
255 |
+
return layer_output
|
256 |
+
|
257 |
+
|
258 |
+
class CrossEncoder(nn.Module):
|
259 |
+
def __init__(self, config):
|
260 |
+
super(CrossEncoder, self).__init__()
|
261 |
+
layer = CrossLayer(config)
|
262 |
+
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
|
263 |
+
|
264 |
+
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
|
265 |
+
all_encoder_layers = []
|
266 |
+
for layer_module in self.layer:
|
267 |
+
hidden_states = layer_module(hidden_states, attention_mask)
|
268 |
+
if output_all_encoded_layers:
|
269 |
+
all_encoder_layers.append(hidden_states)
|
270 |
+
if not output_all_encoded_layers:
|
271 |
+
all_encoder_layers.append(hidden_states)
|
272 |
+
return all_encoder_layers
|
273 |
+
|
274 |
+
|
275 |
+
class CrossPooler(nn.Module):
|
276 |
+
def __init__(self, config):
|
277 |
+
super(CrossPooler, self).__init__()
|
278 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
279 |
+
self.activation = nn.Tanh()
|
280 |
+
|
281 |
+
def forward(self, hidden_states):
|
282 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
283 |
+
# to the first token.
|
284 |
+
first_token_tensor = hidden_states[:, 0]
|
285 |
+
pooled_output = self.dense(first_token_tensor)
|
286 |
+
pooled_output = self.activation(pooled_output)
|
287 |
+
return pooled_output
|
288 |
+
|
289 |
+
|
290 |
+
class CrossPredictionHeadTransform(nn.Module):
|
291 |
+
def __init__(self, config):
|
292 |
+
super(CrossPredictionHeadTransform, self).__init__()
|
293 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
294 |
+
self.transform_act_fn = ACT2FN[config.hidden_act] \
|
295 |
+
if isinstance(config.hidden_act, str) else config.hidden_act
|
296 |
+
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
|
297 |
+
|
298 |
+
def forward(self, hidden_states):
|
299 |
+
hidden_states = self.dense(hidden_states)
|
300 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
301 |
+
hidden_states = self.LayerNorm(hidden_states)
|
302 |
+
return hidden_states
|
303 |
+
|
304 |
+
|
305 |
+
class CrossLMPredictionHead(nn.Module):
|
306 |
+
def __init__(self, config, cross_model_embedding_weights):
|
307 |
+
super(CrossLMPredictionHead, self).__init__()
|
308 |
+
self.transform = CrossPredictionHeadTransform(config)
|
309 |
+
|
310 |
+
# The output weights are the same as the input embeddings, but there is
|
311 |
+
# an output-only bias for each token.
|
312 |
+
self.decoder = nn.Linear(cross_model_embedding_weights.size(1),
|
313 |
+
cross_model_embedding_weights.size(0),
|
314 |
+
bias=False)
|
315 |
+
self.decoder.weight = cross_model_embedding_weights
|
316 |
+
self.bias = nn.Parameter(torch.zeros(cross_model_embedding_weights.size(0)))
|
317 |
+
|
318 |
+
def forward(self, hidden_states):
|
319 |
+
hidden_states = self.transform(hidden_states)
|
320 |
+
hidden_states = self.decoder(hidden_states) + self.bias
|
321 |
+
return hidden_states
|
322 |
+
|
323 |
+
|
324 |
+
class CrossOnlyMLMHead(nn.Module):
|
325 |
+
def __init__(self, config, cross_model_embedding_weights):
|
326 |
+
super(CrossOnlyMLMHead, self).__init__()
|
327 |
+
self.predictions = CrossLMPredictionHead(config, cross_model_embedding_weights)
|
328 |
+
|
329 |
+
def forward(self, sequence_output):
|
330 |
+
prediction_scores = self.predictions(sequence_output)
|
331 |
+
return prediction_scores
|
332 |
+
|
333 |
+
|
334 |
+
class CrossOnlyNSPHead(nn.Module):
|
335 |
+
def __init__(self, config):
|
336 |
+
super(CrossOnlyNSPHead, self).__init__()
|
337 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
338 |
+
|
339 |
+
def forward(self, pooled_output):
|
340 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
341 |
+
return seq_relationship_score
|
342 |
+
|
343 |
+
|
344 |
+
class CrossPreTrainingHeads(nn.Module):
|
345 |
+
def __init__(self, config, cross_model_embedding_weights):
|
346 |
+
super(CrossPreTrainingHeads, self).__init__()
|
347 |
+
self.predictions = CrossLMPredictionHead(config, cross_model_embedding_weights)
|
348 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
349 |
+
|
350 |
+
def forward(self, sequence_output, pooled_output):
|
351 |
+
prediction_scores = self.predictions(sequence_output)
|
352 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
353 |
+
return prediction_scores, seq_relationship_score
|
354 |
+
|
355 |
+
|
356 |
+
class CrossModel(PreTrainedModel):
|
357 |
+
def __init__(self, config):
|
358 |
+
super(CrossModel, self).__init__(config)
|
359 |
+
self.embeddings = CrossEmbeddings(config)
|
360 |
+
self.encoder = CrossEncoder(config)
|
361 |
+
self.pooler = CrossPooler(config)
|
362 |
+
self.apply(self.init_weights)
|
363 |
+
|
364 |
+
def forward(self, concat_input, concat_type=None, attention_mask=None, output_all_encoded_layers=True):
|
365 |
+
|
366 |
+
if attention_mask is None:
|
367 |
+
attention_mask = torch.ones(concat_input.size(0), concat_input.size(1))
|
368 |
+
if concat_type is None:
|
369 |
+
concat_type = torch.zeros_like(attention_mask)
|
370 |
+
|
371 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
372 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
373 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
374 |
+
# this attention mask is simpler than the triangular masking of causal attention
|
375 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
376 |
+
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
377 |
+
|
378 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
379 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
380 |
+
# positions we want to attend and -10000.0 for masked positions.
|
381 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
382 |
+
# effectively the same as removing these entirely.
|
383 |
+
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
384 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
385 |
+
|
386 |
+
embedding_output = self.embeddings(concat_input, concat_type)
|
387 |
+
encoded_layers = self.encoder(embedding_output,
|
388 |
+
extended_attention_mask,
|
389 |
+
output_all_encoded_layers=output_all_encoded_layers)
|
390 |
+
sequence_output = encoded_layers[-1]
|
391 |
+
pooled_output = self.pooler(sequence_output)
|
392 |
+
if not output_all_encoded_layers:
|
393 |
+
encoded_layers = encoded_layers[-1]
|
394 |
+
return encoded_layers, pooled_output
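A minimal sketch (not part of the commit) of the additive padding mask built in the forward above: positions to attend (1.0) become 0.0 and padded positions (0.0) become -10000.0, which effectively removes them after the softmax.

import torch

attention_mask = torch.tensor([[1., 1., 1., 0.]])            # (batch, to_seq_length)
extended = attention_mask.unsqueeze(1).unsqueeze(2)          # (batch, 1, 1, to_seq_length)
extended = (1.0 - extended) * -10000.0
print(extended.squeeze())    # -0. for attended positions, -10000. at the padded position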
|
utils/module_decoder.py
ADDED
@@ -0,0 +1,447 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""PyTorch BERT model."""
|
17 |
+
|
18 |
+
from __future__ import absolute_import
|
19 |
+
from __future__ import division
|
20 |
+
from __future__ import print_function
|
21 |
+
|
22 |
+
import os
|
23 |
+
import copy
|
24 |
+
import json
|
25 |
+
import math
|
26 |
+
import logging
|
27 |
+
import tarfile
|
28 |
+
import tempfile
|
29 |
+
import shutil
|
30 |
+
import numpy as np
|
31 |
+
|
32 |
+
import torch
|
33 |
+
from torch import nn
|
34 |
+
from .file_utils import cached_path
|
35 |
+
from .until_config import PretrainedConfig
|
36 |
+
from .until_module import PreTrainedModel, LayerNorm, ACT2FN
|
37 |
+
|
38 |
+
logger = logging.getLogger(__name__)
|
39 |
+
|
40 |
+
PRETRAINED_MODEL_ARCHIVE_MAP = {}
|
41 |
+
CONFIG_NAME = 'decoder_config.json'
|
42 |
+
WEIGHTS_NAME = 'decoder_pytorch_model.bin'
|
43 |
+
|
44 |
+
|
45 |
+
class DecoderConfig(PretrainedConfig):
|
46 |
+
"""Configuration class to store the configuration of a `DecoderModel`.
|
47 |
+
"""
|
48 |
+
pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
|
49 |
+
config_name = CONFIG_NAME
|
50 |
+
weights_name = WEIGHTS_NAME
|
51 |
+
def __init__(self,
|
52 |
+
vocab_size_or_config_json_file,
|
53 |
+
hidden_size=768,
|
54 |
+
num_hidden_layers=12,
|
55 |
+
num_attention_heads=12,
|
56 |
+
intermediate_size=3072,
|
57 |
+
hidden_act="gelu",
|
58 |
+
hidden_dropout_prob=0.1,
|
59 |
+
attention_probs_dropout_prob=0.1,
|
60 |
+
type_vocab_size=2,
|
61 |
+
initializer_range=0.02,
|
62 |
+
max_target_embeddings=128,
|
63 |
+
num_decoder_layers=1):
|
64 |
+
"""Constructs DecoderConfig.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `DecoderModel`.
|
68 |
+
hidden_size: Size of the encoder layers and the pooler layer.
|
69 |
+
num_hidden_layers: Number of hidden layers in the Transformer encoder.
|
70 |
+
num_attention_heads: Number of attention heads for each attention layer in
|
71 |
+
the Transformer encoder.
|
72 |
+
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
|
73 |
+
layer in the Transformer encoder.
|
74 |
+
hidden_act: The non-linear activation function (function or string) in the
|
75 |
+
encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
|
76 |
+
hidden_dropout_prob: The dropout probability for all fully connected
|
77 |
+
layers in the embeddings, encoder, and pooler.
|
78 |
+
attention_probs_dropout_prob: The dropout ratio for the attention
|
79 |
+
probabilities.
|
80 |
+
type_vocab_size: The vocabulary size of the `token_type_ids` passed into
|
81 |
+
`DecoderModel`.
|
82 |
+
initializer_range: The standard deviation of the truncated_normal_initializer for
|
83 |
+
initializing all weight matrices.
|
84 |
+
max_target_embeddings: The maximum sequence length that this model might
|
85 |
+
ever be used with. Typically set this to something large just in case
|
86 |
+
(e.g., 512 or 1024 or 2048).
|
87 |
+
num_decoder_layers:
|
88 |
+
"""
|
89 |
+
if isinstance(vocab_size_or_config_json_file, str):
|
90 |
+
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
|
91 |
+
json_config = json.loads(reader.read())
|
92 |
+
for key, value in json_config.items():
|
93 |
+
self.__dict__[key] = value
|
94 |
+
elif isinstance(vocab_size_or_config_json_file, int):
|
95 |
+
self.vocab_size = vocab_size_or_config_json_file
|
96 |
+
self.hidden_size = hidden_size
|
97 |
+
self.num_hidden_layers = num_hidden_layers
|
98 |
+
self.num_attention_heads = num_attention_heads
|
99 |
+
self.hidden_act = hidden_act
|
100 |
+
self.intermediate_size = intermediate_size
|
101 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
102 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
103 |
+
self.type_vocab_size = type_vocab_size
|
104 |
+
self.initializer_range = initializer_range
|
105 |
+
self.max_target_embeddings = max_target_embeddings
|
106 |
+
self.num_decoder_layers = num_decoder_layers
|
107 |
+
else:
|
108 |
+
raise ValueError("First argument must be either a vocabulary size (int) "
|
109 |
+
"or the path to a pretrained model config file (str)")
|
110 |
+
|
111 |
+
|
112 |
+
class BertSelfOutput(nn.Module):
|
113 |
+
def __init__(self, config):
|
114 |
+
super(BertSelfOutput, self).__init__()
|
115 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
116 |
+
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
|
117 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
118 |
+
|
119 |
+
def forward(self, hidden_states, input_tensor):
|
120 |
+
hidden_states = self.dense(hidden_states)
|
121 |
+
hidden_states = self.dropout(hidden_states)
|
122 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
123 |
+
return hidden_states
|
124 |
+
|
125 |
+
class BertIntermediate(nn.Module):
|
126 |
+
def __init__(self, config):
|
127 |
+
super(BertIntermediate, self).__init__()
|
128 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
129 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
|
130 |
+
if isinstance(config.hidden_act, str) else config.hidden_act
|
131 |
+
|
132 |
+
def forward(self, hidden_states):
|
133 |
+
hidden_states = self.dense(hidden_states)
|
134 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
135 |
+
return hidden_states
|
136 |
+
|
137 |
+
|
138 |
+
class BertOutput(nn.Module):
|
139 |
+
def __init__(self, config):
|
140 |
+
super(BertOutput, self).__init__()
|
141 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
142 |
+
self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
|
143 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
144 |
+
|
145 |
+
def forward(self, hidden_states, input_tensor):
|
146 |
+
hidden_states = self.dense(hidden_states)
|
147 |
+
hidden_states = self.dropout(hidden_states)
|
148 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
149 |
+
return hidden_states
|
150 |
+
|
151 |
+
|
152 |
+
class BertPredictionHeadTransform(nn.Module):
|
153 |
+
def __init__(self, config):
|
154 |
+
super(BertPredictionHeadTransform, self).__init__()
|
155 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
156 |
+
self.transform_act_fn = ACT2FN[config.hidden_act] \
            if isinstance(config.hidden_act, str) else config.hidden_act
        self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    def __init__(self, config, decoder_model_embedding_weights):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(decoder_model_embedding_weights.size(1),
                                 decoder_model_embedding_weights.size(0),
                                 bias=False)
        self.decoder.weight = decoder_model_embedding_weights
        self.bias = nn.Parameter(torch.zeros(decoder_model_embedding_weights.size(0)))

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states


class BertOnlyMLMHead(nn.Module):
    def __init__(self, config, decoder_model_embedding_weights):
        super(BertOnlyMLMHead, self).__init__()
        self.predictions = BertLMPredictionHead(config, decoder_model_embedding_weights)

    def forward(self, sequence_output):
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores


class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, config):
        super(MultiHeadAttention, self).__init__()

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, q, k, v, attention_mask):
        mixed_query_layer = self.query(q)
        mixed_key_layer = self.key(k)
        mixed_value_layer = self.value(v)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in the model's forward() function).
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer, attention_scores


class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Conv1d(d_in, d_hid, 1)  # position-wise
        self.w_2 = nn.Conv1d(d_hid, d_in, 1)  # position-wise
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        output = x.transpose(1, 2)
        output = self.w_2(ACT2FN["gelu"](self.w_1(output)))
        output = output.transpose(1, 2)
        output = self.dropout(output)
        output = self.layer_norm(output + residual)
        return output


class DecoderAttention(nn.Module):
    def __init__(self, config):
        super(DecoderAttention, self).__init__()
        self.att = MultiHeadAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, q, k, v, attention_mask):
        att_output, attention_probs = self.att(q, k, v, attention_mask)
        attention_output = self.output(att_output, q)
        return attention_output, attention_probs


class EncoderLayer(nn.Module):
    def __init__(self, config):
        super(EncoderLayer, self).__init__()
        self.slf_attn = DecoderAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, dec_input, slf_attn_mask=None):
        slf_output, slf_att_scores = self.slf_attn(dec_input, dec_input, dec_input, slf_attn_mask)
        intermediate_output = self.intermediate(slf_output)
        dec_output = self.output(intermediate_output, slf_output)
        return dec_output, slf_att_scores


class DecoderLayer(nn.Module):
    def __init__(self, config):
        super(DecoderLayer, self).__init__()
        self.slf_attn = DecoderAttention(config)
        self.enc_attn = DecoderAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, dec_input, enc_output, slf_attn_mask=None, dec_enc_attn_mask=None):
        slf_output, _ = self.slf_attn(dec_input, dec_input, dec_input, slf_attn_mask)
        dec_output, dec_att_scores = self.enc_attn(slf_output, enc_output, enc_output, dec_enc_attn_mask)
        intermediate_output = self.intermediate(dec_output)
        dec_output = self.output(intermediate_output, dec_output)
        return dec_output, dec_att_scores


class DecoderEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config, decoder_word_embeddings_weight, decoder_position_embeddings_weight):
        super(DecoderEmbeddings, self).__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_target_embeddings, config.hidden_size)
        self.word_embeddings.weight = decoder_word_embeddings_weight
        self.position_embeddings.weight = decoder_position_embeddings_weight

        # self.LayerNorm is not snake-cased to stick with the TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = words_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class Encoder(nn.Module):
    def __init__(self, config):
        super(Encoder, self).__init__()
        layer = EncoderLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_decoder_layers)])

    def forward(self, hidden_states, self_attn_mask, output_all_encoded_layers=False):
        dec_att_scores = None
        all_encoder_layers = []
        all_dec_att_probs = []
        for layer_module in self.layer:
            hidden_states, dec_att_scores = layer_module(hidden_states, self_attn_mask)
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
                all_dec_att_probs.append(dec_att_scores)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
            all_dec_att_probs.append(dec_att_scores)
        return all_encoder_layers, all_dec_att_probs


class Decoder(nn.Module):
    def __init__(self, config):
        super(Decoder, self).__init__()
        layer = DecoderLayer(config)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_decoder_layers)])

    def forward(self, hidden_states, encoder_outs, self_attn_mask, attention_mask, output_all_encoded_layers=False):
        dec_att_scores = None
        all_encoder_layers = []
        all_dec_att_probs = []
        for i, layer_module in enumerate(self.layer):
            if isinstance(encoder_outs, list):
                hidden_states, dec_att_scores = layer_module(hidden_states, encoder_outs[i], self_attn_mask, attention_mask)
            else:
                hidden_states, dec_att_scores = layer_module(hidden_states, encoder_outs, self_attn_mask, attention_mask)
            if output_all_encoded_layers:
                all_encoder_layers.append(hidden_states)
                all_dec_att_probs.append(dec_att_scores)
        if not output_all_encoded_layers:
            all_encoder_layers.append(hidden_states)
            all_dec_att_probs.append(dec_att_scores)
        return all_encoder_layers, all_dec_att_probs


class DecoderClassifier(nn.Module):
    def __init__(self, config, embedding_weights):
        super(DecoderClassifier, self).__init__()
        self.cls = BertOnlyMLMHead(config, embedding_weights)

    def forward(self, hidden_states):
        cls_scores = self.cls(hidden_states)
        return cls_scores


class DecoderModel(PreTrainedModel):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        final_norm (bool, optional): apply layer norm to the output of the
            final decoder layer (default: True).
    """

    def __init__(self, config, decoder_word_embeddings_weight, decoder_position_embeddings_weight):
        super(DecoderModel, self).__init__(config)
        self.config = config
        self.max_target_length = config.max_target_embeddings
        self.embeddings = DecoderEmbeddings(config, decoder_word_embeddings_weight, decoder_position_embeddings_weight)
        self.decoder = Decoder(config)
        self.encoder = Encoder(config)
        self.classifier = DecoderClassifier(config, decoder_word_embeddings_weight)
        self.apply(self.init_weights)

    def forward(self, input_ids, encoder_outs=None, answer_mask=None, encoder_mask=None):
        """
        Args:
            input_ids (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for input feeding/teacher forcing
            encoder_outs (Tensor, optional): output from the encoder, used for encoder-side attention

        Returns:
            tuple:
                - the last decoder layer's output of shape `(batch, tgt_len, vocab)`
                - the last decoder layer's attention weights of shape `(batch, tgt_len, src_len)`
        """

        embedding_output = self.embeddings(input_ids)

        extended_encoder_mask = encoder_mask.unsqueeze(1).unsqueeze(2)  # b x 1 x 1 x ls
        extended_encoder_mask = extended_encoder_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_encoder_mask = (1.0 - extended_encoder_mask) * -10000.0

        extended_answer_mask = answer_mask.unsqueeze(1).unsqueeze(2)
        extended_answer_mask = extended_answer_mask.to(dtype=self.dtype)  # fp16 compatibility

        sz_b, len_s, _ = embedding_output.size()
        subsequent_mask = torch.triu(torch.ones((len_s, len_s), device=embedding_output.device, dtype=embedding_output.dtype), diagonal=1)
        self_attn_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1).unsqueeze(1)  # b x 1 x ls x ls
        slf_attn_mask = ((1.0 - extended_answer_mask) + self_attn_mask).gt(0).to(dtype=self.dtype)
        self_attn_mask = slf_attn_mask * -10000.0

        encoder_outs, _ = self.encoder(encoder_outs, extended_encoder_mask, output_all_encoded_layers=True)
        # encoder_outs = encoder_outs[-1]

        decoded_layers, dec_att_scores = self.decoder(embedding_output,
                                                      encoder_outs,
                                                      self_attn_mask,
                                                      extended_encoder_mask,
                                                      )
        sequence_output = decoded_layers[-1]
        cls_scores = self.classifier(sequence_output)

        return cls_scores
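
The mask arithmetic in DecoderModel.forward above is easy to misread: the padding mask and the causal (subsequent-position) mask are combined into one additive bias before being passed to the attention layers. The following is a minimal standalone sketch of that construction only, with made-up sizes and mask values for illustration; it is not taken from the repository.

import torch

answer_mask = torch.tensor([[1., 1., 1., 0.],
                            [1., 1., 0., 0.]])      # 1 = real target token, 0 = padding
sz_b, len_s = answer_mask.size()

extended_answer_mask = answer_mask.unsqueeze(1).unsqueeze(2)                     # b x 1 x 1 x ls
subsequent_mask = torch.triu(torch.ones(len_s, len_s), diagonal=1)               # upper triangle = future positions
self_attn_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1).unsqueeze(1)  # b x 1 x ls x ls

# 1 wherever a position is either padded or in the future, 0 elsewhere
slf_attn_mask = ((1.0 - extended_answer_mask) + self_attn_mask).gt(0).float()
additive_mask = slf_attn_mask * -10000.0             # added to the raw attention scores
print(additive_mask[0, 0])                           # row i shows which positions step i may attend to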
utils/module_gated_attention.py
ADDED
@@ -0,0 +1,301 @@
r"""Gated Co-attention interface"""
from typing import Callable, List, Optional, Tuple
import math
import warnings

import torch
from torch.nn.functional import *


def gated_coattention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Tensor,
    in_proj_bias: Tensor,
    bias_k: Optional[Tensor],
    bias_v: Optional[Tensor],
    add_zero_attn: bool,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Tensor,
    training: bool = True,
    gt_attention_map: Optional[Tensor] = None,
    key_padding_mask: Optional[Tensor] = None,
    need_weights: bool = True,
    attn_mask: Optional[Tensor] = None,
    use_separate_proj_weight: bool = False,
    q_proj_weight: Optional[Tensor] = None,
    k_proj_weight: Optional[Tensor] = None,
    v_proj_weight: Optional[Tensor] = None,
    static_k: Optional[Tensor] = None,
    static_v: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
            value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask is broadcast over all
            the batches while a 3D mask allows specifying a different mask for the entries of each batch.
        use_separate_proj_weight: the function accepts the projection weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.

    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
          If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
          will be unchanged. If a BoolTensor is provided, the positions with the
          value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged.
        - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
          3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
          S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
          positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
          while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
          are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
          is provided, it will be added to the attention weight.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.

        Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
    """
    tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias)
    if has_torch_function(tens_ops):
        return handle_torch_function(
            multi_head_attention_forward,
            tens_ops,
            query,
            key,
            value,
            embed_dim_to_check,
            num_heads,
            in_proj_weight,
            in_proj_bias,
            bias_k,
            bias_v,
            add_zero_attn,
            dropout_p,
            out_proj_weight,
            out_proj_bias,
            training=training,
            gt_attention_map=gt_attention_map,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attn_mask,
            use_separate_proj_weight=use_separate_proj_weight,
            q_proj_weight=q_proj_weight,
            k_proj_weight=k_proj_weight,
            v_proj_weight=v_proj_weight,
            static_k=static_k,
            static_v=static_v,
        )
    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    # allow MHA to have different sizes for the feature dimension
    assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    if not use_separate_proj_weight:
        # encoder-decoder attention ---->>>>>>>> co-attention style
        # This is the inline in_proj function with in_proj_weight and in_proj_bias
        _b = in_proj_bias
        _start = 0
        _end = embed_dim
        _w = in_proj_weight[_start:_end, :]
        if _b is not None:
            _b = _b[_start:_end]
        q = linear(query, _w, _b)

        if key is None:
            assert value is None
            k = None
            v = None
        else:
            # This is the inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            k, v = linear(key, _w, _b).chunk(2, dim=-1)

    else:
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
            k = linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim : (embed_dim * 2)])
            v = linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2) :])
        else:
            q = linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = linear(value, v_proj_weight_non_opt, in_proj_bias)
    q = q * scaling

    if attn_mask is not None:
        assert (
            attn_mask.dtype == torch.float32
            or attn_mask.dtype == torch.float64
            or attn_mask.dtype == torch.float16
            or attn_mask.dtype == torch.uint8
            or attn_mask.dtype == torch.bool
        ), "Only float, byte, and bool types are supported for attn_mask, not {}".format(attn_mask.dtype)
        if attn_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            attn_mask = attn_mask.to(torch.bool)

        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(0)
            if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                raise RuntimeError("The size of the 2D attn_mask is not correct.")
        elif attn_mask.dim() == 3:
            if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
                raise RuntimeError("The size of the 3D attn_mask is not correct.")
        else:
            raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
        # attn_mask's dim is 3 now.

    # convert ByteTensor key_padding_mask to bool
    if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
        warnings.warn(
            "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead."
        )
        key_padding_mask = key_padding_mask.to(torch.bool)

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = pad(key_padding_mask, (0, 1))
        else:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
    else:
        assert bias_k is None
        assert bias_v is None

    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = pad(attn_mask, (0, 1))
        if key_padding_mask is not None:
            key_padding_mask = pad(key_padding_mask, (0, 1))

    attn_output_weights = torch.bmm(q, k.transpose(1, 2))
    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_output_weights.masked_fill_(attn_mask, float("-inf"))
        else:
            attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float("-inf"),
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

    attn_output_weights = softmax(attn_output_weights, dim=-1)

    # Overwrite the first query row with the supplied (ground-truth) attention map.
    attn_output_weights_new = attn_output_weights.clone()
    attn_output_weights_new[:, 0, :] = gt_attention_map
    attn_output_weights = attn_output_weights_new

    ##################### TAB Forget Gate (start) #########################
    # temp_weights = torch.ones_like(attn_output_weights[:, 0, 0], device=attn_output_weights.device)
    # temp_weights = temp_weights - attn_output_weights[:, 0, 0]
    # weights = torch.mul(attn_output_weights[:, 0, :], temp_weights[:, None])
    # attn_output_weights_new = attn_output_weights.clone()  # Create a copy
    # attn_output_weights_new[:, 0, :] = weights  # Modify the copy
    ##################### TAB Forget Gate (end) #########################

    attn_output_weights_new = dropout(attn_output_weights_new, p=dropout_p, training=training)
    attn_output = torch.bmm(attn_output_weights_new, v)

    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None
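
The only behavioural difference between gated_coattention and the stock multi-head attention forward it is derived from is the step that overwrites the first query row of the softmaxed weights with gt_attention_map before applying them to the values. A simplified, self-contained sketch of just that step follows; the shapes and tensor names are illustrative and are not the repository's API.

import torch

bsz_heads, tgt_len, src_len, head_dim = 4, 3, 5, 8      # bsz * num_heads, query len, key len, head dim
scores = torch.randn(bsz_heads, tgt_len, src_len)
attn = torch.softmax(scores, dim=-1)                     # ordinary attention weights

gt_attention_map = torch.full((bsz_heads, src_len), 1.0 / src_len)  # externally supplied map
gated = attn.clone()
gated[:, 0, :] = gt_attention_map                        # query position 0 is forced to the given map

values = torch.randn(bsz_heads, src_len, head_dim)
context = torch.bmm(gated, values)                       # (bsz * num_heads, tgt_len, head_dim)
print(context.shape)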
utils/tokenization_clip.py
ADDED
@@ -0,0 +1,149 @@
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

        self.vocab = self.encoder

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

    def tokenize(self, text):
        tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
        return tokens

    def convert_tokens_to_ids(self, tokens):
        return [self.encoder[bpe_token] for bpe_token in tokens]

    def convert_ids_to_tokens(self, ids):
        """Converts a sequence of ids into tokens using the vocab."""
        return self.decode(ids)
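
A minimal usage sketch of the tokenizer above, assuming the merges file bpe_simple_vocab_16e6.txt.gz sits next to the module as default_bpe() expects (the import path matches how app.py imports it); the sample sentence is made up.

from utils.tokenization_clip import SimpleTokenizer

tokenizer = SimpleTokenizer()                          # loads bpe_simple_vocab_16e6.txt.gz from utils/
ids = tokenizer.encode("a man is playing guitar")      # list of BPE ids
print(ids)
print(tokenizer.decode(ids))                           # roughly round-trips the input text

tokens = tokenizer.tokenize("a man is playing guitar")  # BPE token strings such as 'guitar</w>'
print(tokenizer.convert_tokens_to_ids(tokens))          # same ids as encode() above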
utils/until_config.py
ADDED
@@ -0,0 +1,126 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import copy
import json
import logging
import tarfile
import tempfile
import shutil
import torch
from .file_utils import cached_path

logger = logging.getLogger(__name__)


class PretrainedConfig(object):

    pretrained_model_archive_map = {}
    config_name = ""
    weights_name = ""

    @classmethod
    def get_config(cls, pretrained_model_name, cache_dir, type_vocab_size, state_dict, task_config=None):
        archive_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), pretrained_model_name)
        if os.path.exists(archive_file) is False:
            if pretrained_model_name in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[pretrained_model_name]
            else:
                archive_file = pretrained_model_name

        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
        except FileNotFoundError:
            if task_config is None or task_config.local_rank == 0:
                logger.error(
                    "Model name '{}' was not found in model name list. "
                    "We assumed '{}' was a path or url but couldn't find any file "
                    "associated to this path or url.".format(
                        pretrained_model_name,
                        archive_file))
            return None
        if resolved_archive_file == archive_file:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {}".format(archive_file))
        else:
            if task_config is None or task_config.local_rank == 0:
                logger.info("loading archive file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        tempdir = None
        if os.path.isdir(resolved_archive_file):
            serialization_dir = resolved_archive_file
        else:
            # Extract archive to temp dir
            tempdir = tempfile.mkdtemp()
            if task_config is None or task_config.local_rank == 0:
                logger.info("extracting archive file {} to temp dir {}".format(
                    resolved_archive_file, tempdir))
            with tarfile.open(resolved_archive_file, 'r:gz') as archive:
                archive.extractall(tempdir)
            serialization_dir = tempdir
        # Load config
        config_file = os.path.join(serialization_dir, cls.config_name)
        config = cls.from_json_file(config_file)
        config.type_vocab_size = type_vocab_size
        if task_config is None or task_config.local_rank == 0:
            logger.info("Model config {}".format(config))

        if state_dict is None:
            weights_path = os.path.join(serialization_dir, cls.weights_name)
            if os.path.exists(weights_path):
                state_dict = torch.load(weights_path, map_location='cpu')
            else:
                if task_config is None or task_config.local_rank == 0:
                    logger.info("Weights file doesn't exist: {}".format(weights_path))

        if tempdir:
            # Clean up temp dir
            shutil.rmtree(tempdir)

        return config, state_dict

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `BertConfig` from a Python dictionary of parameters."""
        config = cls(vocab_size_or_config_json_file=-1)
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `BertConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
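
PretrainedConfig is meant to be subclassed: from_dict() instantiates the subclass with vocab_size_or_config_json_file=-1 and then copies every JSON key onto the instance, so a subclass only needs class-level file names and a permissive __init__. The sketch below uses a hypothetical ToyConfig (not a class from this repository) and assumes it is run from the repository root so that utils/decoder-base/decoder_config.json resolves.

from utils.until_config import PretrainedConfig

class ToyConfig(PretrainedConfig):                     # hypothetical subclass, for illustration only
    config_name = "decoder_config.json"
    weights_name = "pytorch_model.bin"

    def __init__(self, vocab_size_or_config_json_file=-1, **kwargs):
        # from_dict() fills in the remaining fields attribute by attribute.
        self.vocab_size_or_config_json_file = vocab_size_or_config_json_file

config = ToyConfig.from_json_file("utils/decoder-base/decoder_config.json")
print(config)                                          # __repr__ serializes the config back to JSON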
utils/until_module.py
ADDED
@@ -0,0 +1,278 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

import logging
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import math
from .until_config import PretrainedConfig

logger = logging.getLogger(__name__)


def gelu(x):
    """Implementation of the gelu activation function.
    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}


class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root)."""
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias


class PreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    def __init__(self, config, *inputs, **kwargs):
        super(PreTrainedModel, self).__init__()
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, LayerNorm):
            if 'beta' in dir(module) and 'gamma' in dir(module):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def resize_token_embeddings(self, new_num_tokens=None):
        raise NotImplementedError

    @classmethod
    def init_preweight(cls, model, state_dict, prefix=None):
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        if prefix is not None:
            old_keys = []
            new_keys = []
            for key in state_dict.keys():
                old_keys.append(key)
                new_keys.append(prefix + key)
            for old_key, new_key in zip(old_keys, new_keys):
                state_dict[new_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='')
        return model

    @property
    def dtype(self):
        """
        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            # For nn.DataParallel compatibility in PyTorch 1.5
            def find_tensor_attributes(module: nn.Module):
                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
                return tuples

            gen = self._named_members(get_members_fn=find_tensor_attributes)
            first_tuple = next(gen)
            return first_tuple[1].dtype

    @classmethod
    def from_pretrained(cls, config, state_dict=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.
        """
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            return model
        model = cls.init_preweight(model, state_dict)

        return model


##################################
###### LOSS FUNCTION #############
##################################


import itertools


class GroupCEn(nn.Module):
    def __init__(self,):
        super(GroupCEn, self).__init__()

    def forward(self, sim_matrix, target):
        mask = torch.eye(sim_matrix.size()[0], device=sim_matrix.device)
        indicies = torch.nonzero(target)
        com = list(itertools.product(indicies, indicies))
        for x, y in com:
            mask[x, y] = 1.0
        logpt = F.log_softmax(sim_matrix, dim=-1)
        logpt = logpt * mask
        logpt = torch.sum(logpt, dim=-1) / torch.sum(mask, dim=-1)
        logpt = torch.diag(logpt)
        nce_loss = -logpt
        sim_loss = nce_loss.mean()
        return sim_loss


class CrossEn(nn.Module):
    def __init__(self,):
        super(CrossEn, self).__init__()

    def forward(self, sim_matrix):
        logpt = F.log_softmax(sim_matrix, dim=-1)
        logpt = torch.diag(logpt)
        nce_loss = -logpt
        sim_loss = nce_loss.mean()
        return sim_loss


class MILNCELoss(nn.Module):
    def __init__(self, batch_size=1, n_pair=1,):
        super(MILNCELoss, self).__init__()
        self.batch_size = batch_size
        self.n_pair = n_pair
        torch_v = float(".".join(torch.__version__.split(".")[:2]))
        self.bool_dtype = torch.bool if torch_v >= 1.3 else torch.uint8

    def forward(self, sim_matrix):
        mm_mask = np.eye(self.batch_size)
        mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair)))
        mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device)

        from_text_matrix = sim_matrix + mm_mask * -1e12
        from_video_matrix = sim_matrix.transpose(1, 0)

        new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1)
        logpt = F.log_softmax(new_sim_matrix, dim=-1)

        mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1)
        masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12

        new_logpt = -torch.logsumexp(masked_logpt, dim=-1)

        logpt_choice = torch.zeros_like(new_logpt)
        mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair//2)
        logpt_choice[mark_ind] = 1
        sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=self.bool_dtype)).mean()
        return sim_loss


class MaxMarginRankingLoss(nn.Module):
    def __init__(self,
                 margin=1.0,
                 negative_weighting=False,
                 batch_size=1,
                 n_pair=1,
                 hard_negative_rate=0.5,
                 ):
        super(MaxMarginRankingLoss, self).__init__()
        self.margin = margin
        self.n_pair = n_pair
        self.batch_size = batch_size
        easy_negative_rate = 1 - hard_negative_rate
        self.easy_negative_rate = easy_negative_rate
        self.negative_weighting = negative_weighting
        if n_pair > 1 and batch_size > 1:
            alpha = easy_negative_rate / ((batch_size - 1) * (1 - easy_negative_rate))
            mm_mask = (1 - alpha) * np.eye(self.batch_size) + alpha
            mm_mask = np.kron(mm_mask, np.ones((n_pair, n_pair)))
            mm_mask = torch.tensor(mm_mask) * (batch_size * (1 - easy_negative_rate))
            self.mm_mask = mm_mask.float()

    def forward(self, x):
        d = torch.diag(x)
        max_margin = F.relu(self.margin + x - d.view(-1, 1)) + \
                     F.relu(self.margin + x - d.view(1, -1))
        if self.negative_weighting and self.n_pair > 1 and self.batch_size > 1:
            max_margin = max_margin * self.mm_mask.to(max_margin.device)
        return max_margin.mean()


class AllGather(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor."""

    @staticmethod
    def forward(ctx, tensor, args):
        output = [torch.empty_like(tensor) for _ in range(args.world_size)]
        torch.distributed.all_gather(output, tensor)
        ctx.rank = args.rank
        ctx.batch_size = tensor.shape[0]
        return torch.cat(output, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        return (
            grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)],
            None,
        )
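
The contrastive losses above all take a batch-by-batch similarity matrix whose diagonal entries correspond to matched pairs. A minimal sketch of how CrossEn and MaxMarginRankingLoss could be exercised on random similarities; the batch size and values are made up for illustration.

import torch
from utils.until_module import CrossEn, MaxMarginRankingLoss

sim_matrix = torch.randn(8, 8, requires_grad=True)   # pairwise text-video similarities, diagonal = positives
loss = CrossEn()(sim_matrix)                         # cross-entropy with the diagonal as targets
loss.backward()
print(loss.item())

ranking_loss = MaxMarginRankingLoss(margin=0.2)(sim_matrix.detach())
print(ranking_loss.item())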