aoti-blocks-load

#23
by cbensimon (HF Staff) - opened
Files changed (4)
  1. aoti.py +35 -0
  2. app.py +44 -17
  3. optimization.py +0 -106
  4. optimization_utils.py +0 -107
aoti.py ADDED
@@ -0,0 +1,35 @@
+ """
+ """
+
+ from typing import cast
+
+ import torch
+ from huggingface_hub import hf_hub_download
+ from spaces.zero.torch.aoti import ZeroGPUCompiledModel
+ from spaces.zero.torch.aoti import ZeroGPUWeights
+ from torch._functorch._aot_autograd.subclass_parametrization import unwrap_tensor_subclass_parameters
+
+
+ def _shallow_clone_module(module: torch.nn.Module) -> torch.nn.Module:
+     clone = object.__new__(module.__class__)
+     clone.__dict__ = module.__dict__.copy()
+     clone._parameters = module._parameters.copy()
+     clone._buffers = module._buffers.copy()
+     clone._modules = {k: _shallow_clone_module(v) for k, v in module._modules.items() if v is not None}
+     return clone
+
+
+ def aoti_blocks_load(module: torch.nn.Module, repo_id: str, variant: str | None = None):
+     repeated_blocks = cast(list[str], module._repeated_blocks)
+     aoti_files = {name: hf_hub_download(
+         repo_id=repo_id,
+         filename='package.pt2',
+         subfolder=name if variant is None else f'{name}.{variant}',
+     ) for name in repeated_blocks}
+     for block_name, aoti_file in aoti_files.items():
+         for block in module.modules():
+             if block.__class__.__name__ == block_name:
+                 block_ = _shallow_clone_module(block)
+                 unwrap_tensor_subclass_parameters(block_)
+                 weights = ZeroGPUWeights(block_.state_dict())
+                 block.forward = ZeroGPUCompiledModel(aoti_file, weights)
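For context on the new helper: `aoti_blocks_load` reads the block class names listed in the transformer's `_repeated_blocks`, downloads one pre-compiled AOTInductor `package.pt2` per block from the given Hub repo (subfolder `<block_name>` or `<block_name>.<variant>`), and replaces the `forward` of every matching submodule with a `ZeroGPUCompiledModel` bound to that block's own weights (`ZeroGPUCompiledModel` and `ZeroGPUWeights` now come from the `spaces` package rather than the deleted `optimization_utils.py`). The shallow clone keeps the original module untouched while `unwrap_tensor_subclass_parameters` flattens the torchao tensor-subclass parameters into the plain tensors the packaged model expects. A minimal usage sketch, mirroring the calls added to `app.py` below; the repo layout is inferred from the `hf_hub_download` arguments and the block-folder name is illustrative:

```python
# Sketch only. Assumed Hub layout, inferred from hf_hub_download(filename='package.pt2', subfolder=...):
#
#   zerogpu-aoti/Wan2/
#     <RepeatedBlockClassName>.fp8da/package.pt2   # one package per repeated block class (illustrative name)
#
import aoti

# `pipe` is the WanImageToVideoPipeline built in app.py; its transformers expose `_repeated_blocks`.
aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
```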
app.py CHANGED
@@ -9,7 +9,12 @@ import numpy as np
  from PIL import Image
  import random
  import gc
- from optimization import optimize_pipeline_
+
+ from torchao.quantization import quantize_
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
+ from torchao.quantization import Int8WeightOnlyConfig
+
+ import aoti
 
 
  MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
@@ -23,7 +28,7 @@ MAX_SEED = np.iinfo(np.int32).max
 
  FIXED_FPS = 16
  MIN_FRAMES_MODEL = 8
- MAX_FRAMES_MODEL = 81
+ MAX_FRAMES_MODEL = 80
 
  MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS,1)
  MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS,1)
@@ -43,21 +48,29 @@ pipe = WanImageToVideoPipeline.from_pretrained(MODEL_ID,
      torch_dtype=torch.bfloat16,
  ).to('cuda')
 
- for i in range(3):
-     gc.collect()
-     torch.cuda.synchronize()
-     torch.cuda.empty_cache()
+ pipe.load_lora_weights(
+     "Kijai/WanVideo_comfy",
+     weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+     adapter_name="lightx2v"
+ )
+ kwargs_lora = {}
+ kwargs_lora["load_into_transformer_2"] = True
+ pipe.load_lora_weights(
+     "Kijai/WanVideo_comfy",
+     weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+     adapter_name="lightx2v_2", **kwargs_lora
+ )
+ pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1., 1.])
+ pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3., components=["transformer"])
+ pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1., components=["transformer_2"])
+ pipe.unload_lora_weights()
 
- OPTIMIZE_WIDTH = 832
- OPTIMIZE_HEIGHT = 624
+ quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
+ quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
+ quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
 
- optimize_pipeline_(pipe,
-     image=Image.new('RGB', (OPTIMIZE_WIDTH, OPTIMIZE_HEIGHT)),
-     prompt='prompt',
-     height=OPTIMIZE_HEIGHT,
-     width=OPTIMIZE_WIDTH,
-     num_frames=MAX_FRAMES_MODEL,
- )
+ aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
+ aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
 
 
  default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
@@ -109,6 +122,14 @@ def resize_image(image: Image.Image) -> Image.Image:
      return image_to_resize.resize((final_w, final_h), Image.LANCZOS)
 
 
+ def get_num_frames(duration_seconds: float):
+     return 1 + int(np.clip(
+         int(round(duration_seconds * FIXED_FPS)),
+         MIN_FRAMES_MODEL,
+         MAX_FRAMES_MODEL,
+     ))
+
+
  def get_duration(
      input_image,
      prompt,
@@ -121,7 +142,13 @@ def get_duration(
      randomize_seed,
      progress,
  ):
-     return int(steps) * 15
+     BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
+     BASE_STEP_DURATION = 15
+     width, height = resize_image(input_image).size
+     frames = get_num_frames(duration_seconds)
+     factor = frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
+     step_duration = BASE_STEP_DURATION * factor ** 1.5
+     return 10 + int(steps) * step_duration
 
  @spaces.GPU(duration=get_duration)
  def generate_video(
@@ -179,7 +206,7 @@ def generate_video(
      if input_image is None:
          raise gr.Error("Please upload an input image.")
 
-     num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
+     num_frames = get_num_frames(duration_seconds)
      current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
      resized_image = resize_image(input_image)
 
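Two notes on the behaviour introduced in `app.py`. First, `MAX_FRAMES_MODEL` drops from 81 to 80, but `get_num_frames` adds 1 after clipping, so the effective frame count now spans 9 to 81 instead of 8 to 81, while `MAX_DURATION` becomes round(80/16, 1) = 5.0 instead of 5.1. Second, the ZeroGPU time budget returned by `get_duration` now scales with the actual workload rather than a flat 15 s per step: the requested frame count and resized resolution are compared to the 81 × 832 × 624 baseline, the per-step cost is multiplied by that factor raised to the power 1.5, and a 10 s fixed overhead is added. A quick sanity check of the arithmetic (the helper name and the 6-step figure are illustrative, not part of the PR):

```python
# Re-derivation of get_duration()'s estimate for two example requests (illustrative only).
def estimated_gpu_seconds(steps: int, frames: int, width: int, height: int) -> float:
    factor = frames * width * height / (81 * 832 * 624)  # workload relative to the baseline
    return 10 + steps * 15 * factor ** 1.5               # 10 s overhead + scaled per-step cost

print(estimated_gpu_seconds(6, 81, 832, 624))  # 100.0 -> full-length clip at the baseline size
print(estimated_gpu_seconds(6, 41, 832, 624))  # ~42.4 -> a ~2.5 s clip costs well under half of that
```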
optimization.py DELETED
@@ -1,106 +0,0 @@
- """
- """
-
- from typing import Any
- from typing import Callable
- from typing import ParamSpec
-
- import spaces
- import torch
- from torch.utils._pytree import tree_map_only
- from torchao.quantization import quantize_
- from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
- from torchao.quantization import Int8WeightOnlyConfig
-
- from optimization_utils import capture_component_call
- from optimization_utils import aoti_compile
- from optimization_utils import drain_module_parameters
-
-
- P = ParamSpec('P')
-
- LATENT_FRAMES_DIM = torch.export.Dim('num_latent_frames', min=8, max=81)
-
- LATENT_PATCHED_HEIGHT_DIM = torch.export.Dim('latent_patched_height', min=30, max=52)
- LATENT_PATCHED_WIDTH_DIM = torch.export.Dim('latent_patched_width', min=30, max=52)
-
- TRANSFORMER_DYNAMIC_SHAPES = {
-     'hidden_states': {
-         2: LATENT_FRAMES_DIM,
-         3: 2 * LATENT_PATCHED_HEIGHT_DIM,
-         4: 2 * LATENT_PATCHED_WIDTH_DIM,
-     },
- }
-
- INDUCTOR_CONFIGS = {
-     'conv_1x1_as_mm': True,
-     'epilogue_fusion': False,
-     'coordinate_descent_tuning': True,
-     'coordinate_descent_check_all_directions': True,
-     'max_autotune': True,
-     'triton.cudagraphs': True,
- }
-
-
- def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
-
-     @spaces.GPU(duration=1500)
-     def compile_transformer():
-
-         # This LoRA fusion part remains the same
-         pipeline.load_lora_weights(
-             "Kijai/WanVideo_comfy",
-             weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
-             adapter_name="lightx2v"
-         )
-         kwargs_lora = {}
-         kwargs_lora["load_into_transformer_2"] = True
-         pipeline.load_lora_weights(
-             "Kijai/WanVideo_comfy",
-             weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
-             adapter_name="lightx2v_2", **kwargs_lora
-         )
-         pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1., 1.])
-         pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3., components=["transformer"])
-         pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1., components=["transformer_2"])
-         pipeline.unload_lora_weights()
-
-         with capture_component_call(pipeline, 'transformer') as call:
-             pipeline(*args, **kwargs)
-
-         dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
-         dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
-
-         quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
-         quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())
-
-
-         exported_1 = torch.export.export(
-             mod=pipeline.transformer,
-             args=call.args,
-             kwargs=call.kwargs,
-             dynamic_shapes=dynamic_shapes,
-         )
-
-         exported_2 = torch.export.export(
-             mod=pipeline.transformer_2,
-             args=call.args,
-             kwargs=call.kwargs,
-             dynamic_shapes=dynamic_shapes,
-         )
-
-         compiled_1 = aoti_compile(exported_1, INDUCTOR_CONFIGS)
-         compiled_2 = aoti_compile(exported_2, INDUCTOR_CONFIGS)
-
-         return compiled_1, compiled_2
-
-
-     quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
-
-     compiled_transformer_1, compiled_transformer_2 = compile_transformer()
-
-     pipeline.transformer.forward = compiled_transformer_1
-     drain_module_parameters(pipeline.transformer)
-
-     pipeline.transformer_2.forward = compiled_transformer_2
-     drain_module_parameters(pipeline.transformer_2)
optimization_utils.py DELETED
@@ -1,107 +0,0 @@
- """
- """
- import contextlib
- from contextvars import ContextVar
- from io import BytesIO
- from typing import Any
- from typing import cast
- from unittest.mock import patch
-
- import torch
- from torch._inductor.package.package import package_aoti
- from torch.export.pt2_archive._package import AOTICompiledModel
- from torch.export.pt2_archive._package_weights import Weights
-
-
- INDUCTOR_CONFIGS_OVERRIDES = {
-     'aot_inductor.package_constants_in_so': False,
-     'aot_inductor.package_constants_on_disk': True,
-     'aot_inductor.package': True,
- }
-
-
- class ZeroGPUWeights:
-     def __init__(self, constants_map: dict[str, torch.Tensor], to_cuda: bool = False):
-         if to_cuda:
-             self.constants_map = {name: tensor.to('cuda') for name, tensor in constants_map.items()}
-         else:
-             self.constants_map = constants_map
-     def __reduce__(self):
-         constants_map: dict[str, torch.Tensor] = {}
-         for name, tensor in self.constants_map.items():
-             tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
-             constants_map[name] = tensor_.copy_(tensor).detach().share_memory_()
-         return ZeroGPUWeights, (constants_map, True)
-
-
- class ZeroGPUCompiledModel:
-     def __init__(self, archive_file: torch.types.FileLike, weights: ZeroGPUWeights):
-         self.archive_file = archive_file
-         self.weights = weights
-         self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)
-     def __call__(self, *args, **kwargs):
-         if (compiled_model := self.compiled_model.get()) is None:
-             compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
-             compiled_model.load_constants(self.weights.constants_map, check_full_update=True, user_managed=True)
-             self.compiled_model.set(compiled_model)
-         return compiled_model(*args, **kwargs)
-     def __reduce__(self):
-         return ZeroGPUCompiledModel, (self.archive_file, self.weights)
-
-
- def aoti_compile(
-     exported_program: torch.export.ExportedProgram,
-     inductor_configs: dict[str, Any] | None = None,
- ):
-     inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
-     gm = cast(torch.fx.GraphModule, exported_program.module())
-     assert exported_program.example_inputs is not None
-     args, kwargs = exported_program.example_inputs
-     artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
-     archive_file = BytesIO()
-     files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
-     package_aoti(archive_file, files)
-     weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
-     zerogpu_weights = ZeroGPUWeights({name: weights.get_weight(name)[0] for name in weights})
-     return ZeroGPUCompiledModel(archive_file, zerogpu_weights)
-
-
- @contextlib.contextmanager
- def capture_component_call(
-     pipeline: Any,
-     component_name: str,
-     component_method='forward',
- ):
-
-     class CapturedCallException(Exception):
-         def __init__(self, *args, **kwargs):
-             super().__init__()
-             self.args = args
-             self.kwargs = kwargs
-
-     class CapturedCall:
-         def __init__(self):
-             self.args: tuple[Any, ...] = ()
-             self.kwargs: dict[str, Any] = {}
-
-     component = getattr(pipeline, component_name)
-     captured_call = CapturedCall()
-
-     def capture_call(*args, **kwargs):
-         raise CapturedCallException(*args, **kwargs)
-
-     with patch.object(component, component_method, new=capture_call):
-         try:
-             yield captured_call
-         except CapturedCallException as e:
-             captured_call.args = e.args
-             captured_call.kwargs = e.kwargs
-
-
- def drain_module_parameters(module: torch.nn.Module):
-     state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
-     state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
-     module.load_state_dict(state_dict, assign=True)
-     for name, param in state_dict.items():
-         meta = state_dict_meta[name]
-         param.data = torch.Tensor([]).to(**meta)