# Source: Hugging Face Space file upload by Rend19 ("Upload app.py", commit f6a45ee)
import gradio as gr
import pickle
import tensorflow as tf
import keras.ops as ops
import keras
from keras import layers
from keras.layers import TextVectorization
# from gradio_webrtc import WebRTC
@keras.saving.register_keras_serializable()
class TextVectorization(keras.layers.TextVectorization):
    """Serializable alias of ``keras.layers.TextVectorization``.

    Registering the subclass lets Keras resolve the layer by name when
    deserializing saved configs/models.
    """
@keras.saving.register_keras_serializable()
class StringLookup(keras.layers.StringLookup):
    """Serializable alias of ``keras.layers.StringLookup``.

    Registered so saved models referencing this layer can be reloaded.
    """
@keras.saving.register_keras_serializable(package="Transformer")
class TransformerEncoder(layers.Layer):
    """Transformer encoder block: multi-head self-attention followed by a
    position-wise feed-forward network, each wrapped in a residual
    connection + LayerNormalization."""

    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config() can serialize the constructor arguments.
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        # Two-layer position-wise feed-forward network (dense_dim -> embed_dim).
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(dense_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        # Propagate the padding mask produced upstream (see PositionalEmbedding).
        self.supports_masking = True

    def call(self, inputs, mask=None):
        if mask is not None:
            # Insert a broadcast axis so the (batch, seq) padding mask can be
            # applied as an attention mask over key positions.
            padding_mask = ops.cast(mask[:, None, :], dtype="int32")
        else:
            padding_mask = None
        attention_output = self.attention(
            query=inputs, value=inputs, key=inputs, attention_mask=padding_mask
        )
        # Residual + norm around attention, then around the feed-forward net.
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)

    def get_config(self):
        """Return constructor kwargs so Keras can re-create the layer on load."""
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "dense_dim": self.dense_dim,
                "num_heads": self.num_heads,
            }
        )
        return config
@keras.saving.register_keras_serializable(package="Transformer")
class PositionalEmbedding(layers.Layer):
    """Sum of learned token embeddings and learned position embeddings.

    Also emits a padding mask via compute_mask (token id 0 is treated as
    padding) so downstream masking-aware layers can ignore padded positions.
    """

    def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=embed_dim
        )
        # Stored so get_config() can serialize the constructor arguments.
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def call(self, inputs):
        # Positions 0..length-1 for the current (padded) token sequence.
        length = ops.shape(inputs)[-1]
        positions = ops.arange(0, length, 1)
        embedded_tokens = self.token_embeddings(inputs)
        # Position embeddings broadcast across the batch dimension.
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        # True wherever the token is not the padding id (0).
        return ops.not_equal(inputs, 0)

    def get_config(self):
        """Return constructor kwargs so Keras can re-create the layer on load."""
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "vocab_size": self.vocab_size,
                "embed_dim": self.embed_dim,
            }
        )
        return config
@keras.saving.register_keras_serializable(package="Transformer")
class TransformerDecoder(layers.Layer):
    """Transformer decoder block: causal self-attention over the target,
    cross-attention over the encoder output, then a position-wise
    feed-forward network — each with a residual connection + LayerNorm."""

    def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config() can serialize the constructor arguments.
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        # Self-attention over the (shifted) target sequence.
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        # Cross-attention: target queries attend over encoder outputs.
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [
                layers.Dense(latent_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, mask=None):
        # inputs is a pair: (decoder inputs, encoder outputs); mask, when
        # provided, is the matching pair of padding masks.
        inputs, encoder_outputs = inputs
        causal_mask = self.get_causal_attention_mask(inputs)
        if mask is None:
            inputs_padding_mask, encoder_outputs_padding_mask = None, None
        else:
            inputs_padding_mask, encoder_outputs_padding_mask = mask
        # Causal self-attention: each position may only attend to itself and
        # earlier positions.
        attention_output_1 = self.attention_1(
            query=inputs,
            value=inputs,
            key=inputs,
            attention_mask=causal_mask,
            query_mask=inputs_padding_mask,
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)
        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            query_mask=inputs_padding_mask,
            key_mask=encoder_outputs_padding_mask,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)
        proj_output = self.dense_proj(out_2)
        return self.layernorm_3(out_2 + proj_output)

    def get_causal_attention_mask(self, inputs):
        """Build a lower-triangular (batch, seq, seq) mask: entry (i, j) is 1
        iff j <= i, i.e. attention may not look at future positions."""
        input_shape = ops.shape(inputs)
        batch_size, sequence_length = input_shape[0], input_shape[1]
        i = ops.arange(sequence_length)[:, None]
        j = ops.arange(sequence_length)
        mask = ops.cast(i >= j, dtype="int32")
        mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))
        # Tile the single (1, seq, seq) mask across the (dynamic) batch size.
        mult = ops.concatenate(
            [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])],
            axis=0,
        )
        return ops.tile(mask, mult)

    def get_config(self):
        """Return constructor kwargs so Keras can re-create the layer on load."""
        config = super().get_config()
        config.update(
            {
                "embed_dim": self.embed_dim,
                "latent_dim": self.latent_dim,
                "num_heads": self.num_heads,
            }
        )
        return config
