Upload sCT

Files changed:
- config.json  +1 -1
- config.py    +2 -2
- sct.py       +0 -7
config.json

@@ -18,7 +18,7 @@
   "layer_norm_eps": 1e-05,
   "mask_token_id": 5,
   "max_positions": 20480,
-  "model_type": "
+  "model_type": "sCT",
   "num_cells": 50,
   "num_downsamples": 8,
   "num_hidden_layers_head": 1,
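The "model_type" key is what the transformers Auto classes use to decide which configuration and model classes to instantiate for this checkpoint. A minimal loading sketch, assuming the repository also maps "sCT" to the shipped custom classes (e.g. via an auto_map entry, which is not shown in this diff); the repository id below is a placeholder, not the actual repo name:

# Minimal sketch, assumptions noted above; trust_remote_code=True is required
# because sCT ships its own config.py / sct.py next to config.json.
from transformers import AutoConfig, AutoModel

repo_id = "org/sCT"  # placeholder repository id

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

print(config.model_type)  # -> "sCT"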
config.py

@@ -1,4 +1,4 @@
-from dataclasses import dataclass
+from dataclasses import dataclass
 from typing import Tuple

 from transformers import PretrainedConfig
@@ -6,7 +6,7 @@ from transformers import PretrainedConfig


 @dataclass
 class sCTConfig(PretrainedConfig):  # noqa: N801
-    model_type = "
+    model_type = "sCT"

     def __init__(self, **kwargs):  # type: ignore
         super().__init__()
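The class attribute model_type = "sCT" has to match the "model_type" value written to config.json above; it is also the string used when registering the classes with the Auto factories. A sketch of a purely local registration, assuming sCTConfig and sCT are importable from the shipped config.py and sct.py and that sCT sets config_class = sCTConfig, as custom PreTrainedModel subclasses normally do (the import paths are illustrative):

# Local registration sketch; only the class names sCTConfig / sCT and the
# model_type string "sCT" come from the diff, the rest is assumed.
from transformers import AutoConfig, AutoModel

from config import sCTConfig  # shipped config.py
from sct import sCT           # shipped sct.py

# The first argument must equal sCTConfig.model_type, i.e. "sCT".
AutoConfig.register("sCT", sCTConfig)
AutoModel.register(sCTConfig, sCT)

cfg = sCTConfig()  # __init__(**kwargs) forwards to PretrainedConfig
assert cfg.model_type == "sCT"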
sct.py

@@ -672,9 +672,7 @@ class sCT(PreTrainedModel):  # noqa: N801
         for _idx, conv_block in enumerate(self.conv_tower):
             x, res = conv_block(x)
             residuals.append(res)
-        outs["residuals"] = residuals
         residuals = residuals[::-1]
-        conv_block_out = x
         x = x.permute(0, 2, 1)

         for layer_idx, transformer in enumerate(self.transformer_layers):
@@ -686,16 +684,11 @@ class sCT(PreTrainedModel):  # noqa: N801
             for map_number in self._attention_maps_per_layer_to_save[layer_idx + 1]:
                 dkey = f"attention_map_layer_{layer_idx + 1}_number_{map_number}"
                 outs[dkey] = output["attention_weights"][:, map_number + 1]
-        transformer_output = x
         x = x.permute(0, 2, 1)
         for deconv_block, res in zip(self.deconv_tower, residuals):
             x = deconv_block(x, res)
-        deconv_block_out = x
         x = x.permute(0, 2, 1)
         logits = self.lm_head(x)
         outs["logits"] = logits
-        outs["transformer_output"] = transformer_output
-        outs["conv_out"] = conv_block_out
-        outs["deconv_out"] = deconv_block_out

         return outs
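The removed lines only stashed intermediate activations (the residual list, the conv-tower output, the transformer output and the deconv output) into the returned dict; in the portion of forward() shown here, outs now carries only the logits and any requested attention maps. Code that still needs those intermediates can capture them with standard PyTorch forward hooks instead. A sketch follows; the module names conv_tower, transformer_layers and deconv_tower come from the diff, everything else is illustrative, and the containers are assumed to be indexable (e.g. nn.ModuleList), as the iteration in forward() suggests:

# Sketch only: re-capture activations no longer returned in `outs`.
# Each hook stores whatever its block returns (possibly a tuple such as
# (x, res) for the conv blocks, or a dict for the transformer layers).
captured = {}

def save_as(name):
    def hook(module, inputs, output):
        captured[name] = output
    return hook

def attach_intermediate_hooks(model):
    handles = [
        model.conv_tower[-1].register_forward_hook(save_as("conv_out")),
        model.transformer_layers[-1].register_forward_hook(save_as("transformer_output")),
        model.deconv_tower[-1].register_forward_hook(save_as("deconv_out")),
    ]
    return handles  # call h.remove() on each handle when done

# Usage sketch (the forward signature is assumed):
# handles = attach_intermediate_hooks(model)
# outs = model(tokens)
# conv_out = captured["conv_out"]
# for h in handles:
#     h.remove()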