""" """ import torch from huggingface_hub import hf_hub_download from spaces.zero.torch.aoti import ZeroGPUCompiledModel from spaces.zero.torch.aoti import ZeroGPUWeights def aoti_load(module: torch.nn.Module, repo_id: str): repeated_blocks = module._repeated_blocks aoti_files = {name: hf_hub_download(repo_id, f'{name}.pt2') for name in repeated_blocks} for block_name, aoti_file in aoti_files.items(): for block in module.modules(): if block.__class__.__name__ == block_name: weights = ZeroGPUWeights(block.state_dict()) block.forward = ZeroGPUCompiledModel(aoti_file, weights)