File size: 34,576 Bytes
f69d33f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
# Copyright 2022 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main file for the MAXIM model."""

import functools
from typing import Any, Sequence, Tuple

import einops
import flax.linen as nn
import jax
import jax.numpy as jnp


# Convolution factories with a fixed kernel geometry. Callers supply the
# remaining `nn.Conv` arguments (e.g. `features`, `use_bias`, `name`).
Conv3x3 = functools.partial(nn.Conv, kernel_size=(3, 3))
Conv1x1 = functools.partial(nn.Conv, kernel_size=(1, 1))
# Transposed convolution with a 2x2 kernel and stride 2.
ConvT_up = functools.partial(nn.ConvTranspose,
                             kernel_size=(2, 2),
                             strides=(2, 2))
# Strided convolution with a 4x4 kernel and stride 2.
Conv_down = functools.partial(nn.Conv,
                              kernel_size=(4, 4),
                              strides=(2, 2))

# Normal(stddev=0.02) kernel initializer shared by most Dense layers below.
weight_initializer = nn.initializers.normal(stddev=2e-2)


class MlpBlock(nn.Module):
  """A 1-hidden-layer MLP block, applied over the last dimension.

  Attributes:
    mlp_dim: width of the hidden Dense layer.
    dropout_rate: rate of the dropout applied after the GELU activation.
    use_bias: whether both Dense layers use a bias term.
  """
  mlp_dim: int
  dropout_rate: float = 0.0
  use_bias: bool = True

  @nn.compact
  def __call__(self, x, deterministic=True):
    """Applies Dense -> GELU -> Dropout -> Dense over the last axis of `x`.

    Args:
      x: input array; only its last axis is mixed.
      deterministic: forwarded to `nn.Dropout`.

    Returns:
      An array whose last dimension has the same size as that of `x`.
    """
    # Only the size of the last axis is needed, so inputs of any rank are
    # accepted (previously a 4-way unpack of `x.shape` forced a 4-D input).
    out_dim = x.shape[-1]
    x = nn.Dense(self.mlp_dim, use_bias=self.use_bias,
                 kernel_init=weight_initializer)(x)
    x = nn.gelu(x)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic)
    x = nn.Dense(out_dim, use_bias=self.use_bias,
                 kernel_init=weight_initializer)(x)
    return x


def block_images_einops(x, patch_size):
  """Splits an image batch into non-overlapping patches.

  Args:
    x: array that unpacks into (n, height, width, c).
    patch_size: (fh, fw) patch height and width.

  Returns:
    `x` rearranged to (n, gh * gw, fh * fw, c), where gh = height // fh and
    gw = width // fw.
  """
  _, img_h, img_w, _ = x.shape
  patch_h, patch_w = patch_size[0], patch_size[1]
  return einops.rearrange(
      x,
      "n (gh fh) (gw fw) c -> n (gh gw) (fh fw) c",
      gh=img_h // patch_h,
      gw=img_w // patch_w,
      fh=patch_h,
      fw=patch_w)


def unblock_images_einops(x, grid_size, patch_size):
  """Reassembles patches into images (inverse layout of block_images_einops).

  Args:
    x: array laid out as (n, gh * gw, fh * fw, c).
    grid_size: (gh, gw) number of patches along height and width.
    patch_size: (fh, fw) patch height and width.

  Returns:
    `x` rearranged to (n, gh * fh, gw * fw, c).
  """
  grid_h, grid_w = grid_size[0], grid_size[1]
  patch_h, patch_w = patch_size[0], patch_size[1]
  return einops.rearrange(
      x,
      "n (gh gw) (fh fw) c -> n (gh fh) (gw fw) c",
      gh=grid_h,
      gw=grid_w,
      fh=patch_h,
      fw=patch_w)


class UpSampleRatio(nn.Module):
  """Resizes features bilinearly by `ratio`, then applies a 1x1 convolution.

  Attributes:
    features: number of output channels of the 1x1 convolution.
    ratio: multiplier (> 0) applied to the height and width of the input.
    use_bias: whether the 1x1 convolution uses a bias term.
  """
  features: int
  ratio: float
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    batch, height, width, channels = x.shape
    # Target spatial sizes are truncated to integers.
    target_shape = (
        batch, int(height * self.ratio), int(width * self.ratio), channels)
    resized = jax.image.resize(x, shape=target_shape, method="bilinear")
    return Conv1x1(features=self.features, use_bias=self.use_bias)(resized)


class CALayer(nn.Module):
  """Channel attention in the squeeze-and-excitation style.

  The input is average-pooled over axes 1 and 2, passed through two 1x1
  convolutions (ReLU after the first, sigmoid after the second), and the
  result rescales the input.

  ref: https://arxiv.org/abs/1709.01507

  Attributes:
    features: number of output channels of the second 1x1 convolution.
    reduction: divisor applied to `features` to get the bottleneck width.
    use_bias: whether the 1x1 convolutions use a bias term.
  """
  features: int
  reduction: int = 4
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    bottleneck_dim = self.features // self.reduction
    # Global average pooling; the pooled axes are kept as size-1 dims so the
    # gate broadcasts against `x`.
    pooled = jnp.mean(x, axis=(1, 2), keepdims=True)
    # Squeeze.
    hidden = nn.relu(Conv1x1(bottleneck_dim, use_bias=self.use_bias)(pooled))
    # Excitation.
    gate = nn.sigmoid(Conv1x1(self.features, use_bias=self.use_bias)(hidden))
    return x * gate


