thunnai committed on
Commit d123787 · 1 Parent(s): dfcd575

Move EfficientAttentionConfig to module level

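A likely motivation for the move (an assumption; the commit message gives no rationale): `pickle` serializes class instances by importable qualified name, so instances of a class defined inside `Attend.__init__` cannot be pickled, while instances of a module-level class can. A minimal sketch of the difference, using hypothetical stand-in names (`ModuleLevelConfig`, `build_local_config`):

import pickle

class ModuleLevelConfig:  # importable by qualified name, so picklable
    pass

def build_local_config():
    class LocalConfig:  # defined inside a function, like the pre-commit code
        pass
    return LocalConfig()

pickle.dumps(ModuleLevelConfig())  # works

try:
    pickle.dumps(build_local_config())
except (pickle.PicklingError, AttributeError) as err:
    print("instance of a locally defined class cannot be pickled:", err)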
sparktts/modules/speaker/perceiver_encoder.py CHANGED
@@ -45,6 +45,21 @@ def once(fn):
 
 print_once = once(print)
 
+# Define config class at module level
+class EfficientAttentionConfig:
+    def __init__(self, enable_flash, enable_math, enable_mem_efficient):
+        self.enable_flash = enable_flash
+        self.enable_math = enable_math
+        self.enable_mem_efficient = enable_mem_efficient
+
+    def _asdict(self):
+        return {
+            'enable_flash': self.enable_flash,
+            'enable_math': self.enable_math,
+            'enable_mem_efficient': self.enable_mem_efficient
+        }
+
+
 # main class
 
 
@@ -62,20 +77,6 @@ class Attend(nn.Module):
             use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
         ), "in order to use flash attention, you must be using pytorch 2.0 or above"
 
-        # Define config as a regular class instead of namedtuple
-        class EfficientAttentionConfig:
-            def __init__(self, enable_flash, enable_math, enable_mem_efficient):
-                self.enable_flash = enable_flash
-                self.enable_math = enable_math
-                self.enable_mem_efficient = enable_mem_efficient
-
-            def _asdict(self):
-                return {
-                    'enable_flash': self.enable_flash,
-                    'enable_math': self.enable_math,
-                    'enable_mem_efficient': self.enable_mem_efficient
-                }
-
         self.cpu_config = EfficientAttentionConfig(True, True, True)
         self.cuda_config = None
 
@@ -88,7 +89,7 @@ class Attend(nn.Module):
             print_once(
                 "A100 GPU detected, using flash attention if input tensor is on cuda"
            )
-            self.cuda_config = self.config(True, False, False)
+            self.cuda_config = EfficientAttentionConfig(True, False, False)
         else:
             print_once(
                 "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"