import gc
import tempfile
import unittest

from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
from diffusers.models.attention_processor import Attention
from diffusers.utils import is_optimum_quanto_available, is_torch_available
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_reset_peak_memory_stats,
    enable_full_determinism,
    nightly,
    numpy_cosine_similarity_distance,
    require_accelerate,
    require_big_accelerator,
    require_torch_cuda_compatibility,
    torch_device,
)


if is_optimum_quanto_available():
    from optimum.quanto import QLinear

if is_torch_available():
    import torch

    from ..utils import LoRALayer, get_memory_consumption_stat


enable_full_determinism()
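

# Mixin with the checks shared by every Quanto quantization configuration. Concrete
# test classes supply `model_id`, `model_cls`, the dummy inputs, and the quantization kwargs.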
class QuantoBaseTesterMixin:
    model_id = None
    pipeline_model_id = None
    model_cls = None
    torch_dtype = torch.bfloat16
    # the expected reduction in peak memory used compared to an unquantized model expressed as a percentage
    expected_memory_reduction = 0.0
    keep_in_fp32_module = ""
    modules_to_not_convert = ""
    _test_torch_compile = False
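
    # Reset peak-memory stats, empty the accelerator cache, and collect garbage around
    # every test so the memory comparisons below are not polluted by previous runs.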
    def setUp(self):
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()

    def tearDown(self):
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "float8"}

    def get_dummy_model_init_kwargs(self):
        return {
            "pretrained_model_name_or_path": self.model_id,
            "torch_dtype": self.torch_dtype,
            "quantization_config": QuantoConfig(**self.get_dummy_init_kwargs()),
        }
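
    # Quantizing with `QuantoConfig` should replace every `torch.nn.Linear` in the model
    # with optimum-quanto's `QLinear` subclass.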
    def test_quanto_layers(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                assert isinstance(module, QLinear)
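
    # Loads the model both unquantized and quantized and compares the peak memory of a
    # forward pass against `expected_memory_reduction`.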
    def test_quanto_memory_usage(self):
        inputs = self.get_dummy_inputs()
        inputs = {
            k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool)
        }

        unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
        unquantized_model.to(torch_device)
        unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)

        quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        quantized_model.to(torch_device)
        quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)

        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction
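
    # Temporarily overrides `_keep_in_fp32_modules` on the model class and verifies that
    # the matching Linear weights stay in `torch.float32`.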
    def test_keep_modules_in_fp32(self):
        r"""
        A simple test to check whether the modules listed in `_keep_in_fp32_modules` are kept in fp32.
        Also ensures that inference works.
        """
        _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules
        self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        model.to(torch_device)

        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                if name in model._keep_in_fp32_modules:
                    assert module.weight.dtype == torch.float32
        self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules
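
    # Modules listed in `modules_to_not_convert` must be left as regular layers rather
    # than being swapped for `QLinear`.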
    def test_modules_to_not_convert(self):
        init_kwargs = self.get_dummy_model_init_kwargs()

        quantization_config_kwargs = self.get_dummy_init_kwargs()
        quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert})
        quantization_config = QuantoConfig(**quantization_config_kwargs)
        init_kwargs.update({"quantization_config": quantization_config})

        model = self.model_cls.from_pretrained(**init_kwargs)
        model.to(torch_device)

        for name, module in model.named_modules():
            if name in self.modules_to_not_convert:
                assert not isinstance(module, QLinear)
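
    # Quantized models should refuse dtype changes (`.to(dtype)`, `.float()`, `.half()`);
    # moving to a device without changing the dtype is still allowed.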
    def test_dtype_assignment(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())

        with self.assertRaises(ValueError):
            # Tries with a `dtype`
            model.to(torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a `device` and `dtype`
            device_0 = f"{torch_device}:0"
            model.to(device=device_0, dtype=torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a cast
            model.float()

        with self.assertRaises(ValueError):
            # Tries with a cast
            model.half()

        # This should work
        model.to(torch_device)
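
    # Round-trips the quantized model through `save_pretrained`/`from_pretrained` and
    # checks that the reloaded model produces the same outputs.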
    def test_serialization(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        inputs = self.get_dummy_inputs()

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**inputs)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            saved_model = self.model_cls.from_pretrained(
                tmp_dir,
                torch_dtype=torch.bfloat16,
            )

        saved_model.to(torch_device)
        with torch.no_grad():
            saved_model_output = saved_model(**inputs)

        assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5)
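
    # Compares eager and `torch.compile` outputs via cosine distance; only runs for
    # configurations that set `_test_torch_compile = True`.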
    def test_torch_compile(self):
        if not self._test_torch_compile:
            return

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False)

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**self.get_dummy_inputs()).sample

        compiled_model.to(torch_device)
        with torch.no_grad():
            compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample

        model_output = model_output.detach().float().cpu().numpy()
        compiled_model_output = compiled_model_output.detach().float().cpu().numpy()

        max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten())
        assert max_diff < 1e-3
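
    # A `device_map` of this form should be rejected with a `ValueError`.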
    def test_device_map_error(self):
        with self.assertRaises(ValueError):
            _ = self.model_cls.from_pretrained(
                **self.get_dummy_model_init_kwargs(), device_map={0: "8GB", "cpu": "16GB"}
            )
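

# Flux-specific setup: tiny Flux transformer checkpoints, dummy transformer inputs, and
# extra tests for CPU offload and LoRA-style training on top of the quantized weights.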
class FluxTransformerQuantoMixin(QuantoBaseTesterMixin):
    model_id = "hf-internal-testing/tiny-flux-transformer"
    model_cls = FluxTransformer2DModel
    pipeline_cls = FluxPipeline
    torch_dtype = torch.bfloat16
    keep_in_fp32_module = "proj_out"
    modules_to_not_convert = ["proj_out"]
    _test_torch_compile = False
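
    # Full set of transformer inputs (latents, text embeddings, ids, timestep, guidance),
    # built with a fixed CPU generator for reproducibility.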
    def get_dummy_inputs(self):
        return {
            "hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "encoder_hidden_states": torch.randn(
                (1, 512, 4096),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "pooled_projections": torch.randn(
                (1, 768),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
            "img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
        }
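
    # Smaller, seeded inputs used by `test_training` below.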
    def get_dummy_training_inputs(self, device=None, seed: int = 0):
        batch_size = 1
        num_latent_channels = 4
        num_image_channels = 3
        height = width = 4
        sequence_length = 48
        embedding_dim = 32

        torch.manual_seed(seed)
        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
            device, dtype=torch.bfloat16
        )

        torch.manual_seed(seed)
        pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)

        timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_prompt_embeds,
            "txt_ids": text_ids,
            "img_ids": image_ids,
            "timestep": timestep,
        }
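
    # The quantized transformer should still run end-to-end inside a full pipeline with
    # `enable_model_cpu_offload`.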
    def test_model_cpu_offload(self):
        init_kwargs = self.get_dummy_init_kwargs()
        transformer = self.model_cls.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            quantization_config=QuantoConfig(**init_kwargs),
            subfolder="transformer",
            torch_dtype=torch.bfloat16,
        )
        pipe = self.pipeline_cls.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe", transformer=transformer, torch_dtype=torch.bfloat16
        )
        pipe.enable_model_cpu_offload(device=torch_device)
        _ = pipe("a cat holding a sign that says hello", num_inference_steps=2)
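
    # Freezes the quantized weights, attaches small LoRA adapters to the attention
    # projections, and checks that gradients flow into the adapters.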
    def test_training(self):
        quantization_config = QuantoConfig(**self.get_dummy_init_kwargs())
        quantized_model = self.model_cls.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)

        for param in quantized_model.parameters():
            # freeze the model as only adapter layers will be trained
            param.requires_grad = False
            if param.ndim == 1:
                param.data = param.data.to(torch.float32)

        for _, module in quantized_model.named_modules():
            if isinstance(module, Attention):
                module.to_q = LoRALayer(module.to_q, rank=4)
                module.to_k = LoRALayer(module.to_k, rank=4)
                module.to_v = LoRALayer(module.to_v, rank=4)

        with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
            inputs = self.get_dummy_training_inputs(torch_device)
            output = quantized_model(**inputs)[0]
            output.norm().backward()

        for module in quantized_model.modules():
            if isinstance(module, LoRALayer):
                self.assertTrue(module.adapter[1].weight.grad is not None)
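

# Concrete test cases: one per Quanto weight dtype, differing only in `weights_dtype`
# and the expected memory saving.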
class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.6

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "float8"}


class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.6
    _test_torch_compile = True

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "int8"}


class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.55

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "int4"}


class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.65

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "int2"}