Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,092 Bytes
476e0f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import os
import argparse
import safetensors
import safetensors.torch
def main():
    """Merge every ``.safetensors`` file in a directory into a single file.

    Command-line arguments:
        root: directory containing the ``.safetensors`` shards.

    All tensors from every shard are loaded and written to
    ``<prefix>.safetensors`` inside ``root``, where ``<prefix>`` is the part
    of the (alphabetically) first shard's file name before the first ``-``.
    Afterwards the original shards and any ``*.index.json`` files in ``root``
    are removed.

    Nothing is done when ``root`` holds zero or one ``.safetensors`` file.
    """
    parser = argparse.ArgumentParser(description="Merge multiple safetensors in a directory into a single safetensors")
    parser.add_argument("root", type=str, help="Root directory containing safetensors")
    args = parser.parse_args()

    suffix = ".safetensors"
    # Sorted so the shard that names the output is chosen deterministically
    # (os.listdir returns entries in arbitrary order).
    shard_paths = sorted(
        os.path.join(args.root, name)
        for name in os.listdir(args.root)
        if name.endswith(suffix)
    )
    # Zero files: nothing to merge (and shard_paths[0] would raise IndexError).
    # One file: already merged.
    if len(shard_paths) <= 1:
        return

    tensors = {}
    for shard_path in shard_paths:
        with safetensors.safe_open(shard_path, framework="pt") as f:
            for k in f.keys():
                # NOTE: a key present in several shards is overwritten by the
                # later shard, as in the original implementation.
                tensors[k] = f.get_tensor(k)

    prefix = os.path.basename(shard_paths[0]).split("-")[0]
    # A name without "-" still carries its extension; strip it so the result
    # is "name.safetensors" rather than "name.safetensors.safetensors".
    if prefix.endswith(suffix):
        prefix = prefix[: -len(suffix)]
    output_path = os.path.join(args.root, prefix + suffix)

    # Write to a temporary name first: the output name may coincide with one
    # of the input shards, and a failed write must not destroy any input.
    tmp_path = output_path + ".tmp"
    safetensors.torch.save_file(tensors, tmp_path)
    # Drop the loaded tensors before replacing/removing the files they were
    # read from, in case they are still backed by those files.
    del tensors
    os.replace(tmp_path, output_path)

    shard_set = set(shard_paths)
    for name in os.listdir(args.root):
        path = os.path.join(args.root, name)
        if path == output_path:
            # Never delete the freshly merged file, even if an input shard
            # had the same name.
            continue
        if name.endswith(".index.json") or path in shard_set:
            os.remove(path)
if __name__ == "__main__":
    # Run the merge only when executed as a script, not when imported.
    main()
|