|
--- |
|
|
|
library_name: transformers |
|
tags: |
|
- custom_generate |
|
--- |
|
|
|
## Description |
|
|
|
Constrained Beam Search extends standard beam search by allowing you to enforce lexical or phrasal constraints in the generated output. This is useful when you know certain words or phrases must appear (e.g., translation dictionaries, product names, slot values), or when multiple outputs are equally probable but only some are desirable for your use case. |
|
|
|
Unlike ordinary beam search, constrained beam search steers generation so that the required subsequences appear somewhere in the final output while keeping the surrounding text fluent.
|
|
|
--- |
|
|
|
## Why it's difficult |
|
|
|
Beam search generates token-by-token and scores candidates locally. Forcing a phrase like "is fast" to appear somewhere requires the search to plan several steps ahead and decide when to insert the constrained tokens without breaking fluency. The problem becomes more complex with multiple constraints, optional alternatives, or ordering requirements. |
|
|
|
Constrained beam search solves this by: |
|
- Injecting constraint-progressing tokens among regular high-probability candidates |
|
- Grouping beams into banks according to how many constraint tokens each beam has completed

- Selecting beams round-robin across the banks, balancing fluency against constraint progress (sketched below)
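
To make the banking idea concrete, here is a minimal, illustrative sketch of the selection step. It is **not** the actual `transformers` implementation; the candidate format and function name are invented for clarity:

```python
# Illustrative only: each candidate is a (score, tokens, constraint_tokens_done)
# tuple produced by one decoding step; a higher score means a more fluent beam.
def select_beams(candidates, num_beams):
    # Group candidates into banks keyed by constraint progress,
    # keeping each bank sorted by score (best first).
    banks = {}
    for cand in sorted(candidates, key=lambda c: c[0], reverse=True):
        banks.setdefault(cand[2], []).append(cand)

    selected = []
    # Round-robin: sweep the banks from most progressed to least,
    # taking the best remaining beam from each non-empty bank,
    # until the beam budget is filled.
    while len(selected) < num_beams and any(banks.values()):
        for progress in sorted(banks, reverse=True):
            if banks[progress] and len(selected) < num_beams:
                selected.append(banks[progress].pop(0))
    return selected
```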
|
|
|
--- |
|
|
|
## Base model |
|
|
|
* [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) |
|
|
|
--- |
|
|
|
## Model compatibility |
|
|
|
- Encoder-decoder and decoder-only transformer models |
|
|
|
--- |
|
|
|
## Additional Arguments |
|
|
|
- `constraints` (list[Constraint]): Advanced constraints, e.g., `PhrasalConstraint`, `DisjunctiveConstraint` (see the sketch after the notes below)
|
- `force_words_ids` (list[list[int]] | list[list[list[int]]]): Simple way to specify words/phrases or disjunctive sets |
|
- `num_beams` (int): Beam width |
|
- Other standard beam-search arguments: `length_penalty`, `early_stopping`, `num_return_sequences`, `max_length`
|
|
|
Notes: |
|
- Constrained decoding is incompatible with sampling: set `do_sample=False` |
|
- Tokenize constraint words/phrases with `add_special_tokens=False` so no special tokens end up in the forced sequences
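
The `constraints` argument accepts `Constraint` objects directly instead of raw token IDs. A brief sketch using `PhrasalConstraint` from `transformers` (the input sentence and the German phrase are only illustrative):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, PhrasalConstraint

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

input_ids = tokenizer(
    "translate English to German: The car is very fast.", return_tensors="pt"
).input_ids

# Force the contiguous phrase "sehr schnell" to appear in the output.
constraints = [
    PhrasalConstraint(tokenizer("sehr schnell", add_special_tokens=False).input_ids)
]

outputs = model.generate(
    input_ids,
    custom_generate="transformers-community/constrained-beam-search",
    constraints=constraints,
    num_beams=5,
    trust_remote_code=True,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```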
|
|
|
--- |
|
|
|
## Example 1: Forcing a word (formal German translation) |
|
|
|
```python |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
tokenizer = AutoTokenizer.from_pretrained("t5-base") |
|
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") |
|
|
|
encoder_input_str = "translate English to German: How old are you?" |
|
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids |
|
|
|
force_words = ["Sie"] |
|
force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids |
|
|
|
outputs = model.generate( |
|
input_ids, |
|
custom_generate="transformers-community/constrained-beam-search", |
|
force_words_ids=force_words_ids, |
|
num_beams=5, |
|
num_return_sequences=1, |
|
no_repeat_ngram_size=1, |
|
remove_invalid_values=True, |
|
trust_remote_code=True, |
|
) |
|
|
|
print(tokenizer.decode(outputs[0], skip_special_tokens=True)) |
|
``` |
|
|
|
Expected to contain the forced word: `Wie alt sind Sie?` |
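
Multi-word phrases work the same way: each inner list of token IDs must appear contiguously in the output. A hypothetical variation reusing `tokenizer`, `model`, and `input_ids` from the example above:

```python
# Force the two-word phrase "sind Sie" rather than a single word.
force_words_ids = tokenizer(["sind Sie"], add_special_tokens=False).input_ids

outputs = model.generate(
    input_ids,
    custom_generate="transformers-community/constrained-beam-search",
    force_words_ids=force_words_ids,
    num_beams=5,
    trust_remote_code=True,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```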
|
|
|
--- |
|
|
|
## Example 2: Disjunctive constraints (choose any of several forms) |
|
|
|
```python |
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer |
|
|
|
model = GPT2LMHeadModel.from_pretrained("gpt2") |
|
tokenizer = GPT2Tokenizer.from_pretrained("gpt2") |
|
|
|
force_word = "scared" |
|
force_flexible = ["scream", "screams", "screaming", "screamed"] |
|
|
|
force_words_ids = [ |
|
tokenizer([force_word], add_prefix_space=True, add_special_tokens=False).input_ids, |
|
tokenizer(force_flexible, add_prefix_space=True, add_special_tokens=False).input_ids, |
|
] |
|
|
|
# Both prompts tokenize to the same number of tokens, so they can be
# batched without padding (GPT-2 has no pad token by default).
starting_text = ["The soldiers", "The child"]
|
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids |
|
|
|
outputs = model.generate( |
|
input_ids, |
|
custom_generate="transformers-community/constrained-beam-search", |
|
force_words_ids=force_words_ids, |
|
num_beams=10, |
|
num_return_sequences=1, |
|
no_repeat_ngram_size=1, |
|
remove_invalid_values=True, |
|
trust_remote_code=True, |
|
) |
|
|
|
print(tokenizer.decode(outputs[0], skip_special_tokens=True)) |
|
print(tokenizer.decode(outputs[1], skip_special_tokens=True)) |
|
``` |
|
|
|
Each generated sequence will contain the mandatory word and at least one form from the flexible set.
|
|