"""Load a PGL_SUM checkpoint and run chunked, gradient-free inference with it."""
import os
import sys

import torch
from tqdm import tqdm

# Make the directory two levels above this file importable, so the
# project-local ``layers`` and ``config`` imports below resolve.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

from layers.summarizer import PGL_SUM  # noqa: E402  (requires the sys.path tweak above)
from config import DEVICE  # noqa: E402  (requires the sys.path tweak above)


def load_model(weights_path):
    """Build a PGL_SUM model, load a state dict into it, and return it in eval mode.

    The architecture hyper-parameters are fixed here and must match the ones
    the checkpoint was trained with, otherwise ``load_state_dict`` will fail.

    Args:
        weights_path: Path to a file readable by ``torch.load`` that contains
            a state dict for the model configured below.

    Returns:
        The PGL_SUM module, moved to ``DEVICE`` and switched to eval mode.
    """
    model = PGL_SUM(
        input_size=1024,
        output_size=1024,
        num_segments=4,
        heads=8,
        fusion="add",
        pos_enc="absolute",
    ).to(DEVICE)
    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code. Only load trusted checkpoints; on torch versions that
    # support it, consider passing weights_only=True.
    state_dict = torch.load(weights_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def batch_inference(model, input, batch_size=128):
    """Run ``model`` over ``input`` in chunks along dim 0 and concatenate the results.

    Each chunk is moved to ``DEVICE``, passed through the model under
    ``torch.no_grad()``, and its output is moved back to the CPU. The
    per-chunk outputs are concatenated along dim 0.

    NOTE(review): this assumes ``model(batch)`` returns a single tensor
    (``.cpu()`` is called on it). It also assumes chunks can be processed
    independently. If the model attends across the whole sequence, chunking
    changes the result. Confirm both against ``PGL_SUM.forward``.

    Args:
        model: A ``torch.nn.Module``. It is switched to eval mode as a side effect.
        input: A tensor with at least one row along dim 0. (The name shadows
            the builtin but is kept for callers that pass it by keyword.)
        batch_size: Maximum number of rows per forward pass. Must be >= 1.

    Returns:
        A CPU tensor made by concatenating the per-chunk outputs along dim 0.

    Raises:
        ValueError: If ``batch_size`` < 1 or ``input`` has no rows along dim 0.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    num_rows = input.size(0)
    if num_rows == 0:
        # torch.cat on an empty list would otherwise fail with an opaque error.
        raise ValueError("batch_inference() received an empty input (size(0) == 0)")

    model.eval()
    outputs = []
    with torch.no_grad():
        for start in tqdm(range(0, num_rows, batch_size)):
            batch = input[start:start + batch_size].to(DEVICE)
            # Move each chunk's output off the device right away so device
            # memory does not grow with the number of chunks.
            outputs.append(model(batch).cpu())
    return torch.cat(outputs, dim=0)