File size: 1,029 Bytes
bd3a23c 0dabde8 bd3a23c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import numpy as np
class AttrDict(dict):
    """A ``dict`` whose items are also reachable as attributes.

    The instance's ``__dict__`` is bound to the mapping itself, so
    ``obj.key`` and ``obj['key']`` read and write the same storage.
    """

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        # Alias attribute storage to the dict so both access styles agree.
        self.__dict__ = self

    def override(self, attrs):
        """Merge override values into this mapping in place.

        Args:
            attrs: A dict whose items replace/extend the current ones, a
                list/tuple/set of such overrides (applied in iteration
                order, recursively), or ``None`` for a no-op.

        Returns:
            This same instance, so calls can be chained.

        Raises:
            NotImplementedError: If ``attrs`` is any other type.
        """
        if isinstance(attrs, dict):
            # Pass the mapping positionally: unpacking with ** would raise
            # TypeError for any non-string key.
            self.update(attrs)
        elif isinstance(attrs, (list, tuple, set)):
            for attr in attrs:
                self.override(attr)
        elif attrs is not None:
            raise NotImplementedError
        return self
# Registry of named hyper-parameter sets, keyed by configuration name.
all_params = {
    'Plugin_freevc': AttrDict({
        # Diff params
        'diff': AttrDict({
            'num_train_steps': 1000,
            'beta_start': 1e-4,
            'beta_end': 0.02,
            'num_infer_steps': 50,
            'v_prediction': True,
        }),
        # Text encoder selection (model identifier string).
        'text_encoder': AttrDict({
            'model': 'google/flan-t5-base',
        }),
        # Optimizer settings.
        'opt': AttrDict({
            'learning_rate': 1e-4,
            'beta1': 0.9,
            'beta2': 0.999,
            'weight_decay': 1e-4,
            'adam_epsilon': 1e-08,
        }),
    }),
}
def get_params(name):
    """Return the parameter set registered under ``name``.

    Args:
        name: Key of the desired entry in ``all_params``.

    Returns:
        The object stored in ``all_params`` for that key (the same
        instance, not a copy).

    Raises:
        KeyError: If ``name`` is not a key of ``all_params``.
    """
    return all_params[name]