---
library_name: transformers
base_model:
- TIGER-Lab/MAmmoTH2-8B
- rombodawg/rombos_Replete-Coder-Llama3-8B
tags:
- merge
---
Code + Math Llama3-8B merged with the RegMean algorithm.
See details at https://github.com/AuCson/RegMean-LLama3-8B.
## Fast and Numerically Stable RegMean for Merging Llama3-8B
This repo is a fast and numerically stable re-implementation of the RegMean model merging algorithm for Llama3-8B.
We merge the following two models:
- [Code Model] [rombodawg/rombos_Replete-Coder-Llama3-8B](https://huggingface.co/rombodawg/rombos_Replete-Coder-Llama3-8B) (Re-implementation of Replete-Coder)
- [Math Model] [TIGER-Lab/MAmmoTH2-8B](https://huggingface.co/TIGER-Lab/MAmmoTH2-8B)
## Results
| Method/Benchmark | GSM8k (Math, 5-shot EM) | MathQA (Math, 0-shot Acc-norm) | HumanEval-Instruct (Code, 0-shot Pass@1) | MBPP (Code, 3-shot Pass@1) |
| ---- | ---- | ---- | ---- | ---- |
| [Math Model](https://huggingface.co/TIGER-Lab/MAmmoTH2-8B) | 70.40* | 43.85 | 36.59 | 40.04 |
| [Code Model](https://huggingface.co/rombodawg/rombos_Replete-Coder-Llama3-8B) | 57.92 | 37.35 | 42.07 | 49.20 |
| [Average](https://huggingface.co/aucson/llama3-code-math-avg-merge) | 65.27 | 44.05 | 43.29 | 47.80 |
| [RegMean ($\alpha$=0.1)](https://huggingface.co/aucson/llama3-code-math-regmean-merge/tree/main) | 68.31 | 44.99 | 44.51 | 45.20 |
\* Official reported result.
\* Zero-shot results are sensitive to chat templates, so we report the best achievable HumanEval-Instruct result for all models: we modified `lm-evaluation-harness/lm_eval/tasks/humaneval/humaneval.yaml` so that "\`\`\`" is treated as the end of a response.
The merged models, along with the activation inner product matrices, are available on the Hugging Face Hub.
## What's new?
RegMean solves a least-squares regression problem at each linear layer of the transformer. This is now implemented with the built-in `torch.linalg.lstsq` function.
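Concretely, with per-model weights $W_i$ and input activation gram matrices $G_i = X_i^\top X_i$ at a given linear layer, the RegMean merge of Jin et al. (2023) has the closed form

$$W^{*} = \Big(\sum_i G_i\Big)^{-1} \sum_i G_i W_i,$$

where `sum_gram` and `sum_gram_m_ws` in the snippet below correspond to the two sums; `torch.linalg.lstsq` solves the linear system directly rather than forming the inverse.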
```python
# old: explicit inverse, numerically fragile when sum_gram is ill-conditioned
# sum_gram_inv = torch.inverse(sum_gram)
# wt = torch.matmul(sum_gram_inv, sum_gram_m_ws)

# new: solve the linear system directly
# sum_gram:      summed per-model gram matrices, sum_i G_i
# sum_gram_m_ws: summed gram-weighted model weights, sum_i G_i W_i
wt = torch.linalg.lstsq(sum_gram, sum_gram_m_ws).solution
```
According to PyTorch's official documentation,
```
This function computes X = A.pinverse() @ B in a faster and more numerically stable way than performing the computations separately.
```
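As a quick sanity check (a minimal snippet of ours, not part of the repo), the two formulations agree on a well-conditioned square system:

```python
import torch

A = torch.randn(8, 8, dtype=torch.float64)  # stand-in for sum_gram
B = torch.randn(8, 3, dtype=torch.float64)  # stand-in for sum_gram_m_ws

x_lstsq = torch.linalg.lstsq(A, B).solution  # solves A @ X = B
x_pinv = A.pinverse() @ B                    # explicit (pseudo-)inverse

assert torch.allclose(x_lstsq, x_pinv, atol=1e-8)
```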
## Computational efficiency
- **Computing gram matrices**: We compute inner product matrices for the code and math models on 10k training examples. Each takes ~3 hours on one Quadro RTX A6000 GPU (this could likely be accelerated with a more efficient LLM inference framework; a sketch of the hook-based accumulation follows this list). We have uploaded the matrices under the [merged model repo](https://huggingface.co/aucson/llama3-code-math-regmean-merge/tree/main) so that you do not need to re-compute them.
- **Merging models**: ~2 minutes on the same GPU for this re-implementation. Note that loading two 8B models and the (almost) equally sized inner product matrices at once can take >150 GB of memory.
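For reference, here is a minimal, self-contained sketch of how per-layer gram matrices can be accumulated with forward hooks; it is not the repo's `compute_gram.py`, and the toy model and batches are stand-ins for Llama3-8B and the 10k training examples:

```python
import torch
import torch.nn as nn

# Toy stand-in for the real model; in practice, load the HF checkpoint instead.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
grams = {}

def make_hook(name):
    def hook(module, inputs, output):
        # Flatten (batch, ..., hidden) -> (tokens, hidden) and accumulate X^T X.
        x = inputs[0].detach().reshape(-1, inputs[0].shape[-1]).float()
        grams[name] = grams.get(name, 0.0) + x.T @ x
    return hook

handles = [m.register_forward_hook(make_hook(n))
           for n, m in model.named_modules() if isinstance(m, nn.Linear)]

with torch.no_grad():
    for _ in range(10):  # stand-in for forward passes over the training examples
        model(torch.randn(4, 16))

for h in handles:
    h.remove()
```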
## Reproducing the results
1. Create a Python environment and install the modified `lm-eval-harness` library for evaluating merged models.
```
cd lm-eval-harness
pip install -e .
```
The only modification is `lm_eval/tasks/humaneval/humaneval.yaml`.
2. Prepare the activation inner product matrices.
You can download them from the [merged model repo](https://huggingface.co/aucson/llama3-code-math-regmean-merge/tree/main) and place them under `runs/merges/math-llama3/gram.pkl` and `runs/merges/code-llama3/gram.pkl` (a programmatic download sketch follows the commands below). Alternatively, you can compute them yourself with:
```
python compute_gram.py code
python compute_gram.py math
```
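If you would rather fetch them programmatically, here is a sketch with `huggingface_hub`; the filenames are a guess based on the local paths above, so check the repo's file listing for the actual names:

```python
from huggingface_hub import hf_hub_download

# Hypothetical filenames -- verify against the repo's file listing.
for name in ["math-llama3/gram.pkl", "code-llama3/gram.pkl"]:
    hf_hub_download(
        repo_id="aucson/llama3-code-math-regmean-merge",
        filename=name,
        local_dir="runs/merges",
    )
```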
3. Merge the models. A hedged sketch of the per-layer update follows the commands below.
```
python merge_model.py avg
python merge_model.py regmean
```
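For intuition, here is an illustrative sketch of the per-layer RegMean update (not the repo's `merge_model.py`; the off-diagonal scaling by $\alpha$ follows the description in Jin et al., 2023):

```python
import torch

def regmean_merge_layer(weights, grams, alpha=0.1):
    """Illustrative per-layer RegMean merge.

    weights: list of (out_dim, in_dim) weight matrices, one per model
    grams:   matching list of (in_dim, in_dim) gram matrices G_i = X_i^T X_i
    alpha:   scaling applied to off-diagonal gram entries for regularization
    """
    sum_gram, sum_gram_m_ws = 0.0, 0.0
    for w, g in zip(weights, grams):
        # Shrink off-diagonal entries by alpha; keep the diagonal intact.
        g = alpha * g + (1.0 - alpha) * torch.diag(torch.diagonal(g))
        sum_gram = sum_gram + g
        sum_gram_m_ws = sum_gram_m_ws + g @ w.T  # G_i @ W_i^T
    # Solve (sum_i G_i) X = sum_i G_i W_i^T, then transpose back to (out, in).
    return torch.linalg.lstsq(sum_gram, sum_gram_m_ws).solution.T
```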
4. Evaluate with `lm-eval-harness`. Please follow the safety guidelines of HumanEval and MBPP regarding execution of LLM-generated code.
```
merge_exp=regmean_0.1
# merge_exp=avg
HF_ALLOW_CODE_EVAL=1 lm_eval --model vllm \
    --model_args pretrained=runs/merges/${merge_exp},tokenizer=meta-llama/Meta-Llama-3-8B,tensor_parallel_size=1,dtype=bfloat16 \
    --tasks mathqa,gsm8k,humaneval_instruct,mbpp \
    --output_path runs/merges/${merge_exp}/lm_eval_results_preds \
    --log_samples --trust_remote_code --confirm_run_unsafe_code
```
## Caveats
Overall, simple averaging works well for LLMs, and the benefits of more sophisticated merging algorithms diminish at scale [1].
## Citations
For the RegMean algorithm.
```
@inproceedings{
jin2023dataless,
title={Dataless Knowledge Fusion by Merging Weights of Language Models},
author={Xisen Jin and Xiang Ren and Daniel Preotiuc-Pietro and Pengxiang Cheng},
booktitle={The Eleventh International Conference on Learning Representations},
year={2023},
url={https://openreview.net/forum?id=FCnohuR6AnM}
}
```
Here are other useful references that greatly inspired this re-implementation.
[1] Yadav et al. 2024, [What Matters for Model Merging at Scale?](https://arxiv.org/abs/2410.03617)
[2] Tam et al. 2024, [Merging by Matching Models in Task Parameter Subspaces](https://openreview.net/forum?id=qNGo6ghWFB) |