File size: 403 Bytes
4047d7b
 
 
 
 
 
 
 
 
 
9501156
4047d7b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from user_model.configuration import UserModelConfig
from transformers import PreTrainedModel
import tensorflow as tf


class UserModel(PreTrainedModel):
    """``transformers`` model wrapper around a TensorFlow SavedModel.

    The actual computation is delegated to a SavedModel loaded with
    ``tf.saved_model.load``; this class only provides the ``transformers``
    plumbing (``config_class``, ``from_pretrained`` construction path).

    NOTE(review): ``transformers.PreTrainedModel`` is the PyTorch base class,
    but the wrapped object and its outputs are TensorFlow. PyTorch-specific
    features (``.to(device)``, ``state_dict``, ``save_pretrained`` weights)
    will not cover the TF model — confirm this mixing is intentional.
    """

    config_class = UserModelConfig

    # Fallback SavedModel directory, identical to the previously hard-coded
    # value. A relative path is resolved against the current working
    # directory, not the directory the config was loaded from.
    _DEFAULT_SAVED_MODEL_PATH = "tf_retrieval_user_model"

    def __init__(self, config, saved_model_path=None):
        """Build the wrapper and load the TensorFlow SavedModel.

        Args:
            config: A ``UserModelConfig`` instance. If it has a truthy
                ``saved_model_path`` attribute, that path is used when no
                explicit ``saved_model_path`` argument is given.
            saved_model_path: Optional explicit path to the SavedModel
                directory. Takes precedence over the config attribute and
                the built-in default.
        """
        super().__init__(config)
        if saved_model_path is None:
            # getattr with a default keeps configs without this field working.
            saved_model_path = (
                getattr(config, "saved_model_path", None)
                or self._DEFAULT_SAVED_MODEL_PATH
            )
        self.model = tf.saved_model.load(saved_model_path)

    def forward(self, user_id):
        """Run the wrapped SavedModel on ``user_id`` and return its output.

        The input is passed through unchanged. Presumably it must be
        whatever the SavedModel's call signature expects (e.g. a tf.Tensor
        of user ids) — TODO confirm against the export.
        """
        return self.model(user_id)