Add JAX training tutorial #1302

Open
mtsokol wants to merge 2 commits into google:main from mtsokol:jax-training-docs

@mtsokol mtsokol commented May 11, 2026

This PR adds a guide on plugging a Grain dataset into a JAX training loop.


📚 Documentation preview 📚: https://google-grain--1302.org.readthedocs.build/

Comment thread docs/tutorials/jax_training_tutorial.md Outdated
[{"tokens": np.arange(np.random.randint(2, 6))} for _ in range(16)]
)
ragged = ragged.batch(4, batch_fn=pad_collate, drop_remainder=True)
print(ragged[0]["tokens"].shape)
Contributor

Can you add a couple more to show that they are ragged?

Contributor Author

Done!
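
For readers following this thread without the tutorial open, here is a minimal sketch of the padding idea being discussed. The `pad_collate` body below is an assumption, since the PR's own definition is not shown in this excerpt. The sketch also swaps the snippet's `np.random.randint` lengths for a fixed, hypothetical list of lengths so the result is reproducible, and it slices a plain Python list instead of calling Grain's `batch`.

```python
import numpy as np


def pad_collate(elements):
  """Pads each "tokens" array with zeros to the longest length, then stacks."""
  max_len = max(len(e["tokens"]) for e in elements)
  padded = [
      np.pad(e["tokens"], (0, max_len - len(e["tokens"]))) for e in elements
  ]
  return {"tokens": np.stack(padded)}


# Hypothetical fixed lengths (the snippet above draws them at random).
lengths = [2, 3, 2, 3, 5, 4, 2, 3, 2, 2, 2, 2, 4, 3, 4, 2]
records = [{"tokens": np.arange(n)} for n in lengths]

# Group into batches of 4 by hand; each batch is padded to its own max length.
batches = [pad_collate(records[i:i + 4]) for i in range(0, len(records), 4)]
for i, batch in enumerate(batches):
  print(i, batch["tokens"].shape)
# 0 (4, 3)
# 1 (4, 5)
# 2 (4, 2)
# 3 (4, 4)
```

Printing several batches makes the raggedness visible: the second dimension changes from batch to batch because padding only extends to the longest record in that batch.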

Comment thread docs/tutorials/jax_training_tutorial.md Outdated
ds = ds.map(jax.device_put) # transfer still on iter thread

first = next(iter(ds))
print(first["image"].shape, first["image"].sharding)
Contributor

Print another one or two to show that the sharding would change?

Contributor Author

Done!

Comment thread docs/tutorials/jax_training_tutorial.md Outdated
)

batch = next(iter(ds))
print(batch["image"].sharding)
Contributor

Show a couple more batches.

Contributor Author

Done!

Comment thread docs/tutorials/jax_training_tutorial.md Outdated

for step, batch in zip(range(2), ds):
  batch = jax.device_put(batch)
  print(step, batch["image"].sharding)
Contributor

Print shape as well?

Contributor Author

Done!
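
As a rough illustration of what printing the shape next to the sharding looks like, the sketch below runs the same loop over a hypothetical stand-in for `ds`: a list of two host-side NumPy batches. The image shape `(4, 8, 8, 3)` is an assumption, not a value from the tutorial. The printed sharding depends on the devices available, so no expected output is shown.

```python
import jax
import numpy as np

# Hypothetical stand-in for the Grain dataset: two batches held in host memory.
host_batches = [
    {"image": np.zeros((4, 8, 8, 3), dtype=np.float32)} for _ in range(2)
]

shapes = []
for step, batch in zip(range(2), host_batches):
  # device_put accepts a pytree, so every array in the dict is transferred.
  batch = jax.device_put(batch)
  shapes.append(batch["image"].shape)
  print(step, batch["image"].shape, batch["image"].sharding)
```

With no explicit sharding argument, `jax.device_put` places the arrays on the default device, so the shape stays the same across steps and only the reported sharding reflects the placement.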

Comment thread docs/tutorials/jax_training_tutorial.md Outdated
```{code-cell} ipython3
:id: jx-template-code

BATCH = 256
Contributor

nit: BATCH_SIZE

Contributor Author

Done!

@mtsokol mtsokol requested a review from yk5 June 4, 2026 10:42

2 participants