# coding=utf-8
# Copyright 2023-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict

from peft import (
    LoraConfig,
    PrefixTuningConfig,
    PromptEncoderConfig,
    PromptTuningConfig,
)

CONFIG_CLASSES = (
    LoraConfig,
    PrefixTuningConfig,
    PromptEncoderConfig,
    PromptTuningConfig,
)
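
# Default kwargs used when instantiating each config class in tests; one dict per entry of
# CONFIG_CLASSES, in the same order (LoRA, prefix tuning, prompt encoder, prompt tuning).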
CONFIG_TESTING_KWARGS = (
    {
        "r": 8,
        "lora_alpha": 32,
        "target_modules": None,
        "lora_dropout": 0.05,
        "bias": "none",
        "task_type": "CAUSAL_LM",
    },
    {
        "num_virtual_tokens": 10,
        "task_type": "CAUSAL_LM",
    },
    {
        "num_virtual_tokens": 10,
        "encoder_hidden_size": 32,
        "task_type": "CAUSAL_LM",
    },
    {
        "num_virtual_tokens": 10,
        "task_type": "CAUSAL_LM",
    },
)
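
# Map each adapter name used in the tests to its config class and default testing kwargs.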
CLASSES_MAPPING = {
    "lora": (LoraConfig, CONFIG_TESTING_KWARGS[0]),
    "prefix_tuning": (PrefixTuningConfig, CONFIG_TESTING_KWARGS[1]),
    "prompt_encoder": (PromptEncoderConfig, CONFIG_TESTING_KWARGS[2]),
    "prompt_tuning": (PromptTuningConfig, CONFIG_TESTING_KWARGS[3]),
}


# Adapted from https://github.com/huggingface/transformers/blob/48327c57182fdade7f7797d1eaad2d166de5c55b/src/transformers/activations.py#LL166C7-L166C22
class ClassInstantier(OrderedDict):
    def __getitem__(self, key, *args, **kwargs):
        # check if any of the kwargs is inside the config class kwargs
        if any(kwarg in self[key][1] for kwarg in kwargs):
            new_config_kwargs = self[key][1].copy()
            new_config_kwargs.update(kwargs)
            return (self[key][0], new_config_kwargs)

        return super().__getitem__(key, *args, **kwargs)

    def get_grid_parameters(self, model_list):
        r"""
        Build the test grid: for every model tuple in `model_list`, add one entry per adapter type in the
        mapping, merging any per-model kwargs into that adapter's default config kwargs.
        """
        grid_parameters = []
        for model_tuple in model_list:
            model_id, lora_kwargs, prefix_tuning_kwargs, prompt_encoder_kwargs, prompt_tuning_kwargs = model_tuple
            for key, value in self.items():
                if key == "lora":
                    # update value[1] if necessary
                    if lora_kwargs is not None:
                        value[1].update(lora_kwargs)
                elif key == "prefix_tuning":
                    # update value[1] if necessary
                    if prefix_tuning_kwargs is not None:
                        value[1].update(prefix_tuning_kwargs)
                elif key == "prompt_encoder":
                    # update value[1] if necessary
                    if prompt_encoder_kwargs is not None:
                        value[1].update(prompt_encoder_kwargs)
                else:
                    # update value[1] if necessary
                    if prompt_tuning_kwargs is not None:
                        value[1].update(prompt_tuning_kwargs)

                grid_parameters.append((f"test_{model_id}_{key}", model_id, value[0], value[1]))

        return grid_parameters


PeftTestConfigManager = ClassInstantier(CLASSES_MAPPING)
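

# A minimal usage sketch (illustrative, not part of the upstream test suite). The model id
# "hf-internal-testing/tiny-random-gpt2" and the "c_attn" target module are only example values;
# nothing is loaded here, the sketch just shows how the manager resolves an adapter name to a
# (config class, kwargs) pair and how ``get_grid_parameters`` expands a model tuple into test cases.
if __name__ == "__main__":
    # Plain lookup: returns the (config class, default kwargs) tuple stored in CLASSES_MAPPING.
    lora_cls, lora_kwargs = PeftTestConfigManager["lora"]
    print(lora_cls.__name__, lora_kwargs)

    # Explicit lookup with an override: the default kwargs are copied and the override merged in.
    lora_cls, merged_kwargs = PeftTestConfigManager.__getitem__("lora", r=16)
    print(merged_kwargs["r"])  # 16

    # One model tuple: (model_id, lora_kwargs, prefix_tuning_kwargs, prompt_encoder_kwargs, prompt_tuning_kwargs).
    grid = PeftTestConfigManager.get_grid_parameters(
        [("hf-internal-testing/tiny-random-gpt2", {"target_modules": ["c_attn"]}, None, None, None)]
    )
    for test_name, model_id, config_cls, config_kwargs in grid:
        print(test_name, config_cls.__name__, config_kwargs)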