|
from collections.abc import Mapping

import numpy as np
|
|
|
|
|
|
class AttrDict(dict):
    """Dictionary whose keys can also be read and written as attributes.

    ``d['x']`` and ``d.x`` refer to the same entry because the instance's
    attribute namespace (``__dict__``) is the dictionary itself.
    """

    def __init__(self, *args, **kwargs):
        """Build the mapping exactly as ``dict(*args, **kwargs)`` would."""
        super().__init__(*args, **kwargs)
        # Alias the attribute namespace to the mapping so attribute access and
        # item access share a single store.
        self.__dict__ = self

    def override(self, attrs):
        """Update this dictionary in place from ``attrs`` and return ``self``.

        Args:
            attrs: One of:
                * a mapping, whose items are copied into ``self``, overwriting
                  existing keys (shallow: nested values are replaced, not
                  merged);
                * a list, tuple or set of overrides, applied one by one in
                  iteration order (each element may itself be any accepted
                  form);
                * ``None``, which is a no-op.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            NotImplementedError: If ``attrs`` is of any other type.
        """
        if isinstance(attrs, Mapping):
            # dict.update accepts any key type; update(**attrs) would raise
            # TypeError for non-string keys.
            self.update(attrs)
        elif isinstance(attrs, (list, tuple, set)):
            for attr in attrs:
                self.override(attr)
        elif attrs is not None:
            raise NotImplementedError(
                f"Cannot override AttrDict with an object of type "
                f"{type(attrs).__name__!r}"
            )
        return self
|
|
|
|
|
|
# Registry of named hyper-parameter presets, looked up by name via get_params().
all_params = {
    'Plugin_freevc': AttrDict({
        # Diffusion settings. NOTE(review): presumably a beta schedule running
        # from beta_start to beta_end over num_train_steps, sampled with
        # num_infer_steps at inference, with v_prediction selecting a
        # v-parameterised training target — confirm against the model code.
        'diff': AttrDict({
            'num_train_steps': 1000,
            'beta_start': 1e-4,
            'beta_end': 0.02,
            'num_infer_steps': 50,
            'v_prediction': True,
        }),
        # Text-encoder checkpoint name (presumably a Hugging Face hub id).
        'text_encoder': AttrDict({
            'model': 'google/flan-t5-base',
        }),
        # Optimizer hyper-parameters; presumably consumed by an Adam/AdamW
        # setup given the beta1/beta2/adam_epsilon names — verify at call site.
        'opt': AttrDict({
            'learning_rate': 1e-4,
            'beta1': 0.9,
            'beta2': 0.999,
            'weight_decay': 1e-4,
            'adam_epsilon': 1e-08,
        }),
    }),
}
|
|
|
|
def get_params(name):
    """Return the hyper-parameter preset registered under ``name``.

    The returned object is the entry stored in ``all_params`` itself, not a
    copy: mutating it (for example with ``override``) changes the preset seen
    by every later caller.

    Args:
        name: Key of the preset in ``all_params`` (e.g. ``'Plugin_freevc'``).

    Returns:
        The registered preset object.

    Raises:
        KeyError: If no preset is registered under ``name``; the message lists
            the available preset names.
    """
    try:
        return all_params[name]
    except KeyError:
        available = ', '.join(repr(key) for key in all_params)
        raise KeyError(
            f"Unknown params preset {name!r}; available: {available}"
        ) from None