Update API_Model.py
Browse files- API_Model.py +22 -3
API_Model.py
CHANGED
@@ -21,13 +21,31 @@ model.eval().to(device)
|
|
21 |
feature_extractor = VideoMAEFeatureExtractor.from_pretrained("MCG-NJU/videomae-base")
|
22 |
label_map = {0: "Goal", 1: "Card", 2: "Substitution"}
|
23 |
|
24 |
-
def predict_gradio(
|
|
|
|
|
|
|
|
|
25 |
video_id = str(uuid.uuid4())
|
26 |
work_dir = f"./temp/{video_id}"
|
27 |
os.makedirs(work_dir, exist_ok=True)
|
28 |
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
|
|
|
|
|
|
31 |
frames = []
|
32 |
while True:
|
33 |
ret, frame = cap.read()
|
@@ -56,7 +74,7 @@ def predict_gradio(video_path):
|
|
56 |
probs = torch.nn.functional.softmax(outputs.logits, dim=1)
|
57 |
confidence, pred = torch.max(probs, dim=1)
|
58 |
|
59 |
-
if confidence.item() > 0.
|
60 |
label = label_map[pred.item()]
|
61 |
start_time = i / fps
|
62 |
end_time = min((i + segment_size), len(frames)) / fps
|
@@ -79,3 +97,4 @@ def predict_gradio(video_path):
|
|
79 |
return predictions, out_path
|
80 |
else:
|
81 |
return predictions, ""
|
|
|
|
21 |
feature_extractor = VideoMAEFeatureExtractor.from_pretrained("MCG-NJU/videomae-base")
|
22 |
label_map = {0: "Goal", 1: "Card", 2: "Substitution"}
|
23 |
|
24 |
+
def predict_gradio(video):
|
25 |
+
import tempfile
|
26 |
+
import shutil
|
27 |
+
|
28 |
+
# أنشئ مجلد مؤقت للعمل
|
29 |
video_id = str(uuid.uuid4())
|
30 |
work_dir = f"./temp/{video_id}"
|
31 |
os.makedirs(work_dir, exist_ok=True)
|
32 |
|
33 |
+
# نحفظ الفيديو المرفوع على هيئة ملف مؤقت mp4
|
34 |
+
temp_video_path = os.path.join(work_dir, "input.mp4")
|
35 |
+
if isinstance(video, str):
|
36 |
+
# Gradio بيرسل أحيانًا مسار الملف
|
37 |
+
shutil.copy(video, temp_video_path)
|
38 |
+
else:
|
39 |
+
# Gradio بيرسل BytesIO stream (مش شائع بس نغطيه)
|
40 |
+
with open(temp_video_path, "wb") as f:
|
41 |
+
f.write(video.read())
|
42 |
+
|
43 |
+
# نحاول نفتح الفيديو
|
44 |
+
cap = cv2.VideoCapture(temp_video_path)
|
45 |
fps = cap.get(cv2.CAP_PROP_FPS)
|
46 |
+
if fps == 0 or fps != fps: # NaN or 0
|
47 |
+
return [{"error": "Invalid or unreadable video."}], ""
|
48 |
+
|
49 |
frames = []
|
50 |
while True:
|
51 |
ret, frame = cap.read()
|
|
|
74 |
probs = torch.nn.functional.softmax(outputs.logits, dim=1)
|
75 |
confidence, pred = torch.max(probs, dim=1)
|
76 |
|
77 |
+
if confidence.item() > 0.70:
|
78 |
label = label_map[pred.item()]
|
79 |
start_time = i / fps
|
80 |
end_time = min((i + segment_size), len(frames)) / fps
|
|
|
97 |
return predictions, out_path
|
98 |
else:
|
99 |
return predictions, ""
|
100 |
+
|