"""Training / inference wrapper around the RIFE HDv3 ``IFNet`` flow network."""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import AdamW
import numpy as np
import itertools
from .warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from .IFNet_HDv3 import *
from .loss import *
import devicetorch

# Device selected once at import time by the ``devicetorch`` helper.
device = devicetorch.get(torch)


class Model:
    """Bundle an ``IFNet`` with its optimizer and loss helpers.

    Attributes:
        flownet: The ``IFNet`` instance, wrapped in
            ``DistributedDataParallel`` when ``local_rank != -1``.
        optimG: ``AdamW`` optimizer over the network parameters.
        epe: ``EPE`` loss helper (constructed here, not used in this class).
        sobel: ``SOBEL`` helper used for the smoothness loss in ``update``.
    """

    def __init__(self, local_rank=-1):
        """Build the network, move it to ``device`` and create the optimizer.

        Args:
            local_rank: Distributed rank of this process. ``-1`` (default)
                means no ``DistributedDataParallel`` wrapping.
        """
        self.flownet = IFNet()
        self.device()
        # The learning rate set here is overwritten on every ``update`` call.
        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
        self.epe = EPE()
        # self.vgg = VGGPerceptualLoss().to(device)
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)

    def train(self):
        """Put the flow network into training mode."""
        self.flownet.train()

    def eval(self):
        """Put the flow network into evaluation mode."""
        self.flownet.eval()

    def device(self):
        """Move the flow network to the module-level ``device``."""
        self.flownet.to(device)

    def load_model(self, path, rank=0):
        """Load ``{path}/flownet.pkl`` into the flow network.

        Nothing is loaded when ``rank > 0``.

        Args:
            path: Directory containing ``flownet.pkl``.
            rank: When ``-1``, only checkpoint entries whose key contains
                ``"module."`` are kept, with that substring removed. For
                other values ``<= 0`` the checkpoint is used unchanged.
        """

        def convert(param):
            if rank == -1:
                # Entries without "module." in the key are dropped on purpose
                # to match the original behaviour.
                return {
                    k.replace("module.", ""): v
                    for k, v in param.items()
                    if "module." in k
                }
            return param

        if rank > 0:
            return

        model_path = "{}/flownet.pkl".format(path)

        # Check the PyTorch version so ``weights_only`` is only passed to
        # releases whose ``torch.load`` accepts it.
        from packaging import version

        use_weights_only = version.parse(torch.__version__) >= version.parse("1.13")

        load_kwargs = {}
        if not torch.cuda.is_available():
            load_kwargs["map_location"] = "cpu"

        if use_weights_only:
            # For modern PyTorch, be explicit and safe.
            load_kwargs["weights_only"] = True
        else:
            # For older PyTorch, load the old way.
            print("PyTorch < 1.13 detected. Loading RIFE model using legacy method.")

        state_dict = torch.load(model_path, **load_kwargs)
        self.flownet.load_state_dict(convert(state_dict))

    def inference(self, img0, img1, scale=1.0):
        """Run the flow network on a frame pair and return ``merged[2]``.

        Args:
            img0: First frame tensor.
            img1: Second frame tensor; concatenated with ``img0`` along dim 1.
            scale: Divisor applied to the base scale list ``[4, 2, 1]``.

        Returns:
            The element at index 2 of the network's third output
            (presumably the finest-scale interpolated frame; confirm
            against ``IFNet``).
        """
        imgs = torch.cat((img0, img1), 1)
        scale_list = [4 / scale, 2 / scale, 1 / scale]
        _, _, merged = self.flownet(imgs, scale_list)
        return merged[2]

    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        """Run one forward pass and, when ``training``, one optimizer step.

        Args:
            imgs: Input frames tensor; concatenated with ``gt`` along dim 1
                before being passed to the network.
            gt: Ground-truth tensor compared against ``merged[2]``.
            learning_rate: Learning rate written into every optimizer
                parameter group before the forward pass.
            mul: Unused; kept for interface compatibility.
            training: Selects train/eval mode and whether a backward pass
                and optimizer step are performed.
            flow_gt: Unused; kept for interface compatibility.

        Returns:
            A tuple ``(merged[2], info)`` where ``info`` is a dict with keys
            ``"mask"``, ``"flow"``, ``"loss_l1"``, ``"loss_cons"`` and
            ``"loss_smooth"``.
        """
        for param_group in self.optimG.param_groups:
            param_group["lr"] = learning_rate

        if training:
            self.train()
        else:
            self.eval()

        scale = [4, 2, 1]
        flow, mask, merged = self.flownet(
            torch.cat((imgs, gt), 1), scale=scale, training=training
        )

        loss_l1 = (merged[2] - gt).abs().mean()
        loss_smooth = self.sobel(flow[2], flow[2] * 0).mean()
        # loss_vgg = self.vgg(merged[2], gt)

        # NOTE(review): the original code read ``loss_cons`` without ever
        # assigning it, so every call raised NameError. Nothing computed in
        # this method provides a consistency term, so a zero placeholder is
        # used to keep the returned dict shape intact. Replace this with the
        # real term if the training recipe defines one.
        loss_cons = torch.zeros_like(loss_l1)

        if training:
            self.optimG.zero_grad()
            # NOTE(review): ``loss_l1`` is included so the objective depends
            # on ``gt``; with the placeholder alone only the smoothness term
            # would be optimized. Confirm the intended weighting.
            loss_G = loss_l1 + loss_cons + loss_smooth * 0.1
            loss_G.backward()
            self.optimG.step()

        return merged[2], {
            "mask": mask,
            "flow": flow[2][:, :2],
            "loss_l1": loss_l1,
            "loss_cons": loss_cons,
            "loss_smooth": loss_smooth,
        }