class RCAB(nn.Module):
  """Residual channel attention block: LN, Conv, leaky ReLU, Conv, CALayer.

  Attributes:
    features: number of output channels of both 3x3 convolutions.
    reduction: channel reduction factor passed to `CALayer`.
    lrelu_slope: negative slope of the leaky ReLU.
    use_bias: whether the convolutions use a bias term.
  """
  features: int
  reduction: int = 4
  lrelu_slope: float = 0.2
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    conv_kwargs = dict(features=self.features, use_bias=self.use_bias)
    branch = nn.LayerNorm(name="LayerNorm")(x)
    branch = Conv3x3(name="conv1", **conv_kwargs)(branch)
    branch = nn.leaky_relu(branch, negative_slope=self.lrelu_slope)
    branch = Conv3x3(name="conv2", **conv_kwargs)(branch)
    branch = CALayer(
        features=self.features,
        reduction=self.reduction,
        use_bias=self.use_bias,
        name="channel_attention")(branch)
    # Residual connection around the whole branch.
    return branch + x


class GridGatingUnit(nn.Module):
  """A gating unit in the style of gMLP's SpatialGatingUnit.

  The input is split in half along the last axis. One half is layer-normed
  and mixed by a Dense layer along the third-from-last axis (axis -3), then
  gates the other half.

  Attributes:
    use_bias: whether the Dense layer uses a bias term.
  """
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    mix_axis = -3
    num_tokens = x.shape[mix_axis]
    passthrough, gate = jnp.split(x, 2, axis=-1)
    gate = nn.LayerNorm(name="intermediate_layernorm")(gate)
    # Dense acts on the last axis, so swap the mixed axis there and back.
    gate = jnp.swapaxes(gate, -1, mix_axis)
    gate = nn.Dense(
        num_tokens,
        use_bias=self.use_bias,
        kernel_init=weight_initializer)(gate)
    gate = jnp.swapaxes(gate, -1, mix_axis)
    return passthrough * (gate + 1.)


class GridGmlpLayer(nn.Module):
  """Grid gMLP layer that performs global mixing of tokens.

  Attributes:
    grid_size: (gh, gw) grid the image is divided into.
    use_bias: whether the Dense layers use a bias term.
    factor: channel expansion factor of the input projection.
    dropout_rate: rate of the dropout applied after the output projection.
  """
  grid_size: Sequence[int]
  use_bias: bool = True
  factor: int = 2
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, deterministic=True):
    _, height, width, channels = x.shape
    grid_h, grid_w = self.grid_size
    patch = (height // grid_h, width // grid_w)
    dense = functools.partial(
        nn.Dense, use_bias=self.use_bias, kernel_init=weight_initializer)

    tokens = block_images_einops(x, patch_size=patch)
    # gMLP1: Global (grid) mixing part, provides global grid communication.
    mixed = nn.LayerNorm(name="LayerNorm")(tokens)
    mixed = dense(channels * self.factor, name="in_project")(mixed)
    mixed = nn.gelu(mixed)
    mixed = GridGatingUnit(use_bias=self.use_bias, name="GridGatingUnit")(mixed)
    mixed = dense(channels, name="out_project")(mixed)
    mixed = nn.Dropout(self.dropout_rate)(mixed, deterministic)
    # Residual connection in the blocked layout, then restore the image layout.
    return unblock_images_einops(
        tokens + mixed, grid_size=(grid_h, grid_w), patch_size=patch)


class BlockGatingUnit(nn.Module):
  """A gating unit in the style of gMLP's SpatialGatingUnit.

  The input is split in half along the last axis. One half is layer-normed
  and mixed by a Dense layer along the second-to-last axis (axis -2), then
  gates the other half.

  Attributes:
    use_bias: whether the Dense layer uses a bias term.
  """
  use_bias: bool = True

  @nn.compact
  def __call__(self, x):
    mix_axis = -2
    num_tokens = x.shape[mix_axis]
    passthrough, gate = jnp.split(x, 2, axis=-1)
    gate = nn.LayerNorm(name="intermediate_layernorm")(gate)
    # Dense acts on the last axis, so swap the mixed axis there and back.
    gate = jnp.swapaxes(gate, -1, mix_axis)
    gate = nn.Dense(
        num_tokens,
        use_bias=self.use_bias,
        kernel_init=weight_initializer)(gate)
    gate = jnp.swapaxes(gate, -1, mix_axis)
    return passthrough * (gate + 1.)


class BlockGmlpLayer(nn.Module):
  """Block gMLP layer that performs local mixing of tokens.

  Attributes:
    block_size: (fh, fw) size of each block the image is divided into.
    use_bias: whether the Dense layers use a bias term.
    factor: channel expansion factor of the input projection.
    dropout_rate: rate of the dropout applied after the output projection.
  """
  block_size: Sequence[int]
  use_bias: bool = True
  factor: int = 2
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, deterministic=True):
    _, height, width, channels = x.shape
    block_h, block_w = self.block_size
    grid = (height // block_h, width // block_w)
    dense = functools.partial(
        nn.Dense, use_bias=self.use_bias, kernel_init=weight_initializer)

    tokens = block_images_einops(x, patch_size=(block_h, block_w))
    # MLP2: Local (block) mixing part, provides within-block communication.
    mixed = nn.LayerNorm(name="LayerNorm")(tokens)
    mixed = dense(channels * self.factor, name="in_project")(mixed)
    mixed = nn.gelu(mixed)
    mixed = BlockGatingUnit(
        use_bias=self.use_bias, name="BlockGatingUnit")(mixed)
    mixed = dense(channels, name="out_project")(mixed)
    mixed = nn.Dropout(self.dropout_rate)(mixed, deterministic)
    # Residual connection in the blocked layout, then restore the image layout.
    return unblock_images_einops(
        tokens + mixed, grid_size=grid, patch_size=(block_h, block_w))


class ResidualSplitHeadMultiAxisGmlpLayer(nn.Module):
  """The multi-axis gated MLP block.

  After a layer norm and an expanding projection, the channels are split in
  half: one half goes through `GridGmlpLayer`, the other through
  `BlockGmlpLayer`. The halves are concatenated, projected back to the input
  width and added to the input.

  Attributes:
    block_size: block size passed to `BlockGmlpLayer`.
    grid_size: grid size passed to `GridGmlpLayer`.
    block_gmlp_factor: expansion factor passed to `BlockGmlpLayer`.
    grid_gmlp_factor: expansion factor passed to `GridGmlpLayer`.
    input_proj_factor: channel expansion factor of the input projection.
    use_bias: whether the Dense layers use a bias term.
    dropout_rate: dropout rate used here and in both sub-layers.
  """
  block_size: Sequence[int]
  grid_size: Sequence[int]
  block_gmlp_factor: int = 2
  grid_gmlp_factor: int = 2
  input_proj_factor: int = 2
  use_bias: bool = True
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, deterministic=True):
    _, _, _, channels = x.shape
    project = functools.partial(
        nn.Dense, use_bias=self.use_bias, kernel_init=weight_initializer)

    hidden = nn.LayerNorm(name="LayerNorm_in")(x)
    hidden = project(
        channels * self.input_proj_factor, name="in_project")(hidden)
    hidden = nn.gelu(hidden)

    grid_half, block_half = jnp.split(hidden, 2, axis=-1)
    # Global mixing on the first half of the channels.
    grid_half = GridGmlpLayer(
        grid_size=self.grid_size,
        factor=self.grid_gmlp_factor,
        use_bias=self.use_bias,
        dropout_rate=self.dropout_rate,
        name="GridGmlpLayer")(grid_half, deterministic)
    # Local mixing on the second half of the channels.
    block_half = BlockGmlpLayer(
        block_size=self.block_size,
        factor=self.block_gmlp_factor,
        use_bias=self.use_bias,
        dropout_rate=self.dropout_rate,
        name="BlockGmlpLayer")(block_half, deterministic)

    merged = jnp.concatenate([grid_half, block_half], axis=-1)
    merged = project(channels, name="out_project")(merged)
    merged = nn.Dropout(self.dropout_rate)(merged, deterministic)
    return merged + x


