Description

Constrained Beam Search extends standard beam search by allowing you to enforce lexical or phrasal constraints in the generated output. This is useful when you know certain words or phrases must appear (e.g., translation dictionaries, product names, slot values), or when multiple outputs are equally probable but only some are desirable for your use case.

Unlike ordinary beam search, constrained beam search steers generation to include required subsequences somewhere in the final output while balancing fluency.


Why it's difficult

Beam search generates token-by-token and scores candidates locally. Forcing a phrase like "is fast" to appear somewhere requires the search to plan several steps ahead and decide when to insert the constrained tokens without breaking fluency. The problem becomes more complex with multiple constraints, optional alternatives, or ordering requirements.

Constrained beam search solves this by:

  • Injecting constraint-progressing tokens among regular high-probability candidates
  • Grouping beams into banks according to how much of the constraints they have already satisfied
  • Selecting beams round-robin across banks to balance fluency and constraint satisfaction
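The bank and round-robin selection described above can be sketched in plain Python. This is a simplified illustration, not the code in the transformers-community/constrained-beam-search repository. The `Candidate` class and the `select_round_robin` function are hypothetical names. Each candidate's `progress` is assumed to be precomputed, whereas the real search tracks it with Constraint objects. The step that injects constraint-progressing tokens is left out.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class Candidate:
    tokens: list[int]  # token ids generated so far
    score: float       # cumulative log-probability (higher is better)
    progress: int      # constraint tokens satisfied so far (assumed precomputed)


def select_round_robin(candidates, num_beams):
    """Fill the beam by cycling over banks, most-progressed bank first."""
    # Group candidates into banks keyed by constraint progress.
    banks = defaultdict(list)
    for cand in candidates:
        banks[cand.progress].append(cand)
    # Order banks from most to least progress, and sort each bank by score.
    ordered = [
        sorted(banks[p], key=lambda c: c.score, reverse=True)
        for p in sorted(banks, reverse=True)
    ]
    selected = []
    rank = 0
    # Take the rank-th best candidate of every bank in turn until the beam
    # is full or all banks are exhausted.
    while len(selected) < num_beams and any(rank < len(b) for b in ordered):
        for bank in ordered:
            if rank < len(bank) and len(selected) < num_beams:
                selected.append(bank[rank])
        rank += 1
    return selected


# Constraint: the phrase (9, 8). Progress counts how many of its tokens are matched.
cands = [
    Candidate([5, 9, 8], -3.0, 2),
    Candidate([5, 2, 9], -2.0, 1),
    Candidate([5, 2, 3], -0.5, 0),
    Candidate([5, 2, 4], -0.7, 0),
    Candidate([5, 3, 9], -2.5, 1),
]
print([c.tokens for c in select_round_robin(cands, num_beams=4)])
# [[5, 9, 8], [5, 2, 9], [5, 2, 3], [5, 3, 9]]
```

A plain top-4 selection by score would drop the candidate that has already fully satisfied the constraint ([5, 9, 8], score -3.0). The round-robin selection keeps that candidate and still lets the most fluent unconstrained candidate survive.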

Model compatibility

  • Encoder-decoder and decoder-only transformer models

Additional Arguments

  • constraints (list[Constraint]): Advanced constraints, e.g., PhrasalConstraint, DisjunctiveConstraint
  • force_words_ids (list[list[int]] | list[list[list[int]]]): Simple way to specify words/phrases or disjunctive sets
  • num_beams (int): Beam width
  • Other standard beam args: length_penalty, early_stopping, num_return_sequences, max_length
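The two accepted shapes of force_words_ids can be read entry by entry. A flat list of ids is one phrase that must appear. A nested list is a set of alternatives, and any one of them is enough. The sketch below assumes this per-entry rule, which mirrors how the built-in `generate` in earlier transformers releases mapped entries to PhrasalConstraint and DisjunctiveConstraint. It also assumes that the custom_generate version keeps the same convention. `normalize_force_words` is a hypothetical helper, and the token ids are made up.

```python
def normalize_force_words(force_words_ids):
    """Return one list of alternative phrases per force_words_ids entry."""
    constraints = []
    for entry in force_words_ids:
        if not isinstance(entry, list) or not entry:
            raise ValueError("each constraint must be a non-empty list")
        if isinstance(entry[0], list):
            # Nested entry (list[list[int]]): disjunctive, any one phrase satisfies it.
            constraints.append([list(phrase) for phrase in entry])
        else:
            # Flat entry (list[int]): one phrase whose tokens must appear contiguously.
            constraints.append([list(entry)])
    return constraints


# Shape used in Example 1: a single required phrase.
print(normalize_force_words([[101, 102]]))  # [[[101, 102]]]
# Shape used in Example 2: a one-option set plus a multi-option set.
print(normalize_force_words([[[7]], [[1], [2], [3, 4]]]))  # [[[7]], [[1], [2], [3, 4]]]
```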

Notes:

  • Constrained decoding is incompatible with sampling: set do_sample=False
  • Tokenize constraints without adding special tokens

Example 1: Forcing a word (formal German translation)

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

encoder_input_str = "translate English to German: How old are you?"
input_ids = tokenizer(encoder_input_str, return_tensors="pt").input_ids

# Tokenize the forced word without special tokens (T5 would otherwise append </s>).
force_words = ["Sie"]
force_words_ids = tokenizer(force_words, add_special_tokens=False).input_ids

outputs = model.generate(
    input_ids,
    custom_generate="transformers-community/constrained-beam-search",
    force_words_ids=force_words_ids,
    num_beams=5,
    num_return_sequences=1,
    no_repeat_ngram_size=1,
    remove_invalid_values=True,
    trust_remote_code=True,
)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Expected output, which contains the forced word "Sie": Wie alt sind Sie?


Example 2: Disjunctive constraints (choose any of several forms)

from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

force_word = "scared"
force_flexible = ["scream", "screams", "screaming", "screamed"]

force_words_ids = [
    tokenizer([force_word], add_prefix_space=True, add_special_tokens=False).input_ids,
    tokenizer(force_flexible, add_prefix_space=True, add_special_tokens=False).input_ids,
]

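# Both prompts tokenize to the same length, so they can be batched without
# padding (GPT-2 has no pad token by default).
starting_text = ["The soldiers", "The child"]
input_ids = tokenizer(starting_text, return_tensors="pt").input_ids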

outputs = model.generate(
    input_ids,
    custom_generate="transformers-community/constrained-beam-search",
    force_words_ids=force_words_ids,
    num_beams=10,
    num_return_sequences=1,
    no_repeat_ngram_size=1,
    remove_invalid_values=True,
    trust_remote_code=True,
)

# With num_return_sequences=1 there is one output per prompt.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
print(tokenizer.decode(outputs[1], skip_special_tokens=True))

Each output should contain the mandatory word "scared" and at least one word from the flexible set.
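To confirm that a generated sequence actually meets the constraints, you can scan its token ids directly. The helpers below, `contains_phrase` and `satisfies_force_words`, are hypothetical. They assume that a flat force_words_ids entry is a required phrase and a nested entry is a set of alternatives. The ids in the example are invented stand-ins, not real GPT-2 ids.

```python
def contains_phrase(tokens, phrase):
    """True if `phrase` occurs as a contiguous run inside `tokens`."""
    n = len(phrase)
    return any(tokens[i:i + n] == phrase for i in range(len(tokens) - n + 1))


def satisfies_force_words(tokens, force_words_ids):
    """Check every entry; a nested entry needs only one of its alternatives."""
    for entry in force_words_ids:
        alternatives = entry if isinstance(entry[0], list) else [entry]
        if not any(contains_phrase(tokens, alt) for alt in alternatives):
            return False
    return True


# Hypothetical ids: 11 stands for " scared"; 21, 22 and (23, 24) for forms of "scream".
force = [[[11]], [[21], [22], [23, 24]]]
print(satisfies_force_words([1, 11, 5, 23, 24], force))  # True
print(satisfies_force_words([1, 11, 5, 23], force))      # False: (23, 24) is incomplete
print(satisfies_force_words([1, 5, 21], force))          # False: 11 is missing
```

With the real Example 2 outputs, pass `outputs[0].tolist()` and the same `force_words_ids`. GPT-2 returns the prompt followed by the continuation, so use `outputs[0][input_ids.shape[1]:].tolist()` if you only want to check the generated part.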
