How to Run a Hugging Face Model in JAX (Part 4) - Diffusers
In the previous episodes ([ep01](01-run-huggingface-model-in-jax.md), ep02, ep03), we ran the Llama model in JAX using PyTorch model definitions from HuggingFace transformers, with torchax as the interoperability layer. In this episode, we will do the same for an image generation model.
The Stable Diffusion model
Let's start by instantiating a Stable Diffusion model from HuggingFace diffusers in a simple script (in my case it's saved as jax_hg_04.py):
import time
import functools
import jax
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")
print(type(pipe))
print(isinstance(pipe, torch.nn.Module))
Running the above, you will see:
<class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'>
False
Here we find something unusual: the StableDiffusionPipeline is NOT a torch.nn.Module.
Recall from part 1 that, to convert a torch model into a JAX callable, we use torchax.extract_jax, which only works with torch.nn.Modules.
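As a quick recap, here is roughly what that looks like (a sketch based on part 1; see ep01 for the exact usage):

import torch
import torchax

model = torch.nn.Linear(4, 2)  # any torch.nn.Module works
weights, jax_func = torchax.extract_jax(model)
# jax_func is a pure function of (weights, args); weights are jax.Arrays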
Components of the StableDiffusionPipeline
Looking at the pipe object above:
In [6]: pipe
Out[6]:
StableDiffusionPipeline {
  "_class_name": "StableDiffusionPipeline",
  "_diffusers_version": "0.35.1",
  "_name_or_path": "stabilityai/stable-diffusion-2-base",
  "feature_extractor": [
    "transformers",
    "CLIPImageProcessor"
  ],
  "image_encoder": [
    null,
    null
  ],
  "requires_safety_checker": false,
  "safety_checker": [
    null,
    null
  ],
  "scheduler": [
    "diffusers",
    "PNDMScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}
We can see that the pipeline has many components. Inspecting it in the REPL, we can see that vae, unet, and text_encoder are torch.nn.Modules. They will be our starting point.
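We can double-check this quickly in the same session (a small sanity check; the results match what the pipeline repr above tells us):

print(isinstance(pipe.unet, torch.nn.Module))          # True
print(isinstance(pipe.vae, torch.nn.Module))           # True
print(isinstance(pipe.text_encoder, torch.nn.Module))  # True
print(isinstance(pipe.scheduler, torch.nn.Module))     # False -- this will matter later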
The torchax.compile API
For this blog post, we will showcase the compile API in torchax. This API is like torch.compile, except that instead of using TorchInductor to compile your model, it is powered by jax.jit. This way we get JAX compiled performance instead of JAX eager-mode performance.
torchax.compile wraps a torch.nn.Module and applies jax.jit to the module's forward function. The wrapped JittableModule is still a torch.nn.Module, so we can substitute it back into the pipeline.
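To make this concrete, here is a minimal sketch on a toy module (it uses only the torchax calls that appear in this post, and assumes a bare nn.Module can be moved and compiled the same way as the pipeline components):

import torch
import torchax

m = torch.nn.Linear(4, 2)
with torchax.default_env():
    m.to('jax')                  # weights become tensors backed by jax.Array's
    jitted = torchax.compile(m)  # wraps m in a JittableModule
    # Still a torch.nn.Module, so it is a drop-in replacement:
    print(isinstance(jitted, torch.nn.Module))  # True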
So let's modify the above script:
import time
import functools
import jax
import torch
from diffusers import StableDiffusionPipeline
import torchax

pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")
env = torchax.default_env()
prompt = "a photograph of an astronaut riding a horse"

with env:
    # Moves the weights to the 'jax' device, i.e. to tensors backed by jax.Array's
    pipe.to('jax')
    pipe.unet = torchax.compile(pipe.unet)
    pipe.vae = torchax.compile(pipe.vae)
    pipe.text_encoder = torchax.compile(pipe.text_encoder)

    image = pipe(prompt, num_inference_steps=20).images[0]
    image.save('astronaut.png')
Running it, we got:
TypeError: function call_torch at /mnt/disks/hanq/torch_xla/torchax/torchax/interop.py:224 traced for jit returned a value of type <class 'transformers.modeling_outputs.BaseModelOutputWithPooling'> at output component jit, which is not a valid JAX type
Error 1: Pytrees again
Again we hit pytree issues, which we have dealt with many times throughout these posts. Fixing it is as simple as registering the type as follows:
from jax.tree_util import register_pytree_node
import jax
from transformers.modeling_outputs import BaseModelOutputWithPooling

def base_model_output_with_pooling_flatten(v):
    return (v.last_hidden_state, v.pooler_output, v.hidden_states, v.attentions), None

def base_model_output_with_pooling_unflatten(aux_data, children):
    return BaseModelOutputWithPooling(*children)

register_pytree_node(
    BaseModelOutputWithPooling,
    base_model_output_with_pooling_flatten,
    base_model_output_with_pooling_unflatten,
)
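To see what this registration buys us, here is a quick round-trip check (a sketch): once registered, JAX can flatten the output type into leaves and rebuild it, which is exactly what jit must do with a traced return value.

import jax
import jax.numpy as jnp
from transformers.modeling_outputs import BaseModelOutputWithPooling

out = BaseModelOutputWithPooling(last_hidden_state=jnp.zeros((1, 3, 8)))
leaves, treedef = jax.tree_util.tree_flatten(out)
restored = jax.tree_util.tree_unflatten(treedef, leaves)
print(type(restored).__name__)  # BaseModelOutputWithPooling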
Running it again we hit the second error:
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape bool[]
This occurred in the item() method of jax.Array
The error occurred while tracing the function call_torch at /mnt/disks/hanq/torch_xla/torchax/torchax/interop.py:224 for jit. This concrete value was not available in Python because it depends on the value of the argument kwargs['return_dict'].
Error 2: Static argnames
This error also looks familiar: the pipeline calls the model passing return_dict=True (or False; we don't really care which). By default, JAX treats it as a variable that can change between calls, but here we should really treat it as a constant, so that different values trigger a recompile.
If we were calling jax.jit ourselves, we would pass static_argnames to jax.jit (see the API doc); but here we are using torchax.compile, so how do we do that?
We can do it by passing CompileOptions to torchax.compile, like so:
pipe.unet = torchax.compile(
    pipe.unet, torchax.CompileOptions(
        jax_jit_kwargs={'static_argnames': ('return_dict',)}
    )
)
So basically we can pass arbitrary kwargs to the underlying jax.jit as a dictionary.
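For comparison, here is the same idea in plain JAX (a small self-contained sketch): branching on a Python bool under jit fails unless that argument is declared static, and each distinct static value compiles its own variant.

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnames=('return_dict',))
def f(x, return_dict=True):
    y = x * 2
    # This Python-level branch is only possible because return_dict is static:
    return {'out': y} if return_dict else y

print(f(jnp.ones(3), return_dict=True))   # compiles the dict-returning variant
print(f(jnp.ones(3), return_dict=False))  # new static value -> recompile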
Now, let's try again. This time we are greeted with the following, scarier error:
Traceback (most recent call last):
  File "/mnt/disks/hanq/learning_machine/jax-huggingface/jax_hg_04.py", line 43, in <module>
    image = pipe(prompt, num_inference_steps=20).images[0]
            ~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/disks/hanq/miniconda3/envs/py13/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/disks/hanq/miniconda3/envs/py13/lib/python3.13/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 1061, in __call__
    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
              ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/disks/hanq/miniconda3/envs/py13/lib/python3.13/site-packages/diffusers/schedulers/scheduling_pndm.py", line 257, in step
    return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
           ~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/disks/hanq/miniconda3/envs/py13/lib/python3.13/site-packages/diffusers/schedulers/scheduling_pndm.py", line 382, in step_plms
    prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
  File "/mnt/disks/hanq/miniconda3/envs/py13/lib/python3.13/site-packages/diffusers/schedulers/scheduling_pndm.py", line 418, in _get_prev_sample
    alpha_prod_t = self.alphas_cumprod[timestep]
                   ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^
  File "/mnt/disks/hanq/torch_xla/torchax/torchax/tensor.py", line 235, in __torch_function__
    return self.env.dispatch(func, types, args, kwargs)
           ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/disks/hanq/torch_xla/torchax/torchax/tensor.py", line 589, in dispatch
    res = op.func(*args, **kwargs)
  File "/mnt/disks/hanq/torch_xla/torchax/torchax/ops/jtorch.py", line 290, in getitem
    indexes = self._env.t2j_iso(indexes)
              ^^^^^^^^^
AttributeError: 'Tensor' object has no attribute '_env'
Error 3: Need to move the scheduler
This almost looks like a bug, so let's investigate it thoroughly.
Let's run our script under pdb:
python -m pdb jax_hg_04.py
Going up the stack and printing out self.alphas_cumprod:
(Pdb) p type(self.alphas_cumprod)
<class 'torch.Tensor'>
we notice that this tensor was not moved to the jax device, even though pipe.to('jax') should have done so.
It turns out that the scheduler object in pipe, of type PNDMScheduler, is not a torch.nn.Module, so tensors stored on it don't get moved by the to syntax:
(Pdb) p self
PNDMScheduler {
  "_class_name": "PNDMScheduler",
  "_diffusers_version": "0.35.1",
  "beta_end": 0.012,
  "beta_schedule": "scaled_linear",
  "beta_start": 0.00085,
  "clip_sample": false,
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "set_alpha_to_one": false,
  "skip_prk_steps": true,
  "steps_offset": 1,
  "timestep_spacing": "leading",
  "trained_betas": null
}
(Pdb) p isinstance(self, torch.nn.Module)
False
With this identified, we can move the tensors ourselves. Let's add the following function:
def move_scheduler(scheduler):
    for k, v in scheduler.__dict__.items():
        if isinstance(v, torch.Tensor):
            setattr(scheduler, k, v.to('jax'))
and call it right after pipe.to('jax'):
pipe.to('jax')
+ move_scheduler(pipe.scheduler)
...
Adding this makes the above script run successfully, and we get our first image.
Compiling the VAE
Although the above script finishes running, and despite this line:
pipe.vae = torchax.compile(pipe.vae)
we didn't actually run the compiled version of the VAE. Why? Because by default torchax.compile only compiles the forward method of the module. But when the pipeline uses the VAE, it calls the decode method instead.
The correct version is to replace it with the following:
pipe.vae = torchax.compile(
    pipe.vae,
    torchax.CompileOptions(
        methods_to_compile=['decode'],
        jax_jit_kwargs={'static_argnames': ('return_dict',)}
    )
)
Here, the methods_to_compile option specifies which methods to compile.
Finally, let's add some timing and look at the numbers. We run the generation a few times and focus on the iterations after the first one, so that we are not capturing compile times.
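A minimal timing loop might look like this (a sketch; it assumes the same pipe and prompt as in the script above, and iteration 0 includes compilation):

for i in range(3):
    start = time.perf_counter()
    image = pipe(prompt, num_inference_steps=20).images[0]
    print(f'iteration {i} took: {time.perf_counter() - start:.6f}s')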
Before compiling the VAE's decode method, the steady-state iteration looked like this:
iteration 2 took: 5.942152s
Compiling decode really makes a difference: on an A100 GPU, the time to generate one image drops from about 5.9 seconds to about 1.07 seconds:
iteration 0 took: 53.946763s
100%|█████████████████████████████████████████| 20/20 [00:01<00:00, 19.73it/s]
iteration 1 took: 1.074522s
100%|█████████████████████████████████████████| 20/20 [00:01<00:00, 19.82it/s]
iteration 2 took: 1.067044s
(Like before, we run the code three times to show the time both with and without compilation.)
The final script is located here: https://github.com/qihqi/learning_machine/blob/main/jax-huggingface/jax_hg_04.py
Conclusion
In this blog we showcased running the Stable Diffusion model from HuggingFace in JAX. The only issues we needed to deal with were, once again, pytree registration and static arguments for compilation. This shows that HuggingFace can support JAX as a framework without needing to reimplement all of its models in JAX/flax!