|
import streamlit as st |
|
import torch |
|
from transformers import BertTokenizer, BertModel |
|
from torch import nn |
|
import h5py |
|
import numpy as np |
|
import warnings |
|
|
|
|
|
# Suppress FutureWarning and UserWarning messages attributed to the
# transformers.utils.generic module so they do not clutter the app output.
for _ignored_category in (FutureWarning, UserWarning):
    warnings.filterwarnings(
        "ignore",
        category=_ignored_category,
        module="transformers.utils.generic",
    )
|
|
|
|
|
# Page-level settings: title, icon and a centered layout.
# NOTE(review): the page_icon value may be a mis-encoded emoji — confirm
# against the intended icon before changing it.
st.set_page_config(
    page_title="Paraphrase Detection App",
    page_icon="π",
    layout="centered",
)
|
|
|
|
|
class ParaphraseModel(nn.Module):
    """BERT encoder followed by a two-layer head that ends in a sigmoid.

    The submodule attribute names (``bert``, ``dense``, ``output``) determine
    the names yielded by ``named_parameters()``; the weight-loading code in
    this file looks parameters up by those names, so they must not change.
    """

    def __init__(self):
        """Create the pretrained encoder and the classification head."""
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Head: 768 -> 128 -> 1, with ReLU in between and a sigmoid at the end.
        self.dense = nn.Linear(768, 128)
        self.relu = nn.ReLU()
        self.output = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Run the encoder and head; return the sigmoid-activated score.

        Args:
            input_ids: Token ids passed positionally to the encoder.
            attention_mask: Forwarded to the encoder as ``attention_mask``.
            token_type_ids: Forwarded to the encoder as ``token_type_ids``.

        Returns:
            The output of the final sigmoid, with a last dimension of size 1.
        """
        encoded = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Index 1 of the encoder result is used as the pooled representation.
        hidden = self.relu(self.dense(encoded[1]))
        return self.sigmoid(self.output(hidden))
|
|
|
|
|
# Streamlit re-executes this script on every widget interaction. Caching keeps
# the loaded model alive between reruns when the installed Streamlit provides
# st.cache_resource; otherwise the loader simply runs on each execution.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_model_and_tokenizer():
    """Load the tokenizer and a weight-initialised ParaphraseModel from disk.

    Reads 'paraphrase_model.pkl', which is expected to hold a dict with the
    keys 'tokenizer' and 'weights' (the path of an HDF5 weights file), then
    copies every dataset whose name matches a model parameter into the model.

    Returns:
        A ``(tokenizer, model, missing)`` tuple, where ``missing`` lists the
        parameter names that were not found in the weights file (those
        parameters keep their initial values).

    Raises:
        FileNotFoundError: If the pickle file cannot be opened.
        ValueError: If a stored array's shape differs from the parameter's.
    """
    # Local import, as in the original code: pickle is only needed here.
    import pickle

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only ship/load a 'paraphrase_model.pkl' that comes from a trusted source.
    with open('paraphrase_model.pkl', 'rb') as bundle_file:
        saved_data = pickle.load(bundle_file)

    loaded_tokenizer = saved_data['tokenizer']
    loaded_model = ParaphraseModel()

    missing = []
    # no_grad is required for in-place writes to parameters that track gradients.
    with h5py.File(saved_data['weights'], 'r') as weights_file, torch.no_grad():
        for name, param in loaded_model.named_parameters():
            if name not in weights_file:
                missing.append(name)
                continue
            stored = torch.tensor(np.array(weights_file[name]))
            # copy_ would silently broadcast, so reject mismatches explicitly.
            if stored.shape != param.shape:
                raise ValueError(
                    f"Shape mismatch for '{name}': file has {tuple(stored.shape)}, "
                    f"model expects {tuple(param.shape)}"
                )
            # copy_ keeps the parameter's own dtype and device.
            param.copy_(stored)

    loaded_model.eval()
    return loaded_tokenizer, loaded_model, missing


try:
    tokenizer, model, _missing_weights = _load_model_and_tokenizer()
except FileNotFoundError:
    st.error("Error: 'paraphrase_model.pkl' or 'paraphrase_model_weights.h5' not found. Ensure both files are in the same directory as app.py.")
    st.stop()
except Exception as e:
    # Top-level boundary: surface any other loading failure in the UI and halt.
    st.error(f"Error loading model: {str(e)}")
    st.stop()
else:
    if _missing_weights:
        # Previously these were skipped silently, leaving untrained weights.
        _preview = ", ".join(_missing_weights[:5])
        st.warning(
            f"{len(_missing_weights)} model parameter(s) were not found in the "
            f"weights file and keep their initial values (e.g. {_preview}). "
            "Predictions may be unreliable."
        )
|
|
|
|
|
def predict_paraphrase(sentence1, sentence2, threshold=0.6):
    """Score a sentence pair and label it as a paraphrase or not.

    Uses the module-level ``tokenizer`` and ``model`` objects.

    Args:
        sentence1: First sentence. ``None``, empty or whitespace-only input
            is rejected.
        sentence2: Second sentence, validated the same way.
        threshold: The pair is labelled "Paraphrase" only when the model
            score is strictly greater than this value. Defaults to 0.6.

    Returns:
        A ``(label, score)`` tuple. ``label`` is "Paraphrase" or
        "Not a Paraphrase" and ``score`` is the raw model output as a float.
        Note that ``score`` is the same value for both labels; it is not the
        confidence in the returned label. For invalid input the tuple is
        ``(error_message, None)``.
    """
    # `or ""` makes None behave like an empty string instead of raising.
    if not (sentence1 or "").strip() or not (sentence2 or "").strip():
        return "Please enter valid sentences in both fields.", None

    encodings = tokenizer(
        [sentence1], [sentence2],
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    with torch.no_grad():
        score_tensor = model(
            encodings['input_ids'],
            encodings['attention_mask'],
            encodings['token_type_ids'],
        ).squeeze()
    prediction = score_tensor.item()

    result = "Paraphrase" if prediction > threshold else "Not a Paraphrase"
    return result, prediction
|
|
|
|
|
# ---- Page header and usage instructions ----
st.title("π Paraphrase Detection App")
st.markdown("""
Enter two sentences to check if they are paraphrases of each other.
**Example:**
- Sentence 1: "The cat is on the mat."
- Sentence 2: "The mat has a cat on it."
""")

# ---- Input widgets: one text area per sentence ----
sentence1 = st.text_area(
    "Sentence 1",
    placeholder="Enter the first sentence",
    height=100,
    key="sentence1",
)
sentence2 = st.text_area(
    "Sentence 2",
    placeholder="Enter the second sentence",
    height=100,
    key="sentence2",
)

# ---- Run the prediction when the button is pressed ----
if st.button("Check Paraphrase", key="predict_button"):
    with st.spinner("Analyzing sentences..."):
        result, confidence = predict_paraphrase(sentence1, sentence2)
        # A confidence of None means `result` holds a validation message.
        if confidence is not None:
            st.markdown(f"**Prediction**: {result} \n**Confidence**: {confidence:.4f}")
        else:
            st.error(result)