
CoMMALab/dgrax


Diffusion Based Grasping in Jax (D-Grax)

This repository contains a JAX/Flax NNX implementation of SE(3)-DiffusionFields [1] and of Variational Shape Inference for Grasp Diffusion on SE(3) [2].

Install as a Library

The grasp-synthesis model is exposed as a pip-installable package so it can be dropped into a planning pipeline (e.g. alongside pyroffi) without the demo or training tooling.

```shell
pip install -e .            # core: generate + score (jax, flax, numpy, pyyaml)
```

To load the model, download the pretrained weights and pass the .npz path to DGrax.from_pretrained:

```python
import numpy as np
from grasp_diffusion import DGrax

gd = DGrax.from_pretrained("runs-dir/jax_weights/ae_ckpt_best_jax.npz")

# pointcloud: (N, 3) object surface points in metres, in your own world frame
pointcloud = np.load("test-pc/object_points.npy")

# 1) Generate grasps: returns (n_grasps, 4, 4) SE(3) poses in the input frame
grasps = gd.generate(pointcloud, n_grasps=20)

# (optionally) also get the energy of each sampled grasp
grasps, energy = gd.generate(pointcloud, n_grasps=20, return_energy=True)

# 2) Score externally proposed grasps (lower energy = better)
energies = gd.score(pointcloud, grasps)            # (M,)
ranked = grasps[np.argsort(energies)]              # best (lowest energy) first, for a planner
```

The model works in an object-centred frame scaled by data_scale. DGrax mean-centres the point cloud internally and both returns and accepts grasps in your input frame (metres), so it composes directly with IK and motion planners. Pass center=False if your cloud is already object-centred. The underlying meta-module, used for score-gradient or guided sampling as in diffusion_ik.py, is available via gd.model.
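As a rough illustration of that frame handling, the sketch below maps a point cloud into a centred, scaled frame and maps 4×4 grasp poses between the two frames. It assumes the convention `p_model = (p_input - centroid) * data_scale`; the actual convention and `data_scale` value inside DGrax may differ, and `to_model_frame`, `grasp_to_model_frame` and `grasp_to_input_frame` are hypothetical helpers, not part of the package. The value `data_scale=8.0` is arbitrary.

```python
import numpy as np


def to_model_frame(points, data_scale, center=True):
    """Hypothetical helper: map (N, 3) input-frame points to an assumed model frame.

    Assumed convention (not taken from DGrax): p_model = (p_input - centroid) * data_scale.
    Returns the transformed points and the centroid that was subtracted.
    """
    points = np.asarray(points, dtype=float)
    centroid = points.mean(axis=0) if center else np.zeros(3)
    return (points - centroid) * data_scale, centroid


def grasp_to_model_frame(T_input, centroid, data_scale):
    """Map (..., 4, 4) input-frame grasp poses into the assumed model frame.

    A uniform scale plus translation of the world frame leaves each rotation
    block unchanged, so only the translation column is modified.
    """
    T_model = np.array(T_input, dtype=float)  # np.array copies by default
    T_model[..., :3, 3] = (T_model[..., :3, 3] - centroid) * data_scale
    return T_model


def grasp_to_input_frame(T_model, centroid, data_scale):
    """Inverse of grasp_to_model_frame: rescale and shift translations back."""
    T_input = np.array(T_model, dtype=float)
    T_input[..., :3, 3] = T_input[..., :3, 3] / data_scale + centroid
    return T_input


# Usage: a synthetic cloud around (0.4, 0.0, 0.2) m and two model-frame grasps.
rng = np.random.default_rng(0)
pc = rng.normal(size=(500, 3)) * 0.05 + np.array([0.4, 0.0, 0.2])
pc_model, c = to_model_frame(pc, data_scale=8.0)
T = np.tile(np.eye(4), (2, 1, 1))
T[:, :3, 3] = [[0.1, 0.0, 0.0], [0.0, 0.2, 0.0]]
T_in = grasp_to_input_frame(T, c, data_scale=8.0)
```

Keeping the conversion as a pure translation-column update makes the round trip exact and leaves gripper orientations untouched.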

Choosing a backend

DGrax supports two architectures: "VnnEbmSe3" (SE(3)-DiffusionFields [1]) and "VSIGD" (Variational Shape Inference for Grasp Diffusion [2]).

```python
from grasp_diffusion import DGrax

# VSIGD (VAE) backend: the config names its own weights
gd = DGrax.from_config("configs/vsigd_full.yml")

# VSIGD additionally needs the object's mesh-normalization params (from your
# mesh2sdf preprocessing); s, ns and nc are placeholders for those values.
# The generate()/score() surface is otherwise the same:
dom = {"mesh_load_scale": s, "norm_scale": ns, "norm_center": nc}  # nc: (3,)
grasps = gd.generate(pointcloud, n_grasps=20, dom_params=dom)
```
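What `mesh_load_scale`, `norm_scale` and `norm_center` mean exactly is fixed by your own preprocessing. As an assumption only, the sketch below follows the bounding-box normalization used in the mesh2sdf example script: centre on the bounding-box midpoint, then scale so the longest side spans `2 * mesh_scale`. `mesh_norm_params` is a hypothetical helper, and the order of operations (load scale first, then centring and scaling) is assumed; check it against how your SDFs were generated before passing the result as `dom_params`.

```python
import numpy as np


def mesh_norm_params(vertices, mesh_load_scale=1.0, mesh_scale=0.8):
    """Hypothetical helper: bounding-box normalization params for a mesh.

    Assumes vertices are first multiplied by mesh_load_scale (e.g. a per-object
    load scale), then centred on the bounding-box midpoint and scaled so the
    longest side spans 2 * mesh_scale, i.e. fits inside [-mesh_scale, mesh_scale].
    """
    v = np.asarray(vertices, dtype=float) * mesh_load_scale
    bbmin, bbmax = v.min(axis=0), v.max(axis=0)
    norm_center = 0.5 * (bbmin + bbmax)
    norm_scale = 2.0 * mesh_scale / (bbmax - bbmin).max()
    return {
        "mesh_load_scale": mesh_load_scale,
        "norm_scale": float(norm_scale),
        "norm_center": norm_center,  # (3,)
    }


# Example: the 8 corners of a unit cube, scaled by 0.1 on load.
cube = np.array([[x, y, z] for x in (0, 1) for y in (0, 1) for z in (0, 1)], dtype=float)
dom = mesh_norm_params(cube, mesh_load_scale=0.1)
normalised = (cube * dom["mesh_load_scale"] - dom["norm_center"]) * dom["norm_scale"]
```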

For pipelines that already hold normalised, model-frame samples (e.g. dataset fixtures), generate_from_sample(sample, n) / score_from_sample(sample, grasps) operate directly on a sample dict (pc_surf plus, for VSIGD, the domain-param keys) and return model-frame results without coordinate conversion.
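A minimal sketch of assembling such a sample dict, assuming `pc_surf` is an (N, 3) float32 array already in the model frame and that the VSIGD domain-param keys use the same names as the `dom_params` dict shown earlier (an assumption; check the dataset code in lib_data/). `make_sample` is hypothetical and not part of the package.

```python
import numpy as np


def make_sample(pc_model, dom_params=None):
    """Hypothetical: build a sample dict for generate_from_sample/score_from_sample.

    pc_model must already be normalised into the model frame; no coordinate
    conversion happens here, mirroring the *_from_sample methods.
    """
    pc = np.asarray(pc_model, dtype=np.float32)  # float32 is an assumption
    if pc.ndim != 2 or pc.shape[1] != 3:
        raise ValueError(f"pc_surf must have shape (N, 3), got {pc.shape}")
    sample = {"pc_surf": pc}
    if dom_params is not None:  # VSIGD only; key names assumed to match dom_params
        sample.update(dom_params)
    return sample


sample = make_sample(
    np.zeros((1024, 3)),
    {"mesh_load_scale": 1.0, "norm_scale": 16.0, "norm_center": np.zeros(3)},
)
# With the package installed, this could then be passed as, e.g.:
#   grasps = gd.generate_from_sample(sample, 20)
```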

Repository Layout

```text
.
|-- configs/
|   |-- train.yml              # training/model/data config
|   `-- sample.yml             # Langevin sampler override config
|-- data-dir/
|   |-- grasps/                # ACRONYM .h5 grasp files by object class
|   |-- meshes/                # object meshes by class
|   |-- mesh_sampled_pc/       # cached mesh point clouds produced by data_prep.py
|   `-- sdf/                   # SDF json/pickle sources and cached .npy files
|   |-- curated.yml            # curated train/val source object inventory
|   |-- splitted.yml           # train/val split used by configs/train.yml
|   `-- curated-test.yml       # test split inventory
|-- lib_data/                  # dataset, gripper, visualization, mesh check helpers
|-- lib_nn/                    # JAX model, VNN encoder, losses, sampler, utilities
|-- runs-dir/
|   |-- checkpoints/           # training checkpoints
|   |-- jax_weights/           # flat converted JAX .npz weights for demos
|   `-- renders/               # rendered train/test grasp previews
|-- scripts/                   # utility scripts for conversion, inference, integrity checks
|-- test.py                    # evaluate a checkpoint on train/val/test split
|-- train.py                   # train the diffusion model
|-- visualize_grasps.py        # sample and visualize diffusion grasps in viser
|-- visualize_problem.py       # inspect robot + object + point-cloud setup
|-- visualize_test_data.py     # inspect bundled test-pc/test-grasps arrays
|-- diffusion_ik.py            # JAX diffusion-augmented IK solver
`-- example_diffusion_ik.py    # interactive Panda IK demo
```
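configs/sample.yml overrides the Langevin sampler settings. As a generic illustration of what a Langevin sampler does (not the repository's sampler, which operates on SE(3) poses with the learned energy), the sketch below runs unadjusted Langevin dynamics on a toy quadratic energy in R^3 with an analytic gradient. `langevin`, `toy_energy_grad`, `n_steps`, `step_size` and `noise_scale` are illustrative names, not the keys used in sample.yml.

```python
import numpy as np


def toy_energy_grad(x, target):
    """Gradient of the toy energy E(x) = 0.5 * ||x - target||^2."""
    return x - target


def langevin(x0, grad_fn, n_steps=200, step_size=0.05, noise_scale=1.0, seed=0):
    """Unadjusted Langevin dynamics, run independently for each row of x0:

        x <- x - step_size * grad E(x) + noise_scale * sqrt(2 * step_size) * eps

    With noise_scale=1 the chains approximately sample exp(-E(x));
    noise_scale=0 reduces the update to plain gradient descent on E.
    """
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)
    for _ in range(n_steps):
        eps = rng.standard_normal(x.shape)
        x = x - step_size * grad_fn(x) + noise_scale * np.sqrt(2.0 * step_size) * eps
    return x


target = np.array([0.1, -0.2, 0.3])
# 1000 independent chains started at the origin.
samples = langevin(np.zeros((1000, 3)), lambda x: toy_energy_grad(x, target))
```

For this toy energy, exp(-E) is a unit-variance Gaussian centred at `target`, so the sample mean should land near `target` and the per-axis spread near 1.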

References

The original implementation of SE(3)-DiffusionFields can be found here. The original implementation of Variational Shape Inference for Grasp Diffusion on SE(3) can be found here.

[1] J. Urain, N. Funk, J. Peters, and G. Chalvatzaki, SE(3)-DiffusionFields: Learning Smooth Cost Functions for Joint Grasp and Motion Optimization Through Diffusion, arXiv:2209.03855, 2023. Available: https://arxiv.org/abs/2209.03855

[2] S. T. Bukhari, K. Agrawal, Z. Kingston, and A. Bera, Variational Shape Inference for Grasp Diffusion on SE(3), IEEE Robotics and Automation Letters (RA-L), 2025. Available: https://arxiv.org/pdf/2508.17482

About

Jax implementation of grasp-diffusion
