import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import models, transforms
from torchvision.models import GoogLeNet_Weights
from tqdm import tqdm

from config import DEVICE, FRAME_RATE
from services.model_loader import batch_inference

# Load GoogLeNet pretrained on ImageNet and put it in inference mode.
weights = GoogLeNet_Weights.DEFAULT
googlenet = models.googlenet(weights=weights).to(DEVICE).eval()

# GoogLeNet up to and including global average pooling, i.e. the backbone
# without its final fully connected classifier. Each frame comes out as a
# 1024-dimensional feature vector.
feature_extractor = torch.nn.Sequential(
    googlenet.conv1,
    googlenet.maxpool1,
    googlenet.conv2,
    googlenet.conv3,
    googlenet.maxpool2,
    googlenet.inception3a,
    googlenet.inception3b,
    googlenet.maxpool3,
    googlenet.inception4a,
    googlenet.inception4b,
    googlenet.inception4c,
    googlenet.inception4d,
    googlenet.inception4e,
    googlenet.maxpool4,
    googlenet.inception5a,
    googlenet.inception5b,
    googlenet.avgpool,
    torch.nn.Flatten(),
)
feature_extractor = feature_extractor.eval()
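
# Shape sanity check (illustrative, not part of the pipeline): one normalized
# 224x224 RGB frame should map to one 1024-dim vector, e.g.
#   with torch.no_grad():
#       vec = feature_extractor(torch.zeros(1, 3, 224, 224, device=DEVICE))
#   vec.shape  # -> torch.Size([1, 1024])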

# Standard ImageNet preprocessing: resize to GoogLeNet's 224x224 input size,
# convert to a tensor, and normalize with the ImageNet channel statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

def extract_features(video_path):
    """Decode a video frame by frame and return GoogLeNet features.

    Returns a (num_frames, 1024) tensor of features together with the
    list of decoded frame indices.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    indices = []
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    print(f"Total frames in video: {total_frames}")

    # Read frames sequentially; an explicit per-frame seek via
    # cap.set(cv2.CAP_PROP_POS_FRAMES, idx) is unnecessary for a full pass
    # and is very slow with most codecs.
    for idx in tqdm(range(total_frames)):
        ret, frame = cap.read()
        if not ret:
            break

        # OpenCV decodes to BGR; convert to RGB before handing off to PIL.
        frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frame = transform(frame)

        frames.append(frame)
        indices.append(idx)

    cap.release()

    frames = torch.stack(frames).to(DEVICE)
    print("Features before GoogLeNet extraction:", frames.shape)
    frames = batch_inference(model=feature_extractor, input=frames, batch_size=32)
    print("Features after GoogLeNet extraction:", frames.shape)

    return frames, indices
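
# Minimal usage sketch (illustrative): "sample.mp4" is a placeholder path;
# DEVICE and batch_inference come from this project's config and services.
if __name__ == "__main__":
    feats, frame_indices = extract_features("sample.mp4")
    print(f"Extracted {feats.shape[0]} feature vectors of dim {feats.shape[1]}")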