Running Native PyTorch on TPUs with Zero Code Changes

Community Article Published February 21, 2026

If you spend your days building, training, or fine-tuning deep learning models, you already know the great divide in the AI hardware space. For the longest time, the unwritten rule has been: if you are using Nvidia GPUs, you write PyTorch. If you are using Google’s TPUs, you write JAX.

Sure, there have been attempts to bridge this gap in the past. We’ve seen projects that tried to wrap PyTorch to make it run on TPUs. But let’s be brutally honest—they always felt exactly like what they were: wrappers. They forced you to change your mental model. They forced PyTorch to act like JAX under the hood, relying on things like "Lazy Tensors" or making you rewrite your training loops to fit a paradigm that just didn't feel natural to a PyTorch developer.

But what if you didn't have to choose? What if you could take an unmodified, off-the-shelf PyTorch model from Hugging Face, run it on a TPU, and get top-tier performance?

That is exactly what is happening right now with a massive engineering effort known as TorchTPU. It’s arguably one of the most exciting shifts happening in the ML framework infrastructure space, and it’s going to completely change how we think about hardware lock-in.

Let’s get into the weeds of how this actually works, why it’s so hard to pull off, and why it finally feels like true PyTorch.

The Magic Trick: Just Change the Device String

To understand why TorchTPU is such a big deal, you have to look at the developer experience.

Think about how you normally write PyTorch code targeting a GPU. You load your model, you load your tokenizer, and somewhere at the top of your script, you have a line that looks like this:

device = torch.device('cuda')

Then you map your model to that device using .to(device).

If you wanted to move that to a TPU in the past, you were looking at a significant refactor. You had to learn new APIs, change how your data loaders worked, and fundamentally alter your training loop.

With TorchTPU, the goal is brutal simplicity. To run a complex model—say, fine-tuning Llama 3.2 1B on a custom dataset—you literally make one change.

You import the TPU module, and you change your device string:

device = tpu.get_device()

That’s it. You leave your model alone. You leave your training loop untouched. You don't have to rewrite the attention mechanism. You don't have to touch your Hugging Face imports. You map the model to the TPU device, and it just runs.
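As a minimal sketch of what that looks like in practice: the `tpu` module name below comes from the snippet above and may differ in the released package, so treat the import as an assumption. The fallback branch just keeps the same script runnable on ordinary hardware.

```python
import torch

# Hypothetical: `tpu` and `tpu.get_device()` follow the article's snippet
# and are assumptions about the TorchTPU entry point. Everything else is
# stock PyTorch, and none of it changes when the device does.
try:
    import tpu
    device = tpu.get_device()
except ImportError:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The model and the rest of the loop are untouched; only the device string
# at the top of the script is different.
model = torch.nn.Linear(16, 4).to(device)
x = torch.randn(8, 16, device=device)
out = model(x)
print(out.shape)  # torch.Size([8, 4])
```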

When you run inference or kick off an instruction-tuning run, the loss curve drops exactly as it would on a GPU cluster. But how is it doing this without a massive translation layer that kills performance?

Disruptive Principle 1: True PyTorch Citizenship

The core philosophy behind TorchTPU is what the engineering team calls "PyTorch Citizenship."

Previous attempts to get PyTorch on TPUs basically looked at JAX and said, "JAX is really fast on TPUs, so let's make PyTorch act like JAX." This meant forcing PyTorch developers to adopt JAX dependencies and use Lazy Tensors.

Lazy Tensors are a neat concept in theory: they build up a massive computation graph of your whole model before running anything, which lets the compiler optimize the heck out of it. But they absolutely destroy the primary reason people love PyTorch: Eager Mode.

PyTorch developers love eager mode because it executes line-by-line. You can drop a print statement in the middle of a forward pass and see exactly what the tensor shape is. Lazy evaluation breaks that interactive, intuitive flow.
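That debugging flow is easy to see with plain PyTorch on any device: a toy module with an ordinary print in the middle of `forward()`, nothing TPU-specific.

```python
import torch
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 3)

    def forward(self, x):
        h = self.proj(x)
        # Eager mode executes line-by-line, so this print fires mid-forward
        # with a real tensor in hand -- no graph, no deferred execution.
        print("hidden shape:", h.shape)
        return torch.relu(h)

y = Tiny()(torch.randn(2, 8))
```

With lazy evaluation, that print would run during graph tracing, before any real values exist; in eager mode it shows you the actual tensor as it flows through.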

TorchTPU throws away the JAX-envy. There are no Lazy Tensors. There are no JAX dependencies forced on you. It is built on a solid foundation of standard PyTorch Eager Mode.

It also fully integrates with modern PyTorch features like torch.compile(), which uses Dynamo. Underneath, TorchTPU takes the PyTorch ATen operations (the foundational C++ tensor library inside PyTorch) and lowers them directly into StableHLO. StableHLO is the intermediate representation that the OpenXLA compiler uses to generate highly optimized machine code for the TPU hardware.

By natively mapping ATen to StableHLO, it bypasses the clunky translation layers of the past. It respects the PyTorch roadmap.
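TorchTPU's actual StableHLO lowering isn't something we can run here, but the Dynamo hook it plugs into is standard PyTorch. A minimal custom backend shows the shape of that contract: Dynamo captures an FX graph of the traced operations and hands it to the backend, which returns a callable. A real backend would emit StableHLO for XLA at this point; this sketch just records the node kinds and runs the graph unchanged.

```python
import torch

captured_ops = []

def inspect_backend(gm, example_inputs):
    # `gm` is an FX GraphModule of the traced ops -- the same artifact a
    # real backend (e.g. a StableHLO lowering) would compile for its target.
    captured_ops.extend(node.op for node in gm.graph.nodes)
    return gm.forward  # run the captured graph as-is

@torch.compile(backend=inspect_backend)
def f(x):
    return torch.relu(x @ x.T) + 1.0

out = f(torch.randn(4, 4))
```

Swapping `inspect_backend` for a compiler is the whole integration surface: the user-facing code never changes.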

Disruptive Principle 2: Fixing the "Wall of Red Text"

If you are a developer, you know the absolute nightmare of framework errors. You make a tiny mistake in tensor dimensionality, and suddenly your terminal is filled with a 500-line C++ stack trace from deep inside a compiler you don't understand.

Older TPU integrations were notorious for this. Because they were trying to shoehorn PyTorch into a different compiler stack, errors wouldn't bubble up correctly. You’d get a runtime error from the bottom of the XLA stack that had absolutely no context about which line of your Python code caused it.

A massive part of the TorchTPU architecture is focused on developer quality-of-life. By building natively, they are ensuring that when things break, you get usable, actionable error messages. Instead of a cryptic memory fault from a hidden kernel, you get a clean Python traceback that points directly to the line in your model where the tensor size mismatch actually happened.

It sounds like a small thing, but when you are debugging a multi-billion parameter model, actionable error handling is the difference between going home at 5 PM and pulling an all-nighter.

Disruptive Principle 3: Making Eager Mode Actually Fast

So, TorchTPU gives you native Eager Mode. That’s great for debugging, but there is a massive technical elephant in the room: Eager Mode is historically slow on AI accelerators.

When you execute operations one by one (op-by-op), the hardware spends more time waiting for the next instruction from the CPU than it does actually crunching numbers. To get peak performance, you usually need to compile the whole graph so the hardware can fuse operations together (like combining a matrix multiplication and an activation function into a single physical step).

To compete directly with the highly optimized PyTorch-to-CUDA pipeline, the TorchTPU engineers had to get creative. They are implementing a just-in-time (JIT) compilation approach for eager mode.

The goal is to bridge the gap between slow op-by-op execution and fast compiled execution. Even when you are running in eager mode, TorchTPU looks for opportunities to grab subgraphs—chunks of multiple operations—and fuse them together on the fly using the XLA compiler. This gives you the feel and debuggability of eager mode, but the massive acceleration of compiled code.
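A hand-fused stand-in makes the idea concrete. `torch.addmm` computes a bias add and a matrix multiplication as one operation; a JIT performs this kind of combining automatically across whole subgraphs, changing the kernels but not the math.

```python
import torch

x, w, b = torch.randn(32, 64), torch.randn(64, 64), torch.randn(64)

# Op-by-op eager: a matmul, then a separate bias add -- two dispatches,
# with the accelerator idling between them.
unfused = x @ w + b

# The same math as a single fused call: addmm computes b + x @ w in one op.
# A JIT fuses like this across larger subgraphs (e.g. matmul + activation)
# without the user restructuring anything.
fused = torch.addmm(b, x, w)

print(torch.allclose(unfused, fused, atol=1e-5))  # True
```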

They are also modernizing the entire XLA stack underneath. A great example of this is how they handle Bounded Dynamism.

Language models deal with dynamic sequence lengths all the time. Traditional compilers hate dynamic shapes; they want to know exactly how big a tensor is so they can allocate the perfect amount of memory. If the shape changes, the compiler panics and triggers a full recompilation, which stalls your training loop.

TorchTPU is pushing the XLA stack to handle bounded dynamism natively, meaning that as long as the dynamic shapes stay within certain bounds, the TPU can keep crunching without stopping to recompile. They are also implementing pre-compiled kernels to drastically reduce startup times.
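The behavior being targeted can be previewed with stock `torch.compile`. The sketch below uses a trivial counting backend (not TorchTPU) and `dynamic=True`, which asks Dynamo to trace with symbolic shapes so that varying sequence lengths can share one compiled graph instead of forcing recompiles.

```python
import torch

compile_count = 0

def counting_backend(gm, example_inputs):
    # Count how many times Dynamo hands us a fresh graph to compile.
    global compile_count
    compile_count += 1
    return gm.forward

# dynamic=True traces with symbolic shapes up front, so different sequence
# lengths can reuse one graph as long as they stay within the assumed bounds.
@torch.compile(backend=counting_backend, dynamic=True)
def score(tokens):
    return (tokens * 2.0).sum(dim=-1)

# Three different "sequence lengths" through the same compiled function.
outs = [score(torch.ones(2, n)) for n in (8, 16, 32)]
print(compile_count)
```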

The Holy Grail: Solving Distributed Training

Everything I’ve mentioned so far is impressive, but if you want to know what makes TorchTPU a true game-changer, it’s how they handle distributed training.

Training massive models like Llama 3 70B doesn't happen on one chip. It happens across hundreds or thousands of chips.

In the JAX ecosystem, you handle this with a concept called "sharding." You write your code for one giant virtual device, and you tell the compiler how to shard (split) the tensors across the physical chips.

In the PyTorch ecosystem, developers don't express parallelism that way. They use specific distributed idioms like DDP (Distributed Data Parallel) and FSDP (Fully Sharded Data Parallel).

Previous attempts to run PyTorch on TPUs basically said: "Hey PyTorch developers, we know you like FSDP, but to use TPUs, you need to rewrite your models to use JAX-style sharding."

It was a total non-starter. Nobody wants to rewrite their distributed logic.

TorchTPU fixes this by intercepting the communication at the exact right layer.

Think about how PyTorch Distributed works. At the very top, you have wrappers like FSDP v2 or DTensor. Those wrappers talk to the ProcessGroup API, which manages the collective communications across devices (things like AllReduce, AllGather, Broadcast). On Nvidia hardware, these process groups talk to NCCL (the Nvidia Collective Communications Library).

Instead of trying to emulate sharding at the top layer, TorchTPU goes all the way down to the bottom. They intercept the collectives at the ProcessGroup level. When PyTorch calls for an AllReduce, TorchTPU maps that directly to a StableHLO collective operation that the TPU network understands perfectly.

The result? The top layer doesn't even know it's not running on GPUs.

You can take a model that relies heavily on standard PyTorch distributed APIs, whether it's Hugging Face's implementation, FairScale, or standard FSDP v2, and it just works. No patches. No rewrites. The PyTorch distributed wrappers pass instructions down, and TorchTPU quietly translates the network chatter into TPU-native operations.
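That interception point is visible in any stock PyTorch install. The sketch below runs a single-process "cluster" on CPU with the Gloo backend; the `dist.all_reduce` call is exactly what DDP and FSDP issue internally, and which library answers it (NCCL on GPUs, Gloo on CPU, or a TPU backend) is invisible to the caller.

```python
import os
import torch
import torch.distributed as dist

# Single-process world on CPU via the Gloo backend; on TPUs the same calls
# would be routed to TorchTPU's ProcessGroup implementation instead.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

t = torch.tensor([1.0, 2.0, 3.0])
# The wrapper layers (DDP, FSDP) only ever issue calls like this one;
# which backend fulfills them is not their concern.
dist.all_reduce(t, op=dist.ReduceOp.SUM)  # with one rank, the sum is a no-op

print(t.tolist())  # [1.0, 2.0, 3.0]
dist.destroy_process_group()
```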

What This Means for the Future

The progress here is staggering. Moving from a prototype to running distributed inference on a 70-billion parameter Llama 3 model across multiple chips happened in the span of just a few months. They are rapidly churning through the thousands of ATen operations required to support every weird and wonderful thing PyTorch can do.

For the developer community, this is a massive breath of fresh air.

We are finally moving toward a world where your code is completely decoupled from the metal it runs on. You shouldn't have to rewrite your model architecture or your training loop just because your cloud provider gave you a different type of accelerator.

By respecting the PyTorch ecosystem, keeping eager mode intact, and building a bridge from native ATen ops and distributed collectives straight to the XLA compiler, TorchTPU isn't just a wrapper. It's a native citizen. And it means the days of hardware lock-in for AI frameworks might finally be coming to an end.

Most of the information in this post comes from PyTorch on TPUs. The team at Google is actively developing the PyTorch TPU stack to ensure a seamless experience. For background on their roadmap and current status, see their PyTorch on TPU RFC.
