File size: 4,289 Bytes
dc155d4
 
 
 
 
 
 
 
 
 
 
 
879ee4e
dc155d4
 
 
 
 
 
 
 
 
39b7e29
 
 
 
 
 
 
dc155d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
879ee4e
dc155d4
 
 
 
 
 
 
6ff4937
dc155d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82d7cc1
dc155d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
"""

from typing import Any
from typing import Callable
from typing import ParamSpec

import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization import Int8WeightOnlyConfig

from optimization_utils import capture_component_call
from optimization_utils import aoti_compile
from optimization_utils import ZeroGPUCompiledModel


P = ParamSpec('P')


TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)

TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {
        2: TRANSFORMER_NUM_FRAMES_DIM,
    },
}

INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}


def _is_landscape(hidden_states: torch.Tensor) -> bool:
    """Return whether the last axis of ``hidden_states`` is longer than the one before it.

    Anything else, including equal sizes, is treated as portrait. Export
    (choosing which example input each graph is traced with) and runtime
    dispatch (choosing which compiled program to run) both call this single
    predicate, so the two can never disagree.
    """
    return hidden_states.shape[-1] > hidden_states.shape[-2]


def _make_orientation_dispatcher(
    landscape: Callable[..., Any],
    portrait: Callable[..., Any],
    config: Any,
    dtype: Any,
) -> Callable[..., Any]:
    """Return a callable that routes each call by ``hidden_states`` orientation.

    Args:
        landscape: Program invoked when ``_is_landscape(hidden_states)`` is true.
        portrait: Program invoked otherwise.
        config: Attached to the returned callable as ``.config``, standing in
            for the attribute of the module it replaces.
        dtype: Attached to the returned callable as ``.dtype``.

    The returned callable reads ``kwargs['hidden_states']``, so callers must
    pass ``hidden_states`` by keyword; otherwise it raises ``KeyError``.
    """

    def dispatch(*args, **kwargs):
        target = landscape if _is_landscape(kwargs['hidden_states']) else portrait
        return target(*args, **kwargs)

    dispatch.config = config  # pyright: ignore[reportAttributeAccessIssue]
    dispatch.dtype = dtype  # pyright: ignore[reportAttributeAccessIssue]
    return dispatch


def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Quantize and AOT-compile ``pipeline``'s transformers, mutating it in place.

    Steps:

    * ``pipeline.text_encoder`` is quantized with ``Int8WeightOnlyConfig``.
    * Inside a ``spaces.GPU`` task, ``pipeline`` is run once with
      ``*args`` / ``**kwargs`` so the arguments it passes to
      ``pipeline.transformer`` can be captured as export example inputs.
    * Both ``transformer`` and ``transformer_2`` are quantized with
      ``Float8DynamicActivationFloat8WeightConfig``, exported and compiled.

    Only two graphs are compiled: a landscape graph traced from
    ``transformer`` and a portrait graph traced from ``transformer_2``. Each
    archive is then paired with the other transformer's weights, which yields
    all four (transformer, orientation) programs.

    Only axis 2 of ``hidden_states`` (``num_frames``) is exported as dynamic.
    The last two axes are therefore fixed to the captured sizes or their
    transpose.

    On return, ``pipeline.transformer`` and ``pipeline.transformer_2`` are
    plain callables. Each one dispatches on the orientation of
    ``hidden_states`` and carries the ``config`` / ``dtype`` of the module it
    replaced.

    Args:
        pipeline: Pipeline to optimize. It must expose ``text_encoder``,
            ``transformer`` and ``transformer_2``.
        *args, **kwargs: A representative ``pipeline`` call, used only to
            capture example transformer inputs.
    """

    @spaces.GPU(duration=1500)
    def compile_transformer():
        # Record the exact args/kwargs the pipeline passes to its transformer.
        with capture_component_call(pipeline, 'transformer') as call:
            pipeline(*args, **kwargs)

        # Map every tensor/bool leaf of the captured kwargs to None, then
        # override the ``hidden_states`` entry with the num_frames spec.
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
        quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())

        # Build one example input per orientation by swapping the last two
        # axes. ``.contiguous()`` materializes the transposed copy densely.
        hidden_states: torch.Tensor = call.kwargs['hidden_states']
        hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
        if _is_landscape(hidden_states):
            hidden_states_landscape = hidden_states
            hidden_states_portrait = hidden_states_transposed
        else:
            hidden_states_landscape = hidden_states_transposed
            hidden_states_portrait = hidden_states

        exported_landscape_1 = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs | {'hidden_states': hidden_states_landscape},
            dynamic_shapes=dynamic_shapes,
        )

        exported_portrait_2 = torch.export.export(
            mod=pipeline.transformer_2,
            args=call.args,
            kwargs=call.kwargs | {'hidden_states': hidden_states_portrait},
            dynamic_shapes=dynamic_shapes,
        )

        compiled_landscape_1 = aoti_compile(exported_landscape_1, INDUCTOR_CONFIGS)
        compiled_portrait_2 = aoti_compile(exported_portrait_2, INDUCTOR_CONFIGS)

        # Reuse each compiled archive with the *other* transformer's weights.
        # NOTE(review): valid only if transformer and transformer_2 share one
        # architecture (identical graphs up to weights); not checked here.
        compiled_landscape_2 = ZeroGPUCompiledModel(compiled_landscape_1.archive_file, compiled_portrait_2.weights)
        compiled_portrait_1 = ZeroGPUCompiledModel(compiled_portrait_2.archive_file, compiled_landscape_1.weights)

        return (
            compiled_landscape_1,
            compiled_landscape_2,
            compiled_portrait_1,
            compiled_portrait_2,
        )

    quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
    cl1, cl2, cp1, cp2 = compile_transformer()

    # Read each module's own metadata before it is replaced, so transformer_2
    # does not inherit transformer's config/dtype.
    config_1, dtype_1 = pipeline.transformer.config, pipeline.transformer.dtype
    config_2, dtype_2 = pipeline.transformer_2.config, pipeline.transformer_2.dtype

    pipeline.transformer = _make_orientation_dispatcher(cl1, cp1, config_1, dtype_1)
    pipeline.transformer_2 = _make_orientation_dispatcher(cl2, cp2, config_2, dtype_2)