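# NOTE(review): the three lines below look like web-page residue (author,
# commit message, commit id) rather than program code; kept as comments so
# the file parses.
# Krishna086's picture
# Update app.py
# ddd7096 verified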
import streamlit as st
import torch
from transformers import BertTokenizer, BertModel
from torch import nn
import h5py
import numpy as np
import warnings
# Silence FutureWarning/UserWarning messages whose origin module matches
# "transformers.utils.generic" (the `module` argument is a regex on the
# warning's module name). Presumably these are torch-version deprecation
# notices raised by transformers -- confirm against the installed versions.
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers.utils.generic")
warnings.filterwarnings("ignore", category=UserWarning, module="transformers.utils.generic")
# Streamlit page configuration; keep this as the first Streamlit call in the
# script. The icon is the magnifying-glass emoji (U+1F50D); the previous
# literal was a mis-decoded (mojibake) rendering of it.
st.set_page_config(page_title="Paraphrase Detection App", page_icon="🔍", layout="centered")
# Model architecture; it must match the training notebook, because the saved
# HDF5 weights are keyed by these attribute names (bert.*, dense.*, output.*).
class ParaphraseModel(nn.Module):
    """BERT encoder followed by a small MLP head producing one sigmoid score.

    Layout: bert -> dense(768 -> 128) -> ReLU -> output(128 -> 1) -> Sigmoid.
    """

    def __init__(self):
        super().__init__()
        # Pretrained encoder; its weights are later overwritten by the loader.
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dense = nn.Linear(768, 128)
        self.relu = nn.ReLU()
        self.output = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return a score in (0, 1) for each input pair.

        The inputs are the tensors produced by the BERT tokenizer; presumably
        each is shaped (batch, seq_len) -- confirm against the caller.
        """
        bert_out = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Index 1 of a BertModel output is the pooler output: per the Hugging
        # Face docs, the [CLS] hidden state passed through the pooler's
        # dense + tanh layer, not the raw [CLS] embedding.
        hidden = self.relu(self.dense(bert_out[1]))
        return self.sigmoid(self.output(hidden))
# Load saved model and tokenizer.
@st.cache_resource(show_spinner="Loading model...")
def _load_model_and_tokenizer(pickle_path="paraphrase_model.pkl"):
    """Load the pickled tokenizer and the fine-tuned ParaphraseModel weights.

    Streamlit re-executes the whole script on every widget interaction.
    Caching this loader with ``st.cache_resource`` means BERT and the HDF5
    weights are loaded once per server process instead of on every click.
    If this function raises, nothing is cached and the next rerun retries.

    Args:
        pickle_path: Path of the training pickle. A relative path is resolved
            against this script's directory. The pickle must hold a dict with
            a ``'tokenizer'`` object and a ``'weights'`` HDF5 file path (a
            relative weights path is resolved the same way).

    Returns:
        ``(tokenizer, model, missing)``. ``model`` is in eval mode.
        ``missing`` lists parameter names with no dataset in the HDF5 file;
        those keep the values ``ParaphraseModel()`` initialised them with.

    Raises:
        FileNotFoundError: if the pickle or the weights file is missing.
        ValueError: if a stored tensor's shape differs from its parameter's.
    """
    import os
    import pickle
    from pathlib import Path

    # The UI promises the files live "in the same directory as app.py", so
    # resolve against the script directory rather than the process CWD.
    base_dir = Path(__file__).resolve().parent if "__file__" in globals() else Path.cwd()

    # SECURITY: pickle.load can execute arbitrary code. Only ever load a pickle
    # that ships with this app; never one supplied by a user.
    with open(base_dir / pickle_path, "rb") as fh:
        saved_data = pickle.load(fh)
    loaded_tokenizer = saved_data["tokenizer"]
    weights_path = saved_data["weights"]
    if isinstance(weights_path, (str, os.PathLike)):
        weights_path = base_dir / weights_path

    loaded_model = ParaphraseModel()
    missing = []
    with h5py.File(weights_path, "r") as weights, torch.no_grad():
        for name, param in loaded_model.named_parameters():
            if name not in weights:
                missing.append(name)
                continue
            value = torch.tensor(np.array(weights[name]))
            if value.shape != param.shape:
                raise ValueError(
                    f"Shape mismatch for '{name}': weights file has "
                    f"{tuple(value.shape)}, model expects {tuple(param.shape)}"
                )
            # copy_ writes into the existing parameter, keeping its dtype and
            # device instead of replacing the underlying tensor.
            param.copy_(value)
    loaded_model.eval()
    return loaded_tokenizer, loaded_model, missing


try:
    tokenizer, model, _missing_params = _load_model_and_tokenizer()
except FileNotFoundError as e:
    st.error(
        "Error: 'paraphrase_model.pkl' or 'paraphrase_model_weights.h5' not found. "
        "Ensure both files are in the same directory as app.py."
        f" (Missing: {e.filename or e})"
    )
    st.stop()
except Exception as e:
    st.error(f"Error loading model: {str(e)}")
    st.stop()

if _missing_params:
    # Missing parameters keep their initial values: pretrained BERT for the
    # encoder, and a fresh nn.Linear init for the head. Warn rather than
    # serve a half-loaded model silently.
    st.warning(
        f"{len(_missing_params)} model parameter(s) were not found in the weights "
        "file and keep their initial values; predictions may be unreliable."
    )
# Function to predict paraphrase with adjustable threshold
def predict_paraphrase(sentence1, sentence2, threshold=0.6):
    """Classify whether two sentences are paraphrases of each other.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        sentence1: First sentence.
        sentence2: Second sentence.
        threshold: The pair is labelled "Paraphrase" only when the model's
            score is strictly greater than this value. The default of 0.6 is
            set above the usual 0.5 to reduce false positives.

    Returns:
        ``(result, confidence)``. For empty, whitespace-only or ``None``
        input, ``result`` is a validation message and ``confidence`` is
        ``None``. Otherwise ``result`` is "Paraphrase" or "Not a Paraphrase",
        and ``confidence`` is the model's sigmoid output: its estimated
        probability that the pair IS a paraphrase, whichever label was chosen.
    """
    if not (sentence1 and sentence1.strip()) or not (sentence2 and sentence2.strip()):
        return "Please enter valid sentences in both fields.", None

    # Encode as a sentence pair so token_type_ids distinguish the two segments.
    encodings = tokenizer(
        [sentence1], [sentence2],
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        probability = model(
            encodings['input_ids'],
            encodings['attention_mask'],
            encodings['token_type_ids'],
        ).squeeze().item()

    result = "Paraphrase" if probability > threshold else "Not a Paraphrase"
    return result, probability
# Streamlit UI
# The title emoji is the magnifying glass (U+1F50D); the previous literal was
# a mis-decoded (mojibake) rendering of it.
st.title("🔍 Paraphrase Detection App")
# The blank line keeps the intro sentence and the "Example" heading in
# separate Markdown paragraphs instead of one run-on line.
st.markdown("""
Enter two sentences to check if they are paraphrases of each other.

**Example:**
- Sentence 1: "The cat is on the mat."
- Sentence 2: "The mat has a cat on it."
""")
# Input fields
sentence1 = st.text_area("Sentence 1", placeholder="Enter the first sentence", height=100, key="sentence1")
sentence2 = st.text_area("Sentence 2", placeholder="Enter the second sentence", height=100, key="sentence2")
# Predict button
if st.button("Check Paraphrase", key="predict_button"):
    with st.spinner("Analyzing sentences..."):
        result, confidence = predict_paraphrase(sentence1, sentence2)
    if confidence is None:
        # predict_paraphrase returns (message, None) for invalid input.
        st.error(result)
    else:
        # Markdown needs two trailing spaces before "\n" for a hard line break;
        # with one space both fields rendered on the same line.
        # The score is the model's probability that the pair IS a paraphrase,
        # so it is labelled as such rather than as confidence in the verdict.
        st.markdown(f"**Prediction**: {result}  \n**Paraphrase probability**: {confidence:.4f}")