Grain - Feeding JAX Models


Grain is a library for reading data for training and evaluating JAX models. It’s open source, fast and deterministic.

Powerful

Users can bring arbitrary Python transformations.

Flexible

Grain is designed to be modular. Users can readily override Grain components with their own implementations if need be.

Deterministic

Multiple runs of the same pipeline will produce the same output.

Resilient to preemptions

Grain is designed so that checkpoints have minimal size. After preemption, Grain can resume from where it left off and produce the same output as if it had never been preempted.
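
To make the idea of small checkpoints and exact resumption concrete, here is a minimal, library-free sketch in plain Python. It is not Grain's API: `IteratorState`, `DeterministicIterator`, `get_state`, and `from_state` are hypothetical names chosen for illustration. The sketch assumes the visiting order is fully determined by a seed, so a checkpoint only needs the seed and the position of the next record.

```python
import random
from dataclasses import dataclass


@dataclass
class IteratorState:
    """Hypothetical checkpoint: two integers, independent of dataset size."""
    seed: int
    next_index: int


class DeterministicIterator:
    """Toy iterator: a seeded shuffle of record indices plus a per-record transform."""

    def __init__(self, records, transform, seed, next_index=0):
        self._records = records
        self._transform = transform
        self._seed = seed
        self._next_index = next_index
        # The visiting order is a pure function of (seed, number of records).
        self._order = list(range(len(records)))
        random.Random(seed).shuffle(self._order)

    def __iter__(self):
        return self

    def __next__(self):
        if self._next_index >= len(self._order):
            raise StopIteration
        record = self._records[self._order[self._next_index]]
        self._next_index += 1
        return self._transform(record)

    def get_state(self):
        # Only the seed and the position are saved, never the data itself.
        return IteratorState(seed=self._seed, next_index=self._next_index)

    @classmethod
    def from_state(cls, records, transform, state):
        # Rebuild the same order from the seed and skip to the saved position.
        return cls(records, transform, state.seed, state.next_index)


def square(x):
    """Stand-in for an arbitrary user-supplied Python transformation."""
    return x * x


records = list(range(10))

# Uninterrupted run.
full_run = list(DeterministicIterator(records, square, seed=42))

# Interrupted run: consume 4 elements, checkpoint, "preempt", then resume.
iterator = DeterministicIterator(records, square, seed=42)
before = [next(iterator) for _ in range(4)]
state = iterator.get_state()
resumed = DeterministicIterator.from_state(records, square, state)
after = list(resumed)

# The resumed run yields exactly what the uninterrupted run would have.
assert before + after == full_run
```

In this sketch the checkpoint stays the same size no matter how large the dataset is, because the shuffled order is recomputed from the seed rather than stored. A real pipeline has more state to track (for example, multiple workers or stateful transformations), which this toy example deliberately leaves out.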

Performant

We took care while designing Grain to ensure that it’s performant (refer to the Behind the Scenes section of the documentation). We also tested it against multiple data modalities (e.g. text, audio, images, and videos).

With minimal dependencies

Grain minimizes its set of dependencies when possible. For example, it should not depend on TensorFlow.