MO-12 committed
Commit 0827c75 · verified · Parent: d10e811

Upload 4 files

Files changed (4)
  1. API_Model.py +81 -0
  2. README.md +10 -12
  3. app.py +15 -0
  4. requirements.txt +6 -0
API_Model.py ADDED
@@ -0,0 +1,81 @@
+ import torch
+ from transformers import VideoMAEForVideoClassification, VideoMAEFeatureExtractor
+ import os, cv2, uuid
+ import numpy as np
+ import gdown
+
+ model_path = "checkpoint_epoch_1.pt"
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Download the fine-tuned checkpoint from Google Drive if it is not present locally
+ if not os.path.exists(model_path):
+     print("Downloading checkpoint...")
+     url = "https://drive.google.com/uc?id=1dIaptYPq-1fgo0yoBoPlDsbIfs3BEqJI"
+     gdown.download(url, model_path, quiet=False)
+
+ # Load the base VideoMAE backbone with a 3-class head, then restore the fine-tuned weights
+ model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base", num_labels=3)
+ checkpoint = torch.load(model_path, map_location=device)
+ model.load_state_dict(checkpoint["model_state_dict"])
+ model.eval().to(device)
+
+ # Note: recent transformers releases deprecate VideoMAEFeatureExtractor in favor of
+ # VideoMAEImageProcessor; both accept a list of frames and return pixel_values
+ feature_extractor = VideoMAEFeatureExtractor.from_pretrained("MCG-NJU/videomae-base")
+ label_map = {0: "Goal", 1: "Card", 2: "Substitution"}
+
+ def predict_gradio(video_path):
+     video_id = str(uuid.uuid4())
+     work_dir = f"./temp/{video_id}"
+     os.makedirs(work_dir, exist_ok=True)
+
+     # Decode the whole video, resizing every frame to the model's 224x224 input size
+     cap = cv2.VideoCapture(video_path)
+     fps = cap.get(cv2.CAP_PROP_FPS) or 25.0  # some containers report 0 fps; fall back to 25
+     frames = []
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+         resized = cv2.resize(frame, (224, 224))
+         frames.append(resized)
+     cap.release()
+
+     # Classify the video in non-overlapping 5-second segments
+     segment_size = int(fps * 5)
+     predictions = []
+     output_segments = []
+
+     for i in range(0, len(frames), segment_size):
+         segment = frames[i:i+segment_size]
+         if len(segment) < 16:
+             continue
+         # Uniformly sample the 16 frames VideoMAE expects per clip;
+         # OpenCV decodes BGR, so convert to RGB before feature extraction
+         indices = np.linspace(0, len(segment) - 1, 16).astype(int)
+         sampled_frames = [cv2.cvtColor(segment[idx], cv2.COLOR_BGR2RGB) for idx in indices]
+
+         inputs = feature_extractor(sampled_frames, return_tensors="pt")
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         with torch.no_grad():
+             outputs = model(**inputs)
+             probs = torch.nn.functional.softmax(outputs.logits, dim=1)
+             confidence, pred = torch.max(probs, dim=1)
+
+         # Keep confident detections and remember their frames for the summary video
+         if confidence.item() > 0.7:
+             label = label_map[pred.item()]
+             start_time = i / fps
+             end_time = min(i + segment_size, len(frames)) / fps
+             predictions.append({
+                 "start": round(start_time, 2),
+                 "end": round(end_time, 2),
+                 "label": label,
+                 "confidence": round(confidence.item(), 3)
+             })
+             output_segments.append(segment)
+
+     # Stitch the detected segments into a single summary clip
+     out_path = f"{work_dir}/summary.mp4"
+     if output_segments:
+         fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+         out = cv2.VideoWriter(out_path, fourcc, fps, (224, 224))
+         for seg in output_segments:
+             for frame in seg:
+                 out.write(frame)
+         out.release()
+         return predictions, out_path
+     else:
+         return predictions, None
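A quick way to smoke-test the pipeline without the UI (a minimal sketch; `sample_match.mp4` is a placeholder path, not a file in this commit):

```python
# Call API_Model.predict_gradio directly, bypassing Gradio.
from API_Model import predict_gradio

events, summary_path = predict_gradio("sample_match.mp4")  # placeholder clip
for event in events:
    print(f"{event['label']}: {event['start']}s-{event['end']}s "
          f"(confidence {event['confidence']})")
print("Summary video:", summary_path)  # None when no segment clears the 0.7 threshold
```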
README.md CHANGED
@@ -1,12 +1,10 @@
- ---
- title: Api Model
- emoji: 👁
- colorFrom: gray
- colorTo: gray
- sdk: gradio
- sdk_version: 5.33.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # VideoMAE Football Event Classifier
+
+ This app detects important football events (Goal, Card, and Substitution) in uploaded videos using a fine-tuned VideoMAE transformer.
+
+ ## How to use
+ - Upload an `.mp4` video of a football match.
+ - Wait for the model to analyze the video.
+ - You'll receive a list of events (with timestamps) and a summary video of the detected clips.
+
+ Model: [VideoMAE](https://huggingface.co/MCG-NJU/videomae-base)
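The Space can also be called programmatically. A minimal sketch using `gradio_client`; the Space id `MO-12/Api_Model` is an assumption based on the commit author and title, and the exact video payload shape varies across Gradio versions:

```python
from gradio_client import Client, handle_file

client = Client("MO-12/Api_Model")  # assumed Space id
events, summary = client.predict(
    {"video": handle_file("match.mp4")},  # payload shape for gr.Video in recent Gradio
    api_name="/predict",
)
print(events)   # list of {start, end, label, confidence} dicts
print(summary)  # path to the downloaded summary video, or None
```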
 
 
app.py ADDED
@@ -0,0 +1,15 @@
+ import gradio as gr
+ from API_Model import predict_gradio
+
+ # Simple Gradio UI: one video in, JSON predictions and a summary video out
+ iface = gr.Interface(
+     fn=predict_gradio,
+     inputs=gr.Video(label="Upload Video (mp4)"),
+     outputs=[
+         gr.JSON(label="Prediction Results"),
+         gr.Video(label="Summary Video")
+     ],
+     title="VideoMAE Action Classifier",
+     description="Detect Goal / Card / Substitution segments in a football video using VideoMAE"
+ )
+
+ iface.launch()
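Whole-match videos can take minutes to process, so enabling Gradio's request queue is worth considering to avoid HTTP timeouts; `queue()` is standard Gradio API, not part of this commit:

```python
# Variant: queue long-running jobs instead of blocking HTTP workers
iface.queue().launch()
```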
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ transformers
+ opencv-python
+ numpy
+ gdown
+ gradio