import copy
import inspect
import warnings
from typing import Optional, Union, List, Callable

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import (
    GenerationConfig,
    GenerationMode,
    LogitsProcessorList,
    StoppingCriteriaList,
    GenerateOutput,
    GenerationMixin,
    GenerateEncoderDecoderOutput,
    GenerateDecoderOnlyOutput,
    GenerateNonBeamOutput,
    is_deepspeed_zero3_enabled,
    is_torchdynamo_compiling,
    NEED_SETUP_CACHE_CLASSES_MAPPING,
    QUANT_BACKEND_CLASSES_MAPPING,
    is_hqq_available,
    QuantizedCacheConfig,
    is_quanto_available,
    DynamicCache,
    EncoderDecoderCache,
    logging
)
# from transformers.generation.stopping_criteria import validate_stopping_criteria

logger = logging.get_logger(__name__)


class GenerationWithCTC(GenerationMixin):
    """Generation mixin that adds an optional CTC "unit" stream to sampling.

    `generate` only supports the `SAMPLE` and `GREEDY_SEARCH` generation modes.
    When `streaming_unit_gen=True`, decoding goes through
    `_sample_streaming_unit`, which additionally runs
    `self.speech_generator.predict` on the accumulated hidden states at every
    step and pushes newly decoded units to `streamer_unit`.
    """

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
        synced_gpus: Optional[bool] = None,
        assistant_model: Optional["PreTrainedModel"] = None,
        streamer: Optional["BaseStreamer"] = None,
        streamer_unit: Optional["BaseStreamer"] = None,
        streaming_unit_gen: bool = False,
        negative_prompt_ids: Optional[torch.Tensor] = None,
        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Prepare inputs, caches and processors, then run greedy/sampling decoding.

        Args:
            inputs: Prompt tensor forwarded to `_prepare_model_inputs`.
            generation_config: Base generation configuration; `kwargs` may override it.
            logits_processor: Extra logits processors merged by `_get_logits_processor`.
            stopping_criteria: Extra stopping criteria merged by `_get_stopping_criteria`.
            prefix_allowed_tokens_fn: Forwarded to `_get_logits_processor`.
            synced_gpus: When `None`, set to `True` only if DeepSpeed ZeRO-3 is enabled
                and the distributed world size is greater than one.
            assistant_model: Only used for validation and generation-mode detection.
            streamer: Receives the prompt ids and every generated token.
            streamer_unit: Receives CTC units; only used when `streaming_unit_gen` is set.
            streaming_unit_gen: Route decoding through `_sample_streaming_unit`.
            negative_prompt_ids: Forwarded to `_get_logits_processor`.
            negative_prompt_attention_mask: Forwarded to `_get_logits_processor`.
            **kwargs: Generation-config overrides and model kwargs. A `tokenizer`
                entry is popped and used for token healing / stopping criteria.

        Returns:
            Whatever `_sample` / `_sample_streaming_unit` returns.

        Raises:
            ValueError: On conflicting cache arguments, an unsupported cache
                implementation, or a streamer combined with `num_beams > 1`.
            ImportError: When the selected quantized-cache backend is not installed.
            NotImplementedError: For any generation mode other than sample / greedy.
        """
        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        self._validate_model_class()
        tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
        generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
        self._validate_model_kwargs(model_kwargs.copy())
        self._validate_assistant(assistant_model)

        # 2. Set generation parameters if not already defined
        if synced_gpus is None:
            synced_gpus = bool(is_deepspeed_zero3_enabled() and dist.get_world_size() > 1)

        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs
        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None

        # 3. Define model inputs
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
            inputs, generation_config.bos_token_id, model_kwargs
        )
        batch_size = inputs_tensor.shape[0]

        device = inputs_tensor.device
        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

        # decoder-only models must use left-padding for batched generation.
        if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
            # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
            # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
            if (
                generation_config._pad_token_tensor is not None
                and batch_size > 1
                and len(inputs_tensor.shape) == 2
                and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0
            ):
                logger.warning(
                    "A decoder-only architecture is being used, but right-padding was detected! For correct "
                    "generation results, please set `padding_side='left'` when initializing the tokenizer."
                )

        # 4. Define other model kwargs
        # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
        # generating the first new token or not, and we only want to use the embeddings for the first new token)
        if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
            model_kwargs["use_cache"] = True
        else:
            model_kwargs["use_cache"] = generation_config.use_cache

        if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
                inputs_tensor, generation_config._pad_token_tensor, generation_config._eos_token_tensor
            )

        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, model_kwargs, model_input_name, generation_config
            )

        # 5. Prepare `input_ids` which will be used for auto-regressive generation
        if self.config.is_encoder_decoder:
            input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
                batch_size=batch_size,
                model_input_name=model_input_name,
                model_kwargs=model_kwargs,
                decoder_start_token_id=generation_config._decoder_start_token_tensor,
                device=inputs_tensor.device,
            )
        else:
            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

        if generation_config.token_healing:
            input_ids = self.heal_tokens(input_ids, tokenizer)

        if streamer is not None:
            streamer.put(input_ids.cpu())

        # 6. Prepare `max_length` depending on other stopping criteria.
        input_ids_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            has_default_min_length=has_default_min_length,
            model_input_name=model_input_name,
            inputs_tensor=inputs_tensor,
            input_ids_length=input_ids_length,
        )

        if "mamba" in self.__class__.__name__.lower():
            cache_name = "cache_params"
        else:
            cache_name = "past_key_values"

        if generation_config.cache_implementation is not None and (model_kwargs.get(cache_name) is not None):
            raise ValueError(
                f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` (a "
                "Cache object) is unsupported. Please use only one of the two."
            )
        elif generation_config.cache_implementation is not None:
            if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
                if generation_config.cache_implementation == "static" and not self._supports_static_cache:
                    raise ValueError(
                        "This model does not support `cache_implementation='static'`. Please check the following "
                        "issue: https://github.com/huggingface/transformers/issues/28981"
                    )
                model_kwargs[cache_name] = self._get_cache(
                    generation_config.cache_implementation,
                    getattr(generation_config, "num_beams", 1) * batch_size,
                    generation_config.max_length,
                    model_kwargs,
                )
            elif generation_config.cache_implementation == "quantized":
                if not self._supports_quantized_cache:
                    raise ValueError(
                        "This model does not support the quantized cache. If you want your model to support quantized "
                        "cache, please open an issue."
                    )

                cache_config = (
                    generation_config.cache_config
                    if generation_config.cache_config is not None
                    else QuantizedCacheConfig()
                )
                cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]

                if cache_config.backend == "quanto" and not is_quanto_available():
                    raise ImportError(
                        "You need to install `quanto` in order to use KV cache quantization with quanto backend. "
                        "Please install it via  with `pip install quanto`"
                    )
                elif cache_config.backend == "HQQ" and not is_hqq_available():
                    raise ImportError(
                        "You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
                        "Please install it via  with `pip install hqq`"
                    )

                model_kwargs[cache_name] = cache_class(cache_config)
        # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
        # keeps copying the cache thus using much more memory
        elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
            past = model_kwargs.get(cache_name, None)
            requires_cross_attention_cache = (
                self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
            )
            if past is None:
                model_kwargs[cache_name] = (
                    DynamicCache()
                    if not requires_cross_attention_cache
                    else EncoderDecoderCache(DynamicCache(), DynamicCache())
                )
            elif isinstance(past, tuple):
                model_kwargs[cache_name] = (
                    DynamicCache.from_legacy_cache(past)
                    if not requires_cross_attention_cache
                    else EncoderDecoderCache.from_legacy_cache(past)
                )

        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

        # 7. determine generation mode
        generation_mode = generation_config.get_generation_mode(assistant_model)

        if (streamer is not None or streamer_unit is not None) and (generation_config.num_beams > 1):
            raise ValueError(
                "`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
            )

        if self.device.type != input_ids.device.type:
            warnings.warn(
                "You are calling .generate() with the `input_ids` being on a device type different"
                f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
                f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
                " Please make sure that you have put `input_ids` to the"
                f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
                " running `.generate()`.",
                UserWarning,
            )

        # 8. prepare distribution pre_processing samplers
        prepared_logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
            device=inputs_tensor.device,
            model_kwargs=model_kwargs,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
        )

        # 9. prepare stopping criteria
        prepared_stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
        )

        # 10. go into different generation modes
        if generation_mode not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
            # Only sampling and greedy search are implemented by this mixin.
            raise NotImplementedError(
                f"GenerationWithCTC only supports sample and greedy search, got generation mode {generation_mode}."
            )

        # 11. prepare logits warper
        prepared_logits_warper = (
            self._get_logits_warper(generation_config, device=input_ids.device)
            if generation_config.do_sample
            else None
        )

        # 12. expand input_ids with `num_return_sequences` additional sequences per batch
        input_ids, model_kwargs = self._expand_inputs_for_generation(
            input_ids=input_ids,
            expand_size=generation_config.num_return_sequences,
            is_encoder_decoder=self.config.is_encoder_decoder,
            **model_kwargs,
        )

        # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
        if streaming_unit_gen:
            return self._sample_streaming_unit(
                input_ids,
                logits_processor=prepared_logits_processor,
                logits_warper=prepared_logits_warper,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                streamer=streamer,
                streamer_unit=streamer_unit,
                **model_kwargs,
            )
        return self._sample(
            input_ids,
            logits_processor=prepared_logits_processor,
            logits_warper=prepared_logits_warper,
            stopping_criteria=prepared_stopping_criteria,
            generation_config=generation_config,
            synced_gpus=synced_gpus,
            streamer=streamer,
            **model_kwargs,
        )

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        logits_warper: Optional[LogitsProcessorList],
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        """Auto-regressively decode tokens by multinomial sampling or greedy argmax.

        Args:
            input_ids: Prompt ids; one generated token is appended per step.
            logits_processor: Applied to the last-position logits every step.
            stopping_criteria: Evaluated after every step to mark finished sequences.
            generation_config: Supplies pad token, `do_sample` and the output flags.
            synced_gpus: Keep running forward passes until every peer has finished.
            streamer: Receives each step's tokens, then `end()` when decoding stops.
            logits_warper: Applied after `logits_processor` when `do_sample` is set.
            **model_kwargs: Forwarded to `prepare_inputs_for_generation`.

        Returns:
            A `GenerateEncoderDecoderOutput` / `GenerateDecoderOnlyOutput` when
            `return_dict_in_generate` is set, otherwise the token id tensor.

        Raises:
            ValueError: If `do_sample` is `True` but `logits_warper` is not a
                `LogitsProcessorList`.
        """
        # init values
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
        do_sample = generation_config.do_sample
        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
            raise ValueError(
                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
                f"{logits_warper})."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        batch_size = input_ids.shape[0]
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # prepare variable output controls (note: some models won't accept all output controls)
            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

            # forward pass to get next token
            outputs = self(**model_inputs, return_dict=True)

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
            # (the clone itself is always small)
            next_token_logits = outputs.logits[:, -1, :].clone()

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            if do_sample:
                next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # token selection
            if do_sample:
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0

            # This is needed to properly delete outputs.logits which may be very large for first iteration
            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
            del outputs

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GenerateEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
            else:
                return GenerateDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
        else:
            return input_ids

    def _sample_streaming_unit(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        streamer_unit: Optional["BaseStreamer"],
        logits_warper: Optional[LogitsProcessorList],
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        """Like `_sample`, but also decodes and streams CTC units at every step.

        After each forward pass the last-layer hidden states collected so far
        (last prompt position from the first step, then each later step's
        states) are concatenated along dim 1, passed to
        `self.speech_generator.predict`, and collapsed with `ctc_postprocess`
        using `self.model.config.unit_vocab_size` as the blank id. Units beyond
        those already emitted are pushed to `streamer_unit` one at a time.

        Hidden states are always requested from the forward pass, so the unit
        stream works regardless of `output_hidden_states` /
        `return_dict_in_generate`; those flags only control what is returned.

        Args:
            input_ids: Prompt ids; one generated token is appended per step.
            logits_processor: Applied to the last-position logits every step.
            stopping_criteria: Evaluated after every step to mark finished sequences.
            generation_config: Supplies pad token, `do_sample` and the output flags.
            synced_gpus: Keep running forward passes until every peer has finished.
            streamer: Receives each step's tokens, then `end()` when decoding stops.
            streamer_unit: Receives each newly decoded unit as a 1-element tensor.
            logits_warper: Applied after `logits_processor` when `do_sample` is set.
            **model_kwargs: Forwarded to `prepare_inputs_for_generation`.

        Returns:
            A `GenerateEncoderDecoderOutput` / `GenerateDecoderOnlyOutput` when
            `return_dict_in_generate` is set, otherwise the token id tensor.

        Raises:
            ValueError: If `do_sample` is `True` but `logits_warper` is not a
                `LogitsProcessorList`.
        """
        # init values
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
        do_sample = generation_config.do_sample
        if do_sample is True and not isinstance(logits_warper, LogitsProcessorList):
            raise ValueError(
                "`do_sample` is set to `True`, `logits_warper` must be a `LogitsProcessorList` instance (it is "
                f"{logits_warper})."
            )

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        batch_size = input_ids.shape[0]
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        # Last-layer hidden states fed to the speech generator, kept separately from
        # `decoder_hidden_states` so unit decoding does not depend on the output flags.
        unit_hidden_states = []
        # Number of CTC units already pushed to `streamer_unit`.
        num_streamed_units = 0

        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # prepare variable output controls (note: some models won't accept all output controls)
            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
            # The speech generator consumes hidden states, so they are always requested here.
            model_inputs["output_hidden_states"] = True

            # forward pass to get next token
            outputs = self(**model_inputs, return_dict=True)

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
            # (the clone itself is always small)
            next_token_logits = outputs.logits[:, -1, :].clone()

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            if do_sample:
                next_token_scores = logits_warper(input_ids, next_token_scores)

            step_hidden_states = (
                outputs.decoder_hidden_states if self.config.is_encoder_decoder else outputs.hidden_states
            )

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (step_hidden_states,)

            # token selection
            if do_sample:
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

            # speechgen
            # The first step covers the whole prompt, so only its last position is kept;
            # later steps contribute their full last-layer states.
            last_layer_states = step_hidden_states[-1]
            if not unit_hidden_states:
                last_layer_states = last_layer_states[:, -1:, :]
            unit_hidden_states.append(last_layer_states)
            hidden_states = torch.cat(unit_hidden_states, dim=1)
            # NOTE(review): `squeeze(0)` here and in `ctc_postprocess` looks like it assumes
            # batch size 1 - confirm against callers.
            ctc_pred = self.speech_generator.predict(hidden_states.squeeze(0))
            cur_units = ctc_postprocess(ctc_pred, blank=self.model.config.unit_vocab_size)

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            if streamer_unit is not None:
                # Only emit the units that were not part of the previous step's hypothesis.
                for unit in cur_units[num_streamed_units:]:
                    streamer_unit.put(unit.unsqueeze(0))
            num_streamed_units = len(cur_units)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0

            # This is needed to properly delete outputs.logits which may be very large for first iteration
            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
            del outputs

        if streamer is not None:
            streamer.end()
        # NOTE(review): `streamer_unit.end()` is never called; left as-is because consumers
        # may read the unit queue directly - confirm whether an end signal is expected.

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GenerateEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
            else:
                return GenerateDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
        else:
            return input_ids


def ctc_postprocess(tokens, blank):
    """Collapse a CTC prediction: merge consecutive repeats, then drop blanks.

    Args:
        tokens: Tensor of predicted ids; a leading dimension of size 1 is squeezed away.
        blank: Id of the CTC blank symbol to remove.

    Returns:
        A 1-D `torch.long` tensor with the collapsed ids (empty when nothing remains).
    """
    sequence = tokens.squeeze(0)
    if sequence.dim() == 0:
        # A single prediction squeezes down to a scalar; keep it iterable.
        sequence = sequence.unsqueeze(0)
    _toks = sequence.tolist()
    deduplicated_toks = [v for i, v in enumerate(_toks) if i == 0 or v != _toks[i - 1]]
    # Explicit dtype: without it an empty hypothesis would come back as float32.
    hyp = torch.tensor([v for v in deduplicated_toks if v != blank], dtype=torch.long)
    return hyp