
@add_start_docstrings(
    "The bare Gemma2 Model outputting raw hidden-states without any specific head on top.",
    GEMMA2_START_DOCSTRING,
)
class Gemma2Model(Gemma2PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Gemma2DecoderLayer`]

    Args:
        config: Gemma2Config
    """

    def __init__(self, config: Gemma2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Gemma2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value
    @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

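        # `use_cache=True` at inference time with no cache provided: build a `HybridCache` lazily, sized to the
        # current sequence, so the sliding-window and global-attention layers each get the cache layout they expect.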
        if use_cache and past_key_values is None and not self.training:
            batch_size, seq_len, _ = inputs_embeds.shape
            past_key_values = HybridCache(
                self.config,
                batch_size=batch_size,
                max_cache_len=seq_len,
                device=self.device,
                dtype=inputs_embeds.dtype,
            )

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # embed positions
        hidden_states = inputs_embeds

        # normalize the embeddings by sqrt(hidden_size)
        # Gemma2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = past_key_values if use_cache else None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: HybridCache,
        output_attentions: bool,
    ):
        # Flash Attention currently doesn't support static cache, but Gemma2 works only with static cache.
        # So we pass the attention mask through as-is in every case, not only when there's padding. We then use its
        # shape to cut out the trailing zeros of keys/values in the static cache. This workaround should be
        # compile-compatible, as it doesn't introduce dynamic control flow.
        if self.config._attn_implementation == "flash_attention_2":
            return attention_mask

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
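        # With a `HybridCache`, the mask must span the full pre-allocated cache length; otherwise the length of the
        # provided mask (or of the input itself) is enough.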
        if isinstance(past_key_values, HybridCache):
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]

        # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )
        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or does nothing if the input `attention_mask` is already 4D.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding (the part of the cache that is not filled yet).
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
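            # The dense mask starts fully blocked (filled with the most negative value of `dtype`); `torch.triu`
            # then unblocks the diagonal and everything below it, and the comparison against `cache_position`
            # unblocks every key slot at or before each query's absolute cache position. Net effect: a query may
            # attend only to cache slots up to its own position, while the unfilled tail of the static cache stays
            # blocked. A 2D padding mask, if provided, is folded in afterwards so padded key positions are blocked too.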
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask


class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = Gemma2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only the last token's logits are needed for generation, and computing them
                for that token alone saves memory, which becomes significant for long sequences or a large vocabulary
                size.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Gemma2ForCausalLM

        >>> model = Gemma2ForCausalLM.from_pretrained("google/gemma-2-9b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""

        if self.training and self.config._attn_implementation != "eager":
            logger.warning_once(
                "It is strongly recommended to train Gemma2 models with the `eager` attention implementation "
                f"instead of `{self.config._attn_implementation}`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`."
            )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
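        # Final logit soft-capping rescales logits to `cap * tanh(logits / cap)`, bounding them to
        # (-final_logit_softcapping, +final_logit_softcapping).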
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        num_logits_to_keep=None,
        **kwargs,
    ):
        # Overwritten: has a special cache type, `HybridCache`

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing `inputs_embeds`, `input_ids` may be missing entries
        # Exception 2: some generation methods do special slicing of `input_ids`, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no-op, is Exception 2)
                input_ids = input_ids[:, cache_position]
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
                # `mode="reduce-overhead"`, as otherwise the input `position_ids` would have varying strides
                # during decoding. Here, simply using `.contiguous()` is not sufficient, as in the
                # batch size = 1 case `position_ids` is already contiguous but with a varying stride,
                # which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        if (
            isinstance(past_key_values, HybridCache)
            and attention_mask.ndim == 2
            and not self.config._attn_implementation == "flash_attention_2"
        ):
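            # With a `HybridCache` (static layout) and a 2D mask, expand the mask here to 4D over the full cache
            # length so its shape stays constant across decoding steps (friendlier to `torch.compile`/cuda graphs).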
            if model_inputs["inputs_embeds"] is not None:
                batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
                device = model_inputs["inputs_embeds"].device
            else:
                batch_size, sequence_length = model_inputs["input_ids"].shape
                device = model_inputs["input_ids"].device

            attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
                attention_mask,
                sequence_length=sequence_length,
                target_length=past_key_values.get_max_cache_shape(),
                dtype=self.lm_head.weight.dtype,
                device=device,
                cache_position=cache_position,
                batch_size=batch_size,
            )

        if num_logits_to_keep is not None:
            model_inputs["num_logits_to_keep"] = num_logits_to_keep

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs


@add_start_docstrings(
    """
    The Gemma2 Model transformer with a sequence classification head on top (linear layer).

    [`Gemma2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it needs to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    GEMMA2_START_DOCSTRING,
)
class Gemma2ForSequenceClassification(Gemma2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Gemma2Model(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

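        # Pool by taking the logits at the last non-padding token of each sequence (or simply the last token when
        # `sequence_lengths` is -1).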
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@add_start_docstrings(
    """
    The Gemma2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
    output) e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    GEMMA2_START_DOCSTRING,
)
class Gemma2ForTokenClassification(Gemma2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = Gemma2Model(config)
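        # Dropout for the classification head: prefer `classifier_dropout`, fall back to `hidden_dropout`, else 0.1.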
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.score(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.config)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )