"""Convert an existing U-Net checkpoint into a ``save_pretrained`` directory.

The script builds a ``UNetTransformerModel`` from a ``UNetConfig``, loads
previously trained weights into the wrapper's inner ``model`` attribute, and
then writes both the model and the config to ``brain_unet_hf/``.
"""
import torch

from model import UNetConfig, UNetTransformerModel

# Input checkpoint and output directory, named once so they cannot drift apart.
CHECKPOINT_PATH = "unet_epoch20.pth"
OUTPUT_DIR = "brain_unet_hf"

# Load config and model
config = UNetConfig(in_channels=1, out_channels=3, image_size=256)
model = UNetTransformerModel(config)

# Load your existing model weights. They go into the wrapped network
# (``model.model``), not the wrapper itself, and are mapped to CPU so the
# conversion does not require a GPU.
# NOTE(review): torch.load unpickles the file; only run this on checkpoints
# you trust, or pass ``weights_only=True`` if the installed torch supports it.
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu")
model.model.load_state_dict(state_dict)

# Save the model and config in HF-compatible format
model.save_pretrained(OUTPUT_DIR)
# NOTE(review): presumably model.save_pretrained already writes the config;
# the explicit call is kept to preserve the original behavior -- confirm.
config.save_pretrained(OUTPUT_DIR)
print(f"✅ Saved to {OUTPUT_DIR}/")