class RDCAB(nn.Module):
  """Residual dense channel attention block. Used in Bottlenecks.

  Applies LayerNorm, an `MlpBlock` and a `CALayer`, then adds the result to
  the input.

  Attributes:
    features: hidden width of the `MlpBlock` and `features` of the `CALayer`.
    reduction: channel reduction factor passed to `CALayer`.
    use_bias: whether the sub-layers use a bias term.
    dropout_rate: dropout rate passed to the `MlpBlock`.
  """
  features: int
  reduction: int = 16
  use_bias: bool = True
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x, deterministic=True):
    channel_mlp = MlpBlock(
        mlp_dim=self.features,
        dropout_rate=self.dropout_rate,
        use_bias=self.use_bias,
        name="channel_mixing")
    attention = CALayer(
        features=self.features,
        reduction=self.reduction,
        use_bias=self.use_bias,
        name="channel_attention")

    branch = nn.LayerNorm(name="LayerNorm")(x)
    branch = channel_mlp(branch, deterministic=deterministic)
    branch = attention(branch)
    return x + branch


class BottleneckBlock(nn.Module):
  """The bottleneck block consisting of multi-axis gMLP block and RDCAB.

  Attributes:
    features: number of channels after the input projection.
    block_size: block size passed to the multi-axis gMLP layers.
    grid_size: grid size passed to the multi-axis gMLP layers.
    num_groups: number of (gMLP layer, RDCAB) pairs applied in sequence.
    block_gmlp_factor: expansion factor for the block gMLP branch.
    grid_gmlp_factor: expansion factor for the grid gMLP branch.
    input_proj_factor: input projection factor of the gMLP layers.
    channels_reduction: channel reduction factor passed to `RDCAB`.
    dropout_rate: dropout rate passed to the gMLP layers.
    use_bias: whether the sub-layers use a bias term.
  """
  features: int
  block_size: Sequence[int]
  grid_size: Sequence[int]
  num_groups: int = 1
  block_gmlp_factor: int = 2
  grid_gmlp_factor: int = 2
  input_proj_factor: int = 2
  channels_reduction: int = 4
  dropout_rate: float = 0.0
  use_bias: bool = True

  @nn.compact
  def __call__(self, x, deterministic):
    """Applies the Mixer block to inputs."""
    assert x.ndim == 4  # Input has shape [batch, h, w, c]

    # Input projection; its output also feeds the long skip connection.
    projected = Conv1x1(
        self.features, use_bias=self.use_bias, name="input_proj")(x)

    out = projected
    for group in range(self.num_groups):
      out = ResidualSplitHeadMultiAxisGmlpLayer(
          grid_size=self.grid_size,
          block_size=self.block_size,
          grid_gmlp_factor=self.grid_gmlp_factor,
          block_gmlp_factor=self.block_gmlp_factor,
          input_proj_factor=self.input_proj_factor,
          use_bias=self.use_bias,
          dropout_rate=self.dropout_rate,
          name=f"SplitHeadMultiAxisGmlpLayer_{group}")(out, deterministic)
      # Channel-mixing part, which provides within-patch communication.
      out = RDCAB(
          features=self.features,
          reduction=self.channels_reduction,
          use_bias=self.use_bias,
          name=f"channel_attention_block_1_{group}")(out)

    # Long skip connection.
    return out + projected


