import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np
import torch
from typing import List
device = 'cuda' if torch.cuda.is_available() else 'cpu'

class Predictor:
    def __init__(self, path="", tokenizer_path='bert-base-uncased'):
        self.model = AutoModelForSequenceClassification.from_pretrained(path, trust_remote_code=True).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    def preprocess(self, inputs: List[str]):
        MAX_LENGTH = 15
        # tokenize, then pad/truncate every input to MAX_LENGTH tokens
        tokens_unseen = self.tokenizer.batch_encode_plus(
            inputs,
            max_length=MAX_LENGTH,
            padding='max_length',
            truncation=True
        )
        unseen_seq = torch.tensor(tokens_unseen['input_ids'])
        unseen_mask = torch.tensor(tokens_unseen['attention_mask'])
        return unseen_seq, unseen_mask
    def postprocess(self, preds):
        preds = np.argmax(preds, axis=1)
        prediction_label = "This is fake news" if preds[0] == 1 else "This is real news"
        return prediction_label
    def predict(self, inputs: str):
        unseen_seq, unseen_mask = self.preprocess([inputs])
        with torch.no_grad():
            # inputs must live on the same device as the model
            preds = self.model(unseen_seq.to(device), unseen_mask.to(device))
        # the remote model may return a raw score tensor or a ModelOutput with .logits
        if hasattr(preds, "logits"):
            preds = preds.logits
        preds = preds.detach().cpu().numpy()
        return self.postprocess(preds)
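
# Quick sanity check (hypothetical headline; assumes the remote model scores
# class 1 as "fake" and class 0 as "real", matching postprocess() above):
#   Predictor('leroyrr/fake-news-detection-bert').predict("Aliens endorse presidential candidate")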
# Instantiate a predictor
predictor = Predictor('leroyrr/fake-news-detection-bert')
# Create title and description for our task
title = "Fake News Detection Demo"
description = "Enter a news headline or short article text; the model classifies it as fake or real news."
article = "Created from nguyenquocviet/fake-news-detection-bert"
# Create the Gradio interface
iface = gr.Interface(
    fn=predictor.predict,
    inputs="textbox",
    outputs="textbox",
    title=title,
    description=description,
    article=article,
)
# Launch the interface
iface.launch()