640510702phithak committed on
Commit
4355acb
·
verified ·
1 Parent(s): cbac7b2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -7
app.py CHANGED
@@ -2,23 +2,20 @@ from fastapi import FastAPI
2
  from transformers import AutoTokenizer, AutoModel
3
  import torch
4
 
5
- # Load the Sentence-Transformer model
6
  MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
7
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
8
- model = AutoModel.from_pretrained(MODEL_NAME)
9
 
10
- # Create the API
11
  app = FastAPI()
12
 
13
- # Function that converts text into a vector
14
  def get_embedding(text):
15
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
16
  with torch.no_grad():
17
  outputs = model(**inputs)
18
- embedding = outputs.last_hidden_state.mean(dim=1) # use the mean of the hidden states over dim=1; NOTE(review): positions padded by padding=True appear to be included in this mean — confirm for batched input
19
  return embedding.squeeze().tolist()
20
 
21
- # API Endpoint
22
  @app.post("/embed")
23
  async def embed_text(data: dict):
24
  text = data.get("text", "")
 
2
  from transformers import AutoTokenizer, AutoModel
3
  import torch
4
 
5
+ # Use /tmp as the cache directory for the tokenizer and model files
6
  MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
7
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir="/tmp")
8
+ model = AutoModel.from_pretrained(MODEL_NAME, cache_dir="/tmp")
9
 
 
10
  app = FastAPI()
11
 
 
12
  def get_embedding(text):
13
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
14
  with torch.no_grad():
15
  outputs = model(**inputs)
16
+ embedding = outputs.last_hidden_state.mean(dim=1)
17
  return embedding.squeeze().tolist()
18
 
 
19
  @app.post("/embed")
20
  async def embed_text(data: dict):
21
  text = data.get("text", "")