refactor: remove script
Browse files- generate_index.py +0 -46
generate_index.py
DELETED
@@ -1,46 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import json
|
3 |
-
from safetensors import safe_open
|
4 |
-
import torch
|
5 |
-
from typing import Dict
|
6 |
-
|
7 |
-
def generate_index(directory: str) -> Dict:
    """Build a sharded-checkpoint index for the ``.safetensors`` files in *directory*.

    Files are visited in sorted name order so the result is deterministic.
    When the same tensor key appears in more than one file, the last file
    visited wins and a warning is printed.

    Args:
        directory: Path of the directory to scan (not recursive).

    Returns:
        A dict with two entries:
        ``"metadata"`` -> ``{"total_size": <bytes>}``, the summed byte size of
        the tensors listed in the weight map, and
        ``"weight_map"`` -> mapping of tensor key to the file name holding it.

    A file that cannot be processed is reported with ``print`` and skipped;
    keys recorded from it before the failure are kept (best effort).
    """
    weight_map = {}
    # Byte size per key rather than a running total: a duplicate key must
    # replace the earlier occurrence's size, not add to it, otherwise
    # total_size would count tensors that weight_map no longer references.
    sizes = {}

    shard_names = sorted(
        name for name in os.listdir(directory) if name.endswith('.safetensors')
    )

    for filename in shard_names:
        filepath = os.path.join(directory, filename)
        try:
            with safe_open(filepath, framework="pt") as f:
                for key in f.keys():
                    if key in weight_map:
                        print(f"Warning: Duplicate key '{key}' found in {filename}. Using the last occurrence.")
                    weight_map[key] = filename

                    tensor = f.get_tensor(key)
                    sizes[key] = tensor.numel() * tensor.element_size()
        except Exception as e:
            # Deliberately best-effort: one unreadable shard should not
            # abort indexing of the remaining files.
            print(f"Error processing {filename}: {str(e)}")

    return {
        "metadata": {"total_size": sum(sizes.values())},
        "weight_map": weight_map,
    }
|
32 |
-
|
33 |
-
def save_index(index: Dict, output_file: str):
    """Serialize *index* as JSON (two-space indent) into *output_file*.

    The file is opened in text write mode, so existing content is replaced.
    """
    with open(output_file, 'w') as handle:
        handle.write(json.dumps(index, indent=2))
|
36 |
-
|
37 |
-
if __name__ == "__main__":
    # Script entry point: index the shards found in the current working
    # directory and write the result next to them.
    index_path = "model.safetensors.index.json"

    shard_index = generate_index(os.getcwd())
    save_index(shard_index, index_path)

    # Summary of what was written.
    for line in (
        f"Index generated with {len(shard_index['weight_map'])} tensors.",
        f"Total size: {shard_index['metadata']['total_size']} bytes",
        f"Index saved to {index_path}",
    ):
        print(line)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|