class UNetEncoderBlock(nn.Module):
  """Encoder block in MAXIM.

  Optionally concatenates a skip input, projects to `features` channels, runs
  `num_groups` groups of (optional multi-axis gMLP layer, RCAB) with a long
  skip connection, optionally cross-gates with `enc + dec`, and optionally
  downsamples.

  Attributes:
    features: number of channels produced by the block.
    block_size: block size for the gMLP / cross-gating layers.
    grid_size: grid size for the gMLP / cross-gating layers.
    num_groups: number of (gMLP layer, RCAB) groups.
    lrelu_slope: negative slope of the leaky ReLU inside each RCAB.
    block_gmlp_factor: expansion factor for the block gMLP branch.
    grid_gmlp_factor: expansion factor for the grid gMLP branch.
    input_proj_factor: input projection factor for the gMLP / gating layers.
    channels_reduction: channel reduction factor passed to RCAB.
    dropout_rate: dropout rate passed to the gMLP / cross-gating layers.
    downsample: if True, also return a strided-conv downsampled output.
    use_global_mlp: whether to apply the multi-axis gMLP layer in each group.
    use_bias: whether the sub-layers use a bias term.
    use_cross_gating: must be True when both `enc` and `dec` are given.
  """
  features: int
  block_size: Sequence[int]
  grid_size: Sequence[int]
  num_groups: int = 1
  lrelu_slope: float = 0.2
  block_gmlp_factor: int = 2
  grid_gmlp_factor: int = 2
  input_proj_factor: int = 2
  channels_reduction: int = 4
  dropout_rate: float = 0.0
  downsample: bool = True
  use_global_mlp: bool = True
  use_bias: bool = True
  use_cross_gating: bool = False

  @nn.compact
  def __call__(self, x: jnp.ndarray, skip: jnp.ndarray = None,
               enc: jnp.ndarray = None, dec: jnp.ndarray = None, *,
               deterministic: bool = True) -> Any:
    """Runs the encoder block.

    Args:
      x: input features.
      skip: optional features concatenated to `x` along the last axis.
      enc: optional features; used for cross-gating together with `dec`.
      dec: optional features; used for cross-gating together with `enc`.
      deterministic: forwarded to the dropout-carrying sub-layers.

    Returns:
      `(x_down, x)` when `downsample` is True, otherwise just `x`.
    """
    if skip is not None:
      x = jnp.concatenate([x, skip], axis=-1)

    # convolution-in
    x = Conv1x1(self.features, use_bias=self.use_bias)(x)
    shortcut_long = x

    for i in range(self.num_groups):
      if self.use_global_mlp:
        x = ResidualSplitHeadMultiAxisGmlpLayer(
            grid_size=self.grid_size,
            block_size=self.block_size,
            grid_gmlp_factor=self.grid_gmlp_factor,
            block_gmlp_factor=self.block_gmlp_factor,
            input_proj_factor=self.input_proj_factor,
            use_bias=self.use_bias,
            dropout_rate=self.dropout_rate,
            name=f"SplitHeadMultiAxisGmlpLayer_{i}")(x, deterministic)
      # Forward the configured slope; previously RCAB always used its own
      # default, so a non-default `lrelu_slope` was silently ignored.
      # NOTE: the module name below (no underscore before the index) is part of
      # the parameter tree and is kept as-is.
      x = RCAB(
          features=self.features,
          reduction=self.channels_reduction,
          lrelu_slope=self.lrelu_slope,
          use_bias=self.use_bias,
          name=f"channel_attention_block_1{i}")(x)

    # long skip-connect
    x = x + shortcut_long

    if enc is not None and dec is not None:
      assert self.use_cross_gating
      # Gate the block output with the summed enc/dec signal; the second
      # output of the cross-gating block is discarded.
      x, _ = CrossGatingBlock(
          features=self.features,
          block_size=self.block_size,
          grid_size=self.grid_size,
          dropout_rate=self.dropout_rate,
          input_proj_factor=self.input_proj_factor,
          upsample_y=False,
          use_bias=self.use_bias,
          name="cross_gating_block")(
              x, enc + dec, deterministic=deterministic)

    if self.downsample:
      x_down = Conv_down(self.features, use_bias=self.use_bias)(x)
      return x_down, x
    else:
      return x


class UNetDecoderBlock(nn.Module):
  """Decoder block in MAXIM.

  Upsamples the input with `ConvT_up`, then applies a non-downsampling
  `UNetEncoderBlock`, passing `bridge` as its skip input.

  Attributes:
    features: number of channels produced by the block.
    block_size: block size passed to the inner encoder block.
    grid_size: grid size passed to the inner encoder block.
    num_groups: number of groups in the inner encoder block.
    lrelu_slope: leaky ReLU slope passed to the inner encoder block.
    block_gmlp_factor: expansion factor for the block gMLP branch.
    grid_gmlp_factor: expansion factor for the grid gMLP branch.
    input_proj_factor: input projection factor for the gMLP layers.
    channels_reduction: channel reduction factor for channel attention.
    dropout_rate: dropout rate passed to the inner encoder block.
    downsample: declared but not read here; the inner block always runs with
      `downsample=False`.
    use_global_mlp: whether the inner block applies the multi-axis gMLP layer.
    use_bias: whether the sub-layers use a bias term.
  """
  features: int
  block_size: Sequence[int]
  grid_size: Sequence[int]
  num_groups: int = 1
  lrelu_slope: float = 0.2
  block_gmlp_factor: int = 2
  grid_gmlp_factor: int = 2
  input_proj_factor: int = 2
  channels_reduction: int = 4
  dropout_rate: float = 0.0
  downsample: bool = True
  use_global_mlp: bool = True
  use_bias: bool = True

  @nn.compact
  def __call__(self, x: jnp.ndarray, bridge: jnp.ndarray = None,
               deterministic: bool = True) -> jnp.ndarray:
    """Runs the decoder block.

    Args:
      x: input features to upsample.
      bridge: optional skip features concatenated inside the encoder block.
      deterministic: forwarded to the inner encoder block.

    Returns:
      The output of the inner encoder block.
    """
    x = ConvT_up(self.features, use_bias=self.use_bias)(x)

    # `input_proj_factor` is now forwarded; previously the inner block always
    # used its own default, so a non-default value was silently ignored.
    x = UNetEncoderBlock(
        self.features,
        num_groups=self.num_groups,
        lrelu_slope=self.lrelu_slope,
        block_size=self.block_size,
        grid_size=self.grid_size,
        block_gmlp_factor=self.block_gmlp_factor,
        grid_gmlp_factor=self.grid_gmlp_factor,
        input_proj_factor=self.input_proj_factor,
        channels_reduction=self.channels_reduction,
        use_global_mlp=self.use_global_mlp,
        dropout_rate=self.dropout_rate,
        downsample=False,
        use_bias=self.use_bias)(x, skip=bridge, deterministic=deterministic)
    return x


