This repository contains a JAX/Flax NNX implementation of SE(3)-DiffusionFields [1] and Variational Shape Inference for Grasp Diffusion on SE(3) [2].
The grasp-synthesis model is exposed as a pip-installable package so it can be
dropped into a planning pipeline (e.g. alongside pyroffi) without the demo or
training tooling.
pip install -e . # core: generate + score (jax, flax, numpy, pyyaml)

Load the model through DGrax by downloading the pretrained weights and passing
their path to DGrax.from_pretrained:
import numpy as np
from grasp_diffusion import DGrax
gd = DGrax.from_pretrained("runs-dir/jax_weights/ae_ckpt_best_jax.npz")
# pointcloud: (N, 3) object surface points in metres, in your own world frame
pointcloud = np.load("test-pc/object_points.npy")
# 1) Generate grasps — returns (n_grasps, 4, 4) SE(3) poses in the input frame
grasps = gd.generate(pointcloud, n_grasps=20)
# (optionally) also get the energy of each sampled grasp
grasps, energy = gd.generate(pointcloud, n_grasps=20, return_energy=True)
# 2) Score externally proposed grasps (lower energy = better)
energies = gd.score(pointcloud, grasps) # (M,)
ranked = grasps[np.argsort(energies)] # lowest energy (best) first, for a planner

The model works in an object-centred frame scaled by data_scale. DGrax
mean-centres the point cloud internally and both returns and accepts grasps in
your input frame (metres), so it composes directly with IK and motion planners.
Pass center=False if your cloud is already object-centred. The underlying
meta-module (for score-gradient or guided use, as in diffusion_ik.py) is
available via gd.model.
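A minimal NumPy sketch of the frame handling described above follows. It
assumes the model frame is (input frame - point-cloud centroid) * data_scale;
the data_scale value of 8.0 is a placeholder rather than the package's setting,
and to_model_frame / to_input_frame are hypothetical helpers, not DGrax methods.

```python
import numpy as np

# Hypothetical helpers (not part of DGrax). ASSUMPTION: model frame =
# (input frame - centroid) * data_scale, applied to points and grasp translations.

def to_model_frame(pointcloud, grasps, data_scale):
    """Map an (N, 3) cloud and (M, 4, 4) poses from the input frame to the model frame."""
    centroid = pointcloud.mean(axis=0)                   # object centre used for centring
    pc_model = (pointcloud - centroid) * data_scale      # centre, then scale the points
    g_model = grasps.copy()                              # rotations and bottom row kept as-is
    g_model[:, :3, 3] = (grasps[:, :3, 3] - centroid) * data_scale  # only translations change
    return pc_model, g_model, centroid

def to_input_frame(grasps_model, centroid, data_scale):
    """Invert to_model_frame for (M, 4, 4) poses."""
    g = grasps_model.copy()
    g[:, :3, 3] = grasps_model[:, :3, 3] / data_scale + centroid
    return g

# Usage with synthetic data
rng = np.random.default_rng(0)
pc = rng.normal(loc=[0.5, -0.2, 0.1], scale=0.03, size=(256, 3))  # fake object cloud (metres)
poses = np.tile(np.eye(4), (5, 1, 1))                              # 5 identity-rotation grasps
poses[:, :3, 3] = pc[:5]                                           # fake grasp positions
pc_m, poses_m, c = to_model_frame(pc, poses, data_scale=8.0)
recovered = to_input_frame(poses_m, c, data_scale=8.0)
print(np.allclose(recovered, poses))  # True
```

Under this assumption the map is a uniform similarity transform, so grasp
rotations pass through unchanged and only the translation column of each 4x4
pose has to be converted.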
DGrax supports two architectures, "VnnEbmSe3" and "VSIGD":
# VSIGD (VAE) backend -- config names its own weights:
gd = DGrax.from_config("configs/vsigd_full.yml")
# VSIGD additionally needs the object's mesh-normalization params (from your
# mesh2sdf preprocessing). Same generate()/score() surface otherwise:
dom = {"mesh_load_scale": s, "norm_scale": ns, "norm_center": nc} # nc: (3,)
grasps = gd.generate(pointcloud, n_grasps=20, dom_params=dom)

For pipelines that already hold normalised, model-frame samples (e.g. dataset
fixtures), generate_from_sample(sample, n) and score_from_sample(sample, grasps)
operate directly on a sample dict (pc_surf plus, for VSIGD, the domain-param
keys) and return model-frame results without coordinate conversion.
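The VSIGD example above expects a dom dict whose values come from your own
mesh preprocessing. The sketch below shows one way such values could be
derived. It assumes a bounding-box normalisation like the one in mesh2sdf's
example script (centre the mesh at its bounding-box midpoint, then scale so the
longest side spans 2 * mesh_scale) and treats mesh_load_scale as the factor you
applied when loading the mesh. The helper name mesh_norm_params and the 0.8
default are hypothetical; check both against the preprocessing that produced
your SDFs.

```python
import numpy as np

def mesh_norm_params(vertices, mesh_load_scale=1.0, mesh_scale=0.8):
    """Hypothetical helper (not part of DGrax): build a dom_params-style dict.

    ASSUMPTION: preprocessing multiplied raw vertices by mesh_load_scale, then
    normalised them as (v - norm_center) * norm_scale (mesh2sdf-style bounding box).
    """
    v = np.asarray(vertices, dtype=np.float64) * mesh_load_scale  # load-time scaling
    bbmin, bbmax = v.min(axis=0), v.max(axis=0)                    # axis-aligned bounding box
    norm_center = 0.5 * (bbmin + bbmax)                            # box midpoint, shape (3,)
    norm_scale = 2.0 * mesh_scale / (bbmax - bbmin).max()          # longest side -> 2 * mesh_scale
    return {
        "mesh_load_scale": float(mesh_load_scale),
        "norm_scale": float(norm_scale),
        "norm_center": norm_center,
    }

# Usage: corners of a unit cube loaded at scale 0.1
cube = np.array([[x, y, z] for x in (0, 1) for y in (0, 1) for z in (0, 1)], dtype=float)
dom = mesh_norm_params(cube, mesh_load_scale=0.1)
normalised = (cube * dom["mesh_load_scale"] - dom["norm_center"]) * dom["norm_scale"]
# With mesh_scale=0.8 the normalised corners lie at +/-0.8 on every axis
print(dom["norm_center"], dom["norm_scale"], np.abs(normalised).max())
```

The resulting dict uses the same keys as the dom example above, so under these
assumptions it could be passed as dom_params=dom.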
.
|-- configs/
| |-- train.yml # training/model/data config
| `-- sample.yml # Langevin sampler override config
|-- data-dir/
| |-- grasps/ # ACRONYM .h5 grasp files by object class
| |-- meshes/ # object meshes by class
| |-- mesh_sampled_pc/ # cached mesh point clouds produced by data_prep.py
| `-- sdf/ # SDF json/pickle sources and cached .npy files
| |-- curated.yml # curated train/val source object inventory
| |-- splitted.yml # train/val split used by configs/train.yml
| `-- curated-test.yml # test split inventory
|-- lib_data/ # dataset, gripper, visualization, mesh check helpers
|-- lib_nn/ # JAX model, VNN encoder, losses, sampler, utilities
|-- runs-dir/
| |-- checkpoints/ # training checkpoints
| |-- jax_weights/ # flat converted JAX .npz weights for demos
| `-- renders/ # rendered train/test grasp previews
|-- scripts/ # utility scripts for conversion, inference, integrity checks
|-- test.py # evaluate a checkpoint on train/val/test split
|-- train.py # train the diffusion model
|-- visualize_grasps.py # sample and visualize diffusion grasps in viser
|-- visualize_problem.py # inspect robot + object + point-cloud setup
|-- visualize_test_data.py # inspect bundled test-pc/test-grasps arrays
|-- diffusion_ik.py # JAX diffusion-augmented IK solver
`-- example_diffusion_ik.py # interactive Panda IK demo
The original implementation of SE(3)-DiffusionFields can be found here. The original implementation of Variational Shape Inference for Grasp Diffusion on SE(3) can be found here.
[1] J. Urain, N. Funk, J. Peters, and G. Chalvatzaki, SE(3)-DiffusionFields: Learning Smooth Cost Functions for Joint Grasp and Motion Optimization Through Diffusion, arXiv:2209.03855, 2023. Available: https://arxiv.org/abs/2209.03855
[2] S. T. Bukhari, K. Agrawal, Z. Kingston, and A. Bera, Variational Shape Inference for Grasp Diffusion on SE(3), IEEE Robotics and Automation Letters (RA-L), 2025. Available: https://arxiv.org/pdf/2508.17482