---
license: apache-2.0
base_model:
- xai-org/grok-2
---
### huihui-ai/grok-2
This Python [script](https://huggingface.co/huihui-ai/grok-2/blob/main/convert_safetensors.py) processes and merges the sharded weight files (in safetensors format) of the [xai-org/grok-2](https://huggingface.co/xai-org/grok-2) model.

It is only a simple merge: no inference code is included, and the merge does not confirm that the resulting model is reasonable or correct.

Open question: do we need a custom `MixtralForCausalLM`?

The main functionalities are:

1. **Collecting safetensors files**: Locates all `pytorch_model-*.safetensors` files in the specified model directory.
2. **Loading files into cache**: Loads all safetensors files into memory and builds a key-to-file mapping.
3. **Merging Tensor Parallel (TP) shards**: Merges shards for tensor parallelism (TP=8) along specific dimensions and verifies the merged tensor shapes.
4. **Grouping weights by layer**: Organizes weights by model layer, with special weights (e.g., `lm_head.weight`, `model.embed_tokens.weight`, and `model.norm.weight`) handled separately.
5. **Saving merged weights**: Saves the grouped weights as new safetensors files and generates a new index file [pytorch_model.bin.index.json](https://huggingface.co/huihui-ai/grok-2/blob/main/pytorch_model.bin.index.json).
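
Steps 1 and 2 can be sketched with only the Python standard library. This is a minimal illustration, not the script's own code: instead of loading every tensor, it reads just the JSON header of each file (a safetensors file begins with an 8-byte little-endian header length followed by a UTF-8 JSON header). It also maps each tensor name to a *list* of files, on the assumption that TP shards might reuse the same tensor name; this README does not say how shards are named. The function names are hypothetical.

```python
import glob
import json
import os
import struct


def read_safetensors_header(path):
    """Return the tensor entries of a .safetensors header without loading any data.

    A safetensors file starts with an 8-byte little-endian unsigned integer (the
    header length), followed by that many bytes of UTF-8 JSON.
    """
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    header.pop("__metadata__", None)  # optional free-form metadata, not a tensor
    return header  # {name: {"dtype": ..., "shape": [...], "data_offsets": [...]}}


def build_key_to_files_map(model_dir):
    """Hypothetical helper: collect shard files and map each tensor name to the files holding it."""
    pattern = os.path.join(model_dir, "pytorch_model-*.safetensors")
    key_to_files = {}
    for path in sorted(glob.glob(pattern)):  # sorted for a deterministic order
        for key in read_safetensors_header(path):
            # A list rather than a single file, in case TP shards reuse the same name.
            key_to_files.setdefault(key, []).append(path)
    return key_to_files
```

For example, `build_key_to_files_map("xai-org/grok-2")` lists where each tensor lives before anything large is read into memory.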

### Features
- **Input**: Safetensors files in the `xai-org/grok-2` model directory.
- **Output**: Layer-organized safetensors files and an index file in the `huihui-ai/grok-2` directory.
- **Tensor Parallelism Support**: Handles TP=8 shards, merging tensors along specific dimensions (`w1.weight` and `w3.weight` along dim=0, `w2.weight` along dim=1).
- **Error Handling**: Includes warnings and handling for missing files, shape mismatches, and other exceptions.
- **Shape Validation**: Verifies shapes for specific weights (e.g., MoE layer weights), ensuring merged tensors match expected shapes (e.g., `(16384, 8192)` or `(8192, 16384)`).
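
A shape-level sketch of the TP merge and validation described above. It assumes each weight is split evenly across the 8 ranks, so a `w1.weight` shard would be `(2048, 8192)` (16384 / 8 = 2048) and a `w2.weight` shard `(8192, 2048)`; those per-shard shapes are derived from the expected merged shapes, not read from the checkpoint. Only shapes are handled here; on real tensors the concatenation itself would be something like `torch.cat(shards, dim=dim)`. `MERGE_DIMS`, `merged_shape`, and `check_tp_merge` are hypothetical names.

```python
# Concatenation dimension per weight suffix, as listed in Features above.
MERGE_DIMS = {"w1.weight": 0, "w3.weight": 0, "w2.weight": 1}


def merged_shape(shard_shapes, dim):
    """Shape produced by concatenating shards along `dim`; all other dims must agree."""
    first = shard_shapes[0]
    for s in shard_shapes[1:]:
        if len(s) != len(first) or any(a != b for i, (a, b) in enumerate(zip(s, first)) if i != dim):
            raise ValueError(f"incompatible shard shapes: {first} vs {s}")
    out = list(first)
    out[dim] = sum(s[dim] for s in shard_shapes)
    return tuple(out)


def check_tp_merge(name, shard_shapes, expected, tp_count=8):
    """Warn on missing shards or a shape mismatch; return True if the merge looks valid."""
    dim = next((d for suffix, d in MERGE_DIMS.items() if name.endswith(suffix)), None)
    if dim is None:
        raise KeyError(f"no merge dimension known for {name}")
    if len(shard_shapes) != tp_count:
        print(f"Warning: {name} has {len(shard_shapes)} shards, expected {tp_count}")
        return False
    shape = merged_shape(shard_shapes, dim)
    if shape != expected:
        print(f"Warning: {name} merged to {shape}, expected {expected}")
        return False
    return True
```

For example, `check_tp_merge("layers.0.w1.weight", [(2048, 8192)] * 8, (16384, 8192))` returns `True`, while passing only seven shards prints a warning and returns `False`.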

### Usage
1. Install the required Python libraries:
   ```bash
   pip install torch safetensors
   ```
2. Place the script where it can access the `xai-org/grok-2` model directory.
3. Run the script:
   ```bash
   python convert_safetensors.py
   ```
4. Output files will be saved in the `huihui-ai/grok-2` directory, including layer-organized safetensors files and an index file.
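
The layer grouping (step 4) and the index file (step 5) might look roughly like this. The sketch assumes layer tensors are named `model.layers.<N>.*`, which is consistent with the special keys `model.embed_tokens.weight` and `model.norm.weight` but not confirmed by this README, and the output filename scheme `model-<group>.safetensors` is hypothetical. The index follows the usual Hugging Face sharded-index layout: a `metadata.total_size` entry plus a `weight_map` from tensor name to file.

```python
import json
import re

SPECIAL_KEYS = ("lm_head.weight", "model.embed_tokens.weight", "model.norm.weight")
LAYER_RE = re.compile(r"^model\.layers\.(\d+)\.")  # assumed naming convention


def group_by_layer(keys):
    """Bucket tensor names by layer index; special weights get their own group."""
    groups = {}
    for key in keys:
        if key in SPECIAL_KEYS:
            groups.setdefault("special", []).append(key)
            continue
        m = LAYER_RE.match(key)
        groups.setdefault(int(m.group(1)) if m else "other", []).append(key)
    return groups


def build_index(groups, tensor_nbytes):
    """Build a sharded index: total byte size plus a tensor-name -> filename map."""
    weight_map = {}
    for group, keys in groups.items():
        filename = f"model-{group}.safetensors"  # hypothetical naming scheme
        for key in keys:
            weight_map[key] = filename
    total = sum(tensor_nbytes[k] for k in weight_map)
    return {"metadata": {"total_size": total}, "weight_map": weight_map}


def save_index(index, path):
    """Write the index as indented JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(index, f, indent=2)
```

Grouping one file per layer keeps each output file to a single layer's weights, so a loader can locate any tensor through `weight_map` without opening every file. The real script writes its index to `pytorch_model.bin.index.json`.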

### Notes
- Ensure the input directory `xai-org/grok-2` contains valid `pytorch_model-*.safetensors` files.
- The script assumes a tensor parallelism degree of 8 (`tp_count = 8`). Modify the `tp_count` value in the script if needed.
- Because all safetensors files are loaded into memory, memory requirements are high; run the script on a machine with enough RAM.
- If shards are missing or shapes mismatch, the script will print warnings and attempt to proceed.
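
A quick pre-flight check can confirm that the input files exist and give a rough idea of how much RAM loading them all will need. This is a hypothetical helper, not part of the script. It assumes in-memory size is close to on-disk size, which holds roughly for safetensors because the file body is raw tensor data.

```python
import glob
import os


def estimate_load_size_gib(model_dir, pattern="pytorch_model-*.safetensors"):
    """Return (file count, total size in GiB) as a rough lower bound on RAM needed.

    Sums on-disk file sizes; the real peak can be higher because merged
    tensors are built while the original shards are still cached.
    """
    paths = glob.glob(os.path.join(model_dir, pattern))
    if not paths:
        raise FileNotFoundError(f"no files matching {pattern} in {model_dir}")
    total_bytes = sum(os.path.getsize(p) for p in paths)
    return len(paths), total_bytes / 2**30
```

For example, `n, gib = estimate_load_size_gib("xai-org/grok-2")` before running the conversion shows how many shard files were found and their combined size.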