class GetSpatialGatingWeights(nn.Module):
  """Get gating weights for cross-gating MLP block.

  After a layer norm and an expanding projection, the channels are split in
  half. One half is mixed across grid cells, the other within blocks; the
  halves are then concatenated and projected back to the input width.

  Attributes:
    features: declared for the caller's convenience; not read in this module.
    block_size: (fh, fw) block size for the within-block mixing.
    grid_size: (gh, gw) grid size for the across-grid mixing.
    input_proj_factor: channel expansion factor of the input projection.
    dropout_rate: rate of the dropout applied to the output.
    use_bias: whether the Dense layers use a bias term.
  """
  features: int
  block_size: Sequence[int]
  grid_size: Sequence[int]
  input_proj_factor: int = 2
  dropout_rate: float = 0.0
  use_bias: bool = True

  @nn.compact
  def __call__(self, x, deterministic):
    _, height, width, channels = x.shape

    def mix_along(tokens, axis):
      # Dense acts on the last axis, so swap the mixed axis there and back.
      tokens = jnp.swapaxes(tokens, -1, axis)
      tokens = nn.Dense(
          tokens.shape[-1],
          use_bias=self.use_bias,
          kernel_init=nn.initializers.normal(2e-2),
          bias_init=nn.initializers.ones)(tokens)
      return jnp.swapaxes(tokens, -1, axis)

    # Input projection.
    hidden = nn.LayerNorm(name="LayerNorm_in")(x)
    hidden = nn.Dense(
        channels * self.input_proj_factor,
        use_bias=self.use_bias,
        name="in_project")(hidden)
    hidden = nn.gelu(hidden)
    grid_half, block_half = jnp.split(hidden, 2, axis=-1)

    # Grid MLP weights: mix across grid cells (axis -3 of the blocked layout).
    grid_h, grid_w = self.grid_size
    grid_patch = (height // grid_h, width // grid_w)
    grid_half = block_images_einops(grid_half, patch_size=grid_patch)
    grid_half = mix_along(grid_half, -3)
    grid_half = unblock_images_einops(
        grid_half, grid_size=(grid_h, grid_w), patch_size=grid_patch)

    # Block MLP weights: mix within each block (axis -2 of the blocked layout).
    block_h, block_w = self.block_size
    block_grid = (height // block_h, width // block_w)
    block_half = block_images_einops(
        block_half, patch_size=(block_h, block_w))
    block_half = mix_along(block_half, -2)
    block_half = unblock_images_einops(
        block_half, grid_size=block_grid, patch_size=(block_h, block_w))

    merged = jnp.concatenate([grid_half, block_half], axis=-1)
    merged = nn.Dense(
        channels, use_bias=self.use_bias, name="out_project")(merged)
    return nn.Dropout(self.dropout_rate)(merged, deterministic)


class CrossGatingBlock(nn.Module):
  """Cross-gating MLP block.

  Projects `x` and the gating signal `y` to the same shape, derives spatial
  gating weights from each, and uses the weights from one to gate the other.

  Attributes:
    features: number of channels both inputs are projected to.
    block_size: block size passed to `GetSpatialGatingWeights`.
    grid_size: grid size passed to `GetSpatialGatingWeights`.
    dropout_rate: dropout rate used here and in the gating sub-modules.
    input_proj_factor: input projection factor passed to
      `GetSpatialGatingWeights`.
    upsample_y: whether to upsample `y` with `ConvT_up` first.
    use_bias: whether the sub-layers use a bias term.
  """
  features: int
  block_size: Sequence[int]
  grid_size: Sequence[int]
  dropout_rate: float = 0.0
  input_proj_factor: int = 2
  upsample_y: bool = True
  use_bias: bool = True

  @nn.compact
  def __call__(self, x, y, deterministic=True):
    """Cross-gates `x` and `y`.

    Args:
      x: input features.
      y: gating signal; after the optional upsampling and the 1x1 projection
        it must have the same shape as the projected `x`.
      deterministic: forwarded to the dropout-carrying sub-layers.

    Returns:
      A tuple `(x, y)`: the aggregated features and the gated `y` branch.
    """
    # Upscale Y signal, y is the gating signal.
    if self.upsample_y:
      y = ConvT_up(self.features, use_bias=self.use_bias)(y)

    x = Conv1x1(self.features, use_bias=self.use_bias)(x)
    n, h, w, num_channels = x.shape
    y = Conv1x1(num_channels, use_bias=self.use_bias)(y)

    assert y.shape == x.shape
    shortcut_x = x
    shortcut_y = y

    # Get gating weights from X
    x = nn.LayerNorm(name="LayerNorm_x")(x)
    x = nn.Dense(num_channels, use_bias=self.use_bias, name="in_project_x")(x)
    x = nn.gelu(x)
    # `input_proj_factor` is now forwarded to both gating sub-modules;
    # previously they always used their own default, so a non-default value
    # given to this block was silently ignored.
    gx = GetSpatialGatingWeights(
        features=num_channels,
        block_size=self.block_size,
        grid_size=self.grid_size,
        input_proj_factor=self.input_proj_factor,
        dropout_rate=self.dropout_rate,
        use_bias=self.use_bias,
        name="SplitHeadMultiAxisGating_x")(
            x, deterministic=deterministic)

    # Get gating weights from Y
    y = nn.LayerNorm(name="LayerNorm_y")(y)
    y = nn.Dense(num_channels, use_bias=self.use_bias, name="in_project_y")(y)
    y = nn.gelu(y)
    gy = GetSpatialGatingWeights(
        features=num_channels,
        block_size=self.block_size,
        grid_size=self.grid_size,
        input_proj_factor=self.input_proj_factor,
        dropout_rate=self.dropout_rate,
        use_bias=self.use_bias,
        name="SplitHeadMultiAxisGating_y")(
            y, deterministic=deterministic)

    # Apply cross gating: X = X * GY, Y = Y * GX
    y = y * gx
    y = nn.Dense(num_channels, use_bias=self.use_bias, name="out_project_y")(y)
    y = nn.Dropout(self.dropout_rate)(y, deterministic=deterministic)
    y = y + shortcut_y

    x = x * gy  # gating x using y
    x = nn.Dense(num_channels, use_bias=self.use_bias, name="out_project_x")(x)
    x = nn.Dropout(self.dropout_rate)(x, deterministic=deterministic)
    x = x + y + shortcut_x  # get all aggregated signals
    return x, y


class SAM(nn.Module):
  """Supervised attention module for multi-stage training.

  Introduced by MPRNet [CVPR2021]: https://github.com/swz30/MPRNet

  Attributes:
    features: number of channels of the attended feature maps.
    output_channels: number of channels of the restored image.
    use_bias: whether the convolutions use a bias term.
  """
  features: int
  output_channels: int = 3
  use_bias: bool = True

  @nn.compact
  def __call__(self, x: jnp.ndarray, x_image: jnp.ndarray, *,
               train: bool) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Apply the SAM module to the input and features.

    Args:
      x: the output features from the UNet decoder.
      x_image: the input image; added to the restored image only when
        `output_channels` is 3.
      train: whether it is training (not read in this method).

    Returns:
      A tuple (features, image): `features` are the SAM features used for the
      next stage, and `image` is the restored image at the current stage.
    """
    conv = functools.partial(Conv3x3, use_bias=self.use_bias)

    # Feature branch.
    feats = conv(self.features)(x)

    # Restored image; the input image is added as a residual for 3 channels.
    image = conv(self.output_channels)(x)
    if self.output_channels == 3:
      image = image + x_image

    # Attention maps computed from the restored image.
    attention = nn.sigmoid(conv(self.features)(image))

    # Attended features plus a residual connection to the input features.
    return feats * attention + x, image


class MAXIM(nn.Module):
  """The MAXIM model function with multi-stage and multi-scale supervision.



  For more model details, please check the CVPR paper:

  MAXIM: MUlti-Axis MLP for Image Processing (https://arxiv.org/abs/2201.02973)



  Attributes:

    features: initial hidden dimension for the input resolution.

    depth: the number of downsampling depth for the model.

    num_stages: how many stages to use. It also affects the output list.

    num_groups: how many blocks each stage contains.

    use_bias: whether to use bias in all the conv/mlp layers.

    num_supervision_scales: the number of desired supervision scales.

    lrelu_slope: the negative slope parameter in leaky_relu layers.

    use_global_mlp: whether to use the multi-axis gated MLP block (MAB) in each

      layer.

    use_cross_gating: whether to use the cross-gating MLP block (CGB) in the

      skip connections and multi-stage feature fusion layers.

    high_res_stages: how many stages are specified as high-res stages. The

      rest (depth - high_res_stages) are called low_res_stages.

    block_size_hr: the block_size parameter for high-res stages.

    block_size_lr: the block_size parameter for low-res stages.

    grid_size_hr: the grid_size parameter for high-res stages.

    grid_size_lr: the grid_size parameter for low-res stages.

    num_bottleneck_blocks: how many bottleneck blocks.

    block_gmlp_factor: the input projection factor for block_gMLP layers.

    grid_gmlp_factor: the input projection factor for grid_gMLP layers.

    input_proj_factor: the input projection factor for the MAB block.

    channels_reduction: the channel reduction factor for SE layer.

    num_outputs: the output channels.

    dropout_rate: Dropout rate.



  Returns:

    The output contains a list of arrays consisting of multi-stage multi-scale

    outputs. For example, if num_stages = num_supervision_scales = 3 (the

    model used in the paper), the output specs are: outputs =

    [[output_stage1_scale1, output_stage1_scale2, output_stage1_scale3],

     [output_stage2_scale1, output_stage2_scale2, output_stage2_scale3],

     [output_stage3_scale1, output_stage3_scale2, output_stage3_scale3],]

    The final output can be retrieved by outputs[-1][-1].

  """
  features: int = 64
  depth: int = 3
  num_stages: int = 2
  num_groups: int = 1
  use_bias: bool = True
  num_supervision_scales: int = 1
  lrelu_slope: float = 0.2
  use_global_mlp: bool = True
  use_cross_gating: bool = True
  high_res_stages: int = 2
  block_size_hr: Sequence[int] = (16, 16)
  block_size_lr: Sequence[int] = (8, 8)
  grid_size_hr: Sequence[int] = (16, 16)
  grid_size_lr: Sequence[int] = (8, 8)
  num_bottleneck_blocks: int = 1
  block_gmlp_factor: int = 2
  grid_gmlp_factor: int = 2
  input_proj_factor: int = 2
  channels_reduction: int = 4
  num_outputs: int = 3
  dropout_rate: float = 0.0

  @nn.compact
  def __call__(self, x: jnp.ndarray, *, train: bool = False) -> Any:

    n, h, w, c = x.shape  # input image shape
    shortcuts = []
    shortcuts.append(x)
    # Get multi-scale input images
    for i in range(1, self.num_supervision_scales):
      shortcuts.append(jax.image.resize(
          x, shape=(n, h // (2**i), w // (2**i), c), method="nearest"))

    # store outputs from all stages and all scales
    # Eg, [[(64, 64, 3), (128, 128, 3), (256, 256, 3)],   # Stage-1 outputs
    #      [(64, 64, 3), (128, 128, 3), (256, 256, 3)],]  # Stage-2 outputs
    outputs_all = []
    sam_features, encs_prev, decs_prev = [], [], []

    for idx_stage in range(self.num_stages):
      # Input convolution, get multi-scale input features
      x_scales = []
      for i in range(self.num_supervision_scales):
        x_scale = Conv3x3(
            (2**i) * self.features,
            use_bias=self.use_bias,
            name=f"stage_{idx_stage}_input_conv_{i}")(
                shortcuts[i])

        # If later stages, fuse input features with SAM features from prev stage
        if idx_stage > 0:
          # use larger blocksize at high-res stages
          if self.use_cross_gating:
            block_size = self.block_size_hr if i < self.high_res_stages else self.block_size_lr
            grid_size = self.grid_size_hr if i < self.high_res_stages else self.block_size_lr
            x_scale, _ = CrossGatingBlock(
                features=(2**i) * self.features,
                block_size=block_size,
                grid_size=grid_size,
                dropout_rate=self.dropout_rate,
                input_proj_factor=self.input_proj_factor,
                upsample_y=False,
                use_bias=self.use_bias,
                name=f"stage_{idx_stage}_input_fuse_sam_{i}")(
                    x_scale, sam_features.pop(), deterministic=not train)
          else:
            x_scale = Conv1x1(
                (2**i) * self.features,
                use_bias=self.use_bias,
                name=f"stage_{idx_stage}_input_catconv_{i}")(
                    jnp.concatenate(
                        [x_scale, sam_features.pop()], axis=-1))

        x_scales.append(x_scale)

      # start encoder blocks
      encs = []
      x = x_scales[0]  # First full-scale input feature

      for i in range(self.depth):  # 0, 1, 2
        # use larger blocksize at high-res stages, vice versa.
        block_size = self.block_size_hr if i < self.high_res_stages else self.block_size_lr
        grid_size = self.grid_size_hr if i < self.high_res_stages else self.block_size_lr
        use_cross_gating_layer = True if idx_stage > 0 else False

        # Multi-scale input if multi-scale supervision
        x_scale = x_scales[i] if i < self.num_supervision_scales else None

        # UNet Encoder block
        enc_prev = encs_prev.pop() if idx_stage > 0 else None
        dec_prev = decs_prev.pop() if idx_stage > 0 else None

        x, bridge = UNetEncoderBlock(
            features=(2**i) * self.features,
            num_groups=self.num_groups,
            downsample=True,
            lrelu_slope=self.lrelu_slope,
            block_size=block_size,
            grid_size=grid_size,
            block_gmlp_factor=self.block_gmlp_factor,
            grid_gmlp_factor=self.grid_gmlp_factor,
            input_proj_factor=self.input_proj_factor,
            channels_reduction=self.channels_reduction,
            use_global_mlp=self.use_global_mlp,
            dropout_rate=self.dropout_rate,
            use_bias=self.use_bias,
            use_cross_gating=use_cross_gating_layer,
            name=f"stage_{idx_stage}_encoder_block_{i}")(
                x,
                skip=x_scale,
                enc=enc_prev,
                dec=dec_prev,
                deterministic=not train)

        # Cache skip signals
        encs.append(bridge)

      # Global MLP bottleneck blocks
      for i in range(self.num_bottleneck_blocks):
        x = BottleneckBlock(
            block_size=self.block_size_lr,
            grid_size=self.block_size_lr,
            features=(2**(self.depth - 1)) * self.features,
            num_groups=self.num_groups,
            block_gmlp_factor=self.block_gmlp_factor,
            grid_gmlp_factor=self.grid_gmlp_factor,
            input_proj_factor=self.input_proj_factor,
            dropout_rate=self.dropout_rate,
            use_bias=self.use_bias,
            channels_reduction=self.channels_reduction,
            name=f"stage_{idx_stage}_global_block_{i}")(
                x, deterministic=not train)
      # cache global feature for cross-gating
      global_feature = x

      # start cross gating. Use multi-scale feature fusion
      skip_features = []
      for i in reversed(range(self.depth)):  # 2, 1, 0
        # use larger blocksize at high-res stages
        block_size = self.block_size_hr if i < self.high_res_stages else self.block_size_lr
        grid_size = self.grid_size_hr if i < self.high_res_stages else self.block_size_lr

        # get additional multi-scale signals
        signal = jnp.concatenate([
            UpSampleRatio(
                (2**i) * self.features,
                ratio=2**(j - i),
                use_bias=self.use_bias)(enc) for j, enc in enumerate(encs)
        ],
                                 axis=-1)

        # Use cross-gating to cross modulate features
        if self.use_cross_gating:
          skips, global_feature = CrossGatingBlock(
              features=(2**i) * self.features,
              block_size=block_size,
              grid_size=grid_size,
              input_proj_factor=self.input_proj_factor,
              dropout_rate=self.dropout_rate,
              upsample_y=True,
              use_bias=self.use_bias,
              name=f"stage_{idx_stage}_cross_gating_block_{i}")(
                  signal, global_feature, deterministic=not train)
        else:
          skips = Conv1x1(
              (2**i) * self.features, use_bias=self.use_bias)(
                  signal)
          skips = Conv3x3((2**i) * self.features, use_bias=self.use_bias)(skips)

        skip_features.append(skips)

      # start decoder. Multi-scale feature fusion of cross-gated features
      outputs, decs, sam_features = [], [], []
      for i in reversed(range(self.depth)):
        # use larger blocksize at high-res stages
        block_size = self.block_size_hr if i < self.high_res_stages else self.block_size_lr
        grid_size = self.grid_size_hr if i < self.high_res_stages else self.block_size_lr

        # get multi-scale skip signals from cross-gating block
        signal = jnp.concatenate([
            UpSampleRatio(
                (2**i) * self.features,
                ratio=2**(self.depth - j - 1 - i),
                use_bias=self.use_bias)(skip)
            for j, skip in enumerate(skip_features)
        ],
                                 axis=-1)

        # Decoder block
        x = UNetDecoderBlock(
            features=(2**i) * self.features,
            num_groups=self.num_groups,
            lrelu_slope=self.lrelu_slope,
            block_size=block_size,
            grid_size=grid_size,
            block_gmlp_factor=self.block_gmlp_factor,
            grid_gmlp_factor=self.grid_gmlp_factor,
            input_proj_factor=self.input_proj_factor,
            channels_reduction=self.channels_reduction,
            use_global_mlp=self.use_global_mlp,
            dropout_rate=self.dropout_rate,
            use_bias=self.use_bias,
            name=f"stage_{idx_stage}_decoder_block_{i}")(
                x, bridge=signal, deterministic=not train)

        # Cache decoder features for later-stage's usage
        decs.append(x)

        # output conv, if not final stage, use supervised-attention-block.
        if i < self.num_supervision_scales:
          if idx_stage < self.num_stages - 1:  # not last stage, apply SAM
            sam, output = SAM(
                (2**i) * self.features,
                output_channels=self.num_outputs,
                use_bias=self.use_bias,
                name=f"stage_{idx_stage}_supervised_attention_module_{i}")(
                    x, shortcuts[i], train=train)
            outputs.append(output)
            sam_features.append(sam)
          else:  # Last stage, apply output convolutions
            output = Conv3x3(self.num_outputs,
                             use_bias=self.use_bias,
                             name=f"stage_{idx_stage}_output_conv_{i}")(x)
            output = output + shortcuts[i]
            outputs.append(output)
      # Cache encoder and decoder features for later-stage's usage
      encs_prev = encs[::-1]
      decs_prev = decs

      # Store outputs
      outputs_all.append(outputs)
    return outputs_all


def Model(*, variant=None, **kw):
  """Factory function to easily create a Model variant like "S-1".

  Every model file should have this Model() function that returns the flax
  model function. The function name should be fixed.

  Args:
    variant: Name of a predefined configuration, or None to skip presets.
      Options: 'S-1' | 'S-2' | 'S-3' | 'M-1' | 'M-2' | 'M-3'. The named
      variants differ only in `features` and `num_stages`; all other preset
      values are shared.
    **kw: Other config values forwarded to MAXIM(). A key passed here takes
      precedence over the value the chosen variant would otherwise supply.

  Returns:
    The result of calling MAXIM(**kw) with the variant defaults filled in.

  Raises:
    KeyError: If `variant` is not None and is not one of the options above.
  """
  if variant is not None:
    # variant name -> (features, num_stages). These are the only two settings
    # that change between the predefined variants.
    variant_table = {
        # params: 6.108515000000001 M, GFLOPS: 93.163716608
        "S-1": (32, 1),
        # params: 13.35383 M, GFLOPS: 206.743273472
        "S-2": (32, 2),
        # params: 20.599145 M, GFLOPS: 320.32194560000005
        "S-3": (32, 3),
        # params: 19.361219000000002 M, 308.495712256 GFLOPs
        "M-1": (64, 1),
        # params: 40.83911 M, 675.25541888 GFLOPs
        "M-2": (64, 2),
        # params: 62.317001 M, 1042.014666752 GFLOPs
        "M-3": (64, 3),
    }

    try:
      features, num_stages = variant_table[variant]
    except KeyError:
      # Same exception type as a plain dict lookup, but say what is accepted.
      raise KeyError(
          f"Unknown MAXIM variant {variant!r}; expected one of "
          f"{sorted(variant_table)} or None.") from None

    config = {
        "features": features,
        "depth": 3,
        "num_stages": num_stages,
        "num_groups": 2,
        "num_bottleneck_blocks": 2,
        "block_gmlp_factor": 2,
        "grid_gmlp_factor": 2,
        "input_proj_factor": 2,
        "channels_reduction": 4,
    }

    # setdefault keeps any value the caller passed explicitly.
    for k, v in config.items():
      kw.setdefault(k, v)

  return MAXIM(**kw)