import tensorflow as tf
class _EncoderLayer(tf.keras.layers.Layer):
    """Post-norm Transformer encoder layer: self-attention, then a feed-forward net.

    Each sub-layer is wrapped in dropout, a residual connection and layer
    normalization. The feed-forward output width is taken from the last axis
    of the input when the layer is built, so the residual additions line up
    for any (statically known) feature size.
    """

    def __init__(self, num_heads=8, key_dim=64, ff_dim=2048, dropout=0.1):
        super().__init__()
        self.self_attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=key_dim, dropout=dropout
        )
        self.ff_hidden = tf.keras.layers.Dense(ff_dim, activation="relu")
        self.ff_out = None  # Created in build() once the input width is known.
        self.norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.drop1 = tf.keras.layers.Dropout(dropout)
        self.drop2 = tf.keras.layers.Dropout(dropout)

    def build(self, input_shape):
        # Project the feed-forward output back to the input width so the
        # residual connection is shape-compatible.
        self.ff_out = tf.keras.layers.Dense(input_shape[-1])
        super().build(input_shape)

    def call(self, x, training=None):
        attended = self.self_attention(x, x, training=training)
        x = self.norm1(x + self.drop1(attended, training=training))
        ff = self.ff_out(self.ff_hidden(x))
        return self.norm2(x + self.drop2(ff, training=training))


class _DecoderLayer(tf.keras.layers.Layer):
    """Post-norm Transformer decoder layer.

    Applies causal self-attention over the targets, then cross-attention into
    the encoder output (``memory``), then a feed-forward net. Each sub-layer
    is wrapped in dropout, a residual connection and layer normalization.
    """

    def __init__(self, num_heads=8, key_dim=64, ff_dim=2048, dropout=0.1):
        super().__init__()
        self.self_attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=key_dim, dropout=dropout
        )
        self.cross_attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=key_dim, dropout=dropout
        )
        self.ff_hidden = tf.keras.layers.Dense(ff_dim, activation="relu")
        self.ff_out = None  # Created in build() once the target width is known.
        self.norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.norm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.drop1 = tf.keras.layers.Dropout(dropout)
        self.drop2 = tf.keras.layers.Dropout(dropout)
        self.drop3 = tf.keras.layers.Dropout(dropout)

    def build(self, input_shape):
        # Keras passes the shape of the first call argument (the targets).
        self.ff_out = tf.keras.layers.Dense(input_shape[-1])
        super().build(input_shape)

    def call(self, x, memory, training=None):
        # The causal mask stops each target position from attending to later
        # positions (`use_causal_mask` needs TF >= 2.10).
        attended = self.self_attention(x, x, use_causal_mask=True, training=training)
        x = self.norm1(x + self.drop1(attended, training=training))
        attended = self.cross_attention(x, memory, training=training)
        x = self.norm2(x + self.drop2(attended, training=training))
        ff = self.ff_out(self.ff_hidden(x))
        return self.norm3(x + self.drop3(ff, training=training))


class _TransformerEncoder(tf.keras.layers.Layer):
    """Stack of ``num_layers`` encoder layers; other kwargs configure each layer."""

    def __init__(self, num_layers=1, **layer_kwargs):
        super().__init__()
        self.blocks = [_EncoderLayer(**layer_kwargs) for _ in range(num_layers)]

    def call(self, x, training=None):
        for block in self.blocks:
            x = block(x, training=training)
        return x


class _TransformerDecoder(tf.keras.layers.Layer):
    """Stack of ``num_layers`` decoder layers; other kwargs configure each layer."""

    def __init__(self, num_layers=1, **layer_kwargs):
        super().__init__()
        self.blocks = [_DecoderLayer(**layer_kwargs) for _ in range(num_layers)]

    def call(self, x, memory, training=None):
        for block in self.blocks:
            x = block(x, memory, training=training)
        return x


class TransformerModel(tf.keras.Model):
    """Encoder-decoder Transformer over already-embedded sequences.

    tf.keras ships no built-in ``layers.Transformer`` class. The encoder and
    decoder are therefore the stacks defined in this module, built from
    ``MultiHeadAttention``.

    Args:
        config: Mapping with keys ``"encoder_params"`` and ``"decoder_params"``.
            Each value is a dict of keyword arguments for its stack:
            ``num_layers`` (default 1), ``num_heads`` (8), ``key_dim`` (64),
            ``ff_dim`` (2048) and ``dropout`` (0.1). Unknown keys raise
            ``TypeError``.
    """

    def __init__(self, config):
        super().__init__()
        self.encoder = _TransformerEncoder(**config["encoder_params"])
        self.decoder = _TransformerDecoder(**config["decoder_params"])

    def call(self, inputs, targets=None, training=None):
        """Encode ``inputs`` and decode ``targets`` against the encoding.

        No embedding or positional encoding is applied here. Both sequences
        are expected to be rank-3 ``(batch, seq_len, features)`` float tensors
        whose last axis is statically known.

        Args:
            inputs: Source sequence. When ``targets`` is None, this may instead
                be a ``(source, target)`` pair, which is how ``Model.fit``
                passes a single ``x``.
            targets: Target sequence fed to the decoder.
            training: Forwarded to dropout and attention layers.

        Returns:
            Decoder output with the same shape as ``targets``.
        """
        if targets is None:
            inputs, targets = inputs
        memory = self.encoder(inputs, training=training)
        return self.decoder(targets, memory, training=training)