Plugging Grain into JAX training: batching + accelerator transfer
This guide covers the last mile between a Grain pipeline and a JAX training step: how to batch records into arrays of the right shape, and how to move those batches onto your accelerators efficiently with host-to-device prefetching, sharding across devices, and distributed-training shards.
!pip install grain
!pip install tensorflow_datasets
!pip install jax
import grain
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_datasets as tfds
1. Minimal end-to-end pipeline
The shortest pipeline you’d want for JAX training: source -> shuffle -> preprocess -> batch -> iterate -> device_put -> step.
source = tfds.data_source("mnist", split="train")
ds = (
    grain.MapDataset.source(source)
    .seed(42)
    .shuffle()
    .map(
        lambda r: {
            "image": r["image"].astype(np.float32) / 255.0,
            "label": r["label"],
        }
    )
    .batch(batch_size=128, drop_remainder=True)  # new leading dim
    .to_iter_dataset()
)
for batch in ds:
    batch = jax.device_put(batch)  # default device
    print(jax.tree.map(lambda x: (x.shape, x.dtype), batch))
    break
{'image': ((128, 28, 28, 1), dtype('float32')), 'label': ((128,), dtype('int32'))}
A few things to notice:
- batch(...) is applied here to the MapDataset. It stacks PyTree leaves along a new leading axis (here [128, 28, 28, 1] for images and [128] for labels).
- drop_remainder=True guarantees a static batch shape, so jax.jit compiles the step once and reuses the cached version.
- to_iter_dataset() turns the random-access MapDataset into an IterDataset. Call it after any random-access transforms (shuffle, batch, repeat) and before any streaming transforms (prefetch, device_put).
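The stacking and dropping behavior described above can be mimicked with plain NumPy. This is a minimal sketch, not Grain's implementation: stack_collate and num_full_batches are hypothetical helpers, stack_collate only handles flat dicts rather than arbitrary PyTrees, and 60,000 is the size of the MNIST train split.

```python
import numpy as np


def stack_collate(items):
    """Stack a list of flat dicts of arrays along a new leading axis.

    Hypothetical helper: a flat-dict stand-in for default-style collation.
    """
    return {key: np.stack([item[key] for item in items]) for key in items[0]}


def num_full_batches(num_examples, batch_size):
    # With drop_remainder=True only full batches are kept; the remainder is discarded.
    return num_examples // batch_size, num_examples % batch_size


records = [
    {"image": np.zeros((28, 28, 1), np.float32), "label": np.int64(i % 10)}
    for i in range(128)
]
batch = stack_collate(records)
print(batch["image"].shape, batch["label"].shape)  # (128, 28, 28, 1) (128,)
print(num_full_batches(60_000, 128))  # (468, 96)
```

So one MNIST epoch with a batch size of 128 yields 468 full batches and drops the last 96 examples.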
2. Batching tips that matter for JAX
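Stable shapes. JAX recompiles whenever input shapes change, so every batch should have the same shape. drop_remainder=True already discards a short final batch; adding .repeat() before batching turns the source into an infinite stream, so batches roll over epoch boundaries and no examples are dropped at the end of each epoch: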
ds = (
    grain.MapDataset.source(source)
    .seed(42)
    .shuffle()
    .repeat()  # infinite stream
    .map(
        lambda r: {
            "image": r["image"].astype(np.float32) / 255.0,
            "label": r["label"],
        }
    )
    .batch(128, drop_remainder=True)
)
print("length:", len(ds))  # sys.maxsize // 128, i.e. effectively infinite
length: 72057594037927935
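The recompilation cost of changing shapes can be observed directly by counting how often a jitted function is traced. This sketch uses a Python side effect purely as a trace counter (a debugging trick, not something to rely on in training code); the 96-row array stands in for a short final batch.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def step(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces a new input shape
    return jnp.sum(x)


step(jnp.zeros((128, 28, 28, 1)))  # first call: trace + compile
step(jnp.zeros((128, 28, 28, 1)))  # same shape: cached, no retrace
step(jnp.zeros((96, 28, 28, 1)))   # short final batch: new shape, retrace
print(trace_count)  # 2
```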
Custom collation. The default batch_fn stacks leaves with np.stack. Pass your own when you need padding, ragged handling, or anything non-uniform:
def pad_collate(items):
    max_len = max(x["tokens"].shape[0] for x in items)
    tokens = np.stack(
        [np.pad(x["tokens"], (0, max_len - x["tokens"].shape[0])) for x in items]
    )
    return {"tokens": tokens}
# Toy stream of variable-length token sequences.
ragged = grain.MapDataset.source(
    [{"tokens": np.arange(np.random.randint(2, 6))} for _ in range(16)]
)
ragged = ragged.batch(4, batch_fn=pad_collate, drop_remainder=True)
for i in range(3):
    print(ragged[i]["tokens"].shape)
(4, 4)
(4, 5)
(4, 5)
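pad_collate pads to the longest sequence in each batch, so the output shape changes from batch to batch ((4, 4) vs (4, 5) above) and a jitted step recompiles for each new length. A common alternative, sketched here under the assumption that a fixed maximum length is acceptable for your data, pads or truncates to a constant length and returns a mask. pad_to_fixed_collate, max_len, and pad_id are hypothetical names.

```python
import numpy as np


def pad_to_fixed_collate(items, max_len=8, pad_id=0):
    """Pad (or truncate) each "tokens" array to max_len and return a mask.

    Every batch has shape (len(items), max_len), so a jitted step sees one shape.
    """
    tokens = np.full((len(items), max_len), pad_id, dtype=np.int32)
    mask = np.zeros((len(items), max_len), dtype=bool)
    for i, item in enumerate(items):
        seq = item["tokens"][:max_len]  # truncate anything longer than max_len
        tokens[i, : len(seq)] = seq
        mask[i, : len(seq)] = True  # True marks real tokens, False marks padding
    return {"tokens": tokens, "mask": mask}


items = [{"tokens": np.arange(n)} for n in (2, 5, 3, 10)]
batch = pad_to_fixed_collate(items)
print(batch["tokens"].shape)      # (4, 8)
print(batch["mask"].sum(axis=1))  # [2 5 3 8]
```

To use it in the pipeline, pass it as batch_fn, for example through functools.partial(pad_to_fixed_collate, max_len=...) if you need a different length.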
For variable-length token streams, also look at grain.experimental.batch_and_pad: it pads a partial final batch up to the requested batch size with a sentinel value, so you keep one static shape without dropping data.
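The idea of padding a partial batch along the batch axis can be sketched in NumPy. This is not the implementation or signature of grain.experimental.batch_and_pad; pad_batch_to_size and fill_value are hypothetical, and the returned mask is an assumption about how you would tell the loss to ignore padded rows.

```python
import numpy as np


def pad_batch_to_size(batch, batch_size, fill_value=0):
    """Pad the leading axis of every array in a flat dict up to batch_size.

    Returns the padded dict plus a boolean mask marking real (non-padding) rows.
    """
    n = next(iter(batch.values())).shape[0]
    padded = {}
    for key, value in batch.items():
        # Pad only the leading (batch) axis; leave the other axes untouched.
        pad_width = [(0, batch_size - n)] + [(0, 0)] * (value.ndim - 1)
        padded[key] = np.pad(value, pad_width, constant_values=fill_value)
    mask = np.arange(batch_size) < n
    return padded, mask


# A short final batch of 3 examples, padded up to the static batch size of 4.
partial = {"tokens": np.ones((3, 5), np.int32), "label": np.array([1, 2, 3])}
padded, mask = pad_batch_to_size(partial, batch_size=4, fill_value=-1)
print(padded["tokens"].shape, padded["label"], mask)
```

In the train step, multiply per-example losses by the mask and divide by mask.sum() so padded rows do not affect the update.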
3. Moving batches to the accelerator
There are three options, from simplest to most overlapped. Pick the simplest one that keeps your accelerators busy.
Option A: plain jax.device_put
Fine for prototyping and small models:
ds = (
    grain.MapDataset.source(source)
    .seed(42)
    .shuffle()
    .map(
        lambda r: {
            "image": r["image"].astype(np.float32) / 255.0,
            "label": r["label"],
        }
    )
    .batch(128, drop_remainder=True)
    .to_iter_dataset()
)
for step, batch in zip(range(2), ds):
    batch = jax.device_put(batch)
    print(step, batch["image"].sharding)
0 SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
1 SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
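Here the pipeline's CPU work (reading, decoding, normalizing, batching) and the device_put call both run on the main thread inside every next(...), so the host prepares each batch only after the previous step has been issued instead of overlapping the two. In a real training loop this can leave the accelerator idle while it waits for input.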
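A quick way to check whether input is the bottleneck is to time how long each next(...) call blocks. This is a rough sketch: time_input_wait is a hypothetical helper, not a Grain API, and slow_source stands in for a slow pipeline. Because JAX dispatches work asynchronously, call jax.block_until_ready on the step's outputs if you also want accurate step times.

```python
import time


def time_input_wait(iterable, num_steps):
    """Return the seconds spent blocked in next() for each of the first num_steps items."""
    it = iter(iterable)
    waits = []
    for _ in range(num_steps):
        start = time.perf_counter()
        next(it)  # time spent here is time the training loop waits on input
        waits.append(time.perf_counter() - start)
    return waits


def slow_source(delay_s):
    # Hypothetical stand-in for a pipeline that needs delay_s to build each batch.
    while True:
        time.sleep(delay_s)
        yield {}


waits = time_input_wait(slow_source(0.01), num_steps=5)
print(f"mean wait: {sum(waits) / len(waits) * 1e3:.1f} ms")
```

If the mean wait is a large fraction of your step time, move on to Option B or C.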
Option B: overlap host work with ThreadPrefetchIterDataset
Run the pipeline’s CPU work on a background thread so the next batch is ready by the time the device is done with the previous step:
ds = (
    grain.MapDataset.source(source)
    .seed(42)
    .shuffle()
    .map(
        lambda r: {
            "image": r["image"].astype(np.float32) / 255.0,
            "label": r["label"],
        }
    )
    .batch(128, drop_remainder=True)
    .to_iter_dataset()
)
ds = grain.experimental.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=4)
ds = ds.map(jax.device_put)  # transfer still runs on the consuming (main) thread
for step, batch in zip(range(3), ds):
    print(step, batch["image"].shape, batch["image"].sharding)
0 (128, 28, 28, 1) SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
1 (128, 28, 28, 1) SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
2 (128, 28, 28, 1) SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
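The role of ThreadPrefetchIterDataset can be illustrated with a plain producer/consumer: a background thread fills a bounded queue while the training loop drains it. This is a simplified sketch of the technique using only the standard library, not Grain's implementation; thread_prefetch is a hypothetical name.

```python
import queue
import threading

_DONE = object()  # marks the end of the source iterable


def thread_prefetch(iterable, buffer_size):
    """Yield items from iterable, produced ahead of time on a background thread."""
    # Bounded queue: the producer blocks once buffer_size items are waiting.
    q = queue.Queue(maxsize=buffer_size)

    def producer():
        try:
            for item in iterable:
                q.put((None, item))
        except Exception as exc:  # forward pipeline errors to the consumer
            q.put((exc, None))
            return
        q.put((None, _DONE))

    # Daemon thread: if the consumer stops early, a producer blocked on put()
    # does not keep the process alive.
    threading.Thread(target=producer, daemon=True).start()
    while True:
        exc, item = q.get()
        if exc is not None:
            raise exc
        if item is _DONE:
            return
        yield item


for batch in thread_prefetch(({"step": i} for i in range(5)), buffer_size=4):
    print(batch)
```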
Option C: two-stage prefetch with grain.experimental.device_put
The recommended pattern for real training. It runs a CPU buffer and a device-resident buffer, so a batch is already on the accelerator before the step asks for it:
ds = (
    grain.MapDataset.source(source)
    .seed(42)
    .shuffle()
    .map(
        lambda r: {
            "image": r["image"].astype(np.float32) / 255.0,
            "label": r["label"],
        }
    )
    .batch(128, drop_remainder=True)
    .to_iter_dataset()
)
ds = grain.experimental.device_put(
    ds=ds,
    device=jax.devices()[0],  # or a Sharding (see below)
    cpu_buffer_size=4,  # batches buffered on host
    device_buffer_size=2,  # batches buffered on device
)
for step, batch in zip(range(2), ds):
    # `batch` is already a jax.Array on-device.
    print(step, batch["image"].sharding)
0 SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
1 SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
Under the hood this chains ThreadPrefetchIterDataset (the host buffer) -> map(jax.device_put) -> ThreadPrefetchIterDataset (the device buffer).
4. Multi-device: sharding a batch across accelerators
For data-parallel training across all local devices, pass a Sharding to device_put instead of a single device. Each batch is split along its first axis:
devices = jax.devices()
mesh = jax.sharding.Mesh(np.array(devices), axis_names=("data",))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))
ds = (
    grain.MapDataset.source(source)
    .seed(42)
    .shuffle()
    .repeat()
    .map(
        lambda r: {
            "image": r["image"].astype(np.float32) / 255.0,
            "label": r["label"],
        }
    )
    .batch(128, drop_remainder=True)
    .to_iter_dataset()
)
ds = grain.experimental.device_put(
    ds=ds,
    device=sharding,
    cpu_buffer_size=4,
    device_buffer_size=2,
)
for step, batch in zip(range(3), ds):
    print(step, batch["image"].sharding)
0 NamedSharding(mesh=Mesh('data': 2, axis_types=(Auto,)), spec=P('data',), memory_kind=device)
1 NamedSharding(mesh=Mesh('data': 2, axis_types=(Auto,)), spec=P('data',), memory_kind=device)
2 NamedSharding(mesh=Mesh('data': 2, axis_types=(Auto,)), spec=P('data',), memory_kind=device)
Make sure the batch size is divisible by len(devices); otherwise the sharding split fails. Decorate your train step with jax.jit, and JAX compiles a single SPMD program that processes the per-device slices automatically.
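To check how a batch is actually split, you can place one host batch with jax.device_put and inspect its shards. This sketch assumes a CPU-only machine, where the XLA flag --xla_force_host_platform_device_count simulates several devices; the flag only takes effect if it is set before JAX is first imported, and on GPU/TPU hosts jax.devices() returns the real accelerators instead.

```python
import os

# Simulate 2 CPU devices; must run before the first `import jax` in the process.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=2")

import jax
import numpy as np

devices = jax.devices()
mesh = jax.sharding.Mesh(np.array(devices), axis_names=("data",))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))

BATCH_SIZE = 128
assert BATCH_SIZE % len(devices) == 0, "batch must split evenly across devices"

host_batch = {
    "image": np.zeros((BATCH_SIZE, 28, 28, 1), np.float32),
    "label": np.zeros((BATCH_SIZE,), np.int32),
}
# A single Sharding is applied to every leaf of the dict.
batch = jax.device_put(host_batch, sharding)
for shard in batch["image"].addressable_shards:
    print(shard.device, shard.data.shape)  # each device holds BATCH_SIZE // n rows
```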
5. Putting it all together
A realistic single-host, multi-device template:
BATCH_SIZE = 256
devices = jax.devices()
mesh = jax.sharding.Mesh(np.array(devices), axis_names=("data",))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))
def preprocess(r):
    return {"image": r["image"].astype(np.float32) / 255.0, "label": r["label"]}
ds = (
    grain.MapDataset.source(source)
    .seed(42)
    .shuffle()
    .repeat()
    .map(preprocess)
    .batch(BATCH_SIZE, drop_remainder=True)
    .to_iter_dataset()
)
ds = grain.experimental.device_put(
    ds=ds,
    device=sharding,
    cpu_buffer_size=4,
    device_buffer_size=2,
)
@jax.jit
def train_step(params, batch):
    # Replace with your real loss/update.
    return params + batch["image"].mean()
params = jnp.zeros(())
for step, batch in zip(range(3), ds):
    params = train_step(params, batch)
    print(step, " params:", params)
0 params: 0.13213545
1 params: 0.2589087
2 params: 0.39069483
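The placeholder train_step above can be swapped for a real loss and update. The following is a minimal sketch assuming a hypothetical linear softmax classifier trained with plain SGD; init_params, loss_fn, and LEARNING_RATE are illustrative names. It runs on a small synthetic batch with the same keys as the pipeline output so it can be checked in isolation; with the template you would use num_features=28 * 28 and feed batches from ds instead.

```python
import jax
import jax.numpy as jnp
import numpy as np

NUM_CLASSES = 10
LEARNING_RATE = 0.1


def init_params(key, num_features):
    # Hypothetical linear softmax classifier; replace with your model.
    return {
        "w": 0.01 * jax.random.normal(key, (num_features, NUM_CLASSES)),
        "b": jnp.zeros((NUM_CLASSES,)),
    }


def loss_fn(params, batch):
    x = batch["image"].reshape(batch["image"].shape[0], -1)  # flatten H, W, C
    logits = x @ params["w"] + params["b"]
    log_probs = jax.nn.log_softmax(logits)
    labels = jax.nn.one_hot(batch["label"], NUM_CLASSES)
    return -jnp.mean(jnp.sum(labels * log_probs, axis=-1))  # mean cross-entropy


@jax.jit
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Plain SGD: p <- p - lr * g, applied leaf by leaf.
    params = jax.tree.map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


# Synthetic stand-in batch with the same keys as the pipeline output.
rng = np.random.default_rng(0)
batch = {
    "image": rng.random((32, 4, 4, 1), dtype=np.float32),
    "label": rng.integers(0, NUM_CLASSES, size=(32,)).astype(np.int32),
}
params = init_params(jax.random.PRNGKey(0), num_features=16)
losses = []
for step in range(20):
    params, loss = train_step(params, batch)
    losses.append(float(loss))
print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

Because the loss is returned alongside the updated params, the loop can log it without a second forward pass.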