import tensorflow as tf


class _TransformerEncoderLayer(tf.keras.layers.Layer):
    """One post-norm Transformer encoder layer: self-attention, then feed-forward.

    Args:
        d_model: Feature size produced by the attention and feed-forward
            sub-layers. The residual additions require the last axis of the
            input to equal this value.
        num_heads: Number of attention heads.
        dff: Hidden width of the position-wise feed-forward network.
        dropout_rate: Dropout applied to the attention weights and to the
            feed-forward output.
    """

    def __init__(self, d_model, num_heads, dff, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.self_attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=d_model, dropout=dropout_rate
        )
        self.feed_forward = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(dff, activation="relu"),
                tf.keras.layers.Dense(d_model),
                tf.keras.layers.Dropout(dropout_rate),
            ]
        )
        self.attention_norm = tf.keras.layers.LayerNormalization()
        self.output_norm = tf.keras.layers.LayerNormalization()

    def call(self, x, training=None):
        """Apply self-attention and feed-forward, each with residual + norm."""
        attended = self.self_attention(query=x, value=x, key=x, training=training)
        x = self.attention_norm(x + attended)
        return self.output_norm(x + self.feed_forward(x, training=training))


class _TransformerDecoderLayer(tf.keras.layers.Layer):
    """One post-norm Transformer decoder layer.

    The layer runs causal self-attention over the target sequence, then
    cross-attention over an encoder output, then a feed-forward network.
    Each sub-layer is wrapped in a residual connection followed by layer
    normalization.

    Args:
        d_model: Feature size produced by every sub-layer. The residual
            additions require the last axis of the target input to equal
            this value.
        num_heads: Number of attention heads.
        dff: Hidden width of the position-wise feed-forward network.
        dropout_rate: Dropout applied to the attention weights and to the
            feed-forward output.
    """

    def __init__(self, d_model, num_heads, dff, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.self_attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=d_model, dropout=dropout_rate
        )
        self.cross_attention = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=d_model, dropout=dropout_rate
        )
        self.feed_forward = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(dff, activation="relu"),
                tf.keras.layers.Dense(d_model),
                tf.keras.layers.Dropout(dropout_rate),
            ]
        )
        self.self_attention_norm = tf.keras.layers.LayerNormalization()
        self.cross_attention_norm = tf.keras.layers.LayerNormalization()
        self.output_norm = tf.keras.layers.LayerNormalization()

    def call(self, x, context, training=None):
        """Decode ``x`` while attending to the encoder output ``context``."""
        # Causal mask: position i may only attend to positions <= i, so the
        # decoder cannot look at future target positions.
        # (`use_causal_mask` requires TF >= 2.10.)
        attended = self.self_attention(
            query=x, value=x, key=x, use_causal_mask=True, training=training
        )
        x = self.self_attention_norm(x + attended)
        # Cross-attention: queries come from the decoder, keys and values
        # come from the encoder output.
        crossed = self.cross_attention(
            query=x, value=context, key=context, training=training
        )
        x = self.cross_attention_norm(x + crossed)
        return self.output_norm(x + self.feed_forward(x, training=training))


class TransformerModel(tf.keras.Model):
    """Encoder-decoder Transformer assembled from a configuration mapping.

    Args:
        config: Mapping with the keys ``"encoder_params"`` and
            ``"decoder_params"``. Each value is a dict of keyword arguments
            for the encoder and decoder layer respectively: ``d_model``,
            ``num_heads``, ``dff`` and optionally ``dropout_rate``.
    """

    def __init__(self, config):
        super().__init__()
        # `tf.keras.layers` has no `Transformer` layer; referencing it raised
        # AttributeError. Explicit encoder/decoder layers are used instead.
        self.encoder = _TransformerEncoderLayer(**config["encoder_params"])
        self.decoder = _TransformerDecoderLayer(**config["decoder_params"])

    def call(self, inputs, targets, training=None):
        """Encode ``inputs``, then decode ``targets`` against that encoding.

        NOTE(review): both tensors are assumed to be already-embedded
        sequences shaped (batch, seq_len, features). The last axis of
        ``inputs`` must equal the encoder's ``d_model``, and the last axis of
        ``targets`` must equal the decoder's ``d_model``, for the residual
        additions to work. Confirm this against the callers.

        Args:
            inputs: Source sequence fed to the encoder.
            targets: Target sequence fed to the decoder.
            training: Forwarded to the dropout and attention layers by Keras.

        Returns:
            Decoder output with the same shape as ``targets``.
        """
        encoder_output = self.encoder(inputs, training=training)
        return self.decoder(targets, encoder_output, training=training)