File size: 3,767 Bytes
3358274
 
 
 
768966a
 
3358274
 
 
 
 
 
 
 
 
7a778d8
3358274
 
 
 
 
 
 
 
 
 
 
7a778d8
3358274
 
 
 
 
 
 
 
768966a
 
 
 
3358274
768966a
7a778d8
768966a
 
7a778d8
 
 
3358274
 
768966a
3358274
 
 
 
 
ddd7096
3358274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddd7096
 
3358274
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a778d8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import streamlit as st
import torch
from transformers import BertTokenizer, BertModel
from torch import nn
import h5py
import numpy as np
import warnings

# Silence PyTorch-related FutureWarning/UserWarning noise that is raised from
# inside transformers.utils.generic (the `module` argument is a regex matched
# against the warning's originating module).
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers.utils.generic")
warnings.filterwarnings("ignore", category=UserWarning, module="transformers.utils.generic")

# Streamlit page configuration; keep this ahead of any other st.* call.
# The icon is the magnifying-glass emoji (U+1F50D); the previous literal was
# a mis-decoded (mojibake) version of it.
st.set_page_config(page_title="Paraphrase Detection App", page_icon="🔍", layout="centered")

# Model architecture; must match the one used in the training notebook.
class ParaphraseModel(nn.Module):
    """BERT encoder plus a small feed-forward head that scores a sentence pair.

    The attribute names ``bert``, ``dense``, ``relu``, ``output`` and
    ``sigmoid`` determine the keys returned by ``named_parameters()``
    (e.g. ``dense.weight``). Saved weights are looked up by those keys, so
    the names must not change.
    """

    def __init__(self):
        super().__init__()
        # Pretrained bert-base-uncased backbone.
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Head: 768-d pooled BERT vector -> 128 hidden units -> 1 score.
        self.dense = nn.Linear(768, 128)
        self.relu = nn.ReLU()
        self.output = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return a sigmoid-squashed score for each encoded sentence pair."""
        bert_out = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Element 1 of BERT's output is the *pooled* output (BERT's own
        # dense+tanh projection of the [CLS] hidden state), not the raw
        # [CLS] vector.
        hidden = self.relu(self.dense(bert_out[1]))
        return self.sigmoid(self.output(hidden))

# Load saved model and tokenizer.
#
# Streamlit re-executes this whole script on every widget interaction, so the
# expensive part (building BERT and copying the HDF5 weights into it) is
# wrapped in a cached resource. It then runs once per server process instead
# of on every button click.
@st.cache_resource
def _load_model_and_tokenizer(pkl_path='paraphrase_model.pkl'):
    """Build the model, load its trained weights and return it with its tokenizer.

    Args:
        pkl_path: Pickle file holding a dict with a ``'tokenizer'`` object and
            a ``'weights'`` entry that is passed to ``h5py.File`` (presumably
            the path of the HDF5 weights file — confirm against the notebook).

    Returns:
        ``(model, tokenizer, missing)`` where ``model`` is in eval mode and
        ``missing`` lists parameter names that had no dataset in the HDF5
        file (those keep their initial values).

    Raises:
        FileNotFoundError: If the pickle or HDF5 file does not exist.
        ValueError: If a stored array's shape differs from the model parameter.
    """
    import pickle

    # NOTE(review): pickle.load can execute arbitrary code — only ever point
    # this at a trusted, locally produced file.
    with open(pkl_path, 'rb') as pkl_file:
        saved_data = pickle.load(pkl_file)
    loaded_tokenizer = saved_data['tokenizer']
    paraphrase_model = ParaphraseModel()

    missing = []
    with h5py.File(saved_data['weights'], 'r') as weights_file, torch.no_grad():
        for name, param in paraphrase_model.named_parameters():
            if name not in weights_file:
                missing.append(name)
                continue
            values = torch.as_tensor(np.asarray(weights_file[name]))
            # Rebinding param.data would silently accept a wrong shape or a
            # different dtype (e.g. float64), breaking forward() later.
            # Copy in place instead, after validating the shape.
            if values.shape != param.shape:
                raise ValueError(
                    f"Shape mismatch for '{name}': file has {tuple(values.shape)}, "
                    f"model expects {tuple(param.shape)}"
                )
            param.copy_(values.to(dtype=param.dtype, device=param.device))
    paraphrase_model.eval()
    return paraphrase_model, loaded_tokenizer, missing


try:
    model, tokenizer, _missing_weights = _load_model_and_tokenizer()
except FileNotFoundError:
    st.error("Error: 'paraphrase_model.pkl' or 'paraphrase_model_weights.h5' not found. Ensure both files are in the same directory as app.py.")
    st.stop()
except Exception as e:
    st.error(f"Error loading model: {str(e)}")
    st.stop()

# Previously these were skipped silently, which could leave the classifier head
# untrained without any indication.
if _missing_weights:
    st.warning(
        f"{len(_missing_weights)} model parameter(s) were not found in the weights "
        "file and kept their initial values; predictions may be unreliable."
    )

# Function to predict paraphrase with adjustable threshold
def predict_paraphrase(sentence1, sentence2, threshold=0.6):
    """Classify whether two sentences are paraphrases of each other.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        sentence1: First sentence.
        sentence2: Second sentence.
        threshold: Decision cut-off on the model's sigmoid score. A score
            strictly above it is labelled "Paraphrase". The default of 0.6 is
            deliberately above 0.5 to reduce false positives.

    Returns:
        ``(label, score)``. ``score`` is the model's raw sigmoid output, where
        higher means "more likely a paraphrase". It is NOT confidence in the
        returned label. If either sentence is blank, returns
        ``(error_message, None)`` without running the model.
    """
    if not sentence1.strip() or not sentence2.strip():
        return "Please enter valid sentences in both fields.", None

    # Encode the two sentences as one BERT pair input (segment ids in
    # token_type_ids), padded/truncated to 128 tokens.
    encodings = tokenizer(
        [sentence1], [sentence2],
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )

    with torch.no_grad():
        # Batch of one -> squeeze to a scalar tensor before extracting it.
        score = model(
            encodings['input_ids'],
            encodings['attention_mask'],
            encodings['token_type_ids'],
        ).squeeze().item()

    result = "Paraphrase" if score > threshold else "Not a Paraphrase"
    return result, score

# Streamlit UI
# The title emoji is U+1F50D (magnifying glass); the previous literal was mojibake.
st.title("🔍 Paraphrase Detection App")
st.markdown("""
Enter two sentences to check if they are paraphrases of each other.  
**Example:**  
- Sentence 1: "The cat is on the mat."  
- Sentence 2: "The mat has a cat on it."  
""")

# Input fields
sentence1 = st.text_area("Sentence 1", placeholder="Enter the first sentence", height=100, key="sentence1")
sentence2 = st.text_area("Sentence 2", placeholder="Enter the second sentence", height=100, key="sentence2")

# Predict button
if st.button("Check Paraphrase", key="predict_button"):
    with st.spinner("Analyzing sentences..."):
        result, score = predict_paraphrase(sentence1, sentence2)
        if score is None:
            # predict_paraphrase signals blank input by returning a message and None.
            st.error(result)
        else:
            # The number is the model's paraphrase score (higher = more likely
            # a paraphrase), not confidence in the displayed label. Labelling
            # it "Confidence" misread e.g. 0.02 next to "Not a Paraphrase".
            st.markdown(f"**Prediction**: {result}  \n**Paraphrase score**: {score:.4f}")