File size: 1,847 Bytes
0f20058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import gradio as gr
import torch 
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer that matches the pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    "BigSalmon/InformalToFormalLincoln91Paraphrase"
)

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the causal language model and place it on the selected device.
model = AutoModelForCausalLM.from_pretrained(
    "BigSalmon/InformalToFormalLincoln91Paraphrase"
).to(device)

def main_generator(text, number_of_outputs=10):
    """Return the most likely next tokens for ``text``.

    The prompt is tokenized, passed through the causal language model once,
    and the highest-scoring candidates for the position following the prompt
    are decoded one token at a time.

    Args:
        text: Prompt to continue.
        number_of_outputs: Maximum number of candidate tokens to return.
            The name was previously referenced without being defined
            anywhere (a ``NameError`` on every call); it is now a keyword
            parameter so single-argument callers keep working. The default
            of 10 is a chosen value, not one taken from the original code.

    Returns:
        A list of decoded token strings ordered from most to least likely.
        The list is empty when the prompt is empty, tokenizes to nothing,
        or ``number_of_outputs`` is not positive.
    """
    if not text or number_of_outputs <= 0:
        return []

    token_ids = tokenizer.encode(text)
    if not token_ids:
        # An empty id list would build a float tensor of shape (1, 0),
        # which the model cannot consume.
        return []

    # The input has to live on the same device as the model; otherwise the
    # forward pass fails whenever the model was moved to a GPU.
    input_ids = torch.tensor([token_ids], device=device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        # With return_dict=False the model returns a tuple whose first
        # element is the logits tensor.
        logits = model(input_ids, return_dict=False)[0]

    # Scores for the token following the last prompt position (batch item 0).
    next_token_logits = logits[0, -1]

    # Never ask topk for more entries than the vocabulary holds.
    k = min(number_of_outputs, next_token_logits.size(-1))

    # Softmax is monotonic, so ranking the raw logits yields the same order
    # as ranking probabilities; the unused softmax call was dropped.
    _, best_indices = next_token_logits.topk(k)
    return [tokenizer.decode([idx.item()]) for idx in best_indices]
  
# Texts shown on the page and the example prompt offered to the user.
title = "Get the next most likely word"
description = "Get the next most likely word"
examples = ['I wonder']

# One single-line text box in, one text field out.
inputs = [gr.Textbox(lines=1, placeholder="Text Here...", label="Input")]
outputs = gr.Text(label="Insert text")

# Custom stylesheet: gradient background and white text for the primary button.
button_css = """.gr-button-primary { background: -webkit-linear-gradient( 
                    90deg, #355764 0%, #55a8a1 100% ) !important;     background: #355764;
                        background: linear-gradient( 
                    90deg, #355764 0%, #55a8a1 100% ) !important;
                        background: -moz-linear-gradient( 90deg, #355764 0%, #55a8a1 100% ) !important;
                        background: -webkit-linear-gradient( 
                    90deg, #355764 0%, #55a8a1 100% ) !important;
                    color:white !important}"""

# Wire main_generator into a Gradio interface.
io = gr.Interface(
    fn=main_generator,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=examples,
    css=button_css,
)

# Start the web app.
io.launch()