Move EfficientAttentionConfig to module level
sparktts/modules/speaker/perceiver_encoder.py (changed)
@@ -45,6 +45,21 @@ def once(fn):
 
 print_once = once(print)
 
+# Define config class at module level
+class EfficientAttentionConfig:
+    def __init__(self, enable_flash, enable_math, enable_mem_efficient):
+        self.enable_flash = enable_flash
+        self.enable_math = enable_math
+        self.enable_mem_efficient = enable_mem_efficient
+
+    def _asdict(self):
+        return {
+            'enable_flash': self.enable_flash,
+            'enable_math': self.enable_math,
+            'enable_mem_efficient': self.enable_mem_efficient
+        }
+
+
 # main class
 
 
@@ -62,20 +77,6 @@ class Attend(nn.Module):
             use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
         ), "in order to use flash attention, you must be using pytorch 2.0 or above"
 
-        # Define config as a regular class instead of namedtuple
-        class EfficientAttentionConfig:
-            def __init__(self, enable_flash, enable_math, enable_mem_efficient):
-                self.enable_flash = enable_flash
-                self.enable_math = enable_math
-                self.enable_mem_efficient = enable_mem_efficient
-
-            def _asdict(self):
-                return {
-                    'enable_flash': self.enable_flash,
-                    'enable_math': self.enable_math,
-                    'enable_mem_efficient': self.enable_mem_efficient
-                }
-
         self.cpu_config = EfficientAttentionConfig(True, True, True)
         self.cuda_config = None
 
@@ -88,7 +89,7 @@ class Attend(nn.Module):
             print_once(
                 "A100 GPU detected, using flash attention if input tensor is on cuda"
             )
-            self.cuda_config =
+            self.cuda_config = EfficientAttentionConfig(True, False, False)
         else:
             print_once(
                 "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda"