|
---
library_name: transformers
tags:
- custom_generate
---
|
|
|
## Description |
|
|
|
Implementation of [Decoding by Contrasting Layers (DoLa)](https://huggingface.co/papers/2309.03883), a contrastive decoding strategy for improving factuality and reducing hallucinations in language model outputs.
|
|
|
DoLa works by **contrasting the logits** from the final layer with those from earlier layers of the model, amplifying factual knowledge localized in specific layers and suppressing spurious information.
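
As a rough intuition, here is a minimal sketch of the per-step contrast. It is not the packaged implementation; `final_logits`, `premature_logits`, and `alpha` are illustrative names. DoLa keeps only tokens the final layer already considers plausible and scores them by the difference of log-probabilities; the full strategy additionally selects the earlier ("premature") layer at each step by maximizing the Jensen-Shannon divergence with the final layer, which is omitted here.

```python
import torch
import torch.nn.functional as F

def dola_contrast(final_logits: torch.Tensor, premature_logits: torch.Tensor, alpha: float = 0.1) -> torch.Tensor:
    # Hypothetical per-step contrast: `final_logits` and `premature_logits` are
    # [vocab_size] tensors from the LM head applied to the final and to one
    # earlier hidden state, respectively.
    final_logp = F.log_softmax(final_logits, dim=-1)
    premature_logp = F.log_softmax(premature_logits, dim=-1)

    # Adaptive plausibility constraint: keep only tokens whose final-layer
    # probability is at least `alpha` times that of the top token.
    probs = final_logp.exp()
    keep = probs >= alpha * probs.max()

    # Contrastive score: amplify what the final layer adds over the earlier layer;
    # implausible tokens are ruled out entirely.
    scores = torch.full_like(final_logp, float("-inf"))
    scores[keep] = final_logp[keep] - premature_logp[keep]
    return scores
```

With greedy decoding, the next token is then simply `scores.argmax()`.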
|
|
|
This can be useful for: |
|
|
|
* **Short-answer tasks** (e.g., TruthfulQA) — using higher layers (`dola_layers="high"`) |
|
* **Long-answer reasoning tasks** (e.g., GSM8K, StrategyQA, FACTOR, VicunaQA) — using lower layers (`dola_layers="low"`) |
|
|
|
DoLa is **not recommended for smaller models** such as GPT-2, as the improvement may be negligible. |
|
|
|
This implementation matches the `DoLa` functionality present in `transformers<4.53.0`. |
|
|
|
--- |
|
|
|
## Base model |
|
|
|
* [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) |
|
|
|
--- |
|
|
|
## Model compatibility |
|
|
|
* Decoder-only transformer models |
|
|
|
--- |
|
|
|
## Additional Arguments |
|
|
|
* **`dola_layers`** (*str* or *List\[int]*, optional): Which earlier layers to contrast with the final layer. Can be:
|
|
|
* `"low"` — lower half of layers (recommended for long answers) |
|
* `"high"` — upper half of layers (recommended for short answers) |
|
* List of integer indices (e.g., `[18, 20]`) |
|
|
|
**Note:** |
|
|
|
* Layer 0 is the word embedding; layer 1 is the first transformer block. |
|
* If the model has tied word embeddings, layer 0 is skipped and counting starts at layer 2. |
|
* Typical defaults (see the layer-selection sketch after this argument list):
|
|
|
| # Layers | `"low"` range       | `"high"` range        |
| -------- | ------------------- | --------------------- |
| > 40     | `range(0, 20, 2)`   | `range(N - 20, N, 2)` |
| ≤ 40     | `range(0, N//2, 2)` | `range(N//2, N, 2)`   |
|
|
|
* **`repetition_penalty`** (*float*, optional, defaults to `None`): Helps reduce repetition. A value of `1.2` is recommended.
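
The sketch below is a hypothetical helper showing how the `"low"` / `"high"` presets map to candidate layer indices per the table above; the packaged implementation's exact bookkeeping (e.g. the tied-embedding offset) may differ.

```python
def dola_candidate_layers(num_layers: int, dola_layers: str) -> list[int]:
    # `num_layers` is the total number of transformer layers N in the model.
    if num_layers > 40:
        low, high = range(0, 20, 2), range(num_layers - 20, num_layers, 2)
    else:
        low, high = range(0, num_layers // 2, 2), range(num_layers // 2, num_layers, 2)
    return list(low if dola_layers == "low" else high)

# For a 28-layer model:
#   dola_candidate_layers(28, "low")  -> [0, 2, 4, 6, 8, 10, 12]
#   dola_candidate_layers(28, "high") -> [14, 16, 18, 20, 22, 24, 26]
```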
|
|
|
--- |
|
|
|
## Output Type changes |
|
|
|
* The `generate` method output remains the same as with default `transformers` generation, but the logits are post-processed using the DoLa contrastive scoring before token selection.
|
|
|
--- |
|
|
|
## Example usage |
|
|
|
### Using higher layers (short-answer tasks) |
|
|
|
```python
# requires `transformers>=4.56.0`; in earlier versions, DoLa was part of the core library
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, infer_device

device = infer_device()

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B", torch_dtype=torch.float16
).to(device)

inputs = tokenizer("What is the highest peak in the world?", return_tensors="pt").to(device)

outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    do_sample=False,
    custom_generate="transformers-community/dola",
    trust_remote_code=True,
    dola_layers="high"
)

print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```
|
|
|
--- |
|
|
|
### Contrasting specific layers |
|
|
|
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, infer_device

device = infer_device()

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B", torch_dtype=torch.float16
).to(device)

inputs = tokenizer("What is the highest peak in the world?", return_tensors="pt").to(device)

outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    do_sample=False,
    repetition_penalty=1.2,
    custom_generate="transformers-community/dola",
    trust_remote_code=True,
    dola_layers=[18, 20]
)

# Only decode the newly generated tokens
print(tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True))
```
|
|