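# Training configuration for a text-conditioned latent upscaling diffusion model
# (LatentUpscaleDiffusion). The UNet denoises 64x64x4 latents and is conditioned on CLIP
# text embeddings via cross-attention plus a noise-augmented latent of the low-resolution
# input image concatenated to its input ("hybrid-adm" conditioning).
# Typically launched with the CompVis-style trainer, e.g.
#   python main.py --base <path/to/this/config>.yaml -t --gpus 0,
# (launch command is an assumption based on the standard latent-diffusion main.py entry point).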
model:
  base_learning_rate: 1.0e-05
  target: ldm.models.diffusion.ddpm.LatentUpscaleDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: "hybrid-adm"
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    low_scale_key: "lr"
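    # "lr" batch entries hold the low-resolution conditioning image; the LowScaleEncoder below
    # encodes it into the same 4-channel latent space and noise-augments it (up to
    # max_noise_level diffusion steps) before it is concatenated to the UNet input.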

    low_scale_config:
      target: ldm.modules.encoders.modules.LowScaleEncoder
      params:
        scale_factor: 0.18215
        linear_start: 0.00085
        linear_end: 0.0120
        timesteps: 1000
        max_noise_level: 100
        output_size: null
        model_config:
          target: ldm.models.autoencoder.AutoencoderKL
          params:
            embed_dim: 4
            monitor: val/rec_loss
            ddconfig:
              double_z: true
              z_channels: 4
              resolution: 256
              in_channels: 3
              out_ch: 3
              ch: 128
              ch_mult:
                - 1
                - 2
                - 4
                - 4
              num_res_blocks: 2
              attn_resolutions: [ ]
              dropout: 0.0
            lossconfig:
              target: torch.nn.Identity

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 5000 ] # NOTE: for resuming; use 10000 when starting from scratch
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
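        # LR multiplier ramps linearly from f_start to f_max over warm_up_steps, then stays
        # constant since f_max equals f_min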
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
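        # in_channels = 4 latent channels of the target image plus 4 channels of the
        # concatenated low-resolution latent; the class-embedding pathway (num_classes)
        # carries the low-scale noise level rather than semantic class labels.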
        num_classes: 1000
        image_size: 32 # unused
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
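      # frozen CLIP ViT-L/14 text encoder; its 768-dim token embeddings match context_dim above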


data:
  target: ldm.data.laion.WebDataModuleFromConfig
  params:
    tar_base: laion/improved_aesthetics_6plus/ims
    batch_size: 3
    num_workers: 2
    multinode: True
    train:
      shards: '{00000..01209}.tar'
      shuffle: 10000
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.RandomCrop
        params:
          size: 512
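      # AddLR attaches a 4x-downsampled copy of each crop under the "lr" key consumed via
      # low_scale_key above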
      postprocess:
        target: ldm.data.laion.AddLR
        params:
          factor: 4
          output_size: 512

    # NOTE use enough shards to avoid empty validation loops in workers
    validation:
      shards: '{00000..00012}.tar'
      shuffle: 0
      image_key: jpg
      image_transforms:
      - target: torchvision.transforms.Resize
        params:
          size: 512
          interpolation: 3
      - target: torchvision.transforms.CenterCrop
        params:
          size: 512
      postprocess:
        target: ldm.data.laion.AddLR
        params:
          factor: 4
          output_size: 512


lightning:
  find_unused_parameters: False

  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 4
        increase_log_steps: False
        log_first_step: False
        log_images_kwargs:
          use_ema_scope: False
          inpaint: False
          plot_progressive_rows: False
          plot_diffusion_rows: False
          N: 4
          unconditional_guidance_scale: 3.0
          unconditional_guidance_label: [""]

  trainer:
    benchmark: True
    val_check_interval: 5000000  # really sorry; effectively disables validation during training
    num_sanity_val_steps: 0
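    # effective per-GPU batch: batch_size 3 x accumulate_grad_batches 4 = 12 samples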
    accumulate_grad_batches: 4