PJAITEST1903 / model.py
Nah_kagz1092
Update model.py
e2692aa verified
raw
history blame contribute delete
481 Bytes
import tensorflow as tf
class TransformerModel(tf.keras.Model):
    """Keras model that chains an encoder sub-layer into a decoder sub-layer.

    Both sub-layers are built from keyword arguments supplied in ``config``.
    """

    def __init__(self, config):
        """Build the encoder and decoder sub-layers.

        Args:
            config: Mapping with the keys ``"encoder_params"`` and
                ``"decoder_params"``; each value is unpacked as keyword
                arguments into the corresponding layer constructor.
        """
        super().__init__()
        # NOTE(review): `Transformer` does not appear in the public
        # `tf.keras.layers` API listing -- confirm it exists in the target
        # environment, otherwise this lookup raises AttributeError.
        layers = tf.keras.layers
        self.encoder = layers.Transformer(**config["encoder_params"])
        self.decoder = layers.Transformer(**config["decoder_params"])

    def call(self, inputs, targets):
        """Run the encoder on ``inputs``, then the decoder on ``targets``.

        Args:
            inputs: Value passed to the encoder.
            targets: Value passed to the decoder as its first argument.

        Returns:
            Whatever the decoder returns when called with ``targets`` and
            the encoder's output.
        """
        memory = self.encoder(inputs)
        return self.decoder(targets, memory)