from typing import Union, Optional, TYPE_CHECKING

import torch
from transformers import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
from transformers.generation.utils import (
    GenerateBeamOutput,
    GenerationMixin,
    GenerateBeamDecoderOnlyOutput,
    GenerateBeamEncoderDecoderOutput,
)
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import logging
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import ConstrainedBeamSearchScorer

if TYPE_CHECKING:
    from transformers.generation.streamers import BaseStreamer

logger = logging.getLogger(__name__)


def _build_constraints(generation_config: GenerationConfig) -> list:
    """Collect the constraints requested through `constraints` and `force_words_ids`.

    Parameters:
        generation_config ([`~generation.GenerationConfig`]):
            Configuration whose `constraints` and `force_words_ids` attributes are read.

    Return:
        A new list holding the entries of `generation_config.constraints` (if any) followed by one
        `PhrasalConstraint` per flat `list[int]` entry and one `DisjunctiveConstraint` per nested
        `list[list[int]]` entry of `generation_config.force_words_ids`.

    Raises:
        ValueError: If `force_words_ids` is not a non-empty list whose entries are non-empty lists of
            non-negative integers (or non-empty lists of such lists).
    """
    # Copy the user-provided list: the `append` calls below must not leak into `generation_config.constraints`.
    final_constraints = []
    if generation_config.constraints is not None:
        final_constraints = list(generation_config.constraints)

    if generation_config.force_words_ids is None:
        return final_constraints

    def typeerror():
        raise ValueError(
            "`force_words_ids` has to either be a `list[list[list[int]]]` or `list[list[int]]` "
            f"of positive integers, but is {generation_config.force_words_ids}."
        )

    if not isinstance(generation_config.force_words_ids, list) or len(generation_config.force_words_ids) == 0:
        typeerror()

    for word_ids in generation_config.force_words_ids:
        # Validate the container BEFORE looking at `word_ids[0]`; otherwise an empty or non-list entry
        # escapes as an `IndexError`/`TypeError` instead of the documented `ValueError`.
        if not isinstance(word_ids, list) or len(word_ids) == 0:
            typeerror()

        if isinstance(word_ids[0], list):
            # Nested lists: any one of the alternative token sequences satisfies the constraint.
            if any(not isinstance(token_ids, list) for token_ids in word_ids):
                typeerror()
            if any(
                any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
                for token_ids in word_ids
            ):
                typeerror()

            constraint = DisjunctiveConstraint(word_ids)
        else:
            # Flat list: a single token sequence that must appear.
            if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
                typeerror()

            constraint = PhrasalConstraint(word_ids)
        final_constraints.append(constraint)

    return final_constraints


