How to migrate from Nvidia GPU to Google TPU [Part 2]¶
Published: 2026-04-26
Alternative title: How to migrate from PyTorch to Jax
As mentioned in the previous article, one of the biggest mental model shifts is that we need to think about what computational graphs we are launching on our devices, and what their properties are.
The questions to answer are:
- What are the graphs that we are launching?
- What are the inputs and outputs of the graph?
- What is the set of shapes + dtypes that these inputs / outputs will have?
Question 3 is important for avoiding unnecessary recompilations.
Note that none of the above questions is TPU-specific; rather, they are Jax-specific (or XLA-specific), which is great because it means we can continue this exploration without buying TPUs first.
Let's first focus on a training example.
Common training setup¶
Let's go through a well-understood example: say a training setup like the one in https://docs.pytorch.org/tutorials/beginner/introyt/trainingyt.html
In essence, the tutorial from the official PyTorch website does the following:
- Defines a model
- Defines an optimizer
- Runs the actual train loop: (a) load data, (b) run forward pass with the data, (c) run backward pass, (d) run optimizer step
The train loop is illustrated in this bit of code (copied from the tutorial above):
def train_one_epoch(epoch_index, tb_writer):
running_loss = 0.
last_loss = 0.
# Here, we use enumerate(training_loader) instead of
# iter(training_loader) so that we can track the batch
# index and do some intra-epoch reporting
for i, data in enumerate(training_loader):
# Every data instance is an input + label pair
inputs, labels = data
# Zero your gradients for every batch!
optimizer.zero_grad()
# Make predictions for this batch
outputs = model(inputs)
# Compute the loss and its gradients
loss = loss_fn(outputs, labels)
loss.backward()
# Adjust learning weights
optimizer.step()
# Gather data and report
running_loss += loss.item()
if i % 1000 == 999:
last_loss = running_loss / 1000 # loss per batch
print(' batch {} loss: {}'.format(i + 1, last_loss))
tb_x = epoch_index * len(training_loader) + i + 1
tb_writer.add_scalar('Loss/train', last_loss, tb_x)
running_loss = 0.
return last_loss
Now, we want to pinpoint what the computational graphs are here. Naively, one might think that making three graphs (forward, backward, optimizer) is the easiest choice. It turns out that this doesn't work; it will become clear why when we try to define the inputs and outputs of each graph.
The common approach is either (1) one graph for the entire train step, or (2) one joint forward-backward graph plus one optimizer graph.
Let's see what the inputs and outputs are:
A single train step¶
Say we decide to define one graph for the whole train step. What inputs are needed to do a train step? We can work this out by asking, for each line of code, what data it needs.
- output = model(input): this needs the input part of the data, from the (input, label) pair. But also, implicitly, we need the model's current weights (those are stored in the model itself).
- loss = loss_fn(outputs, labels): this needs the label part of the data. Technically, it also needs the implicit state saved when model(input) is run, which is required for the backward pass; with our one-graph setup, that state is not considered an outside input, as it is generated within the graph itself.
- The optimizer step: it needs the gradients, the weights, and the optimizer's internal state.
What about the outputs? In other words, what are we trying to accomplish when running the train step? In the original code the model weights are updated in place, and that is our goal, so the output of our graph must include the updated weights. We also need to return the updated optimizer state, because the optimizer might keep internal state.
OK, with the above analysis, our train step's function signature will look like this:
def train_step(data, model_weights, opt_state) -> (updated_model_weights, updated_opt_state)
We can also throw the loss in there: even though it is not strictly needed, it's nice to log
it out to a monitoring service, say Weights and Biases etc. We cannot
do that inside train_step because train_step must be a pure function (i.e. no side effects).
If we need to log out more things, say you want to monitor the absolute value of the gradient norm to
detect gradient explosion, the right approach is also to return them.
So let's add a metrics field to the return value; it will be a dict containing the loss and other metrics.
def train_step(data, model_weights, opt_state) -> (updated_model_weights, updated_opt_state, metrics)
Let's meditate on this pure function idea a little more. Another property of pure functions is that the output is completely determined by the inputs. In other words, if you want the computation to do something different next time, pass in a different input.
This is profound when dealing with randomness, as a pure function basically rules out hidden randomness. If your model uses dropout and you want to drop different things on different train steps, or you have other random behavior that helps training, then you want to pass the random number generator seed in as another input.
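For instance, here is a minimal sketch of what explicit RNG threading looks like in Jax, with a made-up dropout helper (not the one any particular library provides):

    import jax
    import jax.numpy as jnp

    def dropout(rng, x, rate=0.1):
        # the drop mask is fully determined by the rng key passed in
        keep = jax.random.bernoulli(rng, 1.0 - rate, x.shape)
        return jnp.where(keep, x / (1.0 - rate), 0.0)

    rng = jax.random.PRNGKey(0)
    for step in range(3):
        # split off a fresh subkey each step, so each step drops different units
        rng, step_rng = jax.random.split(rng)
        y = dropout(step_rng, jnp.ones((2, 3)))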
With all the above information, if I were to rewrite the train step in PyTorch to conform to the pure-function requirement, it would look like this:
def train_step(data, model_weights, opt_state, rng):
    torch.manual_seed(rng)
    input, label = data
    # make a new optimizer to simulate the optimizer logic using the passed-in state
    optimizer = Optimizer(model_weights.values())
    optimizer.load_state_dict(opt_state)
    optimizer.zero_grad()
    # functional_call evaluates the model with the passed-in weights
    output = torch.func.functional_call(model, model_weights, args=(input,))
    loss = loss_fn(output, label)
    loss.backward()
    optimizer.step()
    metrics = {'loss': loss.item()}
    # model_weights were updated in place by optimizer.step()
    return model_weights, optimizer.state_dict(), metrics
We can feel the clunkiness of writing torch code this way, because, well, torch assumes you will be doing in-place updates eagerly, and that is not Jax (or XLA, or TPU) compatible. This is not idiomatic or fully runnable PyTorch. Here is the same train step written in Jax style (using optax for the optimizer):
def train_step(rng, data, model_weights, opt_state):
    input, label = data

    # this function is a pure function too
    def compute_loss(model_weights, rng, data):
        input, label = data
        # a Jax-style model is usually a pure function that takes model_weights explicitly
        output = model(rng, model_weights, input)
        return loss_fn(output, label)

    # jax.value_and_grad differentiates with respect to the FIRST argument,
    # so it's important that model_weights is the first argument of
    # compute_loss above
    grad_fn = jax.value_and_grad(compute_loss)
    loss, grads = grad_fn(model_weights, rng, data)
    updates, updated_opt_state = optimizer.update(grads, opt_state)
    updated_weights = optax.apply_updates(model_weights, updates)
    metrics = {'loss': loss}
    return updated_weights, updated_opt_state, metrics
Notice that, although the two versions look different, they represent the same computation. In other words, the changes are mostly cosmetic.
We have now represented our device computational graph as a Python function. The primary way to define a computational graph in Jax is exactly this: define a pure Python function.
To compile this function with XLA, we wrap it with jax.jit.
So our train loop looks like:
step_fn_compiled = jax.jit(train_step)

optimizer = optax.adam(0.1)
opt_state = optimizer.init(weights)
rng = ...  # defines the rng key, e.g. jax.random.PRNGKey(0)

for i, data in enumerate(dataloader):
    # split off a new rng each time, so we run different dropouts on the model
    rng, rng_to_use = jax.random.split(rng)
    weights, opt_state, metrics = step_fn_compiled(rng_to_use, data, weights, opt_state)
    if i % 1000 == 999:
        # log out the loss (a side effect, done outside the compiled graph)
        tb_x = i + 1
        tb_writer.add_scalar('Loss/train', float(metrics['loss']), tb_x)
Above we are (1) defining the computational graph (with jax.jit) and (2) repeatedly calling it. All side effects, such as writing the loss to a dashboard, are handled outside of the graph.
Now, think about input shapes¶
What are the shapes of the inputs? The weights are a dict of model weights (or, more generally, a pytree of model weights), and a model's weight shapes are predefined, so they don't change. The same holds for the optimizer state.
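As a quick illustration, here is a minimal sketch (with made-up layer names and sizes) of what such a weights pytree looks like; the shapes are fixed by the model definition:

    import jax
    import jax.numpy as jnp

    # hypothetical two-layer weight pytree
    weights = {
        "layer1": {"w": jnp.zeros((784, 128)), "b": jnp.zeros((128,))},
        "layer2": {"w": jnp.zeros((128, 10)), "b": jnp.zeros((10,))},
    }
    # print only the shape of every leaf in the pytree
    print(jax.tree_util.tree_map(lambda x: x.shape, weights))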
So the only place where shapes can vary is the data.
The data is whatever is stored in the training files, so to keep the shapes uniform
we need to make sure that the input data all have the same shape.
I have seen cases where they do not.
For example, for next-token prediction, the input sequence length can vary: in eager mode, the cross-entropy loss between tokens (0, n - 1) and tokens (1, n) is equally valid for any n. Or, for image processing, many models can take images of arbitrary resolution, and the training set may contain images of different resolutions.
During normal PyTorch training, the variability of training data shapes doesn't matter, but it starts to matter for Jax/XLA: every time a different shape is processed, Jax recompiles the graph, which can take a few minutes.
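A quick way to see this on any machine (a small sketch: the Python-level print only runs while Jax traces, i.e. compiles, the function):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def double(x):
        # this print only fires during tracing, not on cached calls
        print("tracing for shape", x.shape)
        return x * 2

    double(jnp.ones((8,)))   # prints: tracing for shape (8,)
    double(jnp.ones((8,)))   # same shape: cached, no print
    double(jnp.ones((16,)))  # new shape: traces and compiles again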
Here the recommendation is to limit how many distinct shapes your graph sees. For sequences, we can either pad to the nearest power of 2 (or to a small set of bucket sizes), or concatenate sequences to a fixed length (say, 8k) with padding/masking so that the model can process many sequences at once; see the sketch below. This usually requires changes in the upstream data processing pipeline to produce data in this shape (which is outside the main training loop).
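Here is a minimal sketch of the bucketing idea; pad_to_bucket is a hypothetical helper for the data pipeline, not a library function:

    import numpy as np

    def pad_to_bucket(tokens, buckets=(128, 256, 512, 1024), pad_id=0):
        # pad a 1-D token array up to the smallest bucket length that fits,
        # so the compiled train step only ever sees len(buckets) distinct shapes
        target = next(b for b in buckets if b >= len(tokens))
        padded = np.full((target,), pad_id, dtype=np.int32)
        padded[: len(tokens)] = tokens
        mask = np.zeros((target,), dtype=bool)
        mask[: len(tokens)] = True
        return padded, mask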
A few more things to tidy up¶
With the above exercise, hopefully one has a good idea of the effort involved in migrating to Jax beyond a simple syntax remap. A few more things are useful to add:
1. Donating buffers for weights / opt_state¶
As it currently stands, the loop actually uses double the memory for the weights:
one copy for the weights themselves, and one copy for the updated weights that
step_fn_compiled returns. After the returned weights are assigned to weights,
garbage collection frees one copy, so the peak memory is double the size of the weights
(not even counting the corresponding doubling of opt_state).
The trick to make Jax reuse the memory, which has the same effect as an in-place update,
is buffer donation. This way, the function itself is still logically a pure function;
it's just that some input buffers are invalidated after use.
To use it, simply pass the donate_argnums kwarg to jax.jit (see the Jax documentation on buffer donation for more details).
i.e. replace
step_fn_compiled = jax.jit(train_step)
with
step_fn_compiled = jax.jit(train_step, donate_argnums=(2, 3))
2. Gradient checkpointing¶
A common training trick to save memory is gradient checkpointing, the technique of not keeping some of the intermediate activations required for computing gradients in memory, but instead recomputing them during the backward pass. This trades compute for memory.
To get this in Jax, one simply wraps the loss function with jax.checkpoint
before passing it to jax.value_and_grad.
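A minimal sketch, reusing compute_loss, model_weights, rng, and data from the Jax train_step above:

    # recompute compute_loss's intermediates during the backward pass instead of storing them
    grad_fn = jax.value_and_grad(jax.checkpoint(compute_loss))
    loss, grads = grad_fn(model_weights, rng, data)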
One can get host offloading with this API as well, which saves the intermediates in CPU RAM instead of recomputing them.
Two-graph version¶
I will be brief on the thought process because it's very similar to the one-step version. The two graphs will be:
def forward_backward(weights, data) -> (loss, gradients)
def optimizer_update(weights, gradients, opt_state) -> (updated_weights, updated_opt_state)
Both graphs are jax.jit'd and launched in sequence.
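A minimal sketch of the loop under this split, assuming the two functions above are implemented and optax is used as before:

    fwd_bwd_compiled = jax.jit(forward_backward)
    # donate weights and opt_state so they can be updated without doubling memory
    opt_update_compiled = jax.jit(optimizer_update, donate_argnums=(0, 2))

    for data in dataloader:
        loss, grads = fwd_bwd_compiled(weights, data)
        weights, opt_state = opt_update_compiled(weights, grads, opt_state)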
Also, a side note on why separating the forward and backward passes into two graphs does not work:
the forward graph must return the intermediates needed to compute the backward pass,
and the backward graph needs to take those as inputs. It's very hard to express
that cleanly in code because the saved intermediates are usually hidden by the framework
and are not something cleanly visible. (We could write it out if we had to, though. It's just troublesome.)
Conclusion¶
Beyond just translating PyTorch syntax to Jax, we looked at how to analyze your existing training loop setup and how to represent it as repeatedly launching a pure computational graph. Hopefully this helps you evaluate how much effort it will take to migrate to Jax.