Skip to content

Add CuTe DSL JAX demo#3103

Merged
Junkai-Wu merged 1 commit intoNVIDIA:mainfrom
katjasrz:cute-dsl-jax-demo
Apr 2, 2026
Merged

Add CuTe DSL JAX demo#3103
Junkai-Wu merged 1 commit intoNVIDIA:mainfrom
katjasrz:cute-dsl-jax-demo

Conversation

@katjasrz
Copy link
Copy Markdown
Contributor

Summary

This PR adds a minimal CuTe DSL + JAX demo under examples/python/CuTeDSL/jax.

The demo shows how to define and invoke custom CuTe DSL kernels from JAX, providing a self-contained example that can serve as a starting point for experimentation and integration.

What This Adds

  • cute_dsl_jax.ipynb
    A walkthrough-style notebook demonstrating:

    • Defining CuTe DSL kernels
    • Calling them from JAX
    • Basic usage patterns and expected behavior
  • cute_dsl_jax_kernels.py
    Supporting kernel definitions used by the notebook.

Purpose

The goal is to provide:

  • A simple reference example for users exploring CuTe DSL + JAX
  • A starting template for extending CuTe DSL kernels within a JAX workflow
  • A lightweight demo that complements the existing Python examples

Notes

  • The example is self-contained and lives under examples/python/CuTeDSL/jax.
  • No changes to core functionality.
  • Intended primarily as an illustrative example.

@fengxie fengxie requested a review from Junkai-Wu March 23, 2026 02:04
Copy link
Copy Markdown
Collaborator

@brandon-yujie-sun brandon-yujie-sun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thanks for adding the example and notebook!

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make more sense to move this nb file to examples/python/CuTeDSL/notebooks, which would make the testing easier to have all nb files in consolidated place?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As agreed on another thread, let's keep the current location

Comment thread examples/python/CuTeDSL/jax/cute_dsl_jax.ipynb Outdated
Comment thread examples/python/CuTeDSL/jax/cute_dsl_jax.ipynb Outdated
@katjasrz katjasrz force-pushed the cute-dsl-jax-demo branch 2 times, most recently from 684d64a to f79740f Compare April 1, 2026 12:33
@katjasrz katjasrz force-pushed the cute-dsl-jax-demo branch from b36d022 to affc1e2 Compare April 1, 2026 12:39
@Junkai-Wu Junkai-Wu merged commit 418d38a into NVIDIA:main Apr 2, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants