# Migrated from GitHub (upstream commit 05fcd0f); uploaded by rahul7star.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import AdamW
import numpy as np
import itertools
from .warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from .IFNet_HDv3 import *
from .loss import *
import devicetorch
device = devicetorch.get(torch)
class Model:
    """Wrapper around the RIFE IFNet optical-flow network.

    Owns the network, its AdamW optimizer, and the loss helpers, and
    exposes checkpoint loading, single-frame inference, and one-step
    training updates. `device` is the module-level global selected via
    devicetorch at import time.
    """

    def __init__(self, local_rank=-1):
        """Build the flownet, move it to `device`, and set up the optimizer.

        local_rank != -1 wraps the network in DistributedDataParallel
        for multi-GPU training.
        """
        self.flownet = IFNet()
        self.device()
        self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
        self.epe = EPE()
        # self.vgg = VGGPerceptualLoss().to(device)
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)

    def train(self):
        """Switch the flow network into training mode."""
        self.flownet.train()

    def eval(self):
        """Switch the flow network into evaluation mode."""
        self.flownet.eval()

    def device(self):
        """Move the flow network onto the globally selected device."""
        self.flownet.to(device)

    def load_model(self, path, rank=0):
        """Load flownet weights from ``{path}/flownet.pkl``.

        Only rank <= 0 loads (other DDP ranks are no-ops). rank == -1
        additionally strips the DDP ``module.`` prefix from checkpoint
        keys so a DDP-saved checkpoint loads into a bare network.
        """
        def convert(param):
            # Keep only DDP-prefixed entries, renamed to bare-module keys.
            if rank == -1:
                return {k.replace("module.", ""): v for k, v in param.items() if "module." in k}
            else:
                return param

        if rank <= 0:
            model_path = "{}/flownet.pkl".format(path)
            load_kwargs = {}
            if not torch.cuda.is_available():
                load_kwargs['map_location'] = "cpu"
            # `weights_only` exists (and is the safe default to request)
            # only on PyTorch >= 1.13; older versions must load legacy-style.
            from packaging import version
            if version.parse(torch.__version__) >= version.parse("1.13"):
                load_kwargs['weights_only'] = True
            else:
                print(f"PyTorch < 1.13 detected. Loading RIFE model using legacy method.")
            state_dict = torch.load(model_path, **load_kwargs)
            self.flownet.load_state_dict(convert(state_dict))

    def inference(self, img0, img1, scale=1.0):
        """Interpolate the frame between img0 and img1.

        Images are concatenated on the channel axis and run through the
        flow network at pyramid scales [4/scale, 2/scale, 1/scale];
        the final (third) merged output is returned.
        """
        imgs = torch.cat((img0, img1), 1)
        scale_list = [4 / scale, 2 / scale, 1 / scale]
        flow, mask, merged = self.flownet(imgs, scale_list)
        return merged[2]

    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        """Run one forward pass and, when training, one optimizer step.

        imgs: concatenated input pair (6 channels); gt: ground-truth frame.
        `mul` and `flow_gt` are accepted for interface compatibility but
        unused here. Returns (predicted_frame, dict of diagnostics/losses).
        """
        for param_group in self.optimG.param_groups:
            param_group["lr"] = learning_rate
        if training:
            self.train()
        else:
            self.eval()
        scale = [4, 2, 1]
        flow, mask, merged = self.flownet(torch.cat((imgs, gt), 1), scale=scale, training=training)
        loss_l1 = (merged[2] - gt).abs().mean()
        # Smoothness term: Sobel gradient of the final flow vs. a zero target.
        loss_smooth = self.sobel(flow[2], flow[2] * 0).mean()
        # loss_vgg = self.vgg(merged[2], gt)
        # BUG FIX: `loss_cons` was referenced below without ever being
        # defined, so every call raised NameError. This IFNet variant does
        # not compute the upstream teacher/consistency loss, so fall back
        # to the reconstruction loss. NOTE(review): confirm against the
        # original training recipe if distillation is reinstated.
        loss_cons = loss_l1
        if training:
            self.optimG.zero_grad()
            loss_G = loss_cons + loss_smooth * 0.1
            loss_G.backward()
            self.optimG.step()
        return merged[2], {
            "mask": mask,
            "flow": flow[2][:, :2],
            "loss_l1": loss_l1,
            "loss_cons": loss_cons,
            "loss_smooth": loss_smooth,
        }