Orbax

Orbax is an umbrella namespace providing common training utilities for JAX users. It includes multiple distinct but interrelated libraries.

Checkpointing

A flexible and customizable API for managing checkpoints consisting of various user-defined objects in multi-host, multi-device settings.

Exporting

A library for exporting JAX models to the TensorFlow SavedModel format.

Installation

There is no single orbax package, but rather a separate package for each functionality provided by the Orbax namespace.

The latest release of orbax-checkpoint can be installed from PyPI using

pip install orbax-checkpoint

You may also install directly from GitHub using the following command. This can be used to obtain the most recent version of orbax-checkpoint.

pip install 'git+https://github.com/google/orbax/#subdirectory=checkpoint'

NOTE: Certain edge cases of orbax-checkpoint may not work on Windows, and there are currently no plans to support them.

Similarly, orbax-export can be installed from PyPI using

pip install orbax-export

It can also be installed from GitHub using the following command.

pip install 'git+https://github.com/google/orbax/#subdirectory=export'
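To confirm which of the two packages ended up in your environment (and which version, for example after switching from a PyPI release to a GitHub install), a small standard-library check such as the following can help. It is a convenience sketch, not part of Orbax; the helper name `installed_version` is hypothetical and it only queries installed package metadata.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Orbax ships as separate distributions, so check each one independently.
for dist in ("orbax-checkpoint", "orbax-export"):
    version = installed_version(dist)
    print(f"{dist}: {version or 'not installed'}")
```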

Checkpointing

Getting Started (orbax_checkpoint_101.html)
API Refactor (api_refactor.html)
Checkpointing PyTrees of Arrays (checkpointing_pytrees.html)
Checkpoint Format Guide (checkpoint_format.html)
Optimized Checkpointing (optimized_checkpointing.html)
Custom Handlers (custom_handlers.html)
Transformations (transformations.html)
Preemption Tolerance (preemption_checkpointing.html)
Async Checkpointing (async_checkpointing.html)

Exporting

Getting Started (orbax_export_101.html)

Support

Please report any issues or request support using our issue tracker.