File size: 751 Bytes
56238f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from typing import Dict, Any, Optional

import torch
import torch.nn as nn


import logging
# Module-level logger; ModelLoader.load uses it to report checkpoint entries it could not copy.
logger = logging.getLogger(__name__)

class ModelLoader:
    """Load pretrained weights from a checkpoint file into a denoiser in place.

    The denoiser object must expose ``weight_path`` (falsy to skip loading),
    ``load_ema`` (selects the ``"ema_denoiser."`` vs ``"denoiser."`` key
    prefix inside the checkpoint's ``"state_dict"``) and ``state_dict()``.
    Loading is best-effort: entries that are missing from the checkpoint,
    are not tensors, or whose shape does not match exactly are logged and
    skipped, leaving the denoiser's current values in place.
    """

    def __init__(self):
        super().__init__()

    def load(self, denoiser):
        """Copy matching checkpoint tensors into ``denoiser`` and return it.

        Args:
            denoiser: Object with ``weight_path``, ``load_ema`` and
                ``state_dict()``; its tensors are updated in place.

        Returns:
            The same ``denoiser`` object (also when nothing was loaded).
        """
        if not denoiser.weight_path:
            return denoiser

        # NOTE(review): torch.load unpickles the file (unless weights_only is
        # in effect for the installed torch version) -- only load trusted files.
        weight = torch.load(denoiser.weight_path, map_location=torch.device('cpu'))
        # Look the nested "state_dict" up once instead of per parameter.
        state_dict = weight.get("state_dict") if isinstance(weight, dict) else None
        if not isinstance(state_dict, dict):
            logger.warning(
                "Checkpoint %s has no 'state_dict' entry; denoiser weights left unchanged",
                denoiser.weight_path,
            )
            return denoiser

        prefix = "ema_denoiser." if denoiser.load_ema else "denoiser."
        for name, target in denoiser.state_dict().items():
            key = prefix + name
            source = state_dict.get(key)
            if source is None:
                logger.warning(
                    "Failed to copy %s to denoiser weight: key not found in checkpoint", key
                )
                continue
            if not isinstance(source, torch.Tensor):
                logger.warning(
                    "Failed to copy %s to denoiser weight: checkpoint entry is not a tensor", key
                )
                continue
            # Tensor.copy_ broadcasts, so a mismatched-but-broadcastable tensor
            # would be silently expanded; require an exact shape match instead.
            if source.shape != target.shape:
                logger.warning(
                    "Failed to copy %s to denoiser weight: checkpoint shape %s != model shape %s",
                    key, tuple(source.shape), tuple(target.shape),
                )
                continue
            try:
                # state_dict() returns references to the module's tensors, so
                # copying into them updates the denoiser in place.
                target.copy_(source)
            except RuntimeError as err:
                logger.warning("Failed to copy %s to denoiser weight: %s", key, err)
        return denoiser