PixNerd / src /utils /model_loader.py
wangshuai6
init
56238f0
raw
history blame contribute delete
751 Bytes
from typing import Dict, Any, Optional
import torch
import torch.nn as nn
import logging
logger = logging.getLogger(__name__)


class ModelLoader:
    """Best-effort loader that copies checkpoint tensors into a denoiser.

    Tensors are copied in place, one state-dict entry at a time, so a
    checkpoint that is missing some keys (or has incompatible entries) still
    loads everything it can; each entry that cannot be copied is reported
    with a warning instead of aborting the whole load.
    """

    def __init__(self,):
        super().__init__()

    def load(self, denoiser: nn.Module) -> nn.Module:
        """Copy weights from ``denoiser.weight_path`` into ``denoiser``.

        Args:
            denoiser: Object exposing ``weight_path`` (checkpoint path; a
                falsy value skips loading), ``load_ema`` (selects the
                ``"ema_denoiser."`` key prefix instead of ``"denoiser."``)
                and ``state_dict()``.

        Returns:
            The same ``denoiser`` object, updated in place.
        """
        if denoiser.weight_path:
            # NOTE(review): torch.load unpickles the file; only point
            # weight_path at checkpoints from a trusted source.
            weight = torch.load(denoiser.weight_path, map_location=torch.device('cpu'))
            prefix = "ema_denoiser." if denoiser.load_ema else "denoiser."

            for k, v in denoiser.state_dict().items():
                try:
                    # In-place copy into the tensor returned by state_dict().
                    v.copy_(weight["state_dict"][prefix+k])
                except Exception:
                    # Best effort: a missing key or incompatible tensor only
                    # skips this entry. Catching Exception (not a bare
                    # except) lets KeyboardInterrupt/SystemExit propagate.
                    logger.warning(f"Failed to copy {prefix+k} to denoiser weight")
        return denoiser