yaya36095 commited on
Commit
ba3f2d0
·
verified ·
1 Parent(s): 62624b4

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +16 -20
handler.py CHANGED
@@ -5,25 +5,24 @@ from typing import Dict, Any
5
 
6
  import torch
7
  from PIL import Image
8
- from safetensors.torch import load_file
9
  from timm import create_model
10
  from torchvision import transforms
11
 
12
 
13
  class EndpointHandler:
14
- """Custom image-classification pipeline for Hugging Face Inference Endpoints."""
15
 
16
  # --------------------------------------------------
17
- # 1) تحميل النموذج والوزن مرة واحدة عند تشغيل الـ Endpoint
18
  # --------------------------------------------------
19
  def __init__(self, model_dir: str) -> None:
20
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
22
- # وزن محفوظ بصيغة safetensors
23
- weights_path = os.path.join(model_dir, "model.safetensors")
24
- state_dict = load_file(weights_path)
25
 
26
- # أنشئ نفس معماريّة ViT التى درّبتها (num_classes = 5)
27
  self.model = create_model("vit_base_patch16_224", num_classes=5)
28
  self.model.load_state_dict(state_dict)
29
  self.model.eval().to(self.device)
@@ -47,25 +46,23 @@ class EndpointHandler:
47
  # --------------------------------------------------
48
  # 2) دوال مساعدة
49
  # --------------------------------------------------
50
- def _image_from_bytes(self, b: bytes) -> Image.Image:
51
- """decode base64 → PIL"""
52
- return Image.open(io.BytesIO(base64.b64decode(b)))
53
-
54
- def _to_tensor(self, img: Image.Image) -> torch.Tensor:
55
- """PIL → tensor (1 × 3 × 224 × 224) على نفس الجهاز"""
56
  return self.preprocess(img.convert("RGB")).unsqueeze(0).to(self.device)
57
 
 
 
 
58
  # --------------------------------------------------
59
- # 3) الدالة الرئيسة التى تستدعيها المنصّة لكل طلب
60
  # --------------------------------------------------
61
  def __call__(self, data: Any) -> Dict[str, float]:
62
  """
63
  يدعم:
64
- • Widget — يمرّر PIL.Image مباشرةً
65
- • REST يمرّر dict وفيه مفتاح "inputs" أو "image" (base64)
66
  """
67
- # — الحصول على صورة PIL —
68
  img: Image.Image | None = None
 
69
  if isinstance(data, Image.Image):
70
  img = data
71
  elif isinstance(data, dict):
@@ -73,14 +70,13 @@ class EndpointHandler:
73
  if isinstance(payload, (str, bytes)):
74
  if isinstance(payload, str):
75
  payload = payload.encode()
76
- img = self._image_from_bytes(payload)
77
 
78
  if img is None:
79
  return {"error": "No image provided"}
80
 
81
- # — الاستدلال —
82
  with torch.no_grad():
83
- logits = self.model(self._to_tensor(img))
84
  probs = torch.nn.functional.softmax(logits.squeeze(0), dim=0)
85
 
86
  return {self.labels[i]: float(probs[i]) for i in range(len(self.labels))}
 
5
 
6
  import torch
7
  from PIL import Image
 
8
  from timm import create_model
9
  from torchvision import transforms
10
 
11
 
12
  class EndpointHandler:
13
+ """Custom ViT image-classifier for Hugging Face Inference Endpoints."""
14
 
15
  # --------------------------------------------------
16
+ # 1) تحميل النموذج والوزن مرة واحدة
17
  # --------------------------------------------------
18
  def __init__(self, model_dir: str) -> None:
19
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
 
21
+ # ─── تحميل وزن ViT الصحيح ───
22
+ weights_path = os.path.join(model_dir, "pytorch_model.bin") # موجود بالفعل
23
+ state_dict = torch.load(weights_path, map_location="cpu")
24
 
25
+ # إنشاء ViT Base Patch-16 بعدد فئات 5
26
  self.model = create_model("vit_base_patch16_224", num_classes=5)
27
  self.model.load_state_dict(state_dict)
28
  self.model.eval().to(self.device)
 
46
  # --------------------------------------------------
47
  # 2) دوال مساعدة
48
  # --------------------------------------------------
49
+ def _img_to_tensor(self, img: Image.Image) -> torch.Tensor:
 
 
 
 
 
50
  return self.preprocess(img.convert("RGB")).unsqueeze(0).to(self.device)
51
 
52
+ def _decode_b64(self, b: bytes) -> Image.Image:
53
+ return Image.open(io.BytesIO(base64.b64decode(b)))
54
+
55
  # --------------------------------------------------
56
+ # 3) الدالة الرئيسة
57
  # --------------------------------------------------
58
  def __call__(self, data: Any) -> Dict[str, float]:
59
  """
60
  يدعم:
61
+ • Widget (PIL.Image)
62
+ • REST (base64 فى data["inputs"] أو data["image"])
63
  """
 
64
  img: Image.Image | None = None
65
+
66
  if isinstance(data, Image.Image):
67
  img = data
68
  elif isinstance(data, dict):
 
70
  if isinstance(payload, (str, bytes)):
71
  if isinstance(payload, str):
72
  payload = payload.encode()
73
+ img = self._decode_b64(payload)
74
 
75
  if img is None:
76
  return {"error": "No image provided"}
77
 
 
78
  with torch.no_grad():
79
+ logits = self.model(self._img_to_tensor(img))
80
  probs = torch.nn.functional.softmax(logits.squeeze(0), dim=0)
81
 
82
  return {self.labels[i]: float(probs[i]) for i in range(len(self.labels))}