brain-unet-model / hf_wrapper /convert_save.py
AndaiMD's picture
Upload folder using huggingface_hub
77c66f1 verified
raw
history blame contribute delete
494 Bytes
import torch
from model import UNetTransformerModel, UNetConfig
# Convert a raw PyTorch checkpoint into a Hugging Face-style output directory
# by loading it into the UNetTransformerModel wrapper and calling
# `save_pretrained` on both the model and its config.
#
# NOTE(review): the UNetConfig values below must describe the architecture the
# checkpoint was trained with; otherwise `load_state_dict` (strict by default)
# raises on missing/unexpected keys or tensor size mismatches.

# Single source of truth for input/output locations, so the save calls and the
# status message cannot drift apart.
CHECKPOINT_PATH = "unet_epoch20.pth"
OUTPUT_DIR = "brain_unet_hf"

# Build the wrapper model from an explicit config.
config = UNetConfig(in_channels=1, out_channels=3, image_size=256)
model = UNetTransformerModel(config)

# Load the existing weights into the inner network `model.model`
# (presumably the checkpoint was saved from that inner module — confirm).
# - map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only box.
# - weights_only=True limits unpickling to tensors and primitive containers.
#   torch.load is pickle-based and would otherwise execute arbitrary code
#   embedded in the file. A plain state_dict loads exactly the same way.
#   (The keyword requires torch >= 1.13.)
state_dict = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
model.model.load_state_dict(state_dict)

# Save the model and config in HF-compatible format.
# NOTE(review): if UNetTransformerModel is a PreTrainedModel, its
# save_pretrained already writes config.json and the second call just
# rewrites it; the explicit config save is kept in case the wrapper does not.
model.save_pretrained(OUTPUT_DIR)
config.save_pretrained(OUTPUT_DIR)

print(f"✅ Saved to {OUTPUT_DIR}/")