Update custom_generate/generate.py
Browse files- custom_generate/generate.py +10 -6
    	
        custom_generate/generate.py
    CHANGED
    
    | @@ -5,7 +5,7 @@ import torch | |
| 5 | 
             
            import torch.nn as nn
         | 
| 6 |  | 
| 7 | 
             
            from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
         | 
| 8 | 
            -
            from transformers.cache_utils import Cache, EncoderDecoderCache
         | 
| 9 | 
             
            from transformers.configuration_utils import PretrainedConfig
         | 
| 10 | 
             
            from transformers.generation.utils import (
         | 
| 11 | 
             
                ALL_CACHE_NAMES,
         | 
| @@ -249,13 +249,17 @@ def _contrastive_search( | |
| 249 | 
             
                                f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
         | 
| 250 | 
             
                                "for contrastive search."
         | 
| 251 | 
             
                            )
         | 
| 252 | 
            -
                         | 
| 253 | 
            -
             | 
| 254 | 
            -
                             | 
|  | |
|  | |
|  | |
|  | |
| 255 | 
             
                        ):
         | 
| 256 | 
             
                            raise ValueError(
         | 
| 257 | 
            -
                                f" | 
| 258 | 
            -
                                " | 
| 259 | 
             
                            )
         | 
| 260 |  | 
| 261 | 
             
                    # contrastive_search main logic start:
         | 
|  | |
| 5 | 
             
            import torch.nn as nn
         | 
| 6 |  | 
| 7 | 
             
            from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
         | 
| 8 | 
            +
            from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
         | 
| 9 | 
             
            from transformers.configuration_utils import PretrainedConfig
         | 
| 10 | 
             
            from transformers.generation.utils import (
         | 
| 11 | 
             
                ALL_CACHE_NAMES,
         | 
|  | |
| 249 | 
             
                                f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
         | 
| 250 | 
             
                                "for contrastive search."
         | 
| 251 | 
             
                            )
         | 
| 252 | 
            +
                        # Only those caches have the necesary methods
         | 
| 253 | 
            +
                        elif not (
         | 
| 254 | 
            +
                            isinstance(past_key_values, DynamicCache)
         | 
| 255 | 
            +
                            or (
         | 
| 256 | 
            +
                                isinstance(past_key_values, EncoderDecoderCache)
         | 
| 257 | 
            +
                                and isinstance(past_key_values.self_attention_cache, DynamicCache)
         | 
| 258 | 
            +
                            )
         | 
| 259 | 
             
                        ):
         | 
| 260 | 
             
                            raise ValueError(
         | 
| 261 | 
            +
                                f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
         | 
| 262 | 
            +
                                "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
         | 
| 263 | 
             
                            )
         | 
| 264 |  | 
| 265 | 
             
                    # contrastive_search main logic start:
         | 

