File size: 1,719 Bytes
e3cf75b
e10ddaa
 
 
 
 
 
 
 
6c2eaf8
 
 
e10ddaa
 
 
 
 
 
 
 
 
9f69673
e10ddaa
 
 
 
 
 
 
 
e3cf75b
e10ddaa
 
 
 
6c2eaf8
e10ddaa
 
6c2eaf8
e3cf75b
a0f1fda
e3cf75b
a0f1fda
e3cf75b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from sentence_transformers import SentenceTransformer,util
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import pandas as pd
import os
import sys
src_directory = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..", "src"))
sys.path.append(src_directory)
from data import sample_data
import numpy as np

# Sentence-embedding model loaded at import time. trust_remote_code=True allows
# the model repository's own Python code to run (presumably needed by this
# gte checkpoint — confirm before changing).
model = SentenceTransformer('Alibaba-NLP/gte-base-en-v1.5', trust_remote_code=True)

# Alias used by the functions below to turn text into embedding vectors.
encoding_model = model
# Classifier populated by train_model(); None means "not trained yet"
# (get_prediction() raises ValueError in that state).
logreg_model = None
# Embeddings of the training split, stored by train_model().
X_train_embeddings = None

# NOTE(review): this relative path is resolved against the current working
# directory, not against this file (unlike src_directory above) — presumably
# the module is run from the repository root; confirm.
file_path = r"src/data/sms_process_data_main.xlsx"
# Training data, loaded at import time; train_model() reads its
# 'MessageText' and 'label' columns.
df = sample_data.get_data_frame(file_path)

def train_model():
    """Fit the logistic-regression classifier on sentence embeddings of ``df``.

    Splits ``df['MessageText']`` / ``df['label']`` into train and test parts
    (``test_size=0.2``, fixed ``random_state=42``). It encodes the training
    messages with ``encoding_model`` and fits a ``LogisticRegression`` on the
    resulting embeddings.

    The fitted classifier is stored in the module-level ``logreg_model``. The
    training embeddings are stored in ``X_train_embeddings``.

    Calling this again after a successful fit does nothing. If encoding or
    fitting raises, both globals are left unchanged so a later call can retry.
    """
    global logreg_model, X_train_embeddings

    if logreg_model is not None:
        return

    # Only the training split is used here; the held-out split is discarded.
    X_train, _X_test, y_train, _y_test = train_test_split(
        df['MessageText'], df['label'], test_size=0.2, random_state=42
    )
    train_embeddings = encoding_model.encode(X_train.tolist())

    classifier = LogisticRegression(max_iter=100)
    classifier.fit(train_embeddings, y_train)

    # Publish only after fit() succeeds. Assigning the global before fitting
    # would leave an unfitted model behind on failure. That model would break
    # get_prediction(), and the `is not None` guard above would stop any retry.
    X_train_embeddings = train_embeddings
    logreg_model = classifier

def get_prediction(message):
    """Embed a single message and classify it with the trained model.

    Returns a 3-tuple:
      - the length of the message's embedding vector,
      - a one-column DataFrame (column "Dimension") holding the embedding values,
      - the classifier's predictions as a plain list (one entry for the message).

    Raises ValueError if train_model() has not populated the classifier yet.
    """
    if logreg_model is None:
        raise ValueError("Model has not been trained yet. Please call train_model first.")

    embedded = encoding_model.encode([message])
    values = np.array(embedded)[0].tolist()

    frame = pd.DataFrame(values, columns=["Dimension"])
    labels = logreg_model.predict(embedded).tolist()

    return len(embedded[0]), frame, labels

def get_cosine_similarity(msg_1: str, msg_2: str):
    """Return the cosine similarity between the embeddings of two messages.

    Both messages are encoded in one batch with ``encoding_model``. The score
    is returned as a Python float rounded to 4 decimal places.
    """
    first_vec, second_vec = encoding_model.encode([msg_1, msg_2])
    score_tensor = util.cos_sim(first_vec, second_vec)
    score = score_tensor.item()
    return round(score, 4)