Update custom_generate/generate.py
Browse files- custom_generate/generate.py +10 -6
custom_generate/generate.py
CHANGED
|
@@ -5,7 +5,7 @@ import torch
|
|
| 5 |
import torch.nn as nn
|
| 6 |
|
| 7 |
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
|
| 8 |
-
from transformers.cache_utils import Cache, EncoderDecoderCache
|
| 9 |
from transformers.configuration_utils import PretrainedConfig
|
| 10 |
from transformers.generation.utils import (
|
| 11 |
ALL_CACHE_NAMES,
|
|
@@ -249,13 +249,17 @@ def _contrastive_search(
|
|
| 249 |
f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
|
| 250 |
"for contrastive search."
|
| 251 |
)
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
):
|
| 256 |
raise ValueError(
|
| 257 |
-
f"
|
| 258 |
-
"
|
| 259 |
)
|
| 260 |
|
| 261 |
# contrastive_search main logic start:
|
|
|
|
| 5 |
import torch.nn as nn
|
| 6 |
|
| 7 |
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
|
| 8 |
+
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
| 9 |
from transformers.configuration_utils import PretrainedConfig
|
| 10 |
from transformers.generation.utils import (
|
| 11 |
ALL_CACHE_NAMES,
|
|
|
|
| 249 |
f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
|
| 250 |
"for contrastive search."
|
| 251 |
)
|
| 252 |
+
# Only those caches have the necesary methods
|
| 253 |
+
elif not (
|
| 254 |
+
isinstance(past_key_values, DynamicCache)
|
| 255 |
+
or (
|
| 256 |
+
isinstance(past_key_values, EncoderDecoderCache)
|
| 257 |
+
and isinstance(past_key_values.self_attention_cache, DynamicCache)
|
| 258 |
+
)
|
| 259 |
):
|
| 260 |
raise ValueError(
|
| 261 |
+
f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
|
| 262 |
+
"dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
|
| 263 |
)
|
| 264 |
|
| 265 |
# contrastive_search main logic start:
|