File size: 403 Bytes
4047d7b 9501156 4047d7b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
from user_model.configuration import UserModelConfig
from transformers import PreTrainedModel
import tensorflow as tf
class UserModel(PreTrainedModel):
    """Expose a TensorFlow SavedModel through the ``transformers`` ``PreTrainedModel`` API.

    All computation is delegated to a SavedModel loaded with
    ``tf.saved_model.load``; this class only supplies the config plumbing
    (``config_class``) that ``transformers`` expects.

    NOTE(review): ``transformers.PreTrainedModel`` is the PyTorch base class,
    so ``save_pretrained`` will not serialize the TensorFlow object stored in
    ``self.model`` — the SavedModel directory presumably has to be shipped
    alongside the checkpoint separately; confirm against the deployment setup.
    """

    config_class = UserModelConfig

    # Path loaded when the config does not specify one; kept identical to the
    # previously hard-coded value so existing setups behave the same. Being a
    # relative path, it is resolved against the process's working directory.
    _DEFAULT_SAVED_MODEL_PATH = "tf_retrieval_user_model"

    def __init__(self, config):
        """Initialise from ``config`` and load the wrapped TensorFlow SavedModel.

        Args:
            config: A ``UserModelConfig`` instance. If it carries a non-empty
                ``saved_model_path`` attribute, that location is loaded;
                otherwise ``_DEFAULT_SAVED_MODEL_PATH`` is used.
        """
        super().__init__(config)
        # Allow the SavedModel location to be configured instead of relying on
        # a fixed relative path (which only works from one working directory).
        saved_model_path = (
            getattr(config, "saved_model_path", None)
            or self._DEFAULT_SAVED_MODEL_PATH
        )
        self.model = tf.saved_model.load(saved_model_path)

    def forward(self, user_id):
        """Call the wrapped SavedModel on ``user_id`` and return its output unchanged.

        NOTE(review): the accepted type/shape of ``user_id`` is fixed by the
        SavedModel's call signature, which is not visible here — presumably a
        TensorFlow tensor/array of user ids; confirm against the exported model.
        """
        return self.model(user_id)