import os
import sys

import torch
from tqdm import tqdm

# Make the project root importable so that `layers` and `config` resolve
# when this script is run directly.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

from layers.summarizer import PGL_SUM
from config import DEVICE


def load_model(weights_path):
    """Build a PGL_SUM model, load trained weights, and switch it to eval mode."""
    model = PGL_SUM(
        input_size=1024,
        output_size=1024,
        num_segments=4,
        heads=8,
        fusion="add",
        pos_enc="absolute",
    ).to(DEVICE)
    model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
    model.eval()
    return model


def batch_inference(model, input, batch_size=128):
    """Run the model over `input` in mini-batches and return the outputs concatenated on CPU."""
    model.eval()
    output = []
    with torch.no_grad():
        for i in tqdm(range(0, input.size(0), batch_size)):
            # Move one mini-batch to the target device, run the forward pass,
            # and keep the result on CPU to avoid accumulating GPU memory.
            batch = input[i:i + batch_size].to(DEVICE)
            out = model(batch)
            output.append(out.cpu())
    return torch.cat(output, dim=0)
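

# --- Example usage (sketch) ---------------------------------------------------
# A minimal sketch of how these helpers might be wired together. The checkpoint
# and feature-file paths below are hypothetical placeholders, not part of this repo;
# the features are assumed to be a CPU float tensor of shape (N, 1024).
if __name__ == "__main__":
    weights_path = "checkpoints/pgl_sum_best.pt"    # hypothetical checkpoint path
    features_path = "features/video_features.pt"    # hypothetical saved feature tensor

    model = load_model(weights_path)
    features = torch.load(features_path)             # assumed: torch.Tensor, shape (N, 1024)
    scores = batch_inference(model, features, batch_size=128)
    print(scores.shape)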