File size: 494 Bytes
77c66f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import torch
from model import UNetTransformerModel, UNetConfig
# Convert a plain PyTorch checkpoint of the UNet into a Hugging Face-style
# directory via `save_pretrained`.
#
# Paths are named once here so the source checkpoint and the output directory
# cannot drift out of sync between the load, save and log steps.
CHECKPOINT_PATH = "unet_epoch20.pth"
OUTPUT_DIR = "brain_unet_hf"

# Build the config and wrapper model. These hyper-parameters presumably must
# match the ones the checkpoint was trained with; otherwise the strict
# `load_state_dict` below raises on missing/unexpected keys or shape mismatches.
config = UNetConfig(in_channels=1, out_channels=3, image_size=256)
model = UNetTransformerModel(config)

# Load the existing weights into the wrapped network (`model.model`).
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered .pth file cannot execute arbitrary code; a state_dict needs nothing
# more. map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only host.
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
model.model.load_state_dict(state_dict)

# Save the model and config in HF-compatible format. The explicit config save
# may be redundant if the model's `save_pretrained` already writes the config —
# NOTE(review): confirm against UNetTransformerModel's base class.
model.save_pretrained(OUTPUT_DIR)
config.save_pretrained(OUTPUT_DIR)
print(f"✅ Saved to {OUTPUT_DIR}/")
|