Yanisadel committed
Commit 3bbb8d3 · verified · 1 Parent(s): 3750642

Upload sCT

Files changed (3)
  1. config.json +1 -1
  2. config.py +2 -2
  3. sct.py +0 -7
config.json CHANGED
@@ -18,7 +18,7 @@
   "layer_norm_eps": 1e-05,
   "mask_token_id": 5,
   "max_positions": 20480,
-  "model_type": "sCellTransformer",
+  "model_type": "sCT",
   "num_cells": 50,
   "num_downsamples": 8,
   "num_hidden_layers_head": 1,
config.py CHANGED
@@ -1,4 +1,4 @@
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import Tuple
 
 from transformers import PretrainedConfig
@@ -6,7 +6,7 @@ from transformers import PretrainedConfig
 
 @dataclass
 class sCTConfig(PretrainedConfig):  # noqa: N801
-    model_type = "sCellTransformer"
+    model_type = "sCT"
 
     def __init__(self, **kwargs):  # type: ignore
         super().__init__()
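
The class attribute on `sCTConfig` now matches the string in config.json, which matters because the `AutoConfig`/`AutoModel` factories dispatch on that shared `model_type`. A minimal registration sketch using the standard `transformers` registration API; the local imports from this repo's `config.py` and `sct.py` are assumptions about how the modules are laid out:

    from transformers import AutoConfig, AutoModel

    from config import sCTConfig  # this repo's config.py (import path is an assumption)
    from sct import sCT           # this repo's sct.py

    # Register the custom classes under the new model_type string so the
    # Auto* factories can resolve "model_type": "sCT" from config.json.
    AutoConfig.register("sCT", sCTConfig)
    AutoModel.register(sCTConfig, sCT)

    config = AutoConfig.for_model("sCT")  # resolves to an sCTConfig instance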
sct.py CHANGED
@@ -672,9 +672,7 @@ class sCT(PreTrainedModel):  # noqa: N801
         for _idx, conv_block in enumerate(self.conv_tower):
             x, res = conv_block(x)
             residuals.append(res)
-        outs["residuals"] = residuals
         residuals = residuals[::-1]
-        conv_block_out = x
         x = x.permute(0, 2, 1)
 
         for layer_idx, transformer in enumerate(self.transformer_layers):
@@ -686,16 +684,11 @@ class sCT(PreTrainedModel):  # noqa: N801
             for map_number in self._attention_maps_per_layer_to_save[layer_idx + 1]:
                 dkey = f"attention_map_layer_{layer_idx + 1}_number_{map_number}"
                 outs[dkey] = output["attention_weights"][:, map_number + 1]
-        transformer_output = x
         x = x.permute(0, 2, 1)
         for deconv_block, res in zip(self.deconv_tower, residuals):
             x = deconv_block(x, res)
-        deconv_block_out = x
         x = x.permute(0, 2, 1)
         logits = self.lm_head(x)
         outs["logits"] = logits
-        outs["transformer_output"] = transformer_output
-        outs["conv_out"] = conv_block_out
-        outs["deconv_out"] = deconv_block_out
 
         return outs
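
After this cleanup the forward pass returns only `outs["logits"]` plus any attention maps requested via `_attention_maps_per_layer_to_save`; the intermediate `conv_out`, `transformer_output`, `deconv_out`, and `residuals` entries are gone. A usage sketch, with the repo id, forward signature, and input shape all placeholders rather than confirmed details:

    import torch
    from transformers import AutoModel

    # "user/sCT" is a placeholder repo id; trust_remote_code fetches sct.py/config.py.
    model = AutoModel.from_pretrained("user/sCT", trust_remote_code=True)

    tokens = torch.randint(0, 5, (1, 20480))  # dummy ids; vocab and shape are assumptions
    outs = model(tokens)

    print(outs["logits"].shape)
    # Keys such as outs["conv_out"] or outs["residuals"] are no longer present.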