Add similarity endpoint

#1
by Tonycrux - opened
Files changed (1) hide show
  1. app.py +18 -1
app.py CHANGED
@@ -2,7 +2,8 @@ import re
2
  import joblib
3
  from fastapi import FastAPI, Request
4
  from pydantic import BaseModel
5
- from fastapi.middleware.cors import CORSMiddleware
 
6
 
7
  app = FastAPI()
8
 
@@ -16,6 +17,7 @@ app.add_middleware(
16
  # Load model and vectorizer
17
  model = joblib.load("team_classifier_model.joblib")
18
  vectorizer = joblib.load("tfidf_vectorizer.joblib")
 
19
 
20
 
21
  def clean_text(text):
@@ -28,13 +30,28 @@ class InputText(BaseModel):
28
  subject: str
29
  message: str
30
 
 
 
 
 
 
 
31
  @app.get("/")
32
  def root():
33
  return {"status": "running", "message": "Use POST /classify"}
34
 
 
35
  @app.post("/classify")
36
  async def classify_ticket(data: InputText):
37
  combined = clean_text(f"{data.subject} {data.message}")
38
  vec = vectorizer.transform([combined])
39
  prediction = model.predict(vec)[0]
40
  return {"team": prediction}
 
 
 
 
 
 
 
 
 
2
  import joblib
3
  from fastapi import FastAPI, Request
4
  from pydantic import BaseModel
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ from sentence_transformers import SentenceTransformer, util
7
 
8
  app = FastAPI()
9
 
 
17
  # Load model and vectorizer
18
  model = joblib.load("team_classifier_model.joblib")
19
  vectorizer = joblib.load("tfidf_vectorizer.joblib")
20
+ sbert_model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
21
 
22
 
23
  def clean_text(text):
 
30
  subject: str
31
  message: str
32
 
33
+
34
+ class SimilarityRequest(BaseModel):
35
+ text1: str
36
+ text2: str
37
+
38
+
39
  @app.get("/")
40
  def root():
41
  return {"status": "running", "message": "Use POST /classify"}
42
 
43
+
44
  @app.post("/classify")
45
  async def classify_ticket(data: InputText):
46
  combined = clean_text(f"{data.subject} {data.message}")
47
  vec = vectorizer.transform([combined])
48
  prediction = model.predict(vec)[0]
49
  return {"team": prediction}
50
+
51
+
52
+ @app.post("/similarity")
53
+ async def compute_similarity(data: SimilarityRequest):
54
+ emb1 = sbert_model.encode(data.text1, convert_to_tensor=True)
55
+ emb2 = sbert_model.encode(data.text2, convert_to_tensor=True)
56
+ score = util.pytorch_cos_sim(emb1, emb2).item()
57
+ return {"similarity": score}