def _load_text_vectorization(path):
    """Rebuild a TextVectorization layer from a pickled {config, weights, vocab} dict.

    NOTE(security): pickle.load executes arbitrary code from the file —
    only ever load these bundled, trusted artifacts.
    """
    with open(path, "rb") as file:
        from_disk = pickle.load(file)
    layer = TextVectorization.from_config(from_disk["config"])
    # adapt() on a dummy dataset builds the layer's internal lookup state so
    # the saved weights/vocabulary can be restored onto it.
    layer.adapt(tf.data.Dataset.from_tensor_slices(["xyz"]))
    layer.set_weights(from_disk["weights"])
    layer.set_vocabulary(from_disk["vocab"])
    return layer


# Target-(Indonesian) and source-(English) side vectorizers.
id_vectorization = _load_text_vectorization("id_vectorization_transformer.pickle")
en_vectorization = _load_text_vectorization("en_vectorization_transformer.pickle")

transformer = keras.models.load_model(
    "transformer_keras.keras",
    custom_objects={
        "TransformerEncoder": TransformerEncoder,
        "TransformerDecoder": TransformerDecoder,
        "PositionalEmbedding": PositionalEmbedding,
    },
)

id_vocab = id_vectorization.get_vocabulary()
# Map token index -> token string for decoding model outputs.
id_index_lookup = dict(enumerate(id_vocab))
# Cap on generated tokens per translation. (Name keeps the original
# misspelling — decode_sequence below references it.)
max_decoded_sentence_lenth = 20
def decode_sequence(input_sentence):
    """Greedy-decode an English sentence into Indonesian.

    Feeds the tokenized source through the transformer one target step at a
    time, always taking the argmax token, until the "end" token is produced
    or max_decoded_sentence_lenth steps have run.

    Args:
        input_sentence: English source text (str).
    Returns:
        The decoded Indonesian sentence as a plain string, without the
        "[start]"/"end" control tokens.
    """
    tokenized_input_sentence = en_vectorization([input_sentence])
    decoded_tokens = ["[start]"]
    for i in range(max_decoded_sentence_lenth):
        # Re-tokenize the partial target; drop the last position so the model
        # predicts the token at step i.
        tokenized_target_sentence = id_vectorization([" ".join(decoded_tokens)])[:, :-1]
        predictions = transformer(
            {
                "encoder_inputs": tokenized_input_sentence,
                "decoder_inputs": tokenized_target_sentence,
            }
        )
        # Highest-probability token at the current position.
        sampled_token_index = ops.convert_to_numpy(
            ops.argmax(predictions[0, i, :])
        ).item(0)
        sampled_token = id_index_lookup[sampled_token_index]
        if sampled_token == "end":
            break
        decoded_tokens.append(sampled_token)
    # Join whole tokens only. The previous str.replace("end", "") /
    # replace("[start]", "") approach mangled any word merely CONTAINING
    # those substrings; filtering at token granularity cannot.
    return " ".join(decoded_tokens[1:])
# image = WebRTC(label="Stream")

# HTML blurb rendered above the Gradio interface.
# NOTE(review): the two team-member <li> entries are byte-identical — looks
# like a copy/paste leftover; confirm the intended roster with the authors.
desc = (
    "<h2>This is a simple English to Indonesian translator app using "
    "transformer for our final Deep Learning Project.</h2>"
    "<br/> <h3 style='font-weight: bold'>Team Members:</h3>"
    "<br/> <ul> <li>2602082452 - Rendy Susanto</li>"
    "<li>2602082452 - Rendy Susanto</li></ul>"
)
# Wire the translator into a single text-in / text-out Gradio UI and start
# the server with a public share link.
demo = gr.Interface(
    decode_sequence,
    inputs=gr.Textbox(label="Please input your text (English):"),
    outputs=gr.Textbox(label="Output (Indonesian):"),
    title="English To Indonesian Translator",
    description=desc,
)
demo.launch(share=True)