import tempfile
import unittest

import torch

from diffusers import DiffusionPipeline, QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils.testing_utils import (
    is_transformers_available,
    require_accelerate,
    require_bitsandbytes_version_greater,
    require_quanto,
    require_torch,
    require_torch_accelerator,
    slow,
    torch_device,
)


if is_transformers_available():
    from transformers import BitsAndBytesConfig as TranBitsAndBytesConfig
else:
    TranBitsAndBytesConfig = None


@require_bitsandbytes_version_greater("0.43.2")
@require_quanto
@require_accelerate
@require_torch
@require_torch_accelerator
@slow
class PipelineQuantizationTests(unittest.TestCase):
    model_name = "hf-internal-testing/tiny-flux-pipe"
    prompt = "a beautiful sunset amidst the mountains."
    num_inference_steps = 10
    seed = 0

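    # Checks that a single-backend config built from `quant_backend` + `quant_kwargs`
    # is applied to every component listed in `components_to_quantize`.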
    def test_quant_config_set_correctly_through_kwargs(self):
        components_to_quantize = ["transformer", "text_encoder_2"]
        quant_config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
            },
            components_to_quantize=components_to_quantize,
        )
        pipe = DiffusionPipeline.from_pretrained(
            self.model_name,
            quantization_config=quant_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)
        for name, component in pipe.components.items():
            if name in components_to_quantize:
                self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
                quantization_config = component.config.quantization_config
                self.assertTrue(quantization_config.load_in_4bit)
                self.assertTrue(quantization_config.quant_method == "bitsandbytes")

        _ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)

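    # Checks that a per-module `quant_mapping` routes each component to its own
    # backend (quanto for the transformer, bitsandbytes for the text encoder).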
    def test_quant_config_set_correctly_through_granular(self):
        quant_config = PipelineQuantizationConfig(
            quant_mapping={
                "transformer": QuantoConfig(weights_dtype="int8"),
                "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
            }
        )
        components_to_quantize = list(quant_config.quant_mapping.keys())
        pipe = DiffusionPipeline.from_pretrained(
            self.model_name,
            quantization_config=quant_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)
        for name, component in pipe.components.items():
            if name in components_to_quantize:
                self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
                quantization_config = component.config.quantization_config

                if name == "text_encoder_2":
                    self.assertTrue(quantization_config.load_in_4bit)
                    self.assertTrue(quantization_config.quant_method == "bitsandbytes")
                else:
                    self.assertTrue(quantization_config.quant_method == "quanto")

        _ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)

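    # Passing both `quant_backend` and `quant_mapping` is ambiguous and must raise.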
    def test_raises_error_for_invalid_config(self):
        with self.assertRaises(ValueError) as err_context:
            _ = PipelineQuantizationConfig(
                quant_mapping={
                    "transformer": QuantoConfig(weights_dtype="int8"),
                    "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
                },
                quant_backend="bitsandbytes_4bit",
            )

        self.assertTrue(
            str(err_context.exception)
            == "Both `quant_backend` and `quant_mapping` cannot be specified at the same time."
        )

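    # `quant_kwargs` are validated against both the diffusers and transformers config
    # signatures; a kwarg accepted by only one of them should raise.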
    def test_validation_for_kwargs(self):
        components_to_quantize = ["transformer", "text_encoder_2"]
        with self.assertRaises(ValueError) as err_context:
            _ = PipelineQuantizationConfig(
                quant_backend="quanto",
                quant_kwargs={"weights_dtype": "int8"},
                components_to_quantize=components_to_quantize,
            )

        self.assertTrue(
            "The signatures of the __init__ methods of the quantization config classes" in str(err_context.exception)
        )

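    # A plain dict is not accepted at the pipeline level; it must be wrapped in
    # `PipelineQuantizationConfig`.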
    def test_raises_error_for_wrong_config_class(self):
        quant_config = {
            "transformer": QuantoConfig(weights_dtype="int8"),
            "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
        }
        with self.assertRaises(ValueError) as err_context:
            _ = DiffusionPipeline.from_pretrained(
                self.model_name,
                quantization_config=quant_config,
                torch_dtype=torch.bfloat16,
            )
        self.assertTrue(
            str(err_context.exception) == "`quantization_config` must be an instance of `PipelineQuantizationConfig`."
        )

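    # Values in `quant_mapping` must be known quantization config instances;
    # an arbitrary object (here, a bare `DiffusionPipeline`) should be rejected.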
    def test_validation_for_mapping(self):
        with self.assertRaises(ValueError) as err_context:
            _ = PipelineQuantizationConfig(
                quant_mapping={
                    "transformer": DiffusionPipeline(),
                    "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
                }
            )

        self.assertTrue("Provided config for module_name=transformer could not be found" in str(err_context.exception))

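    # Round-trips a granular-quantized pipeline through save_pretrained/from_pretrained
    # and checks that both the quantization configs and the outputs survive.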
    def test_saving_loading(self):
        quant_config = PipelineQuantizationConfig(
            quant_mapping={
                "transformer": QuantoConfig(weights_dtype="int8"),
                "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
            }
        )
        components_to_quantize = list(quant_config.quant_mapping.keys())
        pipe = DiffusionPipeline.from_pretrained(
            self.model_name,
            quantization_config=quant_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)

        pipe_inputs = {"prompt": self.prompt, "num_inference_steps": self.num_inference_steps, "output_type": "latent"}
        output_1 = pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)
            loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir, torch_dtype=torch.bfloat16).to(torch_device)
            for name, component in loaded_pipe.components.items():
                if name in components_to_quantize:
                    self.assertTrue(getattr(component.config, "quantization_config", None) is not None)
                    quantization_config = component.config.quantization_config

                    if name == "text_encoder_2":
                        self.assertTrue(quantization_config.load_in_4bit)
                        self.assertTrue(quantization_config.quant_method == "bitsandbytes")
                    else:
                        self.assertTrue(quantization_config.quant_method == "quanto")

            output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images

        self.assertTrue(torch.allclose(output_1, output_2))