File size: 2,258 Bytes
6a71166 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
# Module-level logger obtained from the transformers logging utilities,
# named after this module.
logger = logging.get_logger(__name__)
class DenseTransformerConfig(PretrainedConfig):
    """Configuration container for the ``dense_transformer`` model type.

    Every constructor argument listed under "Stored on the instance" is saved
    unchanged as an attribute of the same name. The arguments listed under
    "Forwarded to ``PretrainedConfig``" are handed to the base-class
    constructor, which owns their storage.

    Stored on the instance:
        vocab_size, hidden_size, intermediate_size, num_hidden_layers,
        num_attention_heads, num_key_value_heads, max_position_embeddings,
        rms_norm_eps, attention_dropout, hidden_dropout, use_cache,
        pretraining_tp, d_head, window_size, seq_len, attn_eps, ffn_eps.

    Forwarded to ``PretrainedConfig``:
        pad_token_id, bos_token_id, eos_token_id, tie_word_embeddings,
        torch_dtype, output_attentions, output_hidden_states,
        use_return_dict (passed as ``return_dict``), and any extra ``**kwargs``.

    NOTE(review): with the defaults, ``num_attention_heads`` (12) is not a
    multiple of ``num_key_value_heads`` (8). If the model implementation groups
    query heads per key/value head it will need compatible values -- confirm
    against the modeling code.
    """

    model_type = "dense_transformer"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=50256,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=5,
        num_attention_heads=12,
        num_key_value_heads=8,
        max_position_embeddings=1024,
        rms_norm_eps=1e-6,
        attention_dropout=0.1,
        hidden_dropout=0.1,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        torch_dtype="float32",
        pretraining_tp=1,
        output_attentions=False,
        output_hidden_states=False,
        use_return_dict=True,
        # Custom fields from ModelArgs
        d_head=64,
        window_size=128,
        seq_len=512,
        attn_eps=1e-6,
        ffn_eps=1e-6,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout
        self.use_cache = use_cache
        self.pretraining_tp = pretraining_tp

        # Custom fields
        self.d_head = d_head
        self.window_size = window_size
        self.seq_len = seq_len
        self.attn_eps = attn_eps
        self.ffn_eps = ffn_eps

        # `use_return_dict` is a read-only property on PretrainedConfig that is
        # derived from `return_dict`, so it cannot be assigned directly; route
        # the value through `return_dict` instead. `setdefault` keeps an
        # explicit `return_dict` already present in kwargs (e.g. when the
        # config is rebuilt from a saved dict) and avoids passing the keyword
        # twice.
        kwargs.setdefault("return_dict", use_return_dict)

        # `output_attentions` / `output_hidden_states` are named parameters
        # here, so they never reach **kwargs on their own. They must be
        # forwarded explicitly, otherwise the base constructor falls back to
        # its own defaults and the caller's values are lost.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            torch_dtype=torch_dtype,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )