import cv2
import torch
import numpy as np
from PIL import Image
from torchvision import models, transforms
from config import DEVICE, FRAME_RATE
from tqdm import tqdm
from services.model_loader import batch_inference

# Load GoogLeNet once
from torchvision.models import GoogLeNet_Weights

weights = GoogLeNet_Weights.DEFAULT
googlenet = models.googlenet(weights=weights).to(DEVICE).eval()

# Keep everything up to the global average pool and drop the dropout/fc head,
# so the extractor outputs a 1024-dim feature vector per frame.
feature_extractor = torch.nn.Sequential(
    googlenet.conv1,
    googlenet.maxpool1,
    googlenet.conv2,
    googlenet.conv3,
    googlenet.maxpool2,
    googlenet.inception3a,
    googlenet.inception3b,
    googlenet.maxpool3,
    googlenet.inception4a,
    googlenet.inception4b,
    googlenet.inception4c,
    googlenet.inception4d,
    googlenet.inception4e,
    googlenet.maxpool4,
    googlenet.inception5a,
    googlenet.inception5b,
    googlenet.avgpool,
    torch.nn.Flatten()
)
feature_extractor = feature_extractor.eval()

# Standard ImageNet preprocessing expected by the pretrained GoogLeNet.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])


def extract_features(video_path):
    cap = cv2.VideoCapture(video_path)
    frames = []
    indices = []

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # total_frames = 300  # TEMP
    print(f"Total frames in video: {total_frames}")

    for idx in tqdm(range(total_frames)):
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if not ret:
            break

        # Convert BGR (OpenCV) to RGB, then apply the ImageNet transform.
        frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        frame = transform(frame)
        frames.append(frame)
        indices.append(idx)

    cap.release()

    frames = torch.stack(frames).to(DEVICE)
    print("Features before GoogLeNet extraction:", frames.shape)
    frames = batch_inference(model=feature_extractor, input=frames, batch_size=32)
    print("Features after GoogLeNet extraction:", frames.shape)
    return frames, indices
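

# Example usage (a minimal sketch, not part of the original module): the video
# path below is a placeholder, and this assumes extract_features returns a
# (num_frames, 1024) feature tensor plus the matching list of frame indices.
if __name__ == "__main__":
    features, frame_indices = extract_features("sample_video.mp4")  # hypothetical path
    print(features.shape)       # expected: (num_frames, 1024)
    print(len(frame_indices))   # number of frames actually decoded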