File size: 1,029 Bytes
bd3a23c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0dabde8
bd3a23c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np


class AttrDict(dict):
    """Dictionary whose items are also reachable as attributes.

    The instance's ``__dict__`` is bound to the mapping itself, so
    ``d.key`` and ``d['key']`` read and write the same storage.
    """

    def __init__(self, *args, **kwargs):
        """Build the mapping exactly like ``dict(*args, **kwargs)``."""
        super(AttrDict, self).__init__(*args, **kwargs)
        # Alias attribute storage to the dict so both access styles agree.
        self.__dict__ = self

    def override(self, attrs):
        """Merge ``attrs`` into this mapping in place and return ``self``.

        Args:
            attrs: A dict whose items overwrite existing entries; a list,
                tuple or set whose elements are each applied recursively
                in iteration order; or ``None``, which is a no-op.

        Returns:
            This instance, so calls can be chained.

        Raises:
            NotImplementedError: If ``attrs`` is of any other type.
        """
        if isinstance(attrs, dict):
            # Call dict.update on the type rather than through the instance:
            # a stored key named 'update' would otherwise shadow the method,
            # and passing the mapping positionally (instead of ``**attrs``)
            # keeps non-string keys working.
            dict.update(self, attrs)
        elif isinstance(attrs, (list, tuple, set)):
            for attr in attrs:
                self.override(attr)
        elif attrs is not None:
            raise NotImplementedError(
                'cannot override from %s' % type(attrs).__name__
            )
        return self


# Registry of named parameter sets; each entry is an AttrDict grouping
# related settings into nested AttrDicts.
all_params = {
    'Plugin_freevc': AttrDict(
        # Diff params (field names suggest a noise schedule and step
        # counts -- confirm against the code that consumes them).
        diff=AttrDict(
            num_train_steps=1000,
            beta_start=1e-4,
            beta_end=0.02,
            num_infer_steps=50,
            v_prediction=True,
        ),
        # Identifier of the text encoder model to use.
        text_encoder=AttrDict(
            model='google/flan-t5-base',
        ),
        # Optimizer settings (names presumably follow Adam/AdamW
        # conventions -- verify with the training code).
        opt=AttrDict(
            learning_rate=1e-4,
            beta1=0.9,
            beta2=0.999,
            weight_decay=1e-4,
            adam_epsilon=1e-8,
        ),
    ),
}

def get_params(name):
    """Return the parameter set registered under ``name``.

    The lookup indexes the module-level ``all_params`` mapping directly,
    so an unregistered ``name`` raises ``KeyError``.
    """
    params = all_params[name]
    return params