Description
Constrained Beam Search extends standard beam search by allowing you to enforce lexical or phrasal constraints in the generated output. This is useful when you know certain words or phrases must appear (e.g., translation dictionaries, product names, slot values), or when multiple outputs are equally probable but only some are desirable for your use case.
Unlike ordinary beam search, constrained beam search steers generation to include required subsequences somewhere in the final output while balancing fluency.
Why it's difficult
Beam search generates token-by-token and scores candidates locally. Forcing a phrase like "is fast" to appear somewhere requires the search to plan several steps ahead and decide when to insert the constrained tokens without breaking fluency. The problem becomes more complex with multiple constraints, optional alternatives, or ordering requirements.
Constrained beam search solves this by:
- Injecting constraint-progressing tokens among regular high-probability candidates
- Grouping candidate beams into banks according to how many constraint tokens they have satisfied so far
- Selecting beams round-robin across banks to balance fluency and constraint satisfaction
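The grouping-and-selection step above can be sketched in plain Python. This is a simplified illustration, not the transformers implementation: beam states are reduced to `(score, tokens_satisfied)` pairs, and ties in real beam search carry full token histories.

```python
# Simplified sketch of the bank-based selection step: candidates are grouped
# by constraint progress, then picked round-robin from the most-complete bank
# down, so both fluent-but-unconstrained and constrained-but-awkward beams survive.

def select_beams(candidates, num_beams):
    """candidates: list of (score, tokens_satisfied); higher score = better."""
    banks = {}
    for cand in candidates:
        banks.setdefault(cand[1], []).append(cand)
    for bank in banks.values():
        bank.sort(key=lambda c: c[0], reverse=True)  # best-scoring first

    selected = []
    order = sorted(banks, reverse=True)  # most constraint progress first
    while len(selected) < num_beams and any(banks[b] for b in order):
        for b in order:  # round-robin across banks
            if banks[b] and len(selected) < num_beams:
                selected.append(banks[b].pop(0))
    return selected

beams = select_beams(
    [(-1.0, 2), (-0.5, 0), (-0.7, 1), (-2.0, 2), (-0.9, 0), (-1.5, 1)],
    num_beams=4,
)
```

Note that the second-best fully-satisfying beam (score -2.0) survives even though two unselected beams score higher: the banks guarantee constraint progress is represented at every step.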
Base model
Model compatibility
- Encoder-decoder and decoder-only transformer models
Additional Arguments
- `constraints` (list[Constraint]): Advanced constraints, e.g., `PhrasalConstraint`, `DisjunctiveConstraint`
- `force_words_ids` (list[list[int]] | list[list[list[int]]]): Simple way to specify words/phrases or disjunctive sets
- `num_beams` (int): Beam width
- Other standard beam args: `length_penalty`, `early_stopping`, `num_return_sequences`, `max_length`
Notes:
- Constrained decoding is incompatible with sampling: set `do_sample=False`
- Tokenize constraints without adding special tokens (`add_special_tokens=False`)
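The bookkeeping a phrasal constraint performs during generation can be illustrated with a small stand-in class (hypothetical, not the transformers `PhrasalConstraint` API): it tracks progress through a required token sequence, proposes the next progressing token, and resets when generation diverges mid-phrase.

```python
# Hypothetical stand-in illustrating phrasal-constraint tracking
# (not the transformers PhrasalConstraint class).

class PhraseTracker:
    def __init__(self, token_ids):
        self.token_ids = token_ids
        self.pos = 0  # tokens of the phrase matched so far

    @property
    def completed(self):
        return self.pos == len(self.token_ids)

    def advance(self):
        """Next token id that would make progress, or None if done."""
        return None if self.completed else self.token_ids[self.pos]

    def update(self, token_id):
        """Consume a generated token; a mismatch mid-phrase resets progress."""
        if self.completed:
            return True
        if token_id == self.token_ids[self.pos]:
            self.pos += 1
        else:
            self.pos = 0  # partial match broken, start over
        return self.completed
```

For example, a tracker for the two-token phrase `[5, 9]` proposes `5`, then `9`; generating any other token in between resets it to the start of the phrase.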
Example 1: Forcing a word (formal German translation)
```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

encoder_input_str = "translate English to German: How old are you?"
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

force_words = ["Sie"]
force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids

outputs = model.generate(
    input_ids,
    custom_generate="transformers-community/constrained-beam-search",
    force_words_ids=force_words_ids,
    num_beams=5,
    num_return_sequences=1,
    no_repeat_ngram_size=1,
    remove_invalid_values=True,
    trust_remote_code=True,
)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
The output is expected to contain the forced word, e.g. "Wie alt sind Sie?"
Example 2: Disjunctive constraints (choose any of several forms)
```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

force_word = "scared"
force_flexible = ["scream", "screams", "screaming", "screamed"]

force_words_ids = [
    tokenizer([force_word], add_prefix_space=True, add_special_tokens=False).input_ids,
    tokenizer(force_flexible, add_prefix_space=True, add_special_tokens=False).input_ids,
]

starting_text = ["The soldiers", "The child"]
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids

outputs = model.generate(
    input_ids,
    custom_generate="transformers-community/constrained-beam-search",
    force_words_ids=force_words_ids,
    num_beams=10,
    num_return_sequences=1,
    no_repeat_ngram_size=1,
    remove_invalid_values=True,
    trust_remote_code=True,
)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
print(tokenizer.decode(outputs[1], skip_special_tokens=True))
```
Each output will include the mandatory word ("scared") and at least one word from the flexible set.
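Conceptually, a disjunctive set behaves like a single constraint that completes as soon as any one of its alternatives has been fully generated. A minimal stand-in (hypothetical, not the transformers `DisjunctiveConstraint` class) illustrates the bookkeeping:

```python
# Hypothetical stand-in illustrating disjunctive-constraint tracking
# (not the transformers DisjunctiveConstraint class).

class DisjunctiveTracker:
    def __init__(self, phrases):
        self.phrases = phrases                 # list of token-id lists
        self.matched = [0] * len(phrases)      # progress per alternative

    @property
    def completed(self):
        # Satisfied as soon as any single alternative is fully matched.
        return any(m == len(p) for m, p in zip(self.matched, self.phrases))

    def advance(self):
        """Set of token ids that would extend at least one alternative."""
        return {p[m] for m, p in zip(self.matched, self.phrases) if m < len(p)}

    def update(self, token_id):
        """Consume a generated token, advancing or resetting each alternative."""
        for i, p in enumerate(self.phrases):
            if self.matched[i] < len(p) and token_id == p[self.matched[i]]:
                self.matched[i] += 1
            else:
                self.matched[i] = 0
        return self.completed
```

For instance, with alternatives `[[7], [3, 4]]`, generating token `3` then `4` satisfies the constraint without ever emitting `7`; the real implementation hands this progress information to the banked selection described above.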