Grain - Feeding JAX Models
Grain is a library for reading data for training and evaluating JAX models. It’s open source, fast and deterministic.
Users can bring arbitrary Python transformations.
Grain is designed to be modular. If needed, users can readily override Grain components with their own implementations.
Multiple runs of the same pipeline will produce the same output.
Grain is designed so that checkpoints have minimal size. After preemption, Grain can resume from where it left off and produce the same output as if it had never been preempted.
We took care while designing Grain to ensure that it is performant (refer to the Behind the Scenes section of the documentation). We also tested it against multiple data modalities (e.g. text, audio, images, and videos).
Grain minimizes its set of dependencies when possible. For example, it should not depend on TensorFlow.
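The properties listed above (arbitrary Python transformations, swappable components, determinism, and small checkpoints) can be illustrated with a toy pipeline in plain Python. This is a minimal sketch under stated assumptions, not Grain's implementation or API: `ToyPipeline`, `RandomAccessSource`, `get_state` and `set_state` are hypothetical names, and the sketch assumes a random-access source (any object with `__len__` and `__getitem__`) so that the shuffle order can be recomputed from a seed rather than stored. Grain's real classes and method names differ; consult its API reference.

```python
import random
from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Tuple


class RandomAccessSource(Protocol):
    """Hypothetical pluggable source: any object with __len__ and __getitem__."""

    def __len__(self) -> int: ...

    def __getitem__(self, index: int) -> Any: ...


class ToyPipeline:
    """Hypothetical deterministic, resumable pipeline (illustration only, not Grain's API)."""

    def __init__(self, source: RandomAccessSource,
                 transforms: Sequence[Callable[[Any], Any]], seed: int) -> None:
        if len(source) == 0:
            raise ValueError("source must not be empty")
        self._source = source
        self._transforms = list(transforms)  # arbitrary user-supplied Python callables
        self._seed = seed
        self._epoch = 0
        self._position = 0  # records already read in the current epoch
        self._cache_key: Optional[Tuple[int, int]] = None
        self._cache_order: List[int] = []

    def _order(self) -> List[int]:
        # The permutation depends only on (seed, epoch), so a fresh run and a
        # resumed run recompute exactly the same order.
        key = (self._seed, self._epoch)
        if key != self._cache_key:
            order = list(range(len(self._source)))
            random.Random((self._seed << 32) + self._epoch).shuffle(order)
            self._cache_key, self._cache_order = key, order
        return self._cache_order

    def __iter__(self) -> "ToyPipeline":
        return self

    def __next__(self) -> Any:
        if self._position >= len(self._source):  # current epoch exhausted: start the next one
            self._epoch += 1
            self._position = 0
        record = self._source[self._order()[self._position]]
        self._position += 1
        for fn in self._transforms:  # apply user transformations in order
            record = fn(record)
        return record

    def get_state(self) -> Dict[str, int]:
        # The whole checkpoint is three integers; no buffered records are saved.
        return {"seed": self._seed, "epoch": self._epoch, "position": self._position}

    def set_state(self, state: Dict[str, int]) -> None:
        self._seed = state["seed"]
        self._epoch = state["epoch"]
        self._position = state["position"]


if __name__ == "__main__":
    source = list(range(10))  # a plain list already satisfies RandomAccessSource
    transforms = [lambda x: x * x, lambda x: {"value": x}]

    # Determinism: two runs with the same seed yield identical streams.
    run_a = ToyPipeline(source, transforms, seed=42)
    stream = [next(run_a) for _ in range(15)]
    run_b = ToyPipeline(source, transforms, seed=42)
    assert [next(run_b) for _ in range(15)] == stream

    # Preemption: checkpoint after 7 records, then resume in a brand-new object.
    run_c = ToyPipeline(source, transforms, seed=42)
    for _ in range(7):
        next(run_c)
    checkpoint = run_c.get_state()
    resumed = ToyPipeline(source, transforms, seed=0)
    resumed.set_state(checkpoint)
    assert [next(resumed) for _ in range(8)] == stream[7:]
    print(checkpoint)  # {'seed': 42, 'epoch': 0, 'position': 7}
```

Because the read order in this sketch is a pure function of `(seed, epoch)`, resuming needs only the position within the epoch. A design that shuffled through an in-memory buffer would instead have to save the buffer contents, which is one way checkpoints grow large.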