def _constrained_beam_search(
    model,
    input_ids: torch.LongTensor,
    logits_processor: LogitsProcessorList,
    stopping_criteria: StoppingCriteriaList,
    generation_config: GenerationConfig,
    synced_gpus: bool,
    streamer: Optional["BaseStreamer"] = None,
    **model_kwargs,
) -> Union[GenerateBeamOutput, torch.LongTensor]:
    r"""
    Generates sequences of token ids for models with a language modeling head using **constrained beam search
    decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

    Parameters:
        input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
            The sequence used as a prompt for the generation.
        logits_processor (`LogitsProcessorList`):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
        stopping_criteria (`StoppingCriteriaList`):
            An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
            used to tell if the generation loop should stop.
        generation_config ([`~generation.GenerationConfig`]):
            The generation configuration to be used as parametrization of the decoding method.
        synced_gpus (`bool`):
            Whether to continue running the while loop until max_length (needed to avoid deadlocking with
            `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
        streamer (`BaseStreamer`, *optional*):
            Accepted as part of the decoding-function signature; this function does not use it.
        model_kwargs:
            Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
            an encoder-decoder model the kwargs should include `encoder_outputs`.

    Return:
        [`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
        `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
        [`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
        `return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
        `model.config.is_encoder_decoder=True`.

    Raises:
        ValueError: If constraints are requested together with `do_sample=True`, if `force_words_ids` is
            malformed, or if the batch dimension of `input_ids` is not `batch_size * num_beams`.
    """
    if generation_config.constraints is not None or generation_config.force_words_ids is not None:
        constrained_wrong_parameter_msg = (
            "one of `constraints`, `force_words_ids` is not `None`, triggering constrained beam search. "
            "However, `{flag_name}` is set to `{flag_value}`, which is incompatible with this generation "
            "mode. Set `constraints` and `force_words_ids` to `None` or unset `{flag_name}` to continue."
        )
        if generation_config.do_sample is True:
            raise ValueError(
                constrained_wrong_parameter_msg.format(flag_name="do_sample", flag_value=generation_config.do_sample)
            )

    final_constraints = _build_constraints(generation_config)

    # define beam scorer
    constrained_beam_scorer = ConstrainedBeamSearchScorer(
        constraints=final_constraints,
        batch_size=input_ids.shape[0] // generation_config.num_beams,
        num_beams=generation_config.num_beams,
        device=input_ids.device,
        length_penalty=generation_config.length_penalty,
        do_early_stopping=generation_config.early_stopping,
        num_beam_hyps_to_keep=generation_config.num_return_sequences,
        max_length=generation_config.max_length,
    )

    # init values
    pad_token_id = generation_config._pad_token_tensor
    eos_token_id = generation_config._eos_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate

    batch_size = len(constrained_beam_scorer._beam_hyps)
    num_beams = constrained_beam_scorer.num_beams

    batch_beam_size, cur_len = input_ids.shape[:2]
    model_kwargs = model._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)

    if num_beams * batch_size != batch_beam_size:
        raise ValueError(
            f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
        )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    beam_indices = (
        tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
    )
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and model.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
    # of the first beam are considered to avoid sampling the exact same tokens across all beams.
    beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
    beam_scores[:, 1:] = -1e9
    beam_scores = beam_scores.view((batch_size * num_beams,))

    this_peer_finished = False

    decoder_prompt_len = input_ids.shape[1]  # record the prompt length of decoder
    while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # prepare variable output controls (note: some models won't accept all output controls)
        model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
        model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

        outputs = model(**model_inputs, return_dict=True)

        # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=model.config.is_encoder_decoder,
        )
        if synced_gpus and this_peer_finished:
            cur_len = cur_len + 1
            continue

        # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
        # (the clone itself is always small)
        # .float() is needed to retain precision for later logits manipulations
        next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
        next_token_scores = nn.functional.log_softmax(
            next_token_logits, dim=-1
        )  # (batch_size * num_beams, vocab_size)

        next_token_scores_processed = logits_processor(input_ids, next_token_scores)

        next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
            next_token_scores_processed
        )

        scores_for_all_vocab = next_token_scores.clone()

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
                )
                if model.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if model.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # reshape for beam search
        vocab_size = next_token_scores.shape[-1]
        next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

        # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
        n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
        next_token_scores, next_tokens = torch.topk(
            next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
        )

        # Integer floor division: true division would round-trip the flattened (beam, token) index through
        # float32, which cannot represent every integer once `num_beams * vocab_size` exceeds 2**24.
        next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
        next_tokens = next_tokens % vocab_size

        # stateless
        beam_outputs = constrained_beam_scorer.process(
            input_ids,
            next_token_scores,
            next_tokens,
            next_indices,
            scores_for_all_vocab,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            beam_indices=beam_indices,
            decoder_prompt_len=decoder_prompt_len,
        )
        beam_scores = beam_outputs["next_beam_scores"]
        beam_next_tokens = beam_outputs["next_beam_tokens"]
        beam_idx = beam_outputs["next_beam_indices"]

        input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)

        # This is needed to properly delete outputs.logits which may be very large for first iteration
        # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
        # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
        # (that way the memory peak does not include outputs.logits)
        del outputs

        # NOTE: we need to check if `model._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
        if model_kwargs.get("past_key_values", None) is not None:
            if hasattr(model, "_reorder_cache"):
                model_kwargs["past_key_values"] = model._reorder_cache(model_kwargs["past_key_values"], beam_idx)
            else:
                model_kwargs["past_key_values"].reorder_cache(beam_idx)

        if return_dict_in_generate and output_scores:
            beam_indices = tuple(beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))

        # increase cur_len
        cur_len = cur_len + 1

        if constrained_beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
            this_peer_finished = True

    sequence_outputs = constrained_beam_scorer.finalize(
        input_ids,
        beam_scores,
        next_tokens,
        next_indices,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        max_length=stopping_criteria.max_length,
        beam_indices=beam_indices,
        decoder_prompt_len=decoder_prompt_len,
    )

    if return_dict_in_generate:
        if not output_scores:
            sequence_outputs["sequence_scores"] = None
        if model.config.is_encoder_decoder:
            return GenerateBeamEncoderDecoderOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                logits=raw_logits,
                beam_indices=sequence_outputs["beam_indices"],
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return GenerateBeamDecoderOnlyOutput(
                sequences=sequence_outputs["sequences"],
                sequences_scores=sequence_outputs["sequence_scores"],
                scores=scores,
                logits=raw_logits,
                beam_indices=sequence_outputs["beam_indices"],
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return sequence_outputs["sequences"]


def generate(model, *args, **kwargs):
    """Custom generate function for constrained beam search decoding.

    Delegates to `GenerationMixin.generate`, passing `_constrained_beam_search` as `custom_generate`.

    Args:
        model (`PreTrainedModel`):
            The model to generate from.
        num_beams (`int`):
            The number of beams to use for beam search.
        constraints (`list[Constraint]`, *optional*):
            Custom constraints that can be added to the generation to ensure that the output will contain the use
            of certain tokens as defined by `Constraint` objects, in the most sensible way possible.
        force_words_ids (`list[list[list[int]]]`):
            List of token ids that must be generated. If given a `list[list[int]]`, this is treated as a simple
            list of words that must be included, the opposite to `bad_words_ids`. If given
            `list[list[list[int]]]`, this triggers a
            [disjunctive constraint](https://github.com/huggingface/transformers/issues/14081), where one can
            allow different forms of each word.
        length_penalty (`float`):
            The length penalty to use for beam search.
        early_stopping (`bool`):
            Whether to stop beam search when sufficient beams have finished.
        num_return_sequences (`int`):
            The number of sequences to return.
        max_length (`int`):
            The maximum length of the generated sequence.

    Returns:
        Whatever `GenerationMixin.generate` returns for the given arguments.
    """
    return GenerationMixin.generate(model, *args, custom_generate=_constrained_beam_search, **kwargs)