File size: 494 Bytes
77c66f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
from model import UNetTransformerModel, UNetConfig

# Input checkpoint (raw PyTorch weights) and output directory for the
# HF-style export; defined once so both save calls stay in sync.
CHECKPOINT_PATH = "unet_epoch20.pth"
OUTPUT_DIR = "brain_unet_hf"

# Build the config and a model with the same architecture the checkpoint was
# trained with; its parameters are overwritten by the checkpoint below.
config = UNetConfig(in_channels=1, out_channels=3, image_size=256)
model = UNetTransformerModel(config)

# Load the existing weights into the wrapped network (`model.model`).
# NOTE(review): assumes the checkpoint is a plain state_dict saved from that
# inner module (keys without a "model." prefix) — confirm against training code.
# weights_only=True restricts unpickling to tensors and primitive containers,
# so a tampered checkpoint file cannot execute arbitrary code during loading;
# map_location="cpu" allows a checkpoint saved on GPU to load on a CPU-only host.
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
model.model.load_state_dict(state_dict)

# Save the model and config in HF-compatible format.
model.save_pretrained(OUTPUT_DIR)
config.save_pretrained(OUTPUT_DIR)

print(f"✅ Saved to {OUTPUT_DIR}/")