Let's reproduce NanoGPT with Jax (Part 2)

Louis Wang
Aug 4, 2024


Part 1: Build 124M GPT2 with JAX.
Part 2: Optimize the training speed on a single GPU.
Part 3: Multi-GPU Training in Jax.

In Let’s reproduce NanoGPT with Jax (Part 1), we built NanoGPT with Jax from scratch and introduced the fundamental APIs of JAX, Flax, and Optax. In this post, we explain how to optimize training so it runs faster and more stably. By the end, Jax’s optimized GPT trains at 1350k tokens/sec. That is a 7.7x speedup over Andrej’s fully optimized GPT from his video, which runs at 175k tokens/sec.

We will cover the following topics:

  • Increase Batch Size
  • Weight Sharing
  • Mixed Precision Training
  • Gradient Checkpointing
  • Gradient Clipping
  • Cosine learning rate scheduler
  • Gradient Accumulation
  • Tips for High-Performance LLMs with JAX and XLA

We skip FlashAttention and torch.compile. Jax’s just-in-time (JIT) compilation already does something very similar to torch.compile under the hood. FlashAttention speeds up training by fusing operations to reduce data movement between SRAM and HBM, and kernel fusion is also part of what XLA does under jit. Let’s now start the journey to 1350k tokens/sec training!
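The sketch below shows what "just-in-time compilation" looks like in practice. The function `mlp`, its shapes and its weight scaling are made-up illustrations, not code from the trainer. `jax.jit` traces the Python function once for the given shapes and dtypes. It then hands the whole program to XLA, which can fuse elementwise operations such as the GELU instead of running them one op at a time.

```python
import jax
import jax.numpy as jnp

def mlp(x, w1, w2):
    # matmul -> GELU -> matmul; under jit, XLA compiles this as one program
    # and is free to fuse elementwise ops such as the GELU
    return jax.nn.gelu(x @ w1) @ w2

mlp_jit = jax.jit(mlp)  # compiled on the first call for each new shape/dtype combination

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (8, 64))
w1 = jax.random.normal(k2, (64, 256)) / 8.0   # scaled down to keep activations small
w2 = jax.random.normal(k3, (256, 64)) / 16.0

out_eager = mlp(x, w1, w2)    # op-by-op execution
out_jit = mlp_jit(x, w1, w2)  # first call compiles; later calls reuse the compiled program
print(out_jit.shape)
```

In the trainer, the same idea applies to the whole `train_step`, so the forward pass, backward pass and optimizer update are compiled together.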

Increase Batch Size

First, take the free lunch: use a larger batch size to maximize GPU utilization. While training, we can monitor GPU utilization with nvidia-smi or nvtop. A GPU has many streaming multiprocessors (SMs) that run the CUDA kernels, and keeping most of them busy is a sign of a well-utilized GPU. A side note on GPU memory in Jax: by default, it preallocates 75% of the total GPU memory when the first JAX operation runs. Preallocation minimizes allocation overhead and memory fragmentation, but it can sometimes cause out-of-memory (OOM) errors, for example when another process also needs the GPU.
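If preallocation gets in the way, JAX reads two environment variables that control it. This is a minimal sketch: the 0.90 fraction is an arbitrary example value. The variables only take effect if they are set before JAX initializes its GPU backend.

```python
import os

# Must be set before JAX initializes its backend, so set them before importing jax.
# Option 1: keep preallocation but change the fraction of GPU memory it grabs (default 0.75).
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.90"
# Option 2: disable preallocation and allocate on demand (more prone to fragmentation).
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax

print(jax.devices())  # the backend is initialized here, after the variables are set
```

Raising the fraction leaves more room for a larger batch. Disabling preallocation is mainly useful when several processes share one GPU.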

99% GPU utilization during training with 74% GPU memory usage.

Weight Sharing

Weight sharing (also called weight tying) reuses the token embedding matrix as the weight of the final pre-softmax dense layer. Both hold [vocab_size, embedding_dimension] values (the dense kernel is the transpose), which for GPT-2 small is [50257, 768]. In OpenAI’s original GPT-2, these two weight matrices are shared, and this significantly reduces the total number of parameters. Sharing them takes the model from 163M to 124M parameters, since 768 × 50257 ≈ 38.6M parameters are no longer duplicated.

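# Before: a separate output layer with its own [768, 50257] kernel
# logits = nn.Dense(self.config.vocab_size)(x)
# After: reuse the token embedding (wte is the nn.Embed layer) as the output projection
logits = wte.attend(x)  # parameter sharing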

We can also estimate the number of GPT parameters from our config settings. Each layer’s attention has four matrices: wQ, wK, wV, and wO (the output projection). Each of them has size d_model × (n_head × d_head) = d_model². The feed-forward (MLP) block of each layer has two matrices of size d_model × 4·d_model, which together contribute 8·d_model². In total, each layer has 12·d_model² parameters. The other large block of parameters is the token embedding: vocab_size × d_model.

12 (layers) * 12 * 768^2 + 50257 * 768 * 2 = 162,129,408 (without weight sharing)
12 (layers) * 12 * 768^2 + 50257 * 768 = 123,532,032 (with weight sharing)

The gap between these estimates and the actual total comes from the biases, the LayerNorm parameters, and the learned position embeddings (1024 × 768 = 786,432 parameters).
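The arithmetic above can be checked with a short script. The helper names `gpt2_param_estimate` and `gpt2_param_count` are hypothetical. The exact count assumes GPT-2’s layout:

- biases on every linear layer;
- two LayerNorms per block plus a final one;
- learned position embeddings of length block_size = 1024;
- a tied output layer with no bias.

```python
def gpt2_param_estimate(n_layer, d_model, vocab_size, tied=True):
    # per layer: Q, K, V, O projections (4 * d_model^2) + two MLP matrices (8 * d_model^2)
    per_layer = 12 * d_model ** 2
    # token embedding; without tying, the pre-softmax layer adds a second copy
    embed = vocab_size * d_model * (1 if tied else 2)
    return n_layer * per_layer + embed

def gpt2_param_count(n_layer, d_model, vocab_size, block_size, tied=True):
    # adds what the estimate leaves out (output layer assumed to have no bias)
    matrices = gpt2_param_estimate(n_layer, d_model, vocab_size, tied)
    # biases: fused qkv (3d), attention proj (d), MLP fc (4d), MLP proj (d)
    biases = n_layer * (3 * d_model + d_model + 4 * d_model + d_model)
    # each LayerNorm has a scale and a bias vector: 2 per block + 1 final
    layernorms = n_layer * 2 * (2 * d_model) + 2 * d_model
    position_embeddings = block_size * d_model
    return matrices + biases + layernorms + position_embeddings

print(gpt2_param_estimate(12, 768, 50257, tied=False))  # 162129408
print(gpt2_param_estimate(12, 768, 50257, tied=True))   # 123532032
print(gpt2_param_count(12, 768, 50257, block_size=1024))
```

Under these assumptions the last line gives 124,439,808, the familiar "124M" of GPT-2 small.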

Mixed Precision Training

Here we use both 16-bit and 32-bit floating point numbers during training. 16-bit numbers give faster computation and lower memory use. 32-bit numbers are kept for the model updates and the loss calculation, for numerical stability.

We usually avoid float16 because it has only 5 exponent bits, which gives it a narrow range. Gradients that are too small underflow to zero, values that are too large overflow to infinity, and the information is lost. Loss scaling works around this by multiplying the loss by a constant factor, which by the chain rule scales every gradient by the same factor. The gradients are divided by that factor again before the update. Some GPUs (such as the NVIDIA A100 and newer) support bfloat16. It keeps float32’s 8-bit exponent and reduces the mantissa to 7 bits, so it has the same range as float32 at lower precision. Let’s now use bfloat16 for the Dense layers and float32 for the weights and optimizer states. For numerical stability, we also use float32 for the softmax.
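The range problem and the loss-scaling fix can be seen directly on a few numbers. This is a sketch: the sample values (1e-8, 7e4, 2e-8) and the scale factor 2^16 are arbitrary illustrations. Loss scaling is shown on a single gradient value rather than through a real backward pass.

```python
import jax.numpy as jnp

# Range and precision of the three formats
for dtype in (jnp.float16, jnp.bfloat16, jnp.float32):
    info = jnp.finfo(dtype)
    print(jnp.dtype(dtype), "max:", float(info.max), "smallest normal:", float(info.tiny), "eps:", float(info.eps))

# A tiny gradient underflows in float16 but survives in bfloat16; a large value overflows float16
small_fp16 = jnp.array(1e-8, dtype=jnp.float32).astype(jnp.float16)   # rounds to 0.0
small_bf16 = jnp.array(1e-8, dtype=jnp.float32).astype(jnp.bfloat16)  # stays nonzero
large_fp16 = jnp.array(7e4, dtype=jnp.float32).astype(jnp.float16)    # inf: float16 max is 65504

# Static loss scaling: multiplying the loss by S multiplies every gradient by S (chain rule),
# lifting small gradients into float16's range; unscale in float32 before the update
S = 2.0 ** 16
true_grad = jnp.array(2e-8, dtype=jnp.float32)
naive = true_grad.astype(jnp.float16)          # 0.0, information lost
scaled = (true_grad * S).astype(jnp.float16)   # about 1.31e-3, representable
recovered = scaled.astype(jnp.float32) / S     # about 2e-8 again
print(small_fp16, small_bf16, large_fp16, naive, recovered)
```

With bfloat16 the first two problems largely disappear thanks to the wider exponent, which is why loss scaling is usually unnecessary there.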

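To control the data type easily, we pass dtype=jnp.bfloat16 in the model config and hand it to each flax.linen.Dense. In Flax, `dtype` sets the computation dtype, while `param_dtype` (float32 by default) keeps the weights in float32. We then manually cast q and k to jnp.float32 so the attention scores and the softmax are computed in full precision.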

# Inside the attention module's __call__(self, x), with x of shape [b, l, d]
# (assumes jax, jax.numpy as jnp and flax.linen as nn are imported)
b, l, d = x.shape
n_head = self.config.n_head
head_dim = d // n_head
q = nn.Dense(self.config.n_embd, dtype=self.config.dtype, kernel_init=self.config.kernel_init, bias_init=self.config.bias_init)(x)
k = nn.Dense(self.config.n_embd, dtype=self.config.dtype, kernel_init=self.config.kernel_init, bias_init=self.config.bias_init)(x)
v = nn.Dense(self.config.n_embd, dtype=self.config.dtype, kernel_init=self.config.kernel_init, bias_init=self.config.bias_init)(x)
# split heads: [b, l, d] -> [b, n_head, l, head_dim]; q and k go to float32 for the softmax path
q = q.reshape(b, l, n_head, head_dim).transpose(0, 2, 1, 3).astype(jnp.float32)
k = k.reshape(b, l, n_head, head_dim).transpose(0, 2, 1, 3).astype(jnp.float32)
v = v.reshape(b, l, n_head, head_dim).transpose(0, 2, 1, 3)
# q @ k^T / sqrt(head_dim) -> causal mask -> softmax in float32 -> @ v in bfloat16
attn = jnp.matmul(q, jnp.transpose(k, (0, 1, 3, 2))) / jnp.sqrt(head_dim)  # [b, n_head, l, l]
mask = jnp.tril(jnp.ones((l, l), dtype=bool))  # position i may attend to positions j <= i
attn = jnp.where(mask, attn, float("-inf"))
probs = jax.nn.softmax(attn, axis=-1).astype(self.config.dtype)
y = jnp.matmul(probs, v)                       # [b, n_head, l, head_dim]
y = y.transpose(0, 2, 1, 3).reshape(b, l, d)   # merge the heads back
y = nn.Dense(self.config.n_embd, dtype=self.config.dtype)(y)

Gradient Checkpointing with jax.checkpoint (jax.remat)

Gradient checkpointing trades compute for memory. Instead of keeping all activations in memory after the forward pass, we keep only some of them and recompute the rest during the backward pass. This is very useful for models like transformers, where activations account for a significant portion of the memory consumption.

In Jax, we can implement gradient checkpointing with jax.checkpoint, also available as jax.remat. From its doc: "Use the jax.checkpoint decorator (aliased as jax.remat) with jax.grad to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs."

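We can rematerialize according to a policy. For FLOP-bound operations like matrix multiplications, the policy jax.checkpoint_policies.dots_with_no_batch_dims_saveable saves their results instead of recomputing them, and every other intermediate is recomputed. Flax’s nn.remat is the module-level counterpart of jax.checkpoint and accepts the same policy argument. Below, we wrap the whole GPT module with checkpoint_dots_with_no_batch_dims, an alias of the same policy.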

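# Rematerialize the GPT forward pass during backprop: matmul outputs without
# batch dimensions are saved, all other intermediates are recomputed.
model = nn.remat(
    GPT, policy=jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
)(config)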
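To see that rematerialization changes only what is stored, not the result, here is a small sketch with a hypothetical two-matmul block. It computes gradients with and without `jax.checkpoint` and prints which residuals the policy keeps for the backward pass. The shapes are arbitrary toy values.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import print_saved_residuals

def mlp_block(w1, w2, x):
    # two matmuls without batch dimensions, each followed by an elementwise tanh
    h = jnp.tanh(x @ w1)
    return jnp.tanh(h @ w2)

policy = jax.checkpoint_policies.dots_with_no_batch_dims_saveable
remat_block = jax.checkpoint(mlp_block, policy=policy)  # save matmul outputs, recompute the rest

def make_loss(block):
    def loss(w1, w2, x):
        return jnp.sum(block(w1, w2, x) ** 2)
    return loss

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(k1, (16, 32)) * 0.1
w2 = jax.random.normal(k2, (32, 16)) * 0.1
x = jax.random.normal(k3, (8, 16))

g_plain = jax.grad(make_loss(mlp_block), argnums=(0, 1))(w1, w2, x)
g_remat = jax.grad(make_loss(remat_block), argnums=(0, 1))(w1, w2, x)

# Lists what is kept from the forward pass under the policy
print_saved_residuals(make_loss(remat_block), w1, w2, x)
```

The gradients match; only the memory/compute trade-off differs. In a real transformer, applying remat per block (rather than to the whole model) is a common alternative placement.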

Gradient Clipping and learning rate scheduler

To make training more stable for large models like GPT, we follow the original settings of the GPT-3 paper: gradient clipping with a max global norm of 1.0, and a cosine learning rate schedule.

Gradient clipping is a simple but effective way to stabilize the training of deep learning models. During training, we compute the gradients of the loss with respect to the model’s parameters and use them to update the parameters. When the gradients become very large (exploding gradients), the parameters change abruptly. This makes training unstable and can cause the model to diverge or settle on a suboptimal solution. Gradient clipping limits the magnitude of the gradients to prevent this. With optax.clip_by_global_norm, the whole gradient is rescaled whenever its global L2 norm (computed over all parameters) exceeds the threshold, so its direction is preserved.

A learning rate scheduler adjusts the learning rate during training, which helps the model converge faster. The cosine learning rate scheduler uses a cosine function to set the learning rate. It starts with a warmup period that increases the learning rate linearly to a peak value. It then decays the learning rate along a cosine curve until the end of the decay steps, where it reaches the end value. In optax.warmup_cosine_decay_schedule, decay_steps is the total length of the schedule, including the warmup steps.

We can use optax.chain to combine the gradient clipping and learning rate scheduler together.

learning_rate = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=2.5e-4,
    warmup_steps=2000,
    decay_steps=150000,  # total schedule length, including the warmup steps
    end_value=1e-5,
)
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),  # clip first, so AdamW sees the clipped gradients
    optax.adamw(learning_rate, b1=0.9, b2=0.95, weight_decay=1e-2),
)

Gradient Accumulation

Gradient accumulation lets us train with a larger effective batch size than fits in accelerator memory, which gives a more accurate estimate of the gradient. The batch is split into micro-batches, and each micro-batch gets its own forward and backward pass. The gradients are accumulated, and the optimizer updates the parameters only once all micro-batches have been processed. We add gradient_accumulation_steps to the model config. When its value is greater than 1, we wrap the optimizer with optax.MultiSteps, which accumulates the gradients (averaging them by default) and only applies an update every k steps.

if config.gradient_accumulation_steps > 1:
    optimizer = optax.MultiSteps(
        optimizer, every_k_schedule=config.gradient_accumulation_steps
    )

We also change the training loop to account for the accumulation steps:

import time

for step in range(train_steps):
    t0 = time.time()
    for _ in range(config.gradient_accumulation_steps):
        x, y = data_loader.next_batch()
        loss, train_state = train_step(train_state, x, y)
    # JAX dispatches work asynchronously: wait for the device before reading the clock
    loss.block_until_ready()
    t1 = time.time()
    dt = t1 - t0
    tokens_processed = data_loader.B * data_loader.T * config.gradient_accumulation_steps
    tokens_per_sec = tokens_processed / dt
    print(f"step {step}/{train_steps} | loss: {loss:.4f} | dt: {dt*1000:.2f}ms | tokens/sec: {tokens_per_sec:.2f}")

Tips for High-Performance LLMs with JAX and XLA

According to Jax’s documentation, the following XLA flags can improve performance. Some relate to communication between GPUs and so only matter when running on multiple devices. Others affect code generation on each device. Flag availability depends on the XLA version, and XLA aborts on flags it does not recognize, so check them against your installed version. Here is a more detailed document.

import os

# XLA reads these flags when JAX initializes its backend, so set them before importing jax.
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_async_collectives=true '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
)

Let’s now look at the outcomes. The first screenshot below is from Andrej’s fully optimized PyTorch script with a (B=16, T=1024) data loader. The following two are from Jax’s trainer. That is a 7.7x speedup!

Andrej Karpathy’s final optimized Pytorch GPT: ~170k tokens/sec
Jax GPT Before Optimization: 650k tokens/sec
Jax GPT After optimization: 1350k tokens/sec

With the optimizations above, our GPT training reaches 1350k tokens/sec. In Part 3, we will discuss multi-GPU training to push performance further.


Louis Wang

Machine Learning Engineer @ Snap. I write and talk about ML, and others.