Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import argparse | |
import safetensors | |
import safetensors.torch | |
def main(): | |
parser = argparse.ArgumentParser(description="Merge multiple safetensors in a directory into a single safetensors") | |
parser.add_argument("root", type=str, help="Root directory containing safetensors") | |
args = parser.parse_args() | |
safetensor_file_paths = [os.path.join(args.root, f) for f in os.listdir(args.root) if f.endswith(".safetensors")] | |
if len(safetensor_file_paths) == 1: | |
return | |
tensors = {} | |
for path in safetensor_file_paths: | |
with safetensors.safe_open(path, framework="pt") as f: | |
for k in f.keys(): | |
tensors[k] = f.get_tensor(k) | |
safetensors.torch.save_file(tensors, os.path.join(args.root, os.path.basename(safetensor_file_paths[0]).split("-")[0]) + ".safetensors") | |
for f in os.listdir(args.root): | |
path = os.path.join(args.root, f) | |
if path.endswith(".index.json"): | |
os.remove(path) | |
if path in safetensor_file_paths: | |
os.remove(path) | |
if __name__ == "__main__": | |
main() | |