Salimshakeel's picture
fixing
1579b70
raw
history blame contribute delete
853 Bytes
import torch
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
from layers.summarizer import PGL_SUM
from config import DEVICE
from tqdm import tqdm
def load_model(
    weights_path,
    *,
    input_size=1024,
    output_size=1024,
    num_segments=4,
    heads=8,
    fusion="add",
    pos_enc="absolute",
):
    """Build a PGL_SUM model, load saved weights into it and put it in eval mode.

    The architecture hyper-parameters default to the values this module has
    always used. Override them (keyword-only) when loading a checkpoint that
    was trained with a different configuration. ``load_state_dict`` is called
    with its default ``strict=True``, so a mismatched configuration raises
    instead of silently loading partial weights.

    Args:
        weights_path: Path or file-like object accepted by ``torch.load`` whose
            contents are the model's ``state_dict``.
        input_size, output_size, num_segments, heads, fusion, pos_enc:
            Forwarded unchanged to the ``PGL_SUM`` constructor.

    Returns:
        The ``PGL_SUM`` instance, moved to ``DEVICE`` and switched to eval mode.
    """
    model = PGL_SUM(
        input_size=input_size,
        output_size=output_size,
        num_segments=num_segments,
        heads=heads,
        fusion=fusion,
        pos_enc=pos_enc,
    ).to(DEVICE)
    # map_location remaps tensors saved on another device (e.g. a GPU) onto
    # DEVICE, so checkpoints load on hosts without the original hardware.
    # SECURITY NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code -- only load trusted checkpoints. Consider passing
    # weights_only=True once the minimum supported torch version allows it.
    state_dict = torch.load(weights_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.eval()
    return model
def batch_inference(model, input, batch_size=128):
    """Run ``model`` over ``input`` in mini-batches and concatenate the outputs.

    Each slice ``input[i:i + batch_size]`` (along dim 0) is moved to ``DEVICE``
    and passed through ``model`` with autograd disabled. Each result is copied
    back to the CPU, so the accumulated outputs live in host memory rather than
    on ``DEVICE``. The per-batch outputs are then concatenated along dim 0.

    ``model.eval()`` is called, and the model is left in eval mode afterwards.

    Args:
        model: A ``torch.nn.Module``. It is assumed to return a single tensor
            per batch whose first dimension indexes the batch rows -- TODO
            confirm against the model's ``forward``.
        input: Tensor that is sliced along its first dimension. The name
            shadows the ``input`` builtin but is kept for keyword callers.
        batch_size: Maximum number of rows per forward pass; must be positive.

    Returns:
        A CPU tensor holding all batch outputs concatenated along dim 0.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if ``input`` has
            size 0 along dim 0 (there would be nothing to concatenate).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    num_rows = input.size(0)
    if num_rows == 0:
        # torch.cat on an empty list raises an opaque RuntimeError; fail early
        # with a message that points at the actual cause.
        raise ValueError("batch_inference() received an empty input (size 0 along dim 0)")

    model.eval()
    outputs = []
    with torch.no_grad():
        for start in tqdm(range(0, num_rows, batch_size)):
            batch = input[start:start + batch_size].to(DEVICE)
            out = model(batch)
            # Copy each result to host memory so outputs do not accumulate on DEVICE.
            outputs.append(out.cpu())
    return torch.cat(outputs, dim=0)