MO-12 commited on
Commit
d00772b
·
verified ·
1 Parent(s): 7f00975

Update API_Model.py

Browse files
Files changed (1) hide show
  1. API_Model.py +22 -3
API_Model.py CHANGED
@@ -21,13 +21,31 @@ model.eval().to(device)
21
  feature_extractor = VideoMAEFeatureExtractor.from_pretrained("MCG-NJU/videomae-base")
22
  label_map = {0: "Goal", 1: "Card", 2: "Substitution"}
23
 
24
- def predict_gradio(video_path):
 
 
 
 
25
  video_id = str(uuid.uuid4())
26
  work_dir = f"./temp/{video_id}"
27
  os.makedirs(work_dir, exist_ok=True)
28
 
29
- cap = cv2.VideoCapture(video_path)
 
 
 
 
 
 
 
 
 
 
 
30
  fps = cap.get(cv2.CAP_PROP_FPS)
 
 
 
31
  frames = []
32
  while True:
33
  ret, frame = cap.read()
@@ -56,7 +74,7 @@ def predict_gradio(video_path):
56
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)
57
  confidence, pred = torch.max(probs, dim=1)
58
 
59
- if confidence.item() > 0.7:
60
  label = label_map[pred.item()]
61
  start_time = i / fps
62
  end_time = min((i + segment_size), len(frames)) / fps
@@ -79,3 +97,4 @@ def predict_gradio(video_path):
79
  return predictions, out_path
80
  else:
81
  return predictions, ""
 
 
21
  feature_extractor = VideoMAEFeatureExtractor.from_pretrained("MCG-NJU/videomae-base")
22
  label_map = {0: "Goal", 1: "Card", 2: "Substitution"}
23
 
24
+ def predict_gradio(video):
25
+ import tempfile
26
+ import shutil
27
+
28
+ # أنشئ مجلد مؤقت للعمل
29
  video_id = str(uuid.uuid4())
30
  work_dir = f"./temp/{video_id}"
31
  os.makedirs(work_dir, exist_ok=True)
32
 
33
+ # نحفظ الفيديو المرفوع على هيئة ملف مؤقت mp4
34
+ temp_video_path = os.path.join(work_dir, "input.mp4")
35
+ if isinstance(video, str):
36
+ # Gradio بيرسل أحيانًا مسار الملف
37
+ shutil.copy(video, temp_video_path)
38
+ else:
39
+ # Gradio بيرسل BytesIO stream (مش شائع بس نغطيه)
40
+ with open(temp_video_path, "wb") as f:
41
+ f.write(video.read())
42
+
43
+ # نحاول نفتح الفيديو
44
+ cap = cv2.VideoCapture(temp_video_path)
45
  fps = cap.get(cv2.CAP_PROP_FPS)
46
+ if fps == 0 or fps != fps: # NaN or 0
47
+ return [{"error": "Invalid or unreadable video."}], ""
48
+
49
  frames = []
50
  while True:
51
  ret, frame = cap.read()
 
74
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)
75
  confidence, pred = torch.max(probs, dim=1)
76
 
77
+ if confidence.item() > 0.70:
78
  label = label_map[pred.item()]
79
  start_time = i / fps
80
  end_time = min((i + segment_size), len(frames)) / fps
 
97
  return predictions, out_path
98
  else:
99
  return predictions, ""
100
+