Spaces:
Sleeping
Sleeping
import os | |
from collections import OrderedDict | |
import torch | |
from .u2net_cloth_segm import U2NET | |
def load_cloth_segm_model(device, checkpoint_path, in_ch=3, out_ch=1):
    """Build a U2NET cloth-segmentation model and load weights from a checkpoint.

    State-dict keys in the checkpoint may carry a ``module.`` prefix (added when
    a model is saved while wrapped in ``torch.nn.DataParallel``). That prefix is
    stripped so the weights load into a bare ``U2NET``. Keys without the prefix
    are used unchanged.

    Args:
        device: Target device. It is used both as ``map_location`` for
            ``torch.load`` and as the destination of ``model.to``.
        checkpoint_path: Filesystem path of the checkpoint file.
        in_ch: Number of input channels, forwarded to ``U2NET``.
        out_ch: Number of output channels, forwarded to ``U2NET``.

    Returns:
        The model with the checkpoint weights loaded, moved to ``device``.
        Returns ``None`` (after printing "Invalid path") when
        ``checkpoint_path`` does not exist.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` (called with its
            default ``strict=True``) when checkpoint keys or shapes do not
            match the model.
    """
    if not os.path.exists(checkpoint_path):
        # NOTE(review): returning None forces every caller to check the result.
        # Raising FileNotFoundError would be clearer, but it would change the
        # contract callers may rely on.
        print("Invalid path")
        return

    model = U2NET(in_ch=in_ch, out_ch=out_ch)

    # SECURITY: torch.load unpickles the file, so only load trusted checkpoints.
    # torch >= 1.13 accepts weights_only=True to restrict what can be unpickled.
    model_state_dict = torch.load(checkpoint_path, map_location=device)

    # Strip the DataParallel "module." prefix only where it is actually present.
    # An unconditional k[7:] would chop the first 7 characters off the keys of
    # checkpoints saved without the wrapper, making strict loading fail.
    prefix = "module."
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in model_state_dict.items()
    )
    model.load_state_dict(new_state_dict)

    model = model.to(device=device)
    print("Checkpoints loaded from path: {}".format(checkpoint_path))
    return model