darpanaswal commited on
Commit
e68549b
·
verified ·
1 Parent(s): 3836d32

Update cross_encoder_reranking_train.py

Browse files
Files changed (1) hide show
  1. cross_encoder_reranking_train.py +1 -2
cross_encoder_reranking_train.py CHANGED
@@ -13,8 +13,7 @@ from sklearn.metrics.pairwise import cosine_similarity
13
 
14
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
  # Load embedder once
16
- embedder = SentenceTransformer("all-MiniLM-L6-v2")
17
- embedder = embedder.to(device)
18
 
19
  def embed_text_list(texts):
20
  return embedder.encode(texts, convert_to_tensor=False, device=device)
 
13
 
14
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
  # Load embedder once
16
+ embedder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").to(device)
 
17
 
18
  def embed_text_list(texts):
19
  return embedder.encode(texts, convert_to_tensor=False, device=device)