diff --git a/.gitignore b/.gitignore index bdc7c14f..b2d46b8f 100644 --- a/.gitignore +++ b/.gitignore @@ -18,5 +18,8 @@ build HJCD-IK resources/learned_ik/* pRRTC -src/pyroffi/cuda_kernels/*.so -src/pyroffi/cuda_kernels/*.o \ No newline at end of file +src/pyroffi/cuda_kernels/**/*.so +src/pyroffi/cuda_kernels/**/*.o +src/pyroffi/cuda_kernels/**/*.ptx +NVIDIA*.sh +NVIDIA-OptiX-SDK-9.1.0-linux64-x86_64 \ No newline at end of file diff --git a/src/pyroffi/cuda_kernels/build_all.sh b/build_kernels/build_all.sh similarity index 88% rename from src/pyroffi/cuda_kernels/build_all.sh rename to build_kernels/build_all.sh index bd6dc2ff..33e6e33c 100755 --- a/src/pyroffi/cuda_kernels/build_all.sh +++ b/build_kernels/build_all.sh @@ -2,11 +2,11 @@ # Build all CUDA kernels. # # Usage (from repo root): -# bash src/pyroffi/cuda_kernels/build_all.sh -# bash src/pyroffi/cuda_kernels/build_all.sh --debug +# bash build_kernels/build_all.sh +# bash build_kernels/build_all.sh --debug # # Override GPU arch for all kernels: -# GPU_ARCH=-arch=sm_80 bash src/pyroffi/cuda_kernels/build_all.sh +# GPU_ARCH=-arch=sm_80 bash build_kernels/build_all.sh set -euo pipefail @@ -65,5 +65,7 @@ bash "${SCRIPT_DIR}/build_sco_trajopt_cuda.sh" "${BUILD_ARGS[@]}" bash "${SCRIPT_DIR}/build_ls_trajopt_cuda.sh" "${BUILD_ARGS[@]}" bash "${SCRIPT_DIR}/build_chomp_trajopt_cuda.sh" "${BUILD_ARGS[@]}" bash "${SCRIPT_DIR}/build_stomp_trajopt_cuda.sh" "${BUILD_ARGS[@]}" +bash "${SCRIPT_DIR}/build_robogpu_collision.sh" "${BUILD_ARGS[@]}" +bash "${SCRIPT_DIR}/build_cricket_jit.sh" echo "All CUDA kernels built successfully." diff --git a/src/pyroffi/cuda_kernels/build_brownian_motion_ik_cuda.sh b/build_kernels/build_brownian_motion_ik_cuda.sh similarity index 82% rename from src/pyroffi/cuda_kernels/build_brownian_motion_ik_cuda.sh rename to build_kernels/build_brownian_motion_ik_cuda.sh index 70e41fd2..fdf0c224 100644 --- a/src/pyroffi/cuda_kernels/build_brownian_motion_ik_cuda.sh +++ b/build_kernels/build_brownian_motion_ik_cuda.sh @@ -2,8 +2,8 @@ # Build _brownian_motion_ik_cuda_lib.so from _brownian_motion_ik_cuda_kernel.cu. # # Usage (from repo root): -# bash src/pyroffi/cuda_kernels/build_brownian_motion_ik_cuda.sh -# bash src/pyroffi/cuda_kernels/build_brownian_motion_ik_cuda.sh --debug +# bash build_kernels/build_brownian_motion_ik_cuda.sh +# bash build_kernels/build_brownian_motion_ik_cuda.sh --debug set -euo pipefail @@ -45,8 +45,9 @@ if [[ -n "${MAX_JOINTS_OVERRIDE}" ]]; then fi SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -SRC="${SCRIPT_DIR}/_brownian_motion_ik_cuda_kernel.cu" -OUT="${SCRIPT_DIR}/_brownian_motion_ik_cuda_lib.so" +KERNELS_DIR="$(cd "${SCRIPT_DIR}/../src/pyroffi/cuda_kernels" && pwd)" +SRC="${KERNELS_DIR}/region_ik/_brownian_motion_ik_cuda_kernel.cu" +OUT="${KERNELS_DIR}/region_ik/_brownian_motion_ik_cuda_lib.so" JAXLIB_INC="$(python -c "import os, jaxlib; print(os.path.join(os.path.dirname(jaxlib.__file__), 'include'))")" @@ -71,8 +72,9 @@ nvcc \ ${GPU_ARCH} \ --shared \ --compiler-options "-fPIC" \ - -I"${SCRIPT_DIR}" \ + -I"${KERNELS_DIR}" \ -I"${JAXLIB_INC}" \ + -I"${KERNELS_DIR}" \ -o "${OUT}" \ "${SRC}" diff --git a/src/pyroffi/cuda_kernels/build_chomp_trajopt_cuda.sh b/build_kernels/build_chomp_trajopt_cuda.sh similarity index 84% rename from src/pyroffi/cuda_kernels/build_chomp_trajopt_cuda.sh rename to build_kernels/build_chomp_trajopt_cuda.sh index 99539a50..5474f443 100755 --- a/src/pyroffi/cuda_kernels/build_chomp_trajopt_cuda.sh +++ b/build_kernels/build_chomp_trajopt_cuda.sh @@ -2,8 +2,8 @@ # Build _chomp_trajopt_cuda_lib.so from _chomp_trajopt_cuda_kernel.cu. # # Usage (from repo root): -# bash src/pyroffi/cuda_kernels/build_chomp_trajopt_cuda.sh -# bash src/pyroffi/cuda_kernels/build_chomp_trajopt_cuda.sh --debug +# bash build_kernels/build_chomp_trajopt_cuda.sh +# bash build_kernels/build_chomp_trajopt_cuda.sh --debug set -euo pipefail @@ -45,8 +45,9 @@ if [[ -n "${MAX_JOINTS_OVERRIDE}" ]]; then fi SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -SRC="${SCRIPT_DIR}/_chomp_trajopt_cuda_kernel.cu" -OUT="${SCRIPT_DIR}/_chomp_trajopt_cuda_lib.so" +KERNELS_DIR="$(cd "${SCRIPT_DIR}/../src/pyroffi/cuda_kernels" && pwd)" +SRC="${KERNELS_DIR}/trajopt/_chomp_trajopt_cuda_kernel.cu" +OUT="${KERNELS_DIR}/trajopt/_chomp_trajopt_cuda_lib.so" JAXLIB_INC="$(python -c \ "import os, jaxlib; print(os.path.join(os.path.dirname(jaxlib.__file__), 'include'))")" @@ -73,6 +74,7 @@ nvcc \ --shared \ --compiler-options "-fPIC" \ -I"${JAXLIB_INC}" \ + -I"${KERNELS_DIR}" \ -o "${OUT}" \ "${SRC}" diff --git a/src/pyroffi/cuda_kernels/build_collision_binary_cuda.sh b/build_kernels/build_collision_binary_cuda.sh similarity index 86% rename from src/pyroffi/cuda_kernels/build_collision_binary_cuda.sh rename to build_kernels/build_collision_binary_cuda.sh index 5dcae7ac..3dc60db7 100755 --- a/src/pyroffi/cuda_kernels/build_collision_binary_cuda.sh +++ b/build_kernels/build_collision_binary_cuda.sh @@ -2,8 +2,8 @@ # Build _collision_binary_cuda_lib.so from _collision_binary_cuda_kernel.cu. # # Usage (from repo root): -# bash src/pyroffi/cuda_kernels/build_collision_binary_cuda.sh -# bash src/pyroffi/cuda_kernels/build_collision_binary_cuda.sh --debug +# bash build_kernels/build_collision_binary_cuda.sh +# bash build_kernels/build_collision_binary_cuda.sh --debug # # Requirements: # - nvcc (CUDA toolkit) @@ -51,8 +51,9 @@ if [[ -n "${MAX_JOINTS_OVERRIDE}" ]]; then fi SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -SRC="${SCRIPT_DIR}/_collision_binary_cuda_kernel.cu" -OUT="${SCRIPT_DIR}/_collision_binary_cuda_lib.so" +KERNELS_DIR="$(cd "${SCRIPT_DIR}/../src/pyroffi/cuda_kernels" && pwd)" +SRC="${KERNELS_DIR}/collision/_collision_binary_cuda_kernel.cu" +OUT="${KERNELS_DIR}/collision/_collision_binary_cuda_lib.so" # Locate the jaxlib include directory that ships xla/ffi/api/ffi.h. JAXLIB_INC="$(python -c \ @@ -81,6 +82,7 @@ nvcc \ --shared \ --compiler-options "-fPIC" \ -I"${JAXLIB_INC}" \ + -I"${KERNELS_DIR}" \ -o "${OUT}" \ "${SRC}" diff --git a/src/pyroffi/cuda_kernels/build_collision_cuda.sh b/build_kernels/build_collision_cuda.sh similarity index 86% rename from src/pyroffi/cuda_kernels/build_collision_cuda.sh rename to build_kernels/build_collision_cuda.sh index c336a176..0ff00e1a 100755 --- a/src/pyroffi/cuda_kernels/build_collision_cuda.sh +++ b/build_kernels/build_collision_cuda.sh @@ -2,8 +2,8 @@ # Build _collision_cuda_lib.so from _collision_cuda_kernel.cu. # # Usage (from repo root): -# bash src/pyroffi/cuda_kernels/build_collision_cuda.sh -# bash src/pyroffi/cuda_kernels/build_collision_cuda.sh --debug +# bash build_kernels/build_collision_cuda.sh +# bash build_kernels/build_collision_cuda.sh --debug # # Requirements: # - nvcc (CUDA toolkit) @@ -51,8 +51,9 @@ if [[ -n "${MAX_JOINTS_OVERRIDE}" ]]; then fi SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -SRC="${SCRIPT_DIR}/_collision_cuda_kernel.cu" -OUT="${SCRIPT_DIR}/_collision_cuda_lib.so" +KERNELS_DIR="$(cd "${SCRIPT_DIR}/../src/pyroffi/cuda_kernels" && pwd)" +SRC="${KERNELS_DIR}/collision/_collision_cuda_kernel.cu" +OUT="${KERNELS_DIR}/collision/_collision_cuda_lib.so" # Locate the jaxlib include directory that ships xla/ffi/api/ffi.h. JAXLIB_INC="$(python -c \ @@ -81,6 +82,7 @@ nvcc \ --shared \ --compiler-options "-fPIC" \ -I"${JAXLIB_INC}" \ + -I"${KERNELS_DIR}" \ -o "${OUT}" \ "${SRC}" diff --git a/build_kernels/build_cricket_jit.sh b/build_kernels/build_cricket_jit.sh new file mode 100644 index 00000000..f1c6cc22 --- /dev/null +++ b/build_kernels/build_cricket_jit.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash +# Build + install cricket with the Python extension AND the runtime JIT enabled, +# so pyroffi's VAMPCPUCollisionChecker can JIT-compile robot-specialised VAMP +# collision checkers at runtime. +# +# This mirrors the exact, verified setup (conda-forge deps + scikit-build-core). +# Run it inside the target conda env, e.g.: +# +# conda activate pyroffi +# bash build_kernels/build_cricket_jit.sh +# +# Dependencies are taken from cricket/environment.yaml. The JIT additionally +# needs a `clang` binary on PATH at *runtime* (the JIT driver shells out to it to +# discover system headers) — clangdev provides it inside the env. +set -euo pipefail + +REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +CRICKET_DIR="${REPO_ROOT}/external/cricket" + +if [[ -z "${CONDA_PREFIX:-}" ]]; then + echo "ERROR: activate the target conda env first (conda activate pyroffi)." >&2 + exit 1 +fi + +# 1. Install build/runtime dependencies (no-op if already present). +conda install -c conda-forge --solver=libmamba -y \ + pinocchio cppad eigen cgal-cpp nlohmann_json fmt \ + llvmdev clangdev lld cxx-compiler ninja pkg-config patch \ + nanobind scikit-build-core + +if ! command -v clang >/dev/null 2>&1; then + echo "ERROR: clang still not on PATH after install." >&2 + exit 1 +fi + +# 2. Build + install the cricket Python extension (JIT + Python both ON). +export CMAKE_PREFIX_PATH="${CONDA_PREFIX}:${CMAKE_PREFIX_PATH:-}" +export CMAKE_ARGS="-DCMAKE_PREFIX_PATH=${CONDA_PREFIX} -DCRICKET_BUILD_JIT=ON -DCRICKET_BUILD_PYTHON=ON" +pip install -e "${CRICKET_DIR}" --no-build-isolation + +echo +echo "Done. Verify with:" +echo " python -c 'from cricket import _core_ext as e; print(e.jit.JitSession)'" +echo " python -m pytest tests/test_vamp_cpu_collision.py -s" diff --git a/src/pyroffi/cuda_kernels/build_fk_cuda.sh b/build_kernels/build_fk_cuda.sh similarity index 88% rename from src/pyroffi/cuda_kernels/build_fk_cuda.sh rename to build_kernels/build_fk_cuda.sh index 25860cce..e3fba6bd 100755 --- a/src/pyroffi/cuda_kernels/build_fk_cuda.sh +++ b/build_kernels/build_fk_cuda.sh @@ -2,8 +2,8 @@ # Build _fk_cuda.so from _fk_cuda_kernel.cu. # # Usage (from repo root): -# bash src/pyroffi/cuda_kernels/build_fk_cuda.sh -# bash src/pyroffi/cuda_kernels/build_fk_cuda.sh --debug +# bash build_kernels/build_fk_cuda.sh +# bash build_kernels/build_fk_cuda.sh --debug # # Requirements: # - nvcc (CUDA toolkit) @@ -50,8 +50,9 @@ if [[ -n "${MAX_JOINTS_OVERRIDE}" ]]; then fi SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -SRC="${SCRIPT_DIR}/_fk_cuda_kernel.cu" -OUT="${SCRIPT_DIR}/_fk_cuda_lib.so" +KERNELS_DIR="$(cd "${SCRIPT_DIR}/../src/pyroffi/cuda_kernels" && pwd)" +SRC="${KERNELS_DIR}/fk/_fk_cuda_kernel.cu" +OUT="${KERNELS_DIR}/fk/_fk_cuda_lib.so" # Locate the jaxlib include directory that ships xla/ffi/api/ffi.h. JAXLIB_INC="$(python -c \ @@ -82,6 +83,7 @@ nvcc \ --shared \ --compiler-options "-fPIC" \ -I"${JAXLIB_INC}" \ + -I"${KERNELS_DIR}" \ -o "${OUT}" \ "${SRC}" diff --git a/src/pyroffi/cuda_kernels/build_hit_and_run_ik_cuda.sh b/build_kernels/build_hit_and_run_ik_cuda.sh similarity index 83% rename from src/pyroffi/cuda_kernels/build_hit_and_run_ik_cuda.sh rename to build_kernels/build_hit_and_run_ik_cuda.sh index 3cac7654..1ce4b524 100755 --- a/src/pyroffi/cuda_kernels/build_hit_and_run_ik_cuda.sh +++ b/build_kernels/build_hit_and_run_ik_cuda.sh @@ -2,8 +2,8 @@ # Build _hit_and_run_ik_cuda_lib.so from _hit_and_run_ik_cuda_kernel.cu. # # Usage (from repo root): -# bash src/pyroffi/cuda_kernels/build_hit_and_run_ik_cuda.sh -# bash src/pyroffi/cuda_kernels/build_hit_and_run_ik_cuda.sh --debug +# bash build_kernels/build_hit_and_run_ik_cuda.sh +# bash build_kernels/build_hit_and_run_ik_cuda.sh --debug set -euo pipefail @@ -45,8 +45,9 @@ if [[ -n "${MAX_JOINTS_OVERRIDE}" ]]; then fi SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -SRC="${SCRIPT_DIR}/_hit_and_run_ik_cuda_kernel.cu" -OUT="${SCRIPT_DIR}/_hit_and_run_ik_cuda_lib.so" +KERNELS_DIR="$(cd "${SCRIPT_DIR}/../src/pyroffi/cuda_kernels" && pwd)" +SRC="${KERNELS_DIR}/region_ik/_hit_and_run_ik_cuda_kernel.cu" +OUT="${KERNELS_DIR}/region_ik/_hit_and_run_ik_cuda_lib.so" JAXLIB_INC="$(python -c "import os, jaxlib; print(os.path.join(os.path.dirname(jaxlib.__file__), 'include'))")" @@ -71,8 +72,9 @@ nvcc \ ${GPU_ARCH} \ --shared \ --compiler-options "-fPIC" \ - -I"${SCRIPT_DIR}" \ + -I"${KERNELS_DIR}" \ -I"${JAXLIB_INC}" \ + -I"${KERNELS_DIR}" \ -o "${OUT}" \ "${SRC}" diff --git a/src/pyroffi/cuda_kernels/build_hjcd_ik_cuda.sh b/build_kernels/build_hjcd_ik_cuda.sh similarity index 87% rename from src/pyroffi/cuda_kernels/build_hjcd_ik_cuda.sh rename to build_kernels/build_hjcd_ik_cuda.sh index 25d9aaf6..f0af86ce 100755 --- a/src/pyroffi/cuda_kernels/build_hjcd_ik_cuda.sh +++ b/build_kernels/build_hjcd_ik_cuda.sh @@ -2,8 +2,8 @@ # Build _hjcd_ik_cuda_lib.so from _hjcd_ik_cuda_kernel.cu. # # Usage (from repo root): -# bash src/pyroffi/cuda_kernels/build_hjcd_ik_cuda.sh -# bash src/pyroffi/cuda_kernels/build_hjcd_ik_cuda.sh --debug +# bash build_kernels/build_hjcd_ik_cuda.sh +# bash build_kernels/build_hjcd_ik_cuda.sh --debug # # Requirements: # - nvcc (CUDA toolkit) @@ -50,8 +50,9 @@ if [[ -n "${MAX_JOINTS_OVERRIDE}" ]]; then fi SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -SRC="${SCRIPT_DIR}/_hjcd_ik_cuda_kernel.cu" -OUT="${SCRIPT_DIR}/_hjcd_ik_cuda_lib.so" +KERNELS_DIR="$(cd "${SCRIPT_DIR}/../src/pyroffi/cuda_kernels" && pwd)" +SRC="${KERNELS_DIR}/ik/_hjcd_ik_cuda_kernel.cu" +OUT="${KERNELS_DIR}/ik/_hjcd_ik_cuda_lib.so" # Locate the jaxlib include directory that ships xla/ffi/api/ffi.h. JAXLIB_INC="$(python -c \ @@ -81,8 +82,9 @@ nvcc \ ${GPU_ARCH} \ --shared \ --compiler-options "-fPIC" \ - -I"${SCRIPT_DIR}" \ + -I"${KERNELS_DIR}" \ -I"${JAXLIB_INC}" \ + -I"${KERNELS_DIR}" \ -o "${OUT}" \ "${SRC}" diff --git a/src/pyroffi/cuda_kernels/build_ls_ik_cuda.sh b/build_kernels/build_ls_ik_cuda.sh similarity index 86% rename from src/pyroffi/cuda_kernels/build_ls_ik_cuda.sh rename to build_kernels/build_ls_ik_cuda.sh index f8f7077e..d8dec59d 100755 --- a/src/pyroffi/cuda_kernels/build_ls_ik_cuda.sh +++ b/build_kernels/build_ls_ik_cuda.sh @@ -2,8 +2,8 @@ # Build _ls_ik_cuda_lib.so from _ls_ik_cuda_kernel.cu. # # Usage (from repo root): -# bash src/pyroffi/cuda_kernels/build_ls_ik_cuda.sh -# bash src/pyroffi/cuda_kernels/build_ls_ik_cuda.sh --debug +# bash build_kernels/build_ls_ik_cuda.sh +# bash build_kernels/build_ls_ik_cuda.sh --debug # # Requirements: # - nvcc (CUDA toolkit) @@ -50,8 +50,9 @@ if [[ -n "${MAX_JOINTS_OVERRIDE}" ]]; then fi SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -SRC="${SCRIPT_DIR}/_ls_ik_cuda_kernel.cu" -OUT="${SCRIPT_DIR}/_ls_ik_cuda_lib.so" +KERNELS_DIR="$(cd "${SCRIPT_DIR}/../src/pyroffi/cuda_kernels" && pwd)" +SRC="${KERNELS_DIR}/ik/_ls_ik_cuda_kernel.cu" +OUT="${KERNELS_DIR}/ik/_ls_ik_cuda_lib.so" JAXLIB_INC="$(python -c \ "import os, jaxlib; print(os.path.join(os.path.dirname(jaxlib.__file__), 'include'))")" @@ -77,8 +78,9 @@ nvcc \ ${GPU_ARCH} \ --shared \ --compiler-options "-fPIC" \ - -I"${SCRIPT_DIR}" \ + -I"${KERNELS_DIR}" \ -I"${JAXLIB_INC}" \ + -I"${KERNELS_DIR}" \ -o "${OUT}" \ "${SRC}" diff --git a/src/pyroffi/cuda_kernels/build_ls_trajopt_cuda.sh b/build_kernels/build_ls_trajopt_cuda.sh similarity index 83% rename from src/pyroffi/cuda_kernels/build_ls_trajopt_cuda.sh rename to build_kernels/build_ls_trajopt_cuda.sh index 03aaae69..daed440b 100755 --- a/src/pyroffi/cuda_kernels/build_ls_trajopt_cuda.sh +++ b/build_kernels/build_ls_trajopt_cuda.sh @@ -2,8 +2,8 @@ # Build _ls_trajopt_cuda_lib.so from _ls_trajopt_cuda_kernel.cu. # # Usage (from repo root): -# bash src/pyroffi/cuda_kernels/build_ls_trajopt_cuda.sh -# bash src/pyroffi/cuda_kernels/build_ls_trajopt_cuda.sh --debug +# bash build_kernels/build_ls_trajopt_cuda.sh +# bash build_kernels/build_ls_trajopt_cuda.sh --debug set -euo pipefail @@ -45,8 +45,9 @@ if [[ -n "${MAX_JOINTS_OVERRIDE}" ]]; then fi SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -SRC="${SCRIPT_DIR}/_ls_trajopt_cuda_kernel.cu" -OUT="${SCRIPT_DIR}/_ls_trajopt_cuda_lib.so" +KERNELS_DIR="$(cd "${SCRIPT_DIR}/../src/pyroffi/cuda_kernels" && pwd)" +SRC="${KERNELS_DIR}/trajopt/_ls_trajopt_cuda_kernel.cu" +OUT="${KERNELS_DIR}/trajopt/_ls_trajopt_cuda_lib.so" JAXLIB_INC="$(python -c \ "import os, jaxlib; print(os.path.join(os.path.dirname(jaxlib.__file__), 'include'))")" @@ -72,8 +73,9 @@ nvcc \ ${GPU_ARCH} \ --shared \ --compiler-options "-fPIC" \ - -I"${SCRIPT_DIR}" \ + -I"${KERNELS_DIR}" \ -I"${JAXLIB_INC}" \ + -I"${KERNELS_DIR}" \ -o "${OUT}" \ "${SRC}" diff --git a/src/pyroffi/cuda_kernels/build_mppi_ik_cuda.sh b/build_kernels/build_mppi_ik_cuda.sh similarity index 85% rename from src/pyroffi/cuda_kernels/build_mppi_ik_cuda.sh rename to build_kernels/build_mppi_ik_cuda.sh index d410fa14..d013210e 100755 --- a/src/pyroffi/cuda_kernels/build_mppi_ik_cuda.sh +++ b/build_kernels/build_mppi_ik_cuda.sh @@ -2,8 +2,8 @@ # Build _mppi_ik_cuda_lib.so from _mppi_ik_cuda_kernel.cu. # # Usage (from repo root): -# bash src/pyroffi/cuda_kernels/build_mppi_ik_cuda.sh -# bash src/pyroffi/cuda_kernels/build_mppi_ik_cuda.sh --debug +# bash build_kernels/build_mppi_ik_cuda.sh +# bash build_kernels/build_mppi_ik_cuda.sh --debug # # Requirements: # - nvcc (CUDA toolkit) @@ -50,8 +50,9 @@ if [[ -n "${MAX_JOINTS_OVERRIDE}" ]]; then fi SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -SRC="${SCRIPT_DIR}/_mppi_ik_cuda_kernel.cu" -OUT="${SCRIPT_DIR}/_mppi_ik_cuda_lib.so" +KERNELS_DIR="$(cd "${SCRIPT_DIR}/../src/pyroffi/cuda_kernels" && pwd)" +SRC="${KERNELS_DIR}/ik/_mppi_ik_cuda_kernel.cu" +OUT="${KERNELS_DIR}/ik/_mppi_ik_cuda_lib.so" JAXLIB_INC="$(python -c \ "import os, jaxlib; print(os.path.join(os.path.dirname(jaxlib.__file__), 'include'))")" @@ -77,8 +78,9 @@ nvcc \ ${GPU_ARCH} \ --shared \ --compiler-options "-fPIC" \ - -I"${SCRIPT_DIR}" \ + -I"${KERNELS_DIR}" \ -I"${JAXLIB_INC}" \ + -I"${KERNELS_DIR}" \ -o "${OUT}" \ "${SRC}" diff --git a/build_kernels/build_robogpu_collision.sh b/build_kernels/build_robogpu_collision.sh new file mode 100755 index 00000000..c2c42af3 --- /dev/null +++ b/build_kernels/build_robogpu_collision.sh @@ -0,0 +1,180 @@ +#!/usr/bin/env bash +# Build the RoboGPU OptiX sphere-octree collision checker. +# +# Produces two files in src/pyroffi/cuda_kernels/: +# _robogpu_optix_programs.ptx — OptiX device programs (ray gen / intersection +# / any-hit / miss), loaded at runtime. +# _robogpu_collision_lib.so — Host library with FK CUDA kernel + OptiX +# pipeline management + XLA FFI handler. +# +# Usage (from repo root): +# bash build_kernels/build_robogpu_collision.sh +# bash build_kernels/build_robogpu_collision.sh --debug +# bash build_kernels/build_robogpu_collision.sh --max-joints 128 +# +# Requirements: +# - nvcc (CUDA toolkit >= 11.2 for cudaMallocAsync) +# - NVIDIA OptiX SDK 7.x (set OPTIX_SDK or install to a standard path) +# - jaxlib >= 0.4.14 (provides xla/ffi/api/ffi.h headers) +# +# Optional environment variables: +# OPTIX_SDK Path to the OptiX SDK root (contains include/optix.h). +# If unset, common paths are searched automatically. +# GPU_ARCH nvcc architecture flag, e.g. -arch=sm_86 (default: -arch=native). +# Must be sm_50 or newer (OptiX 7 requirement). + +set -euo pipefail + +DEBUG=0 +MAX_JOINTS_OVERRIDE="" + +while [[ $# -gt 0 ]]; do + case "$1" in + --debug) + DEBUG=1; shift ;; + --max-joints) + [[ $# -lt 2 ]] && { echo "ERROR: --max-joints requires a value"; exit 1; } + MAX_JOINTS_OVERRIDE="$2"; shift 2 ;; + --max-joints=*) + MAX_JOINTS_OVERRIDE="${1#*=}"; shift ;; + *) + echo "ERROR: Unknown argument: $1"; exit 1 ;; + esac +done + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +KERNELS_DIR="$(cd "${SCRIPT_DIR}/../src/pyroffi/cuda_kernels" && pwd)" + +DEVICE_SRC="${KERNELS_DIR}/collision/_robogpu_optix_programs.cu" +HOST_SRC="${KERNELS_DIR}/collision/_robogpu_collision_host.cu" +PTX_OUT="${KERNELS_DIR}/collision/_robogpu_optix_programs.ptx" +SO_OUT="${KERNELS_DIR}/collision/_robogpu_collision_lib.so" + +# ── Locate jaxlib XLA FFI headers ────────────────────────────────────────── + +JAXLIB_INC="$(python -c \ + "import os, jaxlib; print(os.path.join(os.path.dirname(jaxlib.__file__), 'include'))")" + +if [ ! -f "${JAXLIB_INC}/xla/ffi/api/ffi.h" ]; then + echo "ERROR: xla/ffi/api/ffi.h not found under ${JAXLIB_INC}" + echo "Make sure jaxlib >= 0.4.14 is installed." + exit 1 +fi + +# ── Locate OptiX SDK ──────────────────────────────────────────────────────── + +find_optix_sdk() { + # 1. Explicit env var + if [[ -n "${OPTIX_SDK:-}" && -f "${OPTIX_SDK}/include/optix.h" ]]; then + echo "${OPTIX_SDK}"; return 0 + fi + # 2. Common install paths + for p in \ + /usr/local/optix \ + /opt/NVIDIA-OptiX-SDK* \ + "${HOME}/NVIDIA-OptiX-SDK"* \ + "${SCRIPT_DIR}/../NVIDIA-OptiX-SDK"* \ + /usr/local/cuda \ + ; do + # glob expansion in bash 'for' already handles wildcards + if [[ -f "${p}/include/optix.h" ]]; then + echo "${p}"; return 0 + fi + done + # 3. Scan PATH for optixNamespace.h peer + local pp + for pp in $(tr ':' '\n' <<< "${PATH:-}"); do + local candidate + for candidate in "${pp}/../include" "${pp}/../../include"; do + if [[ -f "${candidate}/optix.h" ]]; then + realpath "${candidate}/.." 2>/dev/null && return 0 + fi + done + done + return 1 +} + +OPTIX_ROOT="" +if OPTIX_ROOT="$(find_optix_sdk)"; then + echo "OptiX SDK found: ${OPTIX_ROOT}" +else + echo "ERROR: NVIDIA OptiX SDK 7.x not found." + echo "Install it and set OPTIX_SDK=/path/to/optix or pass it to PATH." + echo "Download: https://developer.nvidia.com/designworks/optix/download" + exit 1 +fi + +OPTIX_INC="${OPTIX_ROOT}/include" + +# ── GPU architecture flag ─────────────────────────────────────────────────── + +GPU_ARCH="${GPU_ARCH:--arch=native}" + +# ── Build flags ───────────────────────────────────────────────────────────── + +if [ "${DEBUG}" -eq 1 ]; then + NVCC_OPT="-O0 -G -lineinfo" + PTX_OPT="-O0 -G" + echo "Building in DEBUG mode (-G / lineinfo)..." +else + NVCC_OPT="-O3" + PTX_OPT="-O3" +fi + +EXTRA_DEFS="" +if [[ -n "${MAX_JOINTS_OVERRIDE}" ]]; then + EXTRA_DEFS="-DRGB_MAX_JOINTS=${MAX_JOINTS_OVERRIDE} -DRGB_MAX_LINKS=${MAX_JOINTS_OVERRIDE}" + echo "Custom bounds: RGB_MAX_JOINTS/LINKS=${MAX_JOINTS_OVERRIDE}" +fi + +# ── Step 1: Compile OptiX device programs to PTX ──────────────────────────── +# The PTX is loaded at runtime by the host library via optixModuleCreate. + +echo "" +echo "Step 1: Compiling OptiX device programs → PTX" +echo " Source : ${DEVICE_SRC}" +echo " Output : ${PTX_OUT}" + +nvcc \ + ${PTX_OPT} \ + -std=c++17 \ + ${GPU_ARCH} \ + --ptx \ + -I"${OPTIX_INC}" \ + -o "${PTX_OUT}" \ + "${DEVICE_SRC}" + +echo " OK: ${PTX_OUT}" + +# ── Step 2: Compile host code → shared library ────────────────────────────── +# The host code contains: FK + sphere-transform CUDA kernel, OptiX pipeline +# management (optix_stubs.h), BVH build/cache, and the XLA FFI handler. + +echo "" +echo "Step 2: Compiling host library → .so" +echo " Source : ${HOST_SRC}" +echo " Output : ${SO_OUT}" + +nvcc \ + ${NVCC_OPT} \ + -std=c++17 \ + ${GPU_ARCH} \ + --shared \ + --compiler-options "-fPIC" \ + -I"${JAXLIB_INC}" \ + -I"${OPTIX_INC}" \ + -I"${KERNELS_DIR}" \ + ${EXTRA_DEFS} \ + -ldl \ + -o "${SO_OUT}" \ + "${HOST_SRC}" + +echo " OK: ${SO_OUT}" + +echo "" +echo "Build complete." +echo " PTX : ${PTX_OUT}" +echo " SO : ${SO_OUT}" +echo "" +echo "Both files must reside in the same directory at runtime" +echo "(the host library locates the PTX file using dladdr)." diff --git a/src/pyroffi/cuda_kernels/build_sco_trajopt_cuda.sh b/build_kernels/build_sco_trajopt_cuda.sh similarity index 87% rename from src/pyroffi/cuda_kernels/build_sco_trajopt_cuda.sh rename to build_kernels/build_sco_trajopt_cuda.sh index 34012020..6a2b4a07 100755 --- a/src/pyroffi/cuda_kernels/build_sco_trajopt_cuda.sh +++ b/build_kernels/build_sco_trajopt_cuda.sh @@ -2,8 +2,8 @@ # Build _sco_trajopt_cuda_lib.so from _sco_trajopt_cuda_kernel.cu. # # Usage (from repo root): -# bash src/pyroffi/cuda_kernels/build_sco_trajopt_cuda.sh -# bash src/pyroffi/cuda_kernels/build_sco_trajopt_cuda.sh --debug +# bash build_kernels/build_sco_trajopt_cuda.sh +# bash build_kernels/build_sco_trajopt_cuda.sh --debug # # Requirements: # - nvcc (CUDA toolkit) @@ -50,8 +50,9 @@ if [[ -n "${MAX_JOINTS_OVERRIDE}" ]]; then fi SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -SRC="${SCRIPT_DIR}/_sco_trajopt_cuda_kernel.cu" -OUT="${SCRIPT_DIR}/_sco_trajopt_cuda_lib.so" +KERNELS_DIR="$(cd "${SCRIPT_DIR}/../src/pyroffi/cuda_kernels" && pwd)" +SRC="${KERNELS_DIR}/trajopt/_sco_trajopt_cuda_kernel.cu" +OUT="${KERNELS_DIR}/trajopt/_sco_trajopt_cuda_lib.so" # Locate the jaxlib include directory that ships xla/ffi/api/ffi.h. JAXLIB_INC="$(python -c \ @@ -82,6 +83,7 @@ nvcc \ --shared \ --compiler-options "-fPIC" \ -I"${JAXLIB_INC}" \ + -I"${KERNELS_DIR}" \ -o "${OUT}" \ "${SRC}" diff --git a/src/pyroffi/cuda_kernels/build_sqp_ik_cuda.sh b/build_kernels/build_sqp_ik_cuda.sh similarity index 85% rename from src/pyroffi/cuda_kernels/build_sqp_ik_cuda.sh rename to build_kernels/build_sqp_ik_cuda.sh index 488f4596..b4a92125 100755 --- a/src/pyroffi/cuda_kernels/build_sqp_ik_cuda.sh +++ b/build_kernels/build_sqp_ik_cuda.sh @@ -2,8 +2,8 @@ # Build _sqp_ik_cuda_lib.so from _sqp_ik_cuda_kernel.cu. # # Usage (from repo root): -# bash src/pyroffi/cuda_kernels/build_sqp_ik_cuda.sh -# bash src/pyroffi/cuda_kernels/build_sqp_ik_cuda.sh --debug +# bash build_kernels/build_sqp_ik_cuda.sh +# bash build_kernels/build_sqp_ik_cuda.sh --debug # # Requirements: # - nvcc (CUDA toolkit) @@ -50,8 +50,9 @@ if [[ -n "${MAX_JOINTS_OVERRIDE}" ]]; then fi SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -SRC="${SCRIPT_DIR}/_sqp_ik_cuda_kernel.cu" -OUT="${SCRIPT_DIR}/_sqp_ik_cuda_lib.so" +KERNELS_DIR="$(cd "${SCRIPT_DIR}/../src/pyroffi/cuda_kernels" && pwd)" +SRC="${KERNELS_DIR}/ik/_sqp_ik_cuda_kernel.cu" +OUT="${KERNELS_DIR}/ik/_sqp_ik_cuda_lib.so" JAXLIB_INC="$(python -c \ "import os, jaxlib; print(os.path.join(os.path.dirname(jaxlib.__file__), 'include'))")" @@ -77,8 +78,9 @@ nvcc \ ${GPU_ARCH} \ --shared \ --compiler-options "-fPIC" \ - -I"${SCRIPT_DIR}" \ + -I"${KERNELS_DIR}" \ -I"${JAXLIB_INC}" \ + -I"${KERNELS_DIR}" \ -o "${OUT}" \ "${SRC}" diff --git a/src/pyroffi/cuda_kernels/build_stomp_trajopt_cuda.sh b/build_kernels/build_stomp_trajopt_cuda.sh similarity index 86% rename from src/pyroffi/cuda_kernels/build_stomp_trajopt_cuda.sh rename to build_kernels/build_stomp_trajopt_cuda.sh index 66e47adf..7e117ca3 100755 --- a/src/pyroffi/cuda_kernels/build_stomp_trajopt_cuda.sh +++ b/build_kernels/build_stomp_trajopt_cuda.sh @@ -2,8 +2,8 @@ # Build _stomp_trajopt_cuda_lib.so from _stomp_trajopt_cuda_kernel.cu. # # Usage (from repo root): -# bash src/pyroffi/cuda_kernels/build_stomp_trajopt_cuda.sh -# bash src/pyroffi/cuda_kernels/build_stomp_trajopt_cuda.sh --debug +# bash build_kernels/build_stomp_trajopt_cuda.sh +# bash build_kernels/build_stomp_trajopt_cuda.sh --debug # # Requirements: # - nvcc (CUDA toolkit) @@ -50,8 +50,9 @@ if [[ -n "${MAX_JOINTS_OVERRIDE}" ]]; then fi SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -SRC="${SCRIPT_DIR}/_stomp_trajopt_cuda_kernel.cu" -OUT="${SCRIPT_DIR}/_stomp_trajopt_cuda_lib.so" +KERNELS_DIR="$(cd "${SCRIPT_DIR}/../src/pyroffi/cuda_kernels" && pwd)" +SRC="${KERNELS_DIR}/trajopt/_stomp_trajopt_cuda_kernel.cu" +OUT="${KERNELS_DIR}/trajopt/_stomp_trajopt_cuda_lib.so" # Locate the jaxlib include directory that ships xla/ffi/api/ffi.h. JAXLIB_INC="$(python -c \ @@ -82,6 +83,7 @@ nvcc \ --shared \ --compiler-options "-fPIC" \ -I"${JAXLIB_INC}" \ + -I"${KERNELS_DIR}" \ -o "${OUT}" \ "${SRC}" diff --git a/src/pyroffi/cuda_kernels/build_svgd_region_ik_cuda.sh b/build_kernels/build_svgd_region_ik_cuda.sh similarity index 71% rename from src/pyroffi/cuda_kernels/build_svgd_region_ik_cuda.sh rename to build_kernels/build_svgd_region_ik_cuda.sh index 2b39a67a..2aa0c2fd 100755 --- a/src/pyroffi/cuda_kernels/build_svgd_region_ik_cuda.sh +++ b/build_kernels/build_svgd_region_ik_cuda.sh @@ -2,13 +2,14 @@ # Build _svgd_region_ik_cuda_lib.so from _svgd_region_ik_cuda_kernel.cu. # # Usage (from repo root): -# bash src/pyroffi/cuda_kernels/build_svgd_region_ik_cuda.sh +# bash build_kernels/build_svgd_region_ik_cuda.sh set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -SRC="${SCRIPT_DIR}/_svgd_region_ik_cuda_kernel.cu" -OUT="${SCRIPT_DIR}/_svgd_region_ik_cuda_lib.so" +KERNELS_DIR="$(cd "${SCRIPT_DIR}/../src/pyroffi/cuda_kernels" && pwd)" +SRC="${KERNELS_DIR}/region_ik/_svgd_region_ik_cuda_kernel.cu" +OUT="${KERNELS_DIR}/region_ik/_svgd_region_ik_cuda_lib.so" JAXLIB_INC="$(python -c " import os, jaxlib @@ -29,8 +30,9 @@ nvcc \ ${GPU_ARCH} \ --shared \ --compiler-options "-fPIC" \ - -I"${SCRIPT_DIR}" \ + -I"${KERNELS_DIR}" \ -I"${JAXLIB_INC}" \ + -I"${KERNELS_DIR}" \ -o "${OUT}" \ "${SRC}" diff --git a/docs/collision_cuda_vs_jax.md b/docs/collision_cuda_vs_jax.md index 41bf3c0d..4d163ef8 100644 --- a/docs/collision_cuda_vs_jax.md +++ b/docs/collision_cuda_vs_jax.md @@ -393,8 +393,8 @@ Previously the JAX FK was called *inside* the vmap — once per batch element. N The CUDA backend requires two compiled shared libraries: ```bash -bash src/pyroffi/cuda_kernels/build_fk_cuda.sh # _fk_cuda_lib.so -bash src/pyroffi/cuda_kernels/build_collision_cuda.sh # _collision_cuda_lib.so +bash build_kernels/build_fk_cuda.sh # _fk_cuda_lib.so +bash build_kernels/build_collision_cuda.sh # _collision_cuda_lib.so ``` --- diff --git a/examples/12_00_box_region_ik_cuda_brownian.py b/examples/12_00_box_region_ik_cuda_brownian.py index da6207f1..22f51698 100644 --- a/examples/12_00_box_region_ik_cuda_brownian.py +++ b/examples/12_00_box_region_ik_cuda_brownian.py @@ -1,7 +1,7 @@ """Sample Panda IK configurations with end-effector coverage inside a box region. Prerequisite: - bash src/pyroffi/cuda_kernels/build_brownian_motion_ik_cuda.sh + bash build_kernels/build_brownian_motion_ik_cuda.sh """ from __future__ import annotations diff --git a/examples/12_01_box_region_ik_cuda_hit_and_run.py b/examples/12_01_box_region_ik_cuda_hit_and_run.py index 9cff4e09..249bcea1 100644 --- a/examples/12_01_box_region_ik_cuda_hit_and_run.py +++ b/examples/12_01_box_region_ik_cuda_hit_and_run.py @@ -1,7 +1,7 @@ """Sample Panda IK configurations with end-effector coverage inside a box region using hit-and-run sampling. Prerequisite: - bash src/pyroffi/cuda_kernels/build_hit_and_run_ik_cuda.sh + bash build_kernels/build_hit_and_run_ik_cuda.sh """ from __future__ import annotations diff --git a/examples/12_02_box_region_ik_cuda_svgd.py b/examples/12_02_box_region_ik_cuda_svgd.py index c61a8e10..068c4e01 100644 --- a/examples/12_02_box_region_ik_cuda_svgd.py +++ b/examples/12_02_box_region_ik_cuda_svgd.py @@ -1,7 +1,7 @@ """Sample Panda IK configurations with end-effector coverage inside a box region using SVGD. Prerequisite: - bash src/pyroffi/cuda_kernels/build_svgd_region_ik_cuda.sh + bash build_kernels/build_svgd_region_ik_cuda.sh """ from __future__ import annotations diff --git a/examples/13_00_vamp_cpu_collision.py b/examples/13_00_vamp_cpu_collision.py new file mode 100644 index 00000000..dd8341c8 --- /dev/null +++ b/examples/13_00_vamp_cpu_collision.py @@ -0,0 +1,77 @@ +"""Example: CPU collision checking with the JIT-compiled VAMP backend. + +``VAMPCPUCollisionChecker`` specialises VAMP's SIMD ``fkcc`` collision routine to +a concrete robot at runtime: cricket parses the URDF, emits a +``vamp::robots::`` struct, and JIT-compiles a binary collision checker for +it (cached on disk for reuse). Forward kinematics is baked into the binary, so +you only pass joint configurations — no pre-built pyroffi collision model. + +This script: + 1. builds the checker for the Panda (first call compiles + caches; later calls + reuse the cached binary), + 2. checks a batch of random configurations against a Sphere + Box world, + 3. reports how many are collision-free and the throughput. + +Run (inside the `pyroffi` conda env, with cricket built — see +build_kernels/build_cricket_jit.sh): + + python examples/13_00_vamp_cpu_collision.py +""" + +import os +import sys +import time +from pathlib import Path + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))) + +import jax.numpy as jnp +import numpy as np + +from pyroffi.collision import Box, Sphere, VAMPCPUCollisionChecker + +REPO_ROOT = Path(__file__).resolve().parents[1] +SPHERIZED_URDF = REPO_ROOT / "resources" / "panda" / "panda_spherized.urdf" +SRDF = REPO_ROOT / "resources" / "panda" / "panda.srdf" + + +def main() -> None: + print("Building VAMP CPU collision checker (first run JIT-compiles, then caches)...") + t0 = time.perf_counter() + checker = VAMPCPUCollisionChecker(SPHERIZED_URDF, srdf_path=SRDF) + print(f" ready in {time.perf_counter() - t0:.2f}s " + f"(dim={checker.dimension}, n_spheres={checker.n_spheres})") + + n = checker.dimension + rng = np.random.RandomState(0) + cfg = jnp.asarray(rng.uniform(-1.5, 1.5, size=(4096, n)), dtype=jnp.float32) + + # A mixed world: a few spheres and a box in the arm's workspace. + spheres = Sphere.from_center_and_radius( + center=jnp.array([[0.35, 0.0, 0.6], [0.0, 0.4, 0.7], [-0.3, 0.0, 0.5]]), + radius=jnp.array([0.13, 0.12, 0.12]), + ) + box = Box.from_center_and_half_lengths( + center=jnp.array([[0.4, 0.0, 0.4]]), + half_lengths=jnp.array([[0.1, 0.4, 0.1]]), + ) + + for name, world in (("sphere world", spheres), ("box world", box)): + # warm up, then time + free = np.asarray(checker.check_collision_free(None, cfg, world)) + t0 = time.perf_counter() + for _ in range(5): + free = np.asarray(checker.check_collision_free(None, cfg, world)) + dt = (time.perf_counter() - t0) / 5 + print(f"\n[{name}] {int(free.sum())}/{len(free)} configs collision-free") + print(f" {dt * 1e3:.2f} ms/call ({dt * 1e6 / len(free):.2f} us/config) " + f"for {len(free)} configs") + + # Single-config check (returns a scalar bool). + home = jnp.zeros((n,), dtype=jnp.float32) + print(f"\nhome configuration collision-free in empty box world: " + f"{bool(checker.check_collision_free(None, home, box))}") + + +if __name__ == "__main__": + main() diff --git a/examples/13_01_vamp_edge_validation.py b/examples/13_01_vamp_edge_validation.py new file mode 100644 index 00000000..790f8996 --- /dev/null +++ b/examples/13_01_vamp_edge_validation.py @@ -0,0 +1,91 @@ +"""Example: batch edge validation (+ point clouds) with the VAMP CPU backend. + +Edge validation is the headline use of the VAMP backend for sampling-based +planning: given many candidate motions (edges), decide which are entirely +collision-free. VAMP discretises each edge internally at the robot's planning +resolution and checks it with its SIMD ``fkcc`` routine, parallelised across the +batch with OpenMP. + +``check_edges_collision_free`` takes the two endpoints of each edge in the +second-to-last axis (shape ``[*batch, 2, n_act]``) — VAMP fills in the +interior — and returns one verdict per edge. + +This script also shows VAMP's CAPT (Collision-Affording Point Tree) path: a +point-cloud obstacle passed via ``point_cloud=`` with a per-point radius. + +Run (inside the `pyroffi` conda env, with cricket built): + + python examples/13_01_vamp_edge_validation.py +""" + +import os +import sys +import time +from pathlib import Path + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))) + +import jax.numpy as jnp +import numpy as np + +from pyroffi.collision import Sphere, VAMPCPUCollisionChecker + +REPO_ROOT = Path(__file__).resolve().parents[1] +SPHERIZED_URDF = REPO_ROOT / "resources" / "panda" / "panda_spherized.urdf" +SRDF = REPO_ROOT / "resources" / "panda" / "panda.srdf" + + +def main() -> None: + checker = VAMPCPUCollisionChecker(SPHERIZED_URDF, srdf_path=SRDF) + n = checker.dimension + rng = np.random.RandomState(5) + + world = Sphere.from_center_and_radius( + center=jnp.array([[0.3, 0.0, 0.6], [0.0, 0.35, 0.7]]), + radius=jnp.array([0.15, 0.14]), + ) + + # ── Batch edge validation ─────────────────────────────────────────────── + E = 4096 + a = jnp.asarray(rng.uniform(-1.2, 1.2, size=(E, n)), dtype=jnp.float32) + b = jnp.asarray(rng.uniform(-1.2, 1.2, size=(E, n)), dtype=jnp.float32) + edges = jnp.stack([a, b], axis=1) # [E, 2, n] + + valid = np.asarray(checker.check_edges_collision_free(None, edges, world)) # warm up + t0 = time.perf_counter() + for _ in range(5): + valid = np.asarray(checker.check_edges_collision_free(None, edges, world)) + dt = (time.perf_counter() - t0) / 5 + print(f"[edges] {int(valid.sum())}/{E} edges collision-free") + print(f" {dt * 1e3:.2f} ms/call ({dt * 1e6 / E:.2f} us/edge) for {E} edges") + + # Sanity: VAMP samples (0, 1], so a valid edge must have its *goal* endpoint + # collision-free (the start is assumed pre-validated by the planner). + b_free = np.asarray(checker.check_collision_free(None, b, world)) + assert np.all(~valid | b_free) + print(" consistency OK: no edge is valid with a colliding goal endpoint") + + # ── Point-cloud (CAPT) obstacle ───────────────────────────────────────── + far = Sphere.from_center_and_radius( + center=jnp.array([[100.0, 100.0, 100.0]]), radius=jnp.array([1e-3]) + ) + cfg = jnp.asarray(rng.uniform(-1.2, 1.2, size=(2048, n)), dtype=jnp.float32) + gx, gz = np.meshgrid(np.linspace(0.15, 0.5, 25), np.linspace(0.2, 1.0, 25)) + cloud = np.stack([gx.ravel(), np.zeros(gx.size), gz.ravel()], axis=1).astype(np.float32) + + base = int(np.asarray(checker.check_collision_free(None, cfg, far)).sum()) + with_pc = int( + np.asarray( + checker.check_collision_free( + None, cfg, far, point_cloud=jnp.asarray(cloud), capt=(0.0, 1.0, 0.04) + ) + ).sum() + ) + print(f"\n[CAPT point cloud] {cloud.shape[0]} points, r_point=0.04") + print(f" free without cloud: {base}/{cfg.shape[0]}") + print(f" free with cloud: {with_pc}/{cfg.shape[0]} " + f"(point-cloud wall removed {base - with_pc} configs)") + + +if __name__ == "__main__": + main() diff --git a/examples/13_02_robogpu_vs_capt_pointcloud.py b/examples/13_02_robogpu_vs_capt_pointcloud.py new file mode 100644 index 00000000..777c0e36 --- /dev/null +++ b/examples/13_02_robogpu_vs_capt_pointcloud.py @@ -0,0 +1,294 @@ +"""Interactive: drag the Panda through a point-cloud field, RoboGPU vs CAPT. + +Grab the transform gizmo and sweep the arm (via IK) through a wall of points. +The robot is drawn as its collision-sphere model; whenever a link's spheres +overlap the point cloud the whole link turns **red**, and the offending points +turn red too. Two collision backends are evaluated every frame and reported in +the GUI panel so you can visually confirm they agree with what you see: + + * **RoboGPU** — OptiX ray-tracing sphere-octree checker (GPU). Uses the same + pyroffi spherized model that drives the red colouring, so its verdict should + match the "any link red" reference exactly. + * **CAPT** — VAMP's Collision-Affording Point Tree checker (CPU). Uses + VAMP's *own* internal spherization, so its verdict may differ slightly near + grazing contacts — that divergence is exactly what this tool lets you see. + +The brute-force per-sphere test (NumPy) is the ground-truth reference for the +red colouring. + +Multi-GPU (pmap) + RoboGPU's CUDA/OptiX kernels are device-safe under ``jax.pmap``: each visible + GPU keeps its own pipeline/BVH/graph caches, so pmap shards a batch of + configurations across all devices automatically. To show this off, every + frame we also fan a batch of perturbed candidate configurations out across all + visible GPUs via ``jax.pmap`` and report the aggregate free-count + timing in + the GUI. Candidate ``(device 0, slot 0)`` is the exact dragged config, so its + pmap verdict must match the single-device RoboGPU verdict (sanity check). + +Run (inside the `pyroffi` conda env): + # build the kernels first: + # bash build_kernels/build_robogpu_collision.sh + # bash build_kernels/build_cricket_jit.sh (optional, for CAPT) + python examples/13_02_robogpu_vs_capt_pointcloud.py +""" + +import os +import sys +import time +from pathlib import Path + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))) + +import numpy as np +import jax +import jax.numpy as jnp +import jaxlie +import trimesh +import viser +import yourdfpy + +import pyroffi as pk +from pyroffi.collision import RobotCollisionSpherized, RoboGPUCollisionChecker, Sphere +from pyroffi.collision._cuda_collision import _spherized_local_geometry + +REPO_ROOT = Path(__file__).resolve().parents[1] +SPHERIZED_URDF = REPO_ROOT / "resources" / "panda" / "panda_spherized.urdf" +SRDF = REPO_ROOT / "resources" / "panda" / "panda.srdf" + +TARGET_LINK = "panda_hand" +R_ENV = 0.02 # environment point sphere radius +N_POINTS = 1200 # point-cloud size +BLOB_CENTER = np.array([0.45, 0.0, 0.55], np.float32) # small sphere location +BLOB_RADIUS = 0.10 # small sphere radius + +# Multi-GPU pmap demo: candidate configs evaluated per GPU each frame. +N_PER_DEVICE = 256 # perturbed candidates checked on every visible GPU +PERTURB_STD = 0.15 # std-dev (rad) of the candidate perturbation + + +# --------------------------------------------------------------------------- +# Geometry helpers +# --------------------------------------------------------------------------- + +def make_point_field(rng) -> np.ndarray: + """A small dense sphere of points the arm sweeps through.""" + # Uniformly sample points inside a small ball so the robot isn't engulfed. + dirs = rng.normal(size=(N_POINTS, 3)) + dirs /= np.linalg.norm(dirs, axis=1, keepdims=True) + radii = BLOB_RADIUS * np.cbrt(rng.uniform(0.0, 1.0, N_POINTS))[:, None] + return (BLOB_CENTER + dirs * radii).astype(np.float32) + + +def main() -> None: + robot = pk.Robot.from_urdf(yourdfpy.URDF.load(str(SPHERIZED_URDF))) + coll = RobotCollisionSpherized.from_urdf(yourdfpy.URDF.load(str(SPHERIZED_URDF))) + + NL = coll.num_links + f_local = np.asarray(_spherized_local_geometry(coll)) # [K, 4] k = s*NL + n + K = f_local.shape[0] + sphere_link = np.arange(K) % NL # link index per sphere + sphere_valid = f_local[:, 3] > 0.0 + radii_all = f_local[:, 3].copy() + r_robot_max = float(radii_all[sphere_valid].max()) + + # Map each collision-model link to its forward-kinematics link index. + fk_idx = np.array( + [robot.links.names.index(name) for name in coll.link_names], dtype=np.int32 + ) + sphere_fk = fk_idx[sphere_link] # FK link idx per sphere + + f_local_j = jnp.asarray(f_local[:, :3]) + sphere_fk_j = jnp.asarray(sphere_fk) + + @jax.jit + def world_spheres(cfg): + """Return [K, 3] world-frame sphere centres for a single config.""" + link_poses = robot.forward_kinematics(cfg) # [NL, 7] wxyz_xyz + T = jaxlie.SE3(link_poses[sphere_fk_j]) # [K] SE3 + return T.apply(f_local_j) # [K, 3] + + # ── Point cloud + checkers ─────────────────────────────────────────────── + rng = np.random.default_rng(0) + points = make_point_field(rng) + points_j = jnp.asarray(points) + far = Sphere.from_center_and_radius( + center=jnp.array([[100.0, 100.0, 100.0]]), radius=jnp.array([0.01])) + + print("Building RoboGPU checker ...") + robogpu = RoboGPUCollisionChecker(coll) + # Disable self-collision so RoboGPU's verdict reflects ONLY point-cloud + # contact — exactly what the red-link reference colouring shows. (Self- + # collision still works in production; we switch it off here so the + # "RoboGPU == reference" indicator is a clean point-cloud comparison.) + robogpu._f_pair_i = jnp.zeros((0,), dtype=jnp.int32) + robogpu._f_pair_j = jnp.zeros((0,), dtype=jnp.int32) + robogpu._cached_robot_id = None + robogpu._jit_fn = None + robogpu.set_world(far, point_cloud=points_j, r_env=R_ENV) + + # ── Multi-GPU pmap setup ───────────────────────────────────────────────── + # A single warmup call builds the checker's jitted FFI function (which closes + # over the static robot/world geometry). We then wrap it in jax.pmap so a + # config batch shaped [n_devices, N_PER_DEVICE, n_act] is sharded across every + # visible GPU — one shard per device, each running on its own CUDA stream and + # per-device kernel caches. + n_devices = jax.device_count() + n_act = robot.joints.lower_limits.shape[0] + mid_cfg = (robot.joints.lower_limits + robot.joints.upper_limits) / 2.0 + robogpu.check_collision_free(robot, mid_cfg[None, :]).block_until_ready() + pmap_check = jax.pmap(robogpu._jit_fn) # [D, P, n_act] -> [D, P] int32 (1=free) + pmap_key = jax.random.PRNGKey(0) + print(f"pmap RoboGPU over {n_devices} GPU(s): " + f"{n_devices * N_PER_DEVICE} candidates/frame") + + capt = None + try: + from pyroffi.collision import VAMPCPUCollisionChecker + print("Building CAPT (VAMP) checker ... (first run JIT-compiles)") + capt = VAMPCPUCollisionChecker(SPHERIZED_URDF, srdf_path=SRDF) + capt.set_world( + far, point_cloud=points_j, + capt_r_min=0.0, capt_r_max=r_robot_max, capt_r_point=R_ENV, + ) + print(" CAPT ready.") + except Exception as exc: + print(f" CAPT unavailable ({exc}); continuing with RoboGPU only.") + + # ── Viser scene ────────────────────────────────────────────────────────── + server = viser.ViserServer() + server.scene.add_grid("/ground", width=2.0, height=2.0) + + # One icosphere mesh, instanced once per collision sphere. + unit = trimesh.creation.icosphere(subdivisions=2, radius=1.0) + verts = np.asarray(unit.vertices, dtype=np.float32) + faces = np.asarray(unit.faces, dtype=np.uint32) + + n_show = int(sphere_valid.sum()) + show_idx = np.where(sphere_valid)[0] + sphere_handle = server.scene.add_batched_meshes_simple( + "/robot_spheres", + vertices=verts, + faces=faces, + batched_positions=np.zeros((n_show, 3), np.float32), + batched_wxyzs=np.tile(np.array([1, 0, 0, 0], np.float32), (n_show, 1)), + batched_scales=radii_all[show_idx].astype(np.float32), + batched_colors=np.tile(np.array([90, 200, 255], np.uint8), (n_show, 1)), + ) + + pc_colors = np.tile(np.array([160, 160, 160], np.uint8), (len(points), 1)) + pc_handle = server.scene.add_point_cloud( + "/point_cloud", points=points, colors=pc_colors, point_size=R_ENV, + point_shape="circle", + ) + + # IK drag target. + ik_target = server.scene.add_transform_controls( + "/ik_target", scale=0.2, position=(0.45, 0.0, 0.55), wxyz=(0, 1, 0, 0) + ) + + # GUI readouts. + g_ref = server.gui.add_text("Reference (red links)", "—", disabled=True) + g_rg = server.gui.add_text("RoboGPU verdict", "—", disabled=True) + g_capt = server.gui.add_text("CAPT verdict", "—", disabled=True) + g_agree = server.gui.add_text("RoboGPU == reference", "—", disabled=True) + g_t_rg = server.gui.add_number("RoboGPU (us)", 0.0, disabled=True) + g_t_capt = server.gui.add_number("CAPT (us)", 0.0, disabled=True) + g_pmap = server.gui.add_text( + f"pmap free / total ({n_devices} GPUs)", "—", disabled=True) + g_t_pmap = server.gui.add_number("pmap batch (us)", 0.0, disabled=True) + g_pmap_agree = server.gui.add_text("pmap[0,0] == single", "—", disabled=True) + + RED = np.array([220, 40, 40], np.uint8) + BLUE = np.array([90, 200, 255], np.uint8) + GREY = np.array([160, 160, 160], np.uint8) + PT_RED = np.array([240, 60, 60], np.uint8) + + target_idx = robot.links.names.index(TARGET_LINK) + ik_solve = jax.jit( + lambda pose, key, prev: robot.inverse_kinematics( + target_link_name=TARGET_LINK, target_pose=pose, + rng_key=key, previous_cfg=prev, + ) + ) + rng_key = jax.random.PRNGKey(0) + solution = (robot.joints.lower_limits + robot.joints.upper_limits) / 2 + + while True: + target_pose = jaxlie.SE3.from_rotation_and_translation( + rotation=jaxlie.SO3(wxyz=jnp.array(ik_target.wxyz)), + translation=jnp.array(ik_target.position), + ) + rng_key, subkey = jax.random.split(rng_key) + solution = ik_solve(target_pose, subkey, solution) + solution.block_until_ready() + cfg = solution + + # World-frame collision spheres + brute-force reference vs the cloud. + centers = np.asarray(world_spheres(cfg)) # [K, 3] + d2 = ((centers[show_idx][:, None, :] - points[None, :, :]) ** 2).sum(-1) + rsum = (radii_all[show_idx][:, None] + R_ENV) ** 2 # [n_show, 1] + sphere_hit = np.any(d2 < rsum, axis=1) # [n_show] bool + + # Per-link: a link is red if any of its spheres collide. + link_hit = np.zeros(NL, dtype=bool) + hit_links = sphere_link[show_idx][sphere_hit] + link_hit[hit_links] = True + # Colour each shown sphere by whether its LINK is in collision. + link_of_show = sphere_link[show_idx] + col = np.where(link_hit[link_of_show][:, None], RED, BLUE).astype(np.uint8) + + # Colliding points → red. + pt_hit = np.any(d2 < rsum, axis=0) # [Mp] bool + pcol = np.where(pt_hit[:, None], PT_RED, GREY).astype(np.uint8) + + sphere_handle.batched_positions = centers[show_idx].astype(np.float32) + sphere_handle.batched_colors = col + pc_handle.colors = pcol + + ref_collision = bool(link_hit.any()) + g_ref.value = ("COLLISION" if ref_collision else "free") + \ + f" ({int(link_hit.sum())} links)" + + # RoboGPU verdict (1 = free). + t0 = time.perf_counter() + rg_free = bool(np.asarray( + robogpu.check_collision_free(robot, cfg[None, :])).reshape(())) + g_t_rg.value = (time.perf_counter() - t0) * 1e6 + g_rg.value = "free" if rg_free else "COLLISION" + g_agree.value = "yes" if (rg_free != ref_collision) else "NO — mismatch!" + + # ── Multi-GPU pmap batch: perturbed candidates fanned across all GPUs ── + # Build [n_devices, N_PER_DEVICE, n_act]. Slot (0, 0) is the exact dragged + # config (zero perturbation) so its verdict must match the single-device + # RoboGPU result above — a live correctness check of the pmap path. + pmap_key, sub = jax.random.split(pmap_key) + noise = PERTURB_STD * jax.random.normal( + sub, (n_devices, N_PER_DEVICE, n_act), dtype=jnp.float32) + noise = noise.at[0, 0].set(0.0) + cands = cfg[None, None, :] + noise + t0 = time.perf_counter() + verdicts = pmap_check(cands) # [n_devices, N_PER_DEVICE] int32 + verdicts.block_until_ready() + g_t_pmap.value = (time.perf_counter() - t0) * 1e6 + verdicts_np = np.asarray(verdicts) + n_total = verdicts_np.size + n_free = int(verdicts_np.sum()) + g_pmap.value = f"{n_free} / {n_total}" + pmap_ref_free = bool(verdicts_np[0, 0]) + g_pmap_agree.value = "yes" if (pmap_ref_free == rg_free) else "NO — mismatch!" + + # CAPT verdict. + if capt is not None: + t0 = time.perf_counter() + capt_free = bool(np.asarray( + capt.check_collision_free(None, cfg[None, :])).reshape(())) + g_t_capt.value = (time.perf_counter() - t0) * 1e6 + g_capt.value = "free" if capt_free else "COLLISION" + else: + g_capt.value = "n/a" + + time.sleep(0.02) + + +if __name__ == "__main__": + main() diff --git a/external/cricket b/external/cricket new file mode 160000 index 00000000..c9a76a75 --- /dev/null +++ b/external/cricket @@ -0,0 +1 @@ +Subproject commit c9a76a75c22c7d49d06c12c9badabe970eded11a diff --git a/external/vamp b/external/vamp new file mode 160000 index 00000000..5aa84f4e --- /dev/null +++ b/external/vamp @@ -0,0 +1 @@ +Subproject commit 5aa84f4e22046db816cf61c3976cb5b8d7739be2 diff --git a/requirements.txt b/requirements.txt index 0b336dc2..d670b9a9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,14 +6,19 @@ attrs==25.4.0 beautifulsoup4==4.14.3 blinker==1.9.0 brax==0.13.0 +casadi==3.7.2 +certifi==2026.4.22 +charset-normalizer==3.4.7 chex==0.1.91 click==8.3.1 colorlog==6.10.1 contourpy==1.3.3 +-e git+https://github.com/commalab/cricket@c9a76a75c22c7d49d06c12c9badabe970eded11a#egg=cricket cycler==0.12.1 docstring_parser==0.17.0 embreex==2.17.7.post7 etils==1.13.0 +exceptiongroup @ file:///home/conda/feedstock_root/build_artifacts/exceptiongroup_1763918002538/work Flask==3.1.2 flask-cors==6.0.1 flax==0.12.0 @@ -28,8 +33,10 @@ httpcore==1.0.9 httpx==0.28.1 humanize==4.14.0 hydra-core==1.3.2 +idna==3.13 ImageIO==2.37.2 importlib_resources==6.5.2 +iniconfig==2.3.0 itsdangerous==2.2.0 jax==0.8.0 jax-cuda13-pjrt==0.8.0 @@ -48,8 +55,10 @@ loguru==0.7.3 lxml==6.0.2 manifold3d==3.2.1 mapbox_earcut==1.0.3 +markdown-it-py==4.0.0 MarkupSafe==3.0.3 matplotlib==3.10.7 +mdurl==0.1.2 ml_collections==1.1.0 ml_dtypes==0.5.3 mpmath==1.3.0 @@ -57,6 +66,8 @@ msgpack==1.1.2 msgspec==0.19.0 mujoco==3.3.7 mujoco-mjx==3.3.7 +nanobind @ file:///home/conda/feedstock_root/build_artifacts/nanobind-split_1781788265459/work +nest-asyncio==1.6.0 networkx==3.5 nodeenv==1.9.1 numpy==2.3.4 @@ -79,24 +90,35 @@ omegaconf==2.3.0 opt_einsum==3.4.0 optax==0.2.6 orbax-checkpoint==0.11.28 +packaging==26.0 pandas==3.0.0 +pathspec @ file:///home/conda/feedstock_root/build_artifacts/pathspec_1777271521463/work pillow==12.0.0 +pluggy==1.6.0 protobuf==6.33.1 psutil==7.1.3 pycollada==0.9.2 +Pygments==2.20.0 pyliblzfse==0.4.1 PyOpenGL==3.1.10 pyparsing==3.2.5 +-e git+https://github.com/commalab/pyroffi@4f0607281f1469cd8e1f170665c027ee52385f49#egg=pyroffi +pytest==9.1.1 +python-dateutil==2.9.0.post0 PyYAML==6.0.3 referencing==0.37.0 +requests==2.33.1 +rich==14.3.4 robot_descriptions==1.21.0 rpds-py==0.28.0 rtree==1.4.1 +scikit_build_core @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_scikit-build-core_1781724892/work scipy==1.16.3 -setuptools==80.9.0 shapely==2.1.2 +shellingham==1.5.4 shtab==1.7.2 simplejson==3.20.2 +six==1.17.0 smmap==5.0.2 sniffio==1.3.1 soupsieve==2.8.3 @@ -105,20 +127,23 @@ sympy==1.14.0 tensorboardX==2.6.4 tensorstore==0.1.79 termcolor==3.2.0 +tomli @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_tomli_1774492402/work toolz==1.1.0 +tqdm==4.67.3 treescope==0.1.10 trimesh==4.9.0 typeguard==4.4.4 typer==0.20.0 typer-slim==0.20.0 +typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1756220668/work tyro==0.9.35 urllib3==2.5.0 +-e git+https://github.com/commalab/vamp@5aa84f4e22046db816cf61c3976cb5b8d7739be2#egg=vamp_planner vhacdx==0.0.9 viser==1.0.15 wadler_lindig==0.1.7 websockets==15.0.1 Werkzeug==3.1.3 -wheel==0.45.1 xxhash==3.6.0 -yourdfpy==0.0.58 -zipp==3.23.0 \ No newline at end of file +yourdfpy==0.0.60 +zipp @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_zipp_1779159948/work diff --git a/src/pyroffi/_robot.py b/src/pyroffi/_robot.py index 57fa3c8e..86ff0060 100644 --- a/src/pyroffi/_robot.py +++ b/src/pyroffi/_robot.py @@ -81,7 +81,7 @@ def forward_kinematics( unroll_fk: If True, unroll the JAX fori_loop over joints (ignored when use_cuda=True). use_cuda: If True, dispatch to an external CUDA kernel via the JAX FFI instead of the default JAX implementation. Requires ``_fk_cuda.so`` to be compiled first - (see ``src/pyroffi/cuda_kernels/build_fk_cuda.sh``). + (see ``build_kernels/build_fk_cuda.sh``). Returns: The SE(3) transforms of the links, ordered by `self.link.names`, @@ -271,7 +271,7 @@ def _fk_cuda_differentiable( input. Both ``jax.jvp`` and ``jax.grad`` work; differentiated calls evaluate the JAX FK (the FFI itself is not differentiable). """ - from .cuda_kernels._fk_cuda import fk_cuda + from .cuda_kernels.fk._fk_cuda import fk_cuda return fk_cuda( cfg=cfg, diff --git a/src/pyroffi/collision/__init__.py b/src/pyroffi/collision/__init__.py index d835c987..74a6d779 100644 --- a/src/pyroffi/collision/__init__.py +++ b/src/pyroffi/collision/__init__.py @@ -21,3 +21,7 @@ from ._cuda_collision import CUDABinaryCollisionChecker as CUDABinaryCollisionChecker from ._cuda_collision import make_cuda_checker as make_cuda_checker from ._cuda_collision import make_cuda_binary_checker as make_cuda_binary_checker +from ._vamp_collision import VAMPCPUCollisionChecker as VAMPCPUCollisionChecker +from ._vamp_collision import make_vamp_cpu_checker as make_vamp_cpu_checker +from ._robogpu_collision import RoboGPUCollisionChecker as RoboGPUCollisionChecker +from ._robogpu_collision import make_robogpu_checker as make_robogpu_checker diff --git a/src/pyroffi/collision/_cuda_collision.py b/src/pyroffi/collision/_cuda_collision.py index 563c2f72..a0523056 100644 --- a/src/pyroffi/collision/_cuda_collision.py +++ b/src/pyroffi/collision/_cuda_collision.py @@ -30,7 +30,7 @@ sphere-robot radii [B, K] Requires the compiled shared library _collision_cuda_lib.so: - bash src/pyroffi/cuda_kernels/build_collision_cuda.sh + bash build_kernels/build_collision_cuda.sh """ from __future__ import annotations @@ -47,7 +47,7 @@ from ._geometry import Box, Capsule, CollGeom, HalfSpace, Sphere from ._robot_collision import RobotCollision, RobotCollisionSpherized -from ..cuda_kernels._collision_cuda_ffi import ( +from ..cuda_kernels.collision._collision_cuda_ffi import ( _load_and_register, collision_world_sphere, collision_world_sphere_reduced, @@ -55,7 +55,7 @@ collision_self_sphere, collision_self_capsule, ) -from ..cuda_kernels._collision_binary_cuda_ffi import ( +from ..cuda_kernels.collision._collision_binary_cuda_ffi import ( _load_and_register as _load_and_register_binary, collision_binary, ) @@ -217,7 +217,7 @@ class CUDADifferentiableSDFCollisionChecker: Notes: - Requires the compiled _collision_cuda_lib.so. - Build with: bash src/pyroffi/cuda_kernels/build_collision_cuda.sh + Build with: bash build_kernels/build_collision_cuda.sh - World geometry must be a flat collection (no leading batch dims). - FK is performed by the CUDA FK kernel (via ``robot.forward_kinematics(use_cuda=True)``); geometry transforms and distance computation both run on CUDA/device. diff --git a/src/pyroffi/collision/_robogpu_collision.py b/src/pyroffi/collision/_robogpu_collision.py new file mode 100644 index 00000000..d56a0490 --- /dev/null +++ b/src/pyroffi/collision/_robogpu_collision.py @@ -0,0 +1,314 @@ +"""GPU collision checker using OptiX ray-tracing cores for point-cloud queries. + +RoboGPUCollisionChecker implements the RoboGPU architecture (arXiv:2603.01517) +adapted for sphere-based robot models: + + - The environment point cloud is built into an OptiX BVH (each point as a + sphere of radius ``r_env``). The BVH is built once and cached across calls + that use the same point cloud; subsequent checks pay only the traversal cost. + + - Robot collision geometry comes from a :class:`RobotCollisionSpherized` + model (the same input as :class:`CUDABinaryCollisionChecker`). + + - Per-call execution on a single CUDA stream (fully asynchronous): + 1. CUDA kernel: FK → world-frame robot spheres → regular world geometry + check (spheres / capsules / boxes / halfspaces) + self-collision. + 2. OptiX kernel: for configs still free after Stage 1, each robot sphere + queries the point-cloud BVH. OptiX any-hit terminates BVH traversal + on the first hit (per-sphere early exit), and the raygen loop breaks + immediately (per-config early exit). + + - Public API deliberately mirrors both :class:`CUDABinaryCollisionChecker` + (same constructor argument, same ``check_collision_free`` / + ``check_edges_collision_free`` signatures) and :class:`VAMPCPUCollisionChecker` + (same ``set_world`` keyword args for the CAPT-style point-cloud parameters). + +Prerequisites: + bash build_kernels/build_robogpu_collision.sh + (Requires NVIDIA OptiX SDK 7.x and CUDA 11.2+.) + +Note: HalfSpace obstacles are supported (via the CUDA stage only), since OptiX +is used exclusively for the point cloud. +""" + +from __future__ import annotations + +from functools import lru_cache +from typing import Optional + +import jax +import jax.numpy as jnp +import numpy as np +from jax import Array +from jaxtyping import Float + +from ._geometry import CollGeom +from ._robot_collision import RobotCollisionSpherized +from ._cuda_collision import _extract_world_arrays, _spherized_local_geometry + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _r_robot_max(f_local: Array) -> float: + """Maximum active (non-padding) robot sphere radius.""" + radii = np.asarray(f_local[:, 3]) + active = radii[radii > 0.0] + return float(active.max()) if active.size > 0 else 0.01 + + +# --------------------------------------------------------------------------- +# RoboGPUCollisionChecker +# --------------------------------------------------------------------------- + +class RoboGPUCollisionChecker: + """OptiX-accelerated sphere-octree collision checker for point-cloud worlds. + + Args: + inner: Spherized robot collision model (same as CUDABinaryCollisionChecker). + edge_granularity: Number of interpolation points per edge for + ``check_edges_collision_free`` (pre-discretised like the CUDA binary + checker; increase for denser edges). + + Example:: + + checker = RoboGPUCollisionChecker(RobotCollisionSpherized.from_urdf(urdf)) + checker.set_world(world_geom, point_cloud=pc, r_env=0.02) + free = checker.check_collision_free(robot, cfg) # [B] int32 + """ + + def __init__( + self, + inner: RobotCollisionSpherized, + *, + edge_granularity: int = 16, + ) -> None: + from ..cuda_kernels.collision._robogpu_collision_ffi import _load_and_register + _load_and_register() + + if not isinstance(inner, RobotCollisionSpherized): + raise TypeError( + "RoboGPUCollisionChecker requires a RobotCollisionSpherized model; " + f"got {type(inner).__name__}." + ) + + self._inner = inner + self._edge_granularity = int(edge_granularity) + + # Robot sphere geometry (link-local, static across configs). + self._f_local = jnp.asarray(_spherized_local_geometry(inner)) # [K, 4] + self._f_pair_i = jnp.asarray(inner.active_idx_i, dtype=jnp.int32) + self._f_pair_j = jnp.asarray(inner.active_idx_j, dtype=jnp.int32) + self._r_robot_max = _r_robot_max(self._f_local) + + # World geometry cache (updated by set_world or lazily on first call). + self._ws: Optional[Array] = None + self._wc: Optional[Array] = None + self._wb: Optional[Array] = None + self._wh: Optional[Array] = None + self._wp: Array = jnp.zeros((0, 3), dtype=jnp.float32) # point cloud + self._r_env: float = 0.01 + self._cached_world_id: Optional[int] = None + + # Per-robot JIT cache (keyed by robot object identity). + self._cached_robot_id: Optional[int] = None + self._jit_fn = None + + # ── Properties ────────────────────────────────────────────────────────── + + @property + def num_links(self) -> int: + return self._inner.num_links + + @property + def link_names(self) -> tuple[str, ...]: + return self._inner.link_names + + # ── World handling ─────────────────────────────────────────────────────── + + def set_world( + self, + world_geom: CollGeom, + point_cloud: Optional[Array] = None, + *, + r_env: float = 0.01, + # CAPT-compatible aliases (ignored — r_env covers all points uniformly) + capt_r_min: float = 0.0, + capt_r_max: float = 1.0, + capt_r_point: float = 0.0, + ) -> None: + """Cache world obstacles and optional point cloud. + + Args: + world_geom: Regular world geometry (spheres, capsules, boxes, + halfspaces) — checked in CUDA Stage 1. + point_cloud: ``[Mp, 3]`` float32 array of environment points. + Pass ``None`` or an empty array to use Stage 1 only. + r_env: Radius of each environment point sphere. Also used as + ``capt_r_point`` equivalent. All points share the same radius. + """ + ws_np, wc_np, wb_np, wh_np = _extract_world_arrays(world_geom) + self._ws = jnp.array(ws_np) + self._wc = jnp.array(wc_np) + self._wb = jnp.array(wb_np) + self._wh = jnp.array(wh_np) + self._cached_world_id = id(world_geom) + + if point_cloud is not None: + self._wp = jnp.asarray(point_cloud, dtype=jnp.float32).reshape(-1, 3) + else: + self._wp = jnp.zeros((0, 3), dtype=jnp.float32) + self._r_env = float(r_env if r_env > 0.0 else capt_r_point) + + # Invalidate robot JIT cache so new world args are picked up. + self._cached_robot_id = None + self._jit_fn = None + + def _ensure_world(self, world_geom: CollGeom) -> None: + if id(world_geom) != self._cached_world_id: + self.set_world(world_geom) + + # ── JIT cache ──────────────────────────────────────────────────────────── + + def _ensure_jit( + self, + robot, + world_geom: Optional[CollGeom], + point_cloud: Optional[Array], + r_env: float, + ) -> None: + """Build and cache a jax.jit'd call for this (robot, world) combination.""" + cache_id = (id(robot), id(world_geom), id(point_cloud), r_env) + if cache_id == self._cached_robot_id: + return + + from ..cuda_kernels.collision._robogpu_collision_ffi import robogpu_collision + + _robot = robot + _f_local = self._f_local + _f_pair_i = self._f_pair_i + _f_pair_j = self._f_pair_j + _r_robot = self._r_robot_max + + # Resolve world arrays once at JIT build time. + if world_geom is not None: + ws_np, wc_np, wb_np, wh_np = _extract_world_arrays(world_geom) + _ws = jnp.array(ws_np) + _wc = jnp.array(wc_np) + _wb = jnp.array(wb_np) + _wh = jnp.array(wh_np) + else: + _ws = self._ws if self._ws is not None else jnp.zeros((0, 4), jnp.float32) + _wc = self._wc if self._wc is not None else jnp.zeros((0, 7), jnp.float32) + _wb = self._wb if self._wb is not None else jnp.zeros((0, 15), jnp.float32) + _wh = self._wh if self._wh is not None else jnp.zeros((0, 6), jnp.float32) + + if point_cloud is not None: + _pc = jnp.asarray(point_cloud, dtype=jnp.float32).reshape(-1, 3) + _re = float(r_env) + else: + _pc = self._wp + _re = self._r_env + + def _impl(cfg_flat): + j = _robot.joints + return robogpu_collision( + cfg_flat, + twists=j.twists, + parent_tf=j.parent_transforms, + parent_idx=j.parent_indices, + act_idx=j.actuated_indices, + mimic_mul=j.mimic_multiplier, + mimic_off=j.mimic_offset, + mimic_act_idx=j.mimic_act_indices, + topo_inv=j._topo_sort_inv, + link_parent_joint=_robot.links.parent_joint_indices, + f_local=_f_local, + f_pair_i=_f_pair_i, + f_pair_j=_f_pair_j, + world_spheres=_ws, + world_capsules=_wc, + world_boxes=_wb, + world_halfspaces=_wh, + point_cloud=_pc, + r_env=_re, + r_robot_max=_r_robot, + ) + + self._jit_fn = jax.jit(_impl) + self._cached_robot_id = cache_id + + # ── Public API ─────────────────────────────────────────────────────────── + + def check_collision_free( + self, + robot, + cfg: Float[Array, "*batch actuated_count"], + world_geom: Optional[CollGeom] = None, + point_cloud: Optional[Array] = None, + r_env: float = 0.0, + ) -> Array: + """Return ``int32[*batch]``: 1 = collision-free, 0 = in-collision. + + Args: + robot: Robot model (provides FK joint arrays). + cfg: Configuration tensor, shape ``[*batch, n_act]``. + world_geom: Override the cached world geometry for this call. + point_cloud: Override the cached point cloud for this call + (``[Mp, 3]`` float32). Pass ``None`` to use the cached cloud. + r_env: Override the env sphere radius for this call. + """ + cfg = jnp.asarray(cfg, dtype=jnp.float32) + batch_axes = cfg.shape[:-1] + n_act = cfg.shape[-1] + B = int(np.prod(batch_axes)) if batch_axes else 1 + cfg_flat = cfg.reshape(B, n_act) + + self._ensure_jit(robot, world_geom, point_cloud, + r_env if r_env > 0.0 else self._r_env) + out = self._jit_fn(cfg_flat) + return out.reshape(batch_axes) if batch_axes else out.reshape(()) + + def check_edges_collision_free( + self, + robot, + edge_cfgs: Float[Array, "*batch granularity actuated_count"], + world_geom: Optional[CollGeom] = None, + point_cloud: Optional[Array] = None, + r_env: float = 0.0, + ) -> Array: + """Batch edge validation: ``int32[*batch]`` — 1 if all points free. + + ``edge_cfgs`` has shape ``[*batch, G, n_act]`` where G is the number of + pre-discretised waypoints along each edge (``edge_granularity``). The + result is 1 only when ALL G points are collision-free. + + Unlike :class:`VAMPCPUCollisionChecker`, this checker does NOT + discretise internally; the caller is responsible for interpolation:: + + G = 16 + ts = jnp.linspace(0, 1, G) + edges = a[:, None, :] * (1 - ts)[None, :, None] + b[:, None, :] * ts[None, :, None] + free = checker.check_edges_collision_free(robot, edges, world) + """ + cfg = jnp.asarray(edge_cfgs, dtype=jnp.float32) + *edge_axes, G, n_act = cfg.shape + E = int(np.prod(edge_axes)) if edge_axes else 1 + + # Flatten [E, G, n_act] → [E*G, n_act], check all at once, then AND. + cfg_flat = cfg.reshape(E * G, n_act) + self._ensure_jit(robot, world_geom, point_cloud, + r_env if r_env > 0.0 else self._r_env) + out_flat = self._jit_fn(cfg_flat) # [E*G] int32 + # A config is free iff ALL G waypoints are free: min over the G axis. + out_edges = out_flat.reshape(E, G).min(axis=1) # [E] int32 + return out_edges.reshape(tuple(edge_axes)) if edge_axes else out_edges.reshape(()) + + +def make_robogpu_checker( + inner: RobotCollisionSpherized, + **kwargs, +) -> RoboGPUCollisionChecker: + """Convenience factory for :class:`RoboGPUCollisionChecker`.""" + return RoboGPUCollisionChecker(inner, **kwargs) diff --git a/src/pyroffi/collision/_vamp_collision.py b/src/pyroffi/collision/_vamp_collision.py new file mode 100644 index 00000000..0f924290 --- /dev/null +++ b/src/pyroffi/collision/_vamp_collision.py @@ -0,0 +1,533 @@ +"""CPU collision checking via VAMP, JIT-compiled per robot through cricket. + +This is the CPU counterpart to :class:`CUDABinaryCollisionChecker`. Where the +CUDA checker fuses a hand-written SIMT kernel, this checker reuses VAMP's +heavily-optimised SIMD ``fkcc`` collision routine +(https://github.com/KavrakiLab/vamp) and specialises it to a concrete robot at +*runtime*: + + 1. When the checker is constructed (i.e. as soon as the robot is defined), + cricket parses the URDF and emits a ``vamp::robots::`` C++ struct + (traced forward kinematics + spherized collision). + 2. That struct is stitched into a tiny translation unit + (``_robot_edge_validation_tu.cc.in``) that instantiates the JAX FFI + handlers in ``_edge_validation_ffi.hh`` for the robot. + 3. cricket's JIT (LLVM ORC) compiles the TU once, caching the compiled object + on disk keyed by a content hash so subsequent constructions for the same + robot reuse the cached binary. + 4. The resulting XLA FFI custom-call handlers are registered with JAX and + invoked from :meth:`check_collision_free` / :meth:`check_edges_collision_free`. + +The public method surface intentionally mirrors +:class:`CUDABinaryCollisionChecker` so call sites are interchangeable; the +constructor differs because cricket needs the URDF/SRDF (not a pre-spherized +pyroffi model). + +Note: VAMP's edge validation supports point-cloud obstacles through its CAPT +(Collision-Affording Point Tree). Pass ``point_cloud=`` (an ``[Mp, 3]`` array) +to :meth:`set_world` / the check methods to enable it. HalfSpace obstacles are +not supported by the VAMP backend (use a large flat Box instead). +""" + +from __future__ import annotations + +import hashlib +from functools import lru_cache +from pathlib import Path +from typing import Optional + +import jax +import jax.numpy as jnp +import numpy as np +from jax import Array +from jaxtyping import Float + +from ._geometry import CollGeom +from ._cuda_collision import _extract_world_arrays + +_KERNELS_DIR = Path(__file__).resolve().parent.parent / "vamp_kernels" +_FFI_HEADER = _KERNELS_DIR / "_edge_validation_ffi.hh" +_TU_TEMPLATE = _KERNELS_DIR / "_robot_edge_validation_tu.cc.in" + +# Repo-relative default include roots for the JIT. These can be overridden via +# the ``include_dirs`` constructor argument (e.g. when vamp's CPM dependencies +# live somewhere else). vamp's transitive header-only deps (pdqsort, +# SIMDxorshift, nigh) are fetched into vamp's build ``.cpm-cache`` — point at +# them with ``include_dirs`` if header resolution fails. +_REPO_ROOT = Path(__file__).resolve().parents[3] +_DEFAULT_VAMP_INCLUDE = _REPO_ROOT / "external" / "vamp" / "src" / "impl" + + +@lru_cache(maxsize=1) +def _xla_ffi_include() -> str: + """Locate the XLA FFI headers shipped with the installed jaxlib.""" + import jaxlib + + root = Path(jaxlib.__file__).resolve().parent / "include" + if not (root / "xla" / "ffi" / "api" / "ffi.h").exists(): + raise RuntimeError( + f"Could not find xla/ffi/api/ffi.h under {root}. " + "Pass include_dirs=[...] pointing at the XLA FFI headers." + ) + return str(root) + + +@lru_cache(maxsize=1) +def _cricket_jit(): + """Import cricket's JIT submodule, with a helpful error if unavailable.""" + try: + import cricket # noqa: F401 + from cricket import _core_ext + + jit = getattr(_core_ext, "jit", None) + if jit is None: + raise RuntimeError( + "cricket was built without JIT support " + "(reconfigure with -DCRICKET_BUILD_JIT=ON)." + ) + return cricket, jit + except ImportError as exc: # pragma: no cover - environment dependent + raise RuntimeError( + "cricket is not importable. Build it from external/cricket with the " + "Python extension and JIT enabled (CRICKET_BUILD_PYTHON=ON, " + "CRICKET_BUILD_JIT=ON)." + ) from exc + + +def _robot_name_from_urdf(urdf_path: Path) -> str: + """Derive a valid C++ struct name from the URDF ```` tag. + + Falls back to the file stem. The result is sanitised to a valid identifier + and capitalised (e.g. ``panda`` -> ``Panda``). + """ + import re + + name = None + try: + text = urdf_path.read_text() + m = re.search(r"]*\bname\s*=\s*\"([^\"]+)\"", text) + if m: + name = m.group(1) + except OSError: + pass + if not name: + name = urdf_path.stem + name = re.sub(r"[^0-9a-zA-Z_]", "_", name) + if name and name[0].isdigit(): + name = "_" + name + return name[:1].upper() + name[1:] if name else "Robot" + + +def _find_pdqsort_dir() -> Optional[str]: + """Locate a directory containing ``pdqsort.h`` (vamp's CAPT dependency). + + vamp fetches pdqsort via CPM into its build tree, so search a few likely + roots: an explicit ``$VAMP_PDQSORT_DIR``, then any vamp ``_deps`` build + cache under the user's work tree. + """ + env = __import__("os").environ.get("VAMP_PDQSORT_DIR") + if env and (Path(env) / "pdqsort.h").exists(): + return env + for base in (_REPO_ROOT.parent, _REPO_ROOT, Path.home() / "Work"): + hits = list(base.glob("**/_deps/pdqsort-src/pdqsort.h")) + if hits: + return str(hits[0].parent) + return None + + +def _default_include_dirs() -> list[str]: + import os + + dirs = [str(_DEFAULT_VAMP_INCLUDE), _xla_ffi_include()] + conda = os.environ.get("CONDA_PREFIX") + for cand in ( + os.path.join(conda, "include", "eigen3") if conda else None, + os.path.join(conda, "include") if conda else None, + "/usr/include/eigen3", + "/usr/local/include/eigen3", + ): + if cand and Path(cand).exists(): + dirs.append(cand) + pdq = _find_pdqsort_dir() + if pdq: + dirs.append(pdq) + return dirs + + +@lru_cache(maxsize=1) +def _preload_runtime_libs() -> None: + """Load libstdc++ and the OpenMP runtime with RTLD_GLOBAL. + + The cricket JIT resolves external symbols via ``dlsym(RTLD_DEFAULT)``, so the + C++ standard library and OpenMP runtime that the JIT-compiled collision code + calls into must be present in the global symbol scope. We load them from the + active conda prefix when available, else by soname. + """ + import ctypes + import os + + candidates = ["libstdc++.so.6", "libomp.so"] + conda = os.environ.get("CONDA_PREFIX") + libdir = Path(conda) / "lib" if conda else None + for name in candidates: + loaded = False + if libdir is not None and (libdir / name).exists(): + try: + ctypes.CDLL(str(libdir / name), mode=ctypes.RTLD_GLOBAL) + loaded = True + except OSError: + pass + if not loaded: + try: + ctypes.CDLL(name, mode=ctypes.RTLD_GLOBAL) + except OSError: + pass # surfaced later as a clear "Symbols not found" JIT error + + +# Registered (target_name) per (robot hash, kind) so we only register once. +_REGISTERED: dict[str, str] = {} + + +class VAMPCPUCollisionChecker: + """JIT-compiled VAMP CPU collision checker with batch edge validation. + + Args: + urdf_path: Path to the robot URDF (cricket parses it to emit FK + CC). + srdf_path: Optional SRDF for self-collision pairs. + end_effector: Optional end-effector link name (forwarded to cricket). + cache_dir: On-disk JIT object cache directory (defaults to cricket's). + include_dirs: Extra ``-I`` include roots for the JIT compile. Appended + to the auto-discovered vamp / Eigen / XLA-FFI roots. + extra_flags: Extra clang flags (e.g. ``["-march=native", "-fopenmp"]``). + """ + + def __init__( + self, + urdf_path: str | Path, + srdf_path: Optional[str | Path] = None, + end_effector: Optional[str] = None, + *, + robot_name: Optional[str] = None, + resolution: int = 32, + cache_dir: Optional[str | Path] = None, + include_dirs: Optional[list[str]] = None, + extra_flags: Optional[list[str]] = None, + ) -> None: + cricket, jit = _cricket_jit() + + urdf_path = Path(urdf_path).resolve() + srdf_path = Path(srdf_path).resolve() if srdf_path is not None else None + + # cricket's template needs a C++ struct `name` and a planning + # `resolution`; neither is in RobotInfo.json(), so supply them via data. + if robot_name is None: + robot_name = _robot_name_from_urdf(urdf_path) + + # 1. Codegen: URDF -> vamp::robots:: source. + gen = cricket.generate_robot_source( + cricket.GenOptions( + urdf=urdf_path, + srdf=srdf_path, + end_effector=end_effector, + data={"name": robot_name, "resolution": int(resolution)}, + ) + ) + robot_type_name = gen.robot_name # struct name, e.g. "Panda" + robot_token = robot_type_name.lower() # symbol suffix, e.g. "panda" + self._dimension = int(gen.dimension) + self._n_spheres = int(gen.n_spheres) + + # 2. Materialise the generated header so the TU can #include it, keyed by + # a content hash for cache stability. The hash folds in the generated + # robot source AND the FFI handler header + TU template contents, so an + # edit to either busts the on-disk object cache (which otherwise only + # sees the TU string, not the headers it #includes by path). + digest = hashlib.sha1() + digest.update(gen.source.encode()) + digest.update(_FFI_HEADER.read_bytes()) + digest.update(_TU_TEMPLATE.read_bytes()) + src_hash = digest.hexdigest()[:16] + work_dir = Path(cache_dir) if cache_dir is not None else Path(jit.default_cache_dir()) + work_dir.mkdir(parents=True, exist_ok=True) + header_path = work_dir / f"vamp_robot_{robot_token}_{src_hash}.hh" + if not header_path.exists(): + header_path.write_text(gen.source) + + # 3. Build the per-robot translation unit from the template. + tu_source = ( + _TU_TEMPLATE.read_text() + .replace("@ROBOT_HEADER@", str(header_path)) + .replace("@FFI_HEADER@", str(_FFI_HEADER)) + .replace("@ROBOT_TYPE@", f"vamp::robots::{robot_type_name}") + .replace("@ROBOT_NAME@", robot_token) + ) + + # 4. JIT compile (object-cached on disk) and register the FFI handlers. + opts = jit.CompileOptions() + opts.std_flag = "-std=c++17" + opts.opt_flag = "-O3" + dirs = _default_include_dirs() + if include_dirs: + dirs.extend(include_dirs) + opts.include_dirs = dirs + opts.extra_flags = extra_flags or ["-march=native", "-fopenmp"] + opts.module_id = f"vamp_{robot_token}_{src_hash}" + + # Cache the (source, opts) hash so re-registration is skipped. + self._key = jit.hash_source(tu_source, opts) + self._configs_target = f"vamp_configs_{robot_token}_{src_hash}" + self._edges_target = f"vamp_edges_{robot_token}_{src_hash}" + + if self._key not in _REGISTERED: + _preload_runtime_libs() + session = jit.JitSession(work_dir) + # Reuse a previously JIT-compiled binary for this robot if present, + # skipping the (expensive) clang front-end; else compile + cache it. + if not session.try_load_cached(opts.module_id): + session.add_source(tu_source, opts) + jax.ffi.register_ffi_target( + self._configs_target, + session.handler_capsule("pyroffi_get_validate_configs"), + platform="cpu", + ) + jax.ffi.register_ffi_target( + self._edges_target, + session.handler_capsule("pyroffi_get_validate_edges"), + platform="cpu", + ) + # Keep the session alive for the process lifetime: the JIT-owned code + # must outlive every FFI call into it. + _REGISTERED[self._key] = self._configs_target + self._session = session + else: + self._session = None + + # World geometry cache (mirrors the CUDA checkers). + self._ws = np.zeros((0, 4), dtype=np.float32) + self._wc = np.zeros((0, 7), dtype=np.float32) + self._wb = np.zeros((0, 15), dtype=np.float32) + self._wp = np.zeros((0, 3), dtype=np.float32) + self._capt = (0.0, 0.0, 0.0) + + # Per-call caches so repeated checks against the same world are cheap: + # * _world_cache: CPU-resident obstacle buffers, keyed by object id + + # capt, so we skip the (GPU->host) CollGeom re-extraction each call; + # * _jit_cache: jax.jit'd FFI calls keyed by (kind, capt) so the kernel + # runs jitted (~tens of us) instead of via slow eager dispatch. + self._world_cache_key = None + self._world_cache = None + self._jit_cache: dict = {} + + # ── Properties ────────────────────────────────────────────────────────── + + @property + def dimension(self) -> int: + return self._dimension + + @property + def n_spheres(self) -> int: + return self._n_spheres + + # ── World handling ────────────────────────────────────────────────────── + + def set_world( + self, + world_geom: CollGeom, + point_cloud: Optional[Array] = None, + *, + capt_r_min: float = 0.0, + capt_r_max: float = 1.0, + capt_r_point: float = 0.0, + ) -> None: + """Cache the world obstacles (and optional CAPT point cloud).""" + self._ws, self._wc, self._wb, wh = _extract_world_arrays(world_geom) + if wh.shape[0] != 0: + raise NotImplementedError( + "The VAMP backend has no half-space primitive; represent a " + "ground plane as a large flat Box instead." + ) + if point_cloud is not None: + self._wp = np.asarray(point_cloud, dtype=np.float32).reshape(-1, 3) + self._capt = (float(capt_r_min), float(capt_r_max), float(capt_r_point)) + else: + self._wp = np.zeros((0, 3), dtype=np.float32) + self._capt = (0.0, 0.0, 0.0) + self._world_cache_key = None # invalidate the per-call cache + + @staticmethod + @lru_cache(maxsize=1) + def _cpu_device(): + # The handlers are registered for platform="cpu"; on a CUDA-default JAX + # install we must place operands on the host so ffi_call dispatches to + # the CPU target rather than CUDA. + return jax.devices("cpu")[0] + + def _world_args( + self, + world_geom: Optional[CollGeom], + point_cloud: Optional[Array], + capt: Optional[tuple[float, float, float]], + ): + """Resolve CPU-resident obstacle buffers for a call (cached by identity). + + Per-call ``world_geom`` / ``point_cloud`` override the cached + :meth:`set_world` state; when both are omitted the cached state is used. + Providing a ``world_geom`` without a ``point_cloud`` keeps the cached + cloud (it does not silently wipe it). + + Re-extracting a ``CollGeom`` means converting its (GPU-resident) JAX + arrays to host numpy — a sync worth ~1 ms. We therefore memoise the + result keyed by the object identities + capt, so repeatedly checking + against the same world (the planner / benchmark pattern) is free after + the first call. + """ + key = (id(world_geom), id(point_cloud), capt) + if key == self._world_cache_key and self._world_cache is not None: + return self._world_cache + + cpu = self._cpu_device() + if world_geom is not None: + ws, wc, wb, wh = _extract_world_arrays(world_geom) + if wh.shape[0] != 0: + raise NotImplementedError( + "The VAMP backend has no half-space primitive; represent a " + "ground plane as a large flat Box instead." + ) + else: + ws, wc, wb = self._ws, self._wc, self._wb + + if point_cloud is not None: + wp = np.asarray(point_cloud, dtype=np.float32).reshape(-1, 3) + capt_v = capt if capt is not None else (0.0, 1.0, 0.0) + else: + wp, capt_v = self._wp, self._capt + + arrays = ( + jax.device_put(np.asarray(ws, dtype=np.float32), cpu), + jax.device_put(np.asarray(wc, dtype=np.float32), cpu), + jax.device_put(np.asarray(wb, dtype=np.float32), cpu), + jax.device_put(np.asarray(wp, dtype=np.float32), cpu), + tuple(float(x) for x in capt_v), + ) + self._world_cache_key = key + self._world_cache = arrays + return arrays + + # ── Public API (mirrors CUDABinaryCollisionChecker) ───────────────────── + + def _jit_fn(self, kind: str, capt: tuple[float, float, float]): + """A cached ``jax.jit`` wrapper of the FFI call for (kind, capt). + + Baking ``capt`` (a static FFI attribute) into the traced function and + caching by (kind, capt) lets the kernel run jitted — ~tens of us — rather + than through slow per-call eager dispatch. jit retraces per operand + shape automatically.""" + key = (kind, capt) + fn = self._jit_cache.get(key) + if fn is not None: + return fn + + rmin, rmax, rpt = (np.float32(capt[0]), np.float32(capt[1]), np.float32(capt[2])) + if kind == "configs": + tgt = self._configs_target + + def impl(a, ws, wc, wb, wp): + out = jax.ffi.ffi_call( + tgt, jax.ShapeDtypeStruct((a.shape[0],), jnp.bool_) + )(a, ws, wc, wb, wp, + capt_r_min=rmin, capt_r_max=rmax, capt_r_point=rpt) + return out + else: + tgt = self._edges_target + + def impl(ab, ws, wc, wb, wp): + # ab: [E, 2, n] — slice the endpoints inside the jit so no eager + # dispatch happens on the host side. + a = ab[:, 0, :] + b = ab[:, 1, :] + out = jax.ffi.ffi_call( + tgt, jax.ShapeDtypeStruct((ab.shape[0],), jnp.bool_) + )(a, b, ws, wc, wb, wp, + capt_r_min=rmin, capt_r_max=rmax, capt_r_point=rpt) + return out + + fn = jax.jit(impl) + self._jit_cache[key] = fn + return fn + + def check_collision_free( + self, + robot, # accepted for API parity; FK is baked into the JIT binary + cfg: Float[Array, "*batch actuated_count"], + world_geom: Optional[CollGeom] = None, + point_cloud: Optional[Array] = None, + capt: Optional[tuple[float, float, float]] = None, + ) -> Array: + """Return a boolean per configuration: ``True`` if collision-free. + + ``capt`` is the optional ``(r_min, r_max, r_point)`` triple for a + ``point_cloud`` (point radius defaults to 0 — set ``r_point`` to give the + cloud thickness). + """ + ws, wc, wb, wp, capt = self._world_args(world_geom, point_cloud, capt) + cfg = jnp.asarray(cfg, dtype=jnp.float32) + batch_axes = cfg.shape[:-1] + n_act = cfg.shape[-1] + B = int(np.prod(batch_axes)) if batch_axes else 1 + cfg_flat = jax.device_put(cfg.reshape(B, n_act), self._cpu_device()) + + out = self._jit_fn("configs", capt)(cfg_flat, ws, wc, wb, wp) + return out.reshape(batch_axes) if batch_axes else out.reshape(()) + + def check_edges_collision_free( + self, + robot, + edge_cfgs: Float[Array, "*batch endpoint actuated_count"], + world_geom: Optional[CollGeom] = None, + point_cloud: Optional[Array] = None, + capt: Optional[tuple[float, float, float]] = None, + ) -> Array: + """Batch edge validation: ``True`` if the whole edge is collision-free. + + Unlike the CUDA checker — which expects pre-discretised points along the + ``granularity`` axis and AND-reduces them — VAMP discretises each edge + internally at the robot's planning resolution. ``edge_cfgs`` therefore + holds the two endpoints in its second-to-last axis (shape ``[*batch, 2, + n_act]``); the return shape is ``edge_cfgs.shape[:-2]``. + + VAMP samples the open interval ``(0, 1]``: the goal endpoint and the + interior are validated, but the start is assumed pre-validated (the usual + planner contract, matching VAMP's own ``validate_motion``). A ``True`` + verdict therefore implies the *goal* endpoint and interior are + collision-free; validate the start separately with + :meth:`check_collision_free` if you need it. + """ + ws, wc, wb, wp, capt = self._world_args(world_geom, point_cloud, capt) + edge_cfgs = jnp.asarray(edge_cfgs, dtype=jnp.float32) + *edge_axes, endpoints, n_act = edge_cfgs.shape + if endpoints != 2: + raise ValueError( + "VAMP edge validation expects exactly 2 endpoints per edge " + f"(got {endpoints}); it discretises internally at the robot " + "resolution." + ) + E = int(np.prod(edge_axes)) if edge_axes else 1 + # Move the whole edge array to the host in a single transfer; the jitted + # function slices the two endpoints internally. (Slicing the GPU array + # and transferring each strided half separately is what made small edge + # batches slow.) + flat = jax.device_put(edge_cfgs.reshape(E, 2, n_act), self._cpu_device()) + out = self._jit_fn("edges", capt)(flat, ws, wc, wb, wp) + return out.reshape(tuple(edge_axes)) if edge_axes else out.reshape(()) + + +def make_vamp_cpu_checker( + urdf_path: str | Path, + srdf_path: Optional[str | Path] = None, + end_effector: Optional[str] = None, + **kwargs, +) -> VAMPCPUCollisionChecker: + """Build a JIT-compiled VAMP CPU collision checker for ``urdf_path``.""" + return VAMPCPUCollisionChecker( + urdf_path, srdf_path=srdf_path, end_effector=end_effector, **kwargs + ) diff --git a/src/pyroffi/cuda_kernels/_brownian_motion_ik_cuda_lib.so b/src/pyroffi/cuda_kernels/_brownian_motion_ik_cuda_lib.so deleted file mode 100755 index 905a2a47..00000000 Binary files a/src/pyroffi/cuda_kernels/_brownian_motion_ik_cuda_lib.so and /dev/null differ diff --git a/src/pyroffi/cuda_kernels/_chomp_trajopt_cuda_lib.so b/src/pyroffi/cuda_kernels/_chomp_trajopt_cuda_lib.so deleted file mode 100755 index 16f17891..00000000 Binary files a/src/pyroffi/cuda_kernels/_chomp_trajopt_cuda_lib.so and /dev/null differ diff --git a/src/pyroffi/cuda_kernels/_collision_cuda_lib.so b/src/pyroffi/cuda_kernels/_collision_cuda_lib.so deleted file mode 100755 index 59b8059e..00000000 Binary files a/src/pyroffi/cuda_kernels/_collision_cuda_lib.so and /dev/null differ diff --git a/src/pyroffi/cuda_kernels/_fk_cuda_lib.so b/src/pyroffi/cuda_kernels/_fk_cuda_lib.so deleted file mode 100755 index cb17f4a5..00000000 Binary files a/src/pyroffi/cuda_kernels/_fk_cuda_lib.so and /dev/null differ diff --git a/src/pyroffi/cuda_kernels/_hit_and_run_ik_cuda_lib.so b/src/pyroffi/cuda_kernels/_hit_and_run_ik_cuda_lib.so deleted file mode 100755 index 9089214f..00000000 Binary files a/src/pyroffi/cuda_kernels/_hit_and_run_ik_cuda_lib.so and /dev/null differ diff --git a/src/pyroffi/cuda_kernels/_hjcd_ik_cuda_lib.so b/src/pyroffi/cuda_kernels/_hjcd_ik_cuda_lib.so deleted file mode 100755 index 653f96a3..00000000 Binary files a/src/pyroffi/cuda_kernels/_hjcd_ik_cuda_lib.so and /dev/null differ diff --git a/src/pyroffi/cuda_kernels/_ls_ik_cuda_lib.so b/src/pyroffi/cuda_kernels/_ls_ik_cuda_lib.so deleted file mode 100755 index 73b2520b..00000000 Binary files a/src/pyroffi/cuda_kernels/_ls_ik_cuda_lib.so and /dev/null differ diff --git a/src/pyroffi/cuda_kernels/_ls_trajopt_cuda_lib.so b/src/pyroffi/cuda_kernels/_ls_trajopt_cuda_lib.so deleted file mode 100755 index 0c5b6e07..00000000 Binary files a/src/pyroffi/cuda_kernels/_ls_trajopt_cuda_lib.so and /dev/null differ diff --git a/src/pyroffi/cuda_kernels/_mppi_ik_cuda_lib.so b/src/pyroffi/cuda_kernels/_mppi_ik_cuda_lib.so deleted file mode 100755 index bf6e55d8..00000000 Binary files a/src/pyroffi/cuda_kernels/_mppi_ik_cuda_lib.so and /dev/null differ diff --git a/src/pyroffi/cuda_kernels/_sco_trajopt_cuda_lib.so b/src/pyroffi/cuda_kernels/_sco_trajopt_cuda_lib.so deleted file mode 100755 index 59b8a898..00000000 Binary files a/src/pyroffi/cuda_kernels/_sco_trajopt_cuda_lib.so and /dev/null differ diff --git a/src/pyroffi/cuda_kernels/_sqp_ik_cuda_lib.so b/src/pyroffi/cuda_kernels/_sqp_ik_cuda_lib.so deleted file mode 100755 index feef69b8..00000000 Binary files a/src/pyroffi/cuda_kernels/_sqp_ik_cuda_lib.so and /dev/null differ diff --git a/src/pyroffi/cuda_kernels/_stomp_trajopt_cuda_lib.so b/src/pyroffi/cuda_kernels/_stomp_trajopt_cuda_lib.so deleted file mode 100755 index 14471f37..00000000 Binary files a/src/pyroffi/cuda_kernels/_stomp_trajopt_cuda_lib.so and /dev/null differ diff --git a/src/pyroffi/cuda_kernels/_svgd_region_ik_cuda_lib.so b/src/pyroffi/cuda_kernels/_svgd_region_ik_cuda_lib.so deleted file mode 100755 index f829da11..00000000 Binary files a/src/pyroffi/cuda_kernels/_svgd_region_ik_cuda_lib.so and /dev/null differ diff --git a/src/pyroffi/cuda_kernels/collision/__init__.py b/src/pyroffi/cuda_kernels/collision/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/pyroffi/cuda_kernels/_collision_binary_cuda_ffi.py b/src/pyroffi/cuda_kernels/collision/_collision_binary_cuda_ffi.py similarity index 96% rename from src/pyroffi/cuda_kernels/_collision_binary_cuda_ffi.py rename to src/pyroffi/cuda_kernels/collision/_collision_binary_cuda_ffi.py index 7e244f94..9acc7df2 100644 --- a/src/pyroffi/cuda_kernels/_collision_binary_cuda_ffi.py +++ b/src/pyroffi/cuda_kernels/collision/_collision_binary_cuda_ffi.py @@ -3,7 +3,7 @@ The companion shared library ``_collision_binary_cuda_lib.so`` must be compiled from ``_collision_binary_cuda_kernel.cu`` first: - bash src/pyroffi/cuda_kernels/build_collision_binary_cuda.sh + bash build_kernels/build_collision_binary_cuda.sh Requires JAX >= 0.4.14 (for jax.ffi). @@ -35,7 +35,7 @@ def _load_and_register() -> None: raise RuntimeError( f"CUDA binary-collision library not found at {lib_path}.\n" "Compile it first with:\n" - " bash src/pyroffi/cuda_kernels/build_collision_binary_cuda.sh\n" + " bash build_kernels/build_collision_binary_cuda.sh\n" "(This produces _collision_binary_cuda_lib.so alongside the kernel source.)" ) lib = ctypes.CDLL(str(lib_path)) diff --git a/src/pyroffi/cuda_kernels/_collision_binary_cuda_kernel.cu b/src/pyroffi/cuda_kernels/collision/_collision_binary_cuda_kernel.cu similarity index 99% rename from src/pyroffi/cuda_kernels/_collision_binary_cuda_kernel.cu rename to src/pyroffi/cuda_kernels/collision/_collision_binary_cuda_kernel.cu index bf467212..0c820711 100644 --- a/src/pyroffi/cuda_kernels/_collision_binary_cuda_kernel.cu +++ b/src/pyroffi/cuda_kernels/collision/_collision_binary_cuda_kernel.cu @@ -36,7 +36,7 @@ * World geometry arrays match _collision_cuda_kernel.cu / _cuda_collision.py: * spheres [Ms, 4], capsules [Mc, 7], boxes [Mb, 15], halfspaces [Mh, 6]. * - * Build: bash src/pyroffi/cuda_kernels/build_collision_binary_cuda.sh + * Build: bash build_kernels/build_collision_binary_cuda.sh */ #include "xla/ffi/api/ffi.h" diff --git a/src/pyroffi/cuda_kernels/_collision_cuda_ffi.py b/src/pyroffi/cuda_kernels/collision/_collision_cuda_ffi.py similarity index 98% rename from src/pyroffi/cuda_kernels/_collision_cuda_ffi.py rename to src/pyroffi/cuda_kernels/collision/_collision_cuda_ffi.py index ac416d70..357e33da 100644 --- a/src/pyroffi/cuda_kernels/_collision_cuda_ffi.py +++ b/src/pyroffi/cuda_kernels/collision/_collision_cuda_ffi.py @@ -3,7 +3,7 @@ The companion shared library ``_collision_cuda_lib.so`` must be compiled from ``_collision_cuda_kernel.cu`` before this module can be imported: - bash src/pyroffi/cuda_kernels/build_collision_cuda.sh + bash build_kernels/build_collision_cuda.sh Requires JAX >= 0.4.14 (for jax.ffi). @@ -43,7 +43,7 @@ def _load_and_register() -> None: raise RuntimeError( f"CUDA collision library not found at {lib_path}.\n" "Compile it first with:\n" - " bash src/pyroffi/cuda_kernels/build_collision_cuda.sh\n" + " bash build_kernels/build_collision_cuda.sh\n" "(This produces _collision_cuda_lib.so alongside the kernel source.)" ) lib = ctypes.CDLL(str(lib_path)) diff --git a/src/pyroffi/cuda_kernels/_collision_cuda_kernel.cu b/src/pyroffi/cuda_kernels/collision/_collision_cuda_kernel.cu similarity index 95% rename from src/pyroffi/cuda_kernels/_collision_cuda_kernel.cu rename to src/pyroffi/cuda_kernels/collision/_collision_cuda_kernel.cu index c39db56c..1361cc5a 100644 --- a/src/pyroffi/cuda_kernels/_collision_cuda_kernel.cu +++ b/src/pyroffi/cuda_kernels/collision/_collision_cuda_kernel.cu @@ -35,7 +35,7 @@ * positive → separated * negative → penetration * - * Build: bash src/pyroffi/cuda_kernels/build_collision_cuda.sh + * Build: bash build_kernels/build_collision_cuda.sh */ #include "xla/ffi/api/ffi.h" @@ -43,6 +43,17 @@ namespace ffi = xla::ffi; +// ── Multi-GPU safety ────────────────────────────────────────────────────────── +// +// Under jax.pmap, one host thread drives each visible GPU and the same FFI +// handler runs concurrently on every device. Any host-side cache (e.g. a CUDA +// graph exec containing device-specific pointers) must therefore be kept +// *per-device* — a single shared cache would let one device replay another +// device's graph (illegal access) and race on the cache struct. We index every +// cache by the current CUDA device ordinal, so each pmap worker only ever +// touches its own slot. +static constexpr int PYROFFI_MAX_GPUS = 16; + // ── Grid / tile constants ───────────────────────────────────────────────────── /// Threads per block along the robot (K or N) dimension. @@ -538,7 +549,13 @@ static ffi::Error CollisionWorldSphereImpl( B = K = Ms = Mc = Mb = Mh = -1; } }; - static GraphCache cache; + static GraphCache cache_pool[PYROFFI_MAX_GPUS]; + int _dev = 0; + cudaGetDevice(&_dev); + if (_dev < 0 || _dev >= PYROFFI_MAX_GPUS) + return ffi::Error(ffi::ErrorCode::kInternal, + "CUDA device ordinal exceeds PYROFFI_MAX_GPUS (rebuild with a larger limit)."); + GraphCache& cache = cache_pool[_dev]; if (B > 0 && K > 0 && M > 0) { const float* sc = sphere_centers.typed_data(); @@ -707,7 +724,13 @@ static ffi::Error CollisionWorldSphereReducedImpl( B = K = N = Ms = Mc = Mb = Mh = -1; } }; - static GraphCache cache; + static GraphCache cache_pool[PYROFFI_MAX_GPUS]; + int _dev = 0; + cudaGetDevice(&_dev); + if (_dev < 0 || _dev >= PYROFFI_MAX_GPUS) + return ffi::Error(ffi::ErrorCode::kInternal, + "CUDA device ordinal exceeds PYROFFI_MAX_GPUS (rebuild with a larger limit)."); + GraphCache& cache = cache_pool[_dev]; if (B > 0 && N > 0 && M > 0) { const float* sc = sphere_centers.typed_data(); @@ -873,7 +896,13 @@ static ffi::Error CollisionWorldCapsuleImpl( B = N = Ms = Mc = Mb = Mh = -1; } }; - static GraphCache cache; + static GraphCache cache_pool[PYROFFI_MAX_GPUS]; + int _dev = 0; + cudaGetDevice(&_dev); + if (_dev < 0 || _dev >= PYROFFI_MAX_GPUS) + return ffi::Error(ffi::ErrorCode::kInternal, + "CUDA device ordinal exceeds PYROFFI_MAX_GPUS (rebuild with a larger limit)."); + GraphCache& cache = cache_pool[_dev]; if (B > 0 && N > 0 && M > 0) { const float* cp = caps.typed_data(); @@ -1032,7 +1061,13 @@ static ffi::Error CollisionSelfSphereImpl( B = S = N = P = -1; } }; - static GraphCache cache; + static GraphCache cache_pool[PYROFFI_MAX_GPUS]; + int _dev = 0; + cudaGetDevice(&_dev); + if (_dev < 0 || _dev >= PYROFFI_MAX_GPUS) + return ffi::Error(ffi::ErrorCode::kInternal, + "CUDA device ordinal exceeds PYROFFI_MAX_GPUS (rebuild with a larger limit)."); + GraphCache& cache = cache_pool[_dev]; if (total > 0) { const int blocks = (total + 255) / 256; @@ -1138,7 +1173,13 @@ static ffi::Error CollisionSelfCapsuleImpl( B = N = P = -1; } }; - static GraphCache cache; + static GraphCache cache_pool[PYROFFI_MAX_GPUS]; + int _dev = 0; + cudaGetDevice(&_dev); + if (_dev < 0 || _dev >= PYROFFI_MAX_GPUS) + return ffi::Error(ffi::ErrorCode::kInternal, + "CUDA device ordinal exceeds PYROFFI_MAX_GPUS (rebuild with a larger limit)."); + GraphCache& cache = cache_pool[_dev]; if (total > 0) { const int blocks = (total + 255) / 256; diff --git a/src/pyroffi/cuda_kernels/collision/_robogpu_collision_ffi.py b/src/pyroffi/cuda_kernels/collision/_robogpu_collision_ffi.py new file mode 100644 index 00000000..63a2cf9d --- /dev/null +++ b/src/pyroffi/cuda_kernels/collision/_robogpu_collision_ffi.py @@ -0,0 +1,122 @@ +"""JAX FFI wrapper for the RoboGPU sphere-octree collision-check kernel. + +The companion shared library ``_robogpu_collision_lib.so`` must be compiled +before importing this module: + + bash build_kernels/build_robogpu_collision.sh + +The build script also compiles ``_robogpu_optix_programs.ptx`` (OptiX device +programs); both files must sit alongside the .so at runtime. + +The XLA FFI target ``"robogpu_collision"`` accepts the same FK model arrays as +``CUDABinaryCollisionChecker`` plus a point cloud [Mp, 3] and two scalar +attributes (``r_env``, ``r_robot_max``). It returns ``int32[B]`` with 1 = +collision-free and 0 = in-collision. +""" + +from __future__ import annotations + +import ctypes +from functools import lru_cache +from pathlib import Path + +import jax +import jax.numpy as jnp +import numpy as np +from jax import Array + +_LIB_NAME = "_robogpu_collision_lib.so" +_FFI_TARGET = "robogpu_collision" + + +@lru_cache(maxsize=1) +def _load_and_register() -> None: + lib_path = Path(__file__).parent / _LIB_NAME + if not lib_path.exists(): + raise RuntimeError( + f"RoboGPU collision library not found at {lib_path}.\n" + "Compile it first with:\n" + " bash build_kernels/build_robogpu_collision.sh\n" + "(Requires NVIDIA OptiX SDK 7.x and nvcc.)" + ) + ptx_path = lib_path.parent / "_robogpu_optix_programs.ptx" + if not ptx_path.exists(): + raise RuntimeError( + f"RoboGPU OptiX PTX not found at {ptx_path}.\n" + "Run build_kernels/build_robogpu_collision.sh to produce it." + ) + + lib = ctypes.CDLL(str(lib_path)) + _New = ctypes.pythonapi.PyCapsule_New + _New.restype = ctypes.py_object + _New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p] + capsule = _New( + ctypes.cast(lib.RoboGPUCollisionFfi, ctypes.c_void_p), + b"xla._CUSTOM_CALL_TARGET", + None, + ) + jax.ffi.register_ffi_target(_FFI_TARGET, capsule, platform="CUDA") + + +def robogpu_collision( + cfg: Array, # [B, n_act] float32 + twists: Array, # [J, 6] float32 + parent_tf: Array, # [J, 7] float32 + parent_idx: Array, # [J] int32 + act_idx: Array, # [J] int32 + mimic_mul: Array, # [J] float32 + mimic_off: Array, # [J] float32 + mimic_act_idx: Array, # [J] int32 + topo_inv: Array, # [J] int32 + link_parent_joint: Array, # [NL] int32 + f_local: Array, # [K, 4] float32 (k = s*NL + n) + f_pair_i: Array, # [Pf] int32 + f_pair_j: Array, # [Pf] int32 + world_spheres: Array, # [Ms, 4] float32 + world_capsules: Array, # [Mc, 7] float32 + world_boxes: Array, # [Mb, 15] float32 + world_halfspaces: Array, # [Mh, 6] float32 + point_cloud: Array, # [Mp, 3] float32 + r_env: float, # env sphere radius per point + r_robot_max: float, # max robot sphere radius (BVH AABB expansion) +) -> Array: # [B] int32 1=free, 0=collision + """Fused FK + binary collision check with OptiX point-cloud BVH traversal. + + Regular world geometry (spheres, capsules, boxes, halfspaces) and self- + collision are checked in a CUDA kernel (Stage 1). The environment point + cloud is indexed in an OptiX BVH; robot sphere centres query it via ray + tracing with any-hit early exit (Stage 2). + + Pass ``point_cloud`` with shape ``[0, 3]`` to skip the OptiX stage and + run only the regular world geometry + self-collision check. + """ + _load_and_register() + B = cfg.shape[0] + r_env_np = np.float32(r_env) + r_robot_np = np.float32(r_robot_max) + return jax.ffi.ffi_call( + _FFI_TARGET, + jax.ShapeDtypeStruct((B,), jnp.int32), + )( + cfg.astype(jnp.float32), + twists.astype(jnp.float32), + parent_tf.astype(jnp.float32), + parent_idx.astype(jnp.int32), + act_idx.astype(jnp.int32), + mimic_mul.astype(jnp.float32), + mimic_off.astype(jnp.float32), + mimic_act_idx.astype(jnp.int32), + topo_inv.astype(jnp.int32), + link_parent_joint.astype(jnp.int32), + f_local.astype(jnp.float32), + f_pair_i.astype(jnp.int32), + f_pair_j.astype(jnp.int32), + world_spheres.astype(jnp.float32), + world_capsules.astype(jnp.float32), + world_boxes.astype(jnp.float32), + world_halfspaces.astype(jnp.float32), + point_cloud.astype(jnp.float32), + # Scalar attributes (consumed by .Attr() in the FFI handler). + r_env=r_env_np, + r_robot_max=r_robot_np, + ) diff --git a/src/pyroffi/cuda_kernels/collision/_robogpu_collision_host.cu b/src/pyroffi/cuda_kernels/collision/_robogpu_collision_host.cu new file mode 100644 index 00000000..513b82d9 --- /dev/null +++ b/src/pyroffi/cuda_kernels/collision/_robogpu_collision_host.cu @@ -0,0 +1,800 @@ +/** + * RoboGPU: GPU-accelerated sphere-octree collision checking via NVIDIA OptiX. + * + * Implements the RoboGPU architecture (arXiv:2603.01517) adapted for + * sphere-based robot representations used in pyroffi. OBB-AABB SAT is + * replaced by sphere-sphere occupancy queries against an OptiX BVH built from + * the environment point cloud, retaining the massive parallelism and early- + * exit benefits of the original design. + * + * Stage 1 (CUDA kernel — robogpu_prepare_kernel): + * FK → link world transforms → transform robot collision spheres to world + * frame → check against regular world geometry (spheres/capsules/boxes/ + * halfspaces) → self-collision check. Outputs: world-frame robot spheres + * [B*K, 4] and per-config free flags [B]. + * + * Stage 2 (OptiX — same CUDA stream, no host sync needed): + * For each config still marked free, fires K "query rays" (one per robot + * sphere) into the env sphere BVH. The custom intersection program does + * the sphere-sphere proximity test; any-hit terminates BVH traversal on + * the first hit (early exit at both the per-env-sphere and per-robot-sphere + * levels). + * + * Build: bash build_kernels/build_robogpu_collision.sh + * + * Requires NVIDIA OptiX SDK 7.x (headers; runtime loaded via optixInit()). + */ + +// optix_function_table_definition.h must appear in exactly one TU — it provides +// the storage for the OptiX function table that optixInit() populates. +// optix_stubs.h provides the thin wrapper functions (optixInit, optixLaunch, …). +#include +#include +#include + +#include "xla/ffi/api/ffi.h" +#include "_collision_cuda_helpers.cuh" // includes _fk_cuda_helpers.cuh; all dist prims + fk_single + +#include +#include +#include // dladdr (Linux) + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ffi = xla::ffi; + +// --------------------------------------------------------------------------- +// Multi-GPU safety +// +// Under jax.pmap each visible GPU is driven by its own host thread running this +// handler concurrently. The OptiX pipeline, BVH cache, and scratch buffer are +// all bound to a specific CUDA device/context, so they must be kept per-device +// and indexed by the current device ordinal — sharing them across devices would +// launch one device's pipeline/BVH against another device's stream and memory. +// --------------------------------------------------------------------------- + +static constexpr int PYROFFI_MAX_GPUS = 16; + +// Current CUDA device ordinal, or -1 if it exceeds PYROFFI_MAX_GPUS / errors. +static int robogpu_current_device() { + int dev = 0; + if (cudaGetDevice(&dev) != cudaSuccess) return -1; + if (dev < 0 || dev >= PYROFFI_MAX_GPUS) return -1; + return dev; +} + +// --------------------------------------------------------------------------- +// Error-checking macros (return ffi::Error on failure) +// --------------------------------------------------------------------------- + +#define CUDA_CHECK(call) \ + do { \ + cudaError_t _e = (call); \ + if (_e != cudaSuccess) { \ + return ffi::Error(ffi::ErrorCode::kInternal, \ + cudaGetErrorString(_e)); \ + } \ + } while (0) + +#define CUDA_CHECK_VOID(call) \ + do { \ + cudaError_t _e = (call); \ + if (_e != cudaSuccess) \ + fprintf(stderr, "CUDA %s:%d %s\n", \ + __FILE__, __LINE__, cudaGetErrorString(_e)); \ + } while (0) + +#define OPTIX_CHECK(call) \ + do { \ + OptixResult _r = (call); \ + if (_r != OPTIX_SUCCESS) { \ + return ffi::Error(ffi::ErrorCode::kInternal, \ + "OptiX call failed (code=" + \ + std::to_string((int)_r) + ")"); \ + } \ + } while (0) + +#define OPTIX_CHECK_VOID(call) \ + do { \ + OptixResult _r = (call); \ + if (_r != OPTIX_SUCCESS) \ + fprintf(stderr, "OptiX %s:%d code=%d\n", \ + __FILE__, __LINE__, (int)_r); \ + } while (0) + +// --------------------------------------------------------------------------- +// Shared data structures (must match _robogpu_optix_programs.cu exactly) +// --------------------------------------------------------------------------- + +struct RoboGPULaunchParams { + OptixTraversableHandle handle; + const float4* robot_spheres; // [B * K, 4] world-frame + int32_t* out_free; // [B] 1=free, 0=collision (in/out) + int B; + int K; +}; + +struct HitGroupData { + const float4* env_spheres; // [Mp, 4] (cx, cy, cz, r_env) +}; + +// SBT record wrappers — must be OPTIX_SBT_RECORD_ALIGNMENT-aligned. +struct alignas(OPTIX_SBT_RECORD_ALIGNMENT) RaygenRecord { + char header[OPTIX_SBT_RECORD_HEADER_SIZE]; +}; +struct alignas(OPTIX_SBT_RECORD_ALIGNMENT) MissRecord { + char header[OPTIX_SBT_RECORD_HEADER_SIZE]; +}; +struct alignas(OPTIX_SBT_RECORD_ALIGNMENT) HitGroupRecord { + char header[OPTIX_SBT_RECORD_HEADER_SIZE]; + HitGroupData data; +}; + +// --------------------------------------------------------------------------- +// CUDA prepare kernel constants +// --------------------------------------------------------------------------- + +#define RGB_MAX_JOINTS 64 +#define RGB_MAX_LINKS 64 +#define RGB_THREADS 64 + +// --------------------------------------------------------------------------- +// World-geometry hit test (mirrors _collision_binary_cuda_kernel.cu) +// --------------------------------------------------------------------------- + +__device__ __forceinline__ bool sphere_world_hit( + float px, float py, float pz, float r, + const float* __restrict__ ws, int Ms, + const float* __restrict__ wc, int Mc, + const float* __restrict__ wb, int Mb, + const float* __restrict__ wh, int Mh) +{ + for (int i = 0; i < Ms; i++) { + const float* o = ws + i * 4; + if (sphere_sphere_dist(px, py, pz, r, o[0], o[1], o[2], o[3]) < 0.0f) return true; + } + for (int i = 0; i < Mc; i++) { + const float* o = wc + i * 7; + if (sphere_capsule_dist(px, py, pz, r, + o[0], o[1], o[2], o[3], o[4], o[5], o[6]) < 0.0f) return true; + } + for (int i = 0; i < Mb; i++) { + const float* o = wb + i * 15; + if (sphere_box_dist(px, py, pz, r, + o[0], o[1], o[2], o[3], o[4], o[5], o[6], o[7], o[8], + o[9], o[10], o[11], o[12], o[13], o[14]) < 0.0f) return true; + } + for (int i = 0; i < Mh; i++) { + const float* o = wh + i * 6; + if (sphere_halfspace_dist(px, py, pz, r, + o[0], o[1], o[2], o[3], o[4], o[5]) < 0.0f) return true; + } + return false; +} + +// --------------------------------------------------------------------------- +// Point-cloud → sphere + AABB conversion kernel +// +// For each env point p_i: sphere = (p_i, r_env), AABB expanded by r_total +// (= r_env + r_robot_max) so that any robot sphere centre within collision +// range of p_i falls inside the AABB and triggers BVH traversal. +// --------------------------------------------------------------------------- + +__global__ void build_env_spheres_kernel( + const float* __restrict__ pc, // [Mp, 3] + float4* __restrict__ sph, // [Mp, 4] → env sphere (x,y,z,r_env) + OptixAabb* __restrict__ aabb, // [Mp] + float r_env, float r_total, // r_total = r_env + r_robot_max + int Mp) +{ + const int i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= Mp) return; + const float x = pc[i * 3 + 0]; + const float y = pc[i * 3 + 1]; + const float z = pc[i * 3 + 2]; + sph[i] = make_float4(x, y, z, r_env); + aabb[i] = { x - r_total, y - r_total, z - r_total, + x + r_total, y + r_total, z + r_total }; +} + +// --------------------------------------------------------------------------- +// Stage-1 CUDA kernel: FK + sphere transform + world geometry + self-collision +// +// One block per configuration (blockIdx.x = b). RGB_THREADS threads cooperate +// over the K robot spheres and Pf self-collision pairs. +// +// Outputs: +// robot_spheres_world[b*K .. b*K+K-1] — world-frame (x,y,z,r) per sphere +// out_free[b] — 1 if free, 0 if collision +// --------------------------------------------------------------------------- + +__global__ void robogpu_prepare_kernel( + const float* __restrict__ cfg, // [B, n_act] + const float* __restrict__ twists, // [J, 6] + const float* __restrict__ parent_tf, // [J, 7] + const int* __restrict__ parent_idx, // [J] + const int* __restrict__ act_idx, // [J] + const float* __restrict__ mimic_mul, // [J] + const float* __restrict__ mimic_off, // [J] + const int* __restrict__ mimic_act_idx, // [J] + const int* __restrict__ topo_inv, // [J] + const int* __restrict__ link_parent_joint, // [NL] + const float* __restrict__ f_local, // [K, 4] k = s*NL + n + const int* __restrict__ f_pair_i, // [Pf] + const int* __restrict__ f_pair_j, // [Pf] + const float* __restrict__ ws, int Ms, + const float* __restrict__ wc, int Mc, + const float* __restrict__ wb, int Mb, + const float* __restrict__ wh, int Mh, + float4* __restrict__ robot_spheres_world, // [B*K, 4] output + int* __restrict__ out_free, // [B] output + int B, int n_act, int J, int NL, int K, int Pf) +{ + const int b = blockIdx.x; + if (b >= B) return; + const int tid = threadIdx.x; + const int nt = blockDim.x; // == RGB_THREADS + + __shared__ float Tw[RGB_MAX_JOINTS * 7]; // joint world transforms + __shared__ float Tl[RGB_MAX_LINKS * 7]; // link world transforms + __shared__ volatile int cc; // collision flag + + const int NL_cap = min(NL, RGB_MAX_LINKS); + // Spheres per link: layout is k = s*NL + n → link n has Sf spheres + const int Sf = (NL > 0) ? (K / NL) : 0; + + // ── FK (thread 0 walks the topological order) ──────────────────────────── + if (tid == 0) { + fk_single(cfg + (long long)b * n_act, + twists, parent_tf, parent_idx, act_idx, + mimic_mul, mimic_off, mimic_act_idx, topo_inv, + Tw, J, n_act); + cc = 0; + } + __syncthreads(); + + // ── Compute per-link world transforms from joint transforms ────────────── + for (int l = tid; l < NL_cap; l += nt) { + const int pj = link_parent_joint[l]; + float* dst = Tl + l * 7; + if (pj < 0) { + dst[0]=1.f; dst[1]=dst[2]=dst[3]=dst[4]=dst[5]=dst[6]=0.f; + } else { + #pragma unroll + for (int i = 0; i < 7; i++) dst[i] = Tw[pj * 7 + i]; + } + } + __syncthreads(); + + // ── Transform all robot spheres to world frame; store in global mem ────── + // We write all spheres even if a collision is found, so that OptiX's + // raygen can use the buffer (it skips configs where out_free[b]=0 anyway). + float4* base = robot_spheres_world + (long long)b * K; + for (int k = tid; k < K; k += nt) { + const int n = k % NL; + const float* lp = f_local + k * 4; + if (lp[3] < 0.0f || n >= NL_cap) { + base[k] = make_float4(0.f, 0.f, 0.f, -1.f); // padding + continue; + } + float p[3] = { lp[0], lp[1], lp[2] }, w[3]; + apply_se3_point(Tl + n * 7, p, w); + base[k] = make_float4(w[0], w[1], w[2], lp[3]); + } + __syncthreads(); + + // ── World geometry check (early exit via shared cc flag) ───────────────── + for (int k = tid; k < K; k += nt) { + if (cc) break; + const float4 s = base[k]; + if (s.w < 0.0f) continue; + if (sphere_world_hit(s.x, s.y, s.z, s.w, + ws, Ms, wc, Mc, wb, Mb, wh, Mh)) + cc = 1; + } + __syncthreads(); + if (cc) { if (tid == 0) out_free[b] = 0; return; } + + // ── Self-collision check (fine sphere pairs) ───────────────────────────── + for (int p_idx = tid; p_idx < Pf; p_idx += nt) { + if (cc) break; + const int li = f_pair_i[p_idx]; + const int lj = f_pair_j[p_idx]; + for (int si = 0; si < Sf && !cc; ++si) { + const int ki = si * NL + li; + if (ki >= K) continue; + const float4 wi = base[ki]; + if (wi.w < 0.0f) continue; + for (int sj = 0; sj < Sf && !cc; ++sj) { + const int kj = sj * NL + lj; + if (kj >= K) continue; + const float4 wj = base[kj]; + if (wj.w < 0.0f) continue; + if (sphere_sphere_dist(wi.x, wi.y, wi.z, wi.w, + wj.x, wj.y, wj.z, wj.w) < 0.0f) + cc = 1; + } + } + } + __syncthreads(); + + if (tid == 0) out_free[b] = (cc ? 0 : 1); +} + +// --------------------------------------------------------------------------- +// OptiX pipeline (process-lifetime singleton) +// --------------------------------------------------------------------------- + +struct OptiXPipeline { + OptixDeviceContext ctx = nullptr; + OptixModule module = nullptr; + OptixProgramGroup pg_rg = nullptr; // raygen + OptixProgramGroup pg_ms = nullptr; // miss + OptixProgramGroup pg_hg = nullptr; // hit group + OptixPipeline pipeline = nullptr; + bool ready = false; +}; + +static OptiXPipeline g_pipe[PYROFFI_MAX_GPUS]; +static std::mutex g_pipe_mtx[PYROFFI_MAX_GPUS]; + +// --------------------------------------------------------------------------- +// Locate the PTX file alongside the running shared library +// --------------------------------------------------------------------------- + +static std::string ptx_file_path() { + Dl_info info{}; + if (dladdr(reinterpret_cast(&ptx_file_path), &info) + && info.dli_fname) { + std::string path(info.dli_fname); + auto slash = path.rfind('/'); + std::string dir = (slash == std::string::npos) ? "." : path.substr(0, slash); + return dir + "/_robogpu_optix_programs.ptx"; + } + return "_robogpu_optix_programs.ptx"; +} + +// --------------------------------------------------------------------------- +// Initialise OptiX pipeline (idempotent, mutex-protected) +// --------------------------------------------------------------------------- + +static ffi::Error ensure_optix_pipeline(int dev) { + OptiXPipeline& gp = g_pipe[dev]; + std::lock_guard lk(g_pipe_mtx[dev]); + if (gp.ready) return ffi::Error::Success(); + + OPTIX_CHECK(optixInit()); + + CUcontext cu_ctx = nullptr; + OptixDeviceContextOptions ctx_opts = {}; + ctx_opts.logCallbackLevel = 1; // errors only + OPTIX_CHECK(optixDeviceContextCreate(cu_ctx, &ctx_opts, &gp.ctx)); + + // Load PTX from file. + std::string ptx_path = ptx_file_path(); + std::ifstream ifs(ptx_path, std::ios::binary); + if (!ifs) + return ffi::Error(ffi::ErrorCode::kInternal, + "RoboGPU: cannot open PTX: " + ptx_path); + std::ostringstream ss; ss << ifs.rdbuf(); + std::string ptx = ss.str(); + + OptixModuleCompileOptions mco = {}; + mco.optLevel = OPTIX_COMPILE_OPTIMIZATION_DEFAULT; + mco.debugLevel = OPTIX_COMPILE_DEBUG_LEVEL_NONE; + + OptixPipelineCompileOptions pco = {}; + pco.traversableGraphFlags = OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_GAS; + pco.numPayloadValues = 2; // p0=robot_r, p1=hit_flag + pco.numAttributeValues = 0; + pco.exceptionFlags = OPTIX_EXCEPTION_FLAG_NONE; + pco.pipelineLaunchParamsVariableName = "params"; + pco.usesPrimitiveTypeFlags = + static_cast(OPTIX_PRIMITIVE_TYPE_FLAGS_CUSTOM); + + char log[2048]; size_t lsz = sizeof(log); + OPTIX_CHECK(optixModuleCreate(gp.ctx, &mco, &pco, + ptx.c_str(), ptx.size(), + log, &lsz, &gp.module)); + + OptixProgramGroupOptions pgo = {}; + + // Raygen + { + OptixProgramGroupDesc d = {}; + d.kind = OPTIX_PROGRAM_GROUP_KIND_RAYGEN; + d.raygen.module = gp.module; + d.raygen.entryFunctionName = "__raygen__sphere_query"; + OPTIX_CHECK(optixProgramGroupCreate(gp.ctx, &d, 1, &pgo, + log, &lsz, &gp.pg_rg)); + } + // Miss + { + OptixProgramGroupDesc d = {}; + d.kind = OPTIX_PROGRAM_GROUP_KIND_MISS; + d.miss.module = gp.module; + d.miss.entryFunctionName = "__miss__sphere"; + OPTIX_CHECK(optixProgramGroupCreate(gp.ctx, &d, 1, &pgo, + log, &lsz, &gp.pg_ms)); + } + // Hit group (intersection + any-hit; no closest-hit) + { + OptixProgramGroupDesc d = {}; + d.kind = OPTIX_PROGRAM_GROUP_KIND_HITGROUP; + d.hitgroup.moduleIS = gp.module; + d.hitgroup.entryFunctionNameIS = "__intersection__sphere"; + d.hitgroup.moduleAH = gp.module; + d.hitgroup.entryFunctionNameAH = "__anyhit__sphere"; + OPTIX_CHECK(optixProgramGroupCreate(gp.ctx, &d, 1, &pgo, + log, &lsz, &gp.pg_hg)); + } + + OptixProgramGroup pgs[] = { gp.pg_rg, gp.pg_ms, gp.pg_hg }; + OptixPipelineLinkOptions plo = {}; + plo.maxTraceDepth = 1; + OPTIX_CHECK(optixPipelineCreate(gp.ctx, &pco, &plo, + pgs, 3, log, &lsz, &gp.pipeline)); + OPTIX_CHECK(optixPipelineSetStackSize(gp.pipeline, + 2048, 2048, 2048, 1)); + + gp.ready = true; + return ffi::Error::Success(); +} + +// --------------------------------------------------------------------------- +// BVH cache entry (one per unique point cloud + r_env + r_robot_max) +// --------------------------------------------------------------------------- + +struct BVHEntry { + CUdeviceptr d_gas = 0; // compacted GAS + size_t d_gas_size = 0; + OptixTraversableHandle handle = {}; + + CUdeviceptr d_env_spheres = 0; // [Mp, 4] float4 on device + CUdeviceptr d_aabbs = 0; // [Mp] OptixAabb on device + int Mp = 0; + + CUdeviceptr d_launch_params = 0; // per-call (updated each call) + + // SBT device buffers + CUdeviceptr d_sbt_rg = 0; + CUdeviceptr d_sbt_ms = 0; + CUdeviceptr d_sbt_hg = 0; + OptixShaderBindingTable sbt = {}; +}; + +static std::unordered_map g_bvh_cache[PYROFFI_MAX_GPUS]; +static std::mutex g_bvh_mtx[PYROFFI_MAX_GPUS]; + +// --------------------------------------------------------------------------- +// Build (or return cached) BVH for a given point cloud +// --------------------------------------------------------------------------- + +static BVHEntry* build_bvh( + cudaStream_t stream, + const float* d_pc, // [Mp, 3] device pointer (from JAX buffer) + int Mp, + float r_env, + float r_robot_max, + int dev) +{ + OptiXPipeline& gp = g_pipe[dev]; + BVHEntry* e = new BVHEntry{}; + e->Mp = Mp; + + CUDA_CHECK_VOID(cudaMalloc(reinterpret_cast(&e->d_env_spheres), + (size_t)Mp * sizeof(float4))); + CUDA_CHECK_VOID(cudaMalloc(reinterpret_cast(&e->d_aabbs), + (size_t)Mp * sizeof(OptixAabb))); + + const int blk = 256; + build_env_spheres_kernel<<<(Mp + blk - 1) / blk, blk, 0, stream>>>( + d_pc, + reinterpret_cast(e->d_env_spheres), + reinterpret_cast(e->d_aabbs), + r_env, r_env + r_robot_max, Mp); + + CUDA_CHECK_VOID(cudaStreamSynchronize(stream)); + + // GAS build + unsigned int geo_flags = OPTIX_GEOMETRY_FLAG_NONE; + CUdeviceptr aabb_ptr = e->d_aabbs; + + OptixBuildInput bi = {}; + bi.type = OPTIX_BUILD_INPUT_TYPE_CUSTOM_PRIMITIVES; + bi.customPrimitiveArray.aabbBuffers = &aabb_ptr; + bi.customPrimitiveArray.numPrimitives = static_cast(Mp); + bi.customPrimitiveArray.strideInBytes = sizeof(OptixAabb); + bi.customPrimitiveArray.flags = &geo_flags; + bi.customPrimitiveArray.numSbtRecords = 1; + + OptixAccelBuildOptions abo = {}; + abo.buildFlags = OPTIX_BUILD_FLAG_ALLOW_COMPACTION + | OPTIX_BUILD_FLAG_PREFER_FAST_TRACE; + abo.operation = OPTIX_BUILD_OPERATION_BUILD; + + OptixAccelBufferSizes bs = {}; + OPTIX_CHECK_VOID(optixAccelComputeMemoryUsage(gp.ctx, &abo, &bi, 1, &bs)); + + CUdeviceptr d_tmp = 0, d_out = 0, d_compact_sz = 0; + CUDA_CHECK_VOID(cudaMalloc(reinterpret_cast(&d_tmp), bs.tempSizeInBytes)); + CUDA_CHECK_VOID(cudaMalloc(reinterpret_cast(&d_out), bs.outputSizeInBytes)); + CUDA_CHECK_VOID(cudaMalloc(reinterpret_cast(&d_compact_sz), sizeof(size_t))); + + OptixAccelEmitDesc emit = {}; + emit.type = OPTIX_PROPERTY_TYPE_COMPACTED_SIZE; + emit.result = d_compact_sz; + + OptixTraversableHandle raw_handle = {}; + OPTIX_CHECK_VOID(optixAccelBuild(gp.ctx, stream, &abo, + &bi, 1, + d_tmp, bs.tempSizeInBytes, + d_out, bs.outputSizeInBytes, + &raw_handle, &emit, 1)); + CUDA_CHECK_VOID(cudaStreamSynchronize(stream)); + + size_t compact_sz = 0; + CUDA_CHECK_VOID(cudaMemcpy(&compact_sz, + reinterpret_cast(d_compact_sz), + sizeof(size_t), cudaMemcpyDeviceToHost)); + CUDA_CHECK_VOID(cudaMalloc(reinterpret_cast(&e->d_gas), compact_sz)); + e->d_gas_size = compact_sz; + OPTIX_CHECK_VOID(optixAccelCompact(gp.ctx, stream, raw_handle, + e->d_gas, compact_sz, &e->handle)); + CUDA_CHECK_VOID(cudaStreamSynchronize(stream)); + + CUDA_CHECK_VOID(cudaFree(reinterpret_cast(d_tmp))); + CUDA_CHECK_VOID(cudaFree(reinterpret_cast(d_out))); + CUDA_CHECK_VOID(cudaFree(reinterpret_cast(d_compact_sz))); + + // Build SBT + RaygenRecord rg_rec = {}; + MissRecord ms_rec = {}; + HitGroupRecord hg_rec = {}; + OPTIX_CHECK_VOID(optixSbtRecordPackHeader(gp.pg_rg, &rg_rec)); + OPTIX_CHECK_VOID(optixSbtRecordPackHeader(gp.pg_ms, &ms_rec)); + OPTIX_CHECK_VOID(optixSbtRecordPackHeader(gp.pg_hg, &hg_rec)); + hg_rec.data.env_spheres = reinterpret_cast(e->d_env_spheres); + + CUDA_CHECK_VOID(cudaMalloc(reinterpret_cast(&e->d_sbt_rg), sizeof(rg_rec))); + CUDA_CHECK_VOID(cudaMalloc(reinterpret_cast(&e->d_sbt_ms), sizeof(ms_rec))); + CUDA_CHECK_VOID(cudaMalloc(reinterpret_cast(&e->d_sbt_hg), sizeof(hg_rec))); + CUDA_CHECK_VOID(cudaMemcpy(reinterpret_cast(e->d_sbt_rg), &rg_rec, sizeof(rg_rec), cudaMemcpyHostToDevice)); + CUDA_CHECK_VOID(cudaMemcpy(reinterpret_cast(e->d_sbt_ms), &ms_rec, sizeof(ms_rec), cudaMemcpyHostToDevice)); + CUDA_CHECK_VOID(cudaMemcpy(reinterpret_cast(e->d_sbt_hg), &hg_rec, sizeof(hg_rec), cudaMemcpyHostToDevice)); + + e->sbt.raygenRecord = e->d_sbt_rg; + e->sbt.missRecordBase = e->d_sbt_ms; + e->sbt.missRecordStrideInBytes = sizeof(MissRecord); + e->sbt.missRecordCount = 1; + e->sbt.hitgroupRecordBase = e->d_sbt_hg; + e->sbt.hitgroupRecordStrideInBytes = sizeof(HitGroupRecord); + e->sbt.hitgroupRecordCount = 1; + + CUDA_CHECK_VOID(cudaMalloc(reinterpret_cast(&e->d_launch_params), + sizeof(RoboGPULaunchParams))); + return e; +} + +static BVHEntry* get_or_build_bvh( + cudaStream_t stream, + const float* d_pc, int Mp, + float r_env, float r_robot_max, + const std::string& key, + int dev) +{ + { + std::lock_guard lk(g_bvh_mtx[dev]); + auto it = g_bvh_cache[dev].find(key); + if (it != g_bvh_cache[dev].end()) return it->second; + } + BVHEntry* e = build_bvh(stream, d_pc, Mp, r_env, r_robot_max, dev); + if (e) { + std::lock_guard lk(g_bvh_mtx[dev]); + g_bvh_cache[dev].emplace(key, e); + } + return e; +} + +// --------------------------------------------------------------------------- +// Persistent scratch buffer for the world-frame robot spheres [B*K, 4]. +// +// Allocating this per call with cudaMallocAsync forces synchronization against +// JAX's own caching GPU allocator (they manage separate pools), adding a fixed +// ~0.5 ms stall to every check. Instead we keep one buffer that grows +// monotonically and is reused across calls. JAX serialises FFI calls on a +// single stream, so sequential reuse is safe; the mutex guards the rare grow. +// --------------------------------------------------------------------------- + +struct ScratchBuffer { + float4* ptr = nullptr; + size_t capacity = 0; // in float4 elements +}; +static ScratchBuffer g_scratch[PYROFFI_MAX_GPUS]; +static std::mutex g_scratch_mtx[PYROFFI_MAX_GPUS]; + +// Returns a device buffer of at least `n` float4 elements (nullptr on failure). +static float4* get_scratch(size_t n, int dev) { + ScratchBuffer& sc = g_scratch[dev]; + std::lock_guard lk(g_scratch_mtx[dev]); + if (n <= sc.capacity) return sc.ptr; + if (sc.ptr) cudaFree(sc.ptr); + // Over-allocate (1.5x) to amortise growth across increasing batch sizes. + size_t want = n + n / 2; + if (cudaMalloc(reinterpret_cast(&sc.ptr), + want * sizeof(float4)) != cudaSuccess) { + sc.ptr = nullptr; + sc.capacity = 0; + return nullptr; + } + sc.capacity = want; + return sc.ptr; +} + +// --------------------------------------------------------------------------- +// XLA FFI implementation +// --------------------------------------------------------------------------- + +static ffi::Error RoboGPUCheckImpl( + cudaStream_t stream, + ffi::Buffer cfg, // [B, n_act] + ffi::Buffer twists, // [J, 6] + ffi::Buffer parent_tf, // [J, 7] + ffi::Buffer parent_idx, // [J] + ffi::Buffer act_idx, // [J] + ffi::Buffer mimic_mul, // [J] + ffi::Buffer mimic_off, // [J] + ffi::Buffer mimic_act_idx, // [J] + ffi::Buffer topo_inv, // [J] + ffi::Buffer link_parent_joint, // [NL] + ffi::Buffer f_local, // [K, 4] + ffi::Buffer f_pair_i, // [Pf] + ffi::Buffer f_pair_j, // [Pf] + ffi::Buffer world_spheres, // [Ms, 4] + ffi::Buffer world_capsules, // [Mc, 7] + ffi::Buffer world_boxes, // [Mb, 15] + ffi::Buffer world_halfspaces, // [Mh, 6] + ffi::Buffer point_cloud, // [Mp, 3] + float r_env, + float r_robot_max, + ffi::Result> out // [B] +) { + const int B = static_cast(cfg.dimensions()[0]); + const int n_act = static_cast(cfg.dimensions()[1]); + const int J = static_cast(twists.dimensions()[0]); + const int NL = static_cast(link_parent_joint.dimensions()[0]); + const int K = static_cast(f_local.dimensions()[0]); + const int Pf = static_cast(f_pair_i.dimensions()[0]); + const int Ms = static_cast(world_spheres.dimensions()[0]); + const int Mc = static_cast(world_capsules.dimensions()[0]); + const int Mb = static_cast(world_boxes.dimensions()[0]); + const int Mh = static_cast(world_halfspaces.dimensions()[0]); + const int Mp = static_cast(point_cloud.dimensions()[0]); + + if (B <= 0) return ffi::Error::Success(); + + const int dev = robogpu_current_device(); + if (dev < 0) + return ffi::Error(ffi::ErrorCode::kInternal, + "RoboGPU: failed to query CUDA device, or device " + "ordinal exceeds PYROFFI_MAX_GPUS (rebuild with a " + "larger limit)."); + + if (J > RGB_MAX_JOINTS || NL > RGB_MAX_LINKS) + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "RoboGPU: J or NL exceeds compile-time bounds " + "(RGB_MAX_JOINTS=" + std::to_string(RGB_MAX_JOINTS) + + ", RGB_MAX_LINKS=" + std::to_string(RGB_MAX_LINKS) + + "); rebuild with larger values."); + + // ── Stage 1: CUDA prepare (FK + world geom + self-collision) ───────────── + + float4* d_spheres = get_scratch(static_cast(B) * K, dev); + if (!d_spheres) + return ffi::Error(ffi::ErrorCode::kInternal, + "RoboGPU: scratch allocation failed"); + + robogpu_prepare_kernel<<>>( + cfg.typed_data(), twists.typed_data(), parent_tf.typed_data(), + parent_idx.typed_data(), act_idx.typed_data(), + mimic_mul.typed_data(), mimic_off.typed_data(), + mimic_act_idx.typed_data(), topo_inv.typed_data(), + link_parent_joint.typed_data(), f_local.typed_data(), + f_pair_i.typed_data(), f_pair_j.typed_data(), + world_spheres.typed_data(), Ms, + world_capsules.typed_data(), Mc, + world_boxes.typed_data(), Mb, + world_halfspaces.typed_data(), Mh, + d_spheres, out->typed_data(), + B, n_act, J, NL, K, Pf); + + { + cudaError_t e = cudaGetLastError(); + if (e != cudaSuccess) + return ffi::Error(ffi::ErrorCode::kInternal, cudaGetErrorString(e)); + } + + // ── Stage 2: OptiX BVH traversal for point cloud ───────────────────────── + if (Mp > 0) { + { + auto err = ensure_optix_pipeline(dev); + if (err.failure()) return err; + } + + // Key the BVH cache on the point-cloud *device pointer* (+ Mp + radii). + // The Python checker captures the point cloud as a constant in its jitted + // closure, so the buffer — and hence this pointer — is stable across + // calls that reuse the same cloud, and a different cloud yields a + // different buffer. This avoids the per-call D2H copy + stream sync that + // a content hash would require, keeping the whole check asynchronous. + char keybuf[96]; + snprintf(keybuf, sizeof(keybuf), "%p_%d_%g_%g", + reinterpret_cast(point_cloud.typed_data()), + Mp, (double)r_env, (double)r_robot_max); + std::string key(keybuf); + + BVHEntry* bvh = get_or_build_bvh( + stream, point_cloud.typed_data(), Mp, r_env, r_robot_max, key, dev); + if (!bvh) + return ffi::Error(ffi::ErrorCode::kInternal, "RoboGPU: BVH build failed"); + + // Update launch params (stream-ordered ahead of optixLaunch). + RoboGPULaunchParams hp = {}; + hp.handle = bvh->handle; + hp.robot_spheres = d_spheres; + hp.out_free = out->typed_data(); + hp.B = B; + hp.K = K; + CUDA_CHECK(cudaMemcpyAsync(reinterpret_cast(bvh->d_launch_params), + &hp, sizeof(hp), + cudaMemcpyHostToDevice, stream)); + + // OptiX launch — one raygen thread per config. + OPTIX_CHECK(optixLaunch(g_pipe[dev].pipeline, stream, + bvh->d_launch_params, sizeof(hp), + &bvh->sbt, + static_cast(B), 1, 1)); + } + + return ffi::Error::Success(); +} + +// --------------------------------------------------------------------------- +// XLA FFI handler symbol (loaded by Python via ctypes.CDLL) +// --------------------------------------------------------------------------- + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + RoboGPUCollisionFfi, RoboGPUCheckImpl, + ffi::Ffi::Bind() + .Ctx>() + .Arg>() // cfg + .Arg>() // twists + .Arg>() // parent_tf + .Arg>() // parent_idx + .Arg>() // act_idx + .Arg>() // mimic_mul + .Arg>() // mimic_off + .Arg>() // mimic_act_idx + .Arg>() // topo_inv + .Arg>() // link_parent_joint + .Arg>() // f_local + .Arg>() // f_pair_i + .Arg>() // f_pair_j + .Arg>() // world_spheres + .Arg>() // world_capsules + .Arg>() // world_boxes + .Arg>() // world_halfspaces + .Arg>() // point_cloud [Mp, 3] + .Attr("r_env") + .Attr("r_robot_max") + .Ret>()); // out [B] diff --git a/src/pyroffi/cuda_kernels/collision/_robogpu_optix_programs.cu b/src/pyroffi/cuda_kernels/collision/_robogpu_optix_programs.cu new file mode 100644 index 00000000..c8d8039b --- /dev/null +++ b/src/pyroffi/cuda_kernels/collision/_robogpu_optix_programs.cu @@ -0,0 +1,151 @@ +/** + * OptiX device programs for RoboGPU sphere-octree collision checking. + * + * Compiled to PTX (NOT linked into the host .so): + * nvcc --ptx -arch=sm_XX -I${OPTIX_SDK}/include \ + * -o _robogpu_optix_programs.ptx _robogpu_optix_programs.cu + * + * The environment point cloud is represented as a BVH of spheres (one per + * point, radius = r_env expanded by r_robot_max for AABB coverage). Robot + * collision spheres query the BVH via degenerate "rays" whose origin is the + * robot sphere centre. The custom intersection program performs the exact + * sphere-sphere proximity test; the any-hit program terminates traversal on + * the first hit (early exit — the key RoboGPU contribution over plain CUDA). + * + * Per RoboGPU §IV "P-Sphere" approach: environment points are the spheres in + * the BVH; robot sphere centres become query points (degenerate rays with + * direction=(0,0,1), tmax=1). The intersection program ignores the ray + * equation entirely and just tests sphere-sphere overlap, reporting a hit at + * t=0.5 (within [tmin=0, tmax=1]). This lets the OptiX BVH provide the + * tree-traversal acceleration (§III-B early-exit support) while the actual + * primitive test is a simple sphere-sphere distance check. + * + * Payload registers: + * p0 — robot sphere radius (float bits, set by raygen, read by intersection) + * p1 — hit flag (0 = miss, 1 = hit; written by any-hit, read by raygen) + */ + +#include +#include +#include + +// --------------------------------------------------------------------------- +// Launch parameters (set via optixLaunch params buffer; __constant__ in PTX) +// --------------------------------------------------------------------------- + +struct RoboGPULaunchParams { + OptixTraversableHandle handle; // BVH over environment spheres + const float4* robot_spheres; // [B * K, 4] world-frame (x,y,z,r) + int32_t* out_free; // [B] in/out: 1=free, 0=collision + int B; // batch size + int K; // robot spheres per config (incl. padding) +}; + +extern "C" { + __constant__ RoboGPULaunchParams params; +} + +// --------------------------------------------------------------------------- +// Per-primitive SBT hit-group data (one record covers all env sphere prims) +// --------------------------------------------------------------------------- + +struct HitGroupData { + const float4* env_spheres; // [Mp, 4] (cx, cy, cz, r_env) +}; + +// --------------------------------------------------------------------------- +// Ray generation — one OptiX thread per configuration +// +// Loops over K robot collision spheres. For each active sphere, fires a +// single optixTrace into the env BVH. Breaks immediately on the first hit +// (per-config early exit complementing the per-sphere early exit in any-hit). +// --------------------------------------------------------------------------- + +extern "C" __global__ void __raygen__sphere_query() { + const int b = static_cast(optixGetLaunchIndex().x); + if (b >= params.B) return; + + // Skip configs already marked as in-collision by the CUDA prepare stage. + if (params.out_free[b] == 0) return; + + for (int k = 0; k < params.K; ++k) { + const float4 s = params.robot_spheres[b * params.K + k]; + // Padding spheres have negative radius — skip. + if (s.w < 0.0f) continue; + + // p0 = robot sphere radius (read by intersection program) + // p1 = hit flag, 0 initially; set to 1 by any-hit on first collision + unsigned int p0 = __float_as_uint(s.w); + unsigned int p1 = 0u; + + // Degenerate point query: a near-zero-length ray. The robot sphere + // centre lies inside every expanded env-sphere AABB it could collide + // with, so a tiny tmax still visits all relevant AABBs while avoiding + // the spurious candidates a long ray would sweep up on dense clouds. + optixTrace( + params.handle, + make_float3(s.x, s.y, s.z), // ray origin = robot sphere centre + make_float3(0.0f, 0.0f, 1.0f), // dummy direction (ignored by isect) + 0.0f, // tmin + 1.0e-3f, // tmax (intersection reports t=5e-4) + 0.0f, // ray time + OptixVisibilityMask(0xFF), + OPTIX_RAY_FLAG_NONE, + 0, 1, 0, // SBT offset, stride, miss SBT idx + p0, p1 + ); + + if (p1 != 0u) { + // First environment sphere hit → collision for this config. + params.out_free[b] = 0; + return; // early exit: no need to test remaining robot spheres + } + } + // All robot spheres clear of the point cloud — out_free[b] stays 1. +} + +// --------------------------------------------------------------------------- +// Intersection program — custom sphere-sphere proximity test +// +// The ray "origin" encodes the robot sphere centre; payload word 0 carries +// the robot sphere radius. We test overlap with the BVH primitive's env +// sphere and report a hit at t=0.5 if they overlap. +// --------------------------------------------------------------------------- + +extern "C" __global__ void __intersection__sphere() { + const HitGroupData* hg = + reinterpret_cast(optixGetSbtDataPointer()); + const int prim_idx = optixGetPrimitiveIndex(); + const float4 env = hg->env_spheres[prim_idx]; // (cx, cy, cz, r_env) + + // Robot sphere centre from ray origin; robot sphere radius from payload. + const float3 o = optixGetWorldRayOrigin(); + const float r_robot = __uint_as_float(optixGetPayload_0()); + + const float dx = o.x - env.x; + const float dy = o.y - env.y; + const float dz = o.z - env.z; + const float r_sum = r_robot + env.w; + + if (dx*dx + dy*dy + dz*dz < r_sum * r_sum) { + // Spheres overlap: report hit at t=5e-4, within [tmin=0, tmax=1e-3]. + optixReportIntersection(5.0e-4f, 0u); + } +} + +// --------------------------------------------------------------------------- +// Any-hit program — terminate ray on first overlap (RoboGPU early-exit) +// --------------------------------------------------------------------------- + +extern "C" __global__ void __anyhit__sphere() { + optixSetPayload_1(1u); // signal collision back to raygen + optixTerminateRay(); // stop BVH traversal immediately +} + +// --------------------------------------------------------------------------- +// Miss program — no environment sphere hit for this robot sphere +// --------------------------------------------------------------------------- + +extern "C" __global__ void __miss__sphere() { + // p1 remains 0: robot sphere is clear of the point cloud +} diff --git a/src/pyroffi/cuda_kernels/fk/__init__.py b/src/pyroffi/cuda_kernels/fk/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/pyroffi/cuda_kernels/_fk_cuda.py b/src/pyroffi/cuda_kernels/fk/_fk_cuda.py similarity index 97% rename from src/pyroffi/cuda_kernels/_fk_cuda.py rename to src/pyroffi/cuda_kernels/fk/_fk_cuda.py index c44dd798..5e243d11 100644 --- a/src/pyroffi/cuda_kernels/_fk_cuda.py +++ b/src/pyroffi/cuda_kernels/fk/_fk_cuda.py @@ -3,7 +3,7 @@ The companion shared library ``_fk_cuda.so`` must be compiled from ``_fk_cuda_kernel.cu`` before this module can be imported: - bash src/pyroffi/cuda_kernels/build_fk_cuda.sh + bash build_kernels/build_fk_cuda.sh Requires JAX >= 0.4.14 (for jax.ffi). """ @@ -30,7 +30,7 @@ def _load_and_register() -> None: if not lib_path.exists(): raise RuntimeError( f"CUDA FK library not found at {lib_path}.\n" - "Compile it first with: bash src/pyroffi/cuda_kernels/build_fk_cuda.sh\n" + "Compile it first with: bash build_kernels/build_fk_cuda.sh\n" "(This produces _fk_cuda_lib.so alongside the kernel source.)" ) lib = ctypes.CDLL(str(lib_path)) diff --git a/src/pyroffi/cuda_kernels/_fk_cuda_kernel.cu b/src/pyroffi/cuda_kernels/fk/_fk_cuda_kernel.cu similarity index 94% rename from src/pyroffi/cuda_kernels/_fk_cuda_kernel.cu rename to src/pyroffi/cuda_kernels/fk/_fk_cuda_kernel.cu index 7813d9bd..a3204c73 100644 --- a/src/pyroffi/cuda_kernels/_fk_cuda_kernel.cu +++ b/src/pyroffi/cuda_kernels/fk/_fk_cuda_kernel.cu @@ -12,7 +12,7 @@ * SE(3) math and the single-thread FK device function live in * _fk_cuda_helpers.cuh so that _ik_cuda_kernel.cu can reuse them. * - * Build with: bash src/pyroffi/cuda_kernels/build_fk_cuda.sh + * Build with: bash build_kernels/build_fk_cuda.sh */ #include "_fk_cuda_helpers.cuh" @@ -21,6 +21,14 @@ namespace ffi = xla::ffi; +// Under jax.pmap each visible GPU is driven by its own host thread running this +// same handler concurrently. All host-side caches (the per-model __constant__ +// upload tracker and the CUDA graph exec) must therefore be kept per-device and +// indexed by the current CUDA device ordinal — a shared cache would race across +// device threads, thrash the constant-memory re-upload, and replay one device's +// graph on another (illegal access). +static constexpr int PYROFFI_MAX_GPUS = 16; + // Maximum number of joints supported by the FK shared-memory cache. // Increase if your robot has more joints. #define FK_MAX_JOINTS 64 @@ -263,7 +271,13 @@ static ffi::Error FkCudaImpl( int max_level_width = -1; bool valid = false; }; - static ModelCache cache; + static ModelCache cache_pool[PYROFFI_MAX_GPUS]; + int _dev = 0; + cudaGetDevice(&_dev); + if (_dev < 0 || _dev >= PYROFFI_MAX_GPUS) + return ffi::Error(ffi::ErrorCode::kInternal, + "CUDA device ordinal exceeds PYROFFI_MAX_GPUS (rebuild with a larger limit)."); + ModelCache& cache = cache_pool[_dev]; const int batch = static_cast(cfg.dimensions()[0]); const int n_act = static_cast(cfg.dimensions()[1]); @@ -356,7 +370,8 @@ static ffi::Error FkCudaImpl( (void)topo_inv; - static FkGraphCache graph_cache; + static FkGraphCache graph_cache_pool[PYROFFI_MAX_GPUS]; + FkGraphCache& graph_cache = graph_cache_pool[_dev]; const int items_per_warp = fk_pick_items_per_warp(cache.max_level_width); void* cfg_ptr = const_cast(cfg.typed_data()); diff --git a/src/pyroffi/cuda_kernels/ik/__init__.py b/src/pyroffi/cuda_kernels/ik/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/pyroffi/cuda_kernels/_hjcd_ik_cuda.py b/src/pyroffi/cuda_kernels/ik/_hjcd_ik_cuda.py similarity index 98% rename from src/pyroffi/cuda_kernels/_hjcd_ik_cuda.py rename to src/pyroffi/cuda_kernels/ik/_hjcd_ik_cuda.py index 6a7312c6..7260ad74 100644 --- a/src/pyroffi/cuda_kernels/_hjcd_ik_cuda.py +++ b/src/pyroffi/cuda_kernels/ik/_hjcd_ik_cuda.py @@ -3,7 +3,7 @@ The companion shared library ``_hjcd_ik_cuda_lib.so`` must be compiled from ``_hjcd_ik_cuda_kernel.cu`` before this module can be imported: - bash src/pyroffi/cuda_kernels/build_hjcd_ik_cuda.sh + bash build_kernels/build_hjcd_ik_cuda.sh Provides two primitives called by the CUDA path in ``_hjcd_ik.py``: @@ -40,7 +40,7 @@ def _load_and_register() -> None: if not lib_path.exists(): raise RuntimeError( f"CUDA IK library not found at {lib_path}.\n" - "Compile it first with: bash src/pyroffi/cuda_kernels/build_hjcd_ik_cuda.sh\n" + "Compile it first with: bash build_kernels/build_hjcd_ik_cuda.sh\n" ) lib = ctypes.CDLL(str(lib_path)) diff --git a/src/pyroffi/cuda_kernels/_hjcd_ik_cuda_kernel.cu b/src/pyroffi/cuda_kernels/ik/_hjcd_ik_cuda_kernel.cu similarity index 99% rename from src/pyroffi/cuda_kernels/_hjcd_ik_cuda_kernel.cu rename to src/pyroffi/cuda_kernels/ik/_hjcd_ik_cuda_kernel.cu index 191b22e2..b8ef449f 100644 --- a/src/pyroffi/cuda_kernels/_hjcd_ik_cuda_kernel.cu +++ b/src/pyroffi/cuda_kernels/ik/_hjcd_ik_cuda_kernel.cu @@ -20,7 +20,7 @@ * - All kernel launches are associated with the caller's CUDA stream so * there are no implicit device synchronisations. * - * Build with: bash src/pyroffi/cuda_kernels/build_hjcd_ik_cuda.sh + * Build with: bash build_kernels/build_hjcd_ik_cuda.sh */ #include "_ik_cuda_helpers.cuh" diff --git a/src/pyroffi/cuda_kernels/_ls_ik_cuda.py b/src/pyroffi/cuda_kernels/ik/_ls_ik_cuda.py similarity index 97% rename from src/pyroffi/cuda_kernels/_ls_ik_cuda.py rename to src/pyroffi/cuda_kernels/ik/_ls_ik_cuda.py index fb05ac72..73763594 100644 --- a/src/pyroffi/cuda_kernels/_ls_ik_cuda.py +++ b/src/pyroffi/cuda_kernels/ik/_ls_ik_cuda.py @@ -3,7 +3,7 @@ The companion shared library ``_ls_ik_cuda_lib.so`` must be compiled from ``_ls_ik_cuda_kernel.cu`` before this module can be imported: - bash src/pyroffi/cuda_kernels/build_ls_ik_cuda.sh + bash build_kernels/build_ls_ik_cuda.sh Provides one primitive: @@ -34,7 +34,7 @@ def _load_and_register() -> None: if not lib_path.exists(): raise RuntimeError( f"LS-IK CUDA library not found at {lib_path}.\n" - "Compile it first with: bash src/pyroffi/cuda_kernels/build_ls_ik_cuda.sh\n" + "Compile it first with: bash build_kernels/build_ls_ik_cuda.sh\n" ) lib = ctypes.CDLL(str(lib_path)) diff --git a/src/pyroffi/cuda_kernels/_ls_ik_cuda_kernel.cu b/src/pyroffi/cuda_kernels/ik/_ls_ik_cuda_kernel.cu similarity index 99% rename from src/pyroffi/cuda_kernels/_ls_ik_cuda_kernel.cu rename to src/pyroffi/cuda_kernels/ik/_ls_ik_cuda_kernel.cu index 8ce7f540..3496b65e 100644 --- a/src/pyroffi/cuda_kernels/_ls_ik_cuda_kernel.cu +++ b/src/pyroffi/cuda_kernels/ik/_ls_ik_cuda_kernel.cu @@ -19,7 +19,7 @@ * - Normal-equation matrix and Cholesky solve in float64. * * Build with: - * bash src/pyroffi/cuda_kernels/build_ls_ik_cuda.sh + * bash build_kernels/build_ls_ik_cuda.sh */ #include "_ik_cuda_helpers.cuh" diff --git a/src/pyroffi/cuda_kernels/_mppi_ik_cuda.py b/src/pyroffi/cuda_kernels/ik/_mppi_ik_cuda.py similarity index 98% rename from src/pyroffi/cuda_kernels/_mppi_ik_cuda.py rename to src/pyroffi/cuda_kernels/ik/_mppi_ik_cuda.py index 3143a161..7aefb0b3 100644 --- a/src/pyroffi/cuda_kernels/_mppi_ik_cuda.py +++ b/src/pyroffi/cuda_kernels/ik/_mppi_ik_cuda.py @@ -3,7 +3,7 @@ The companion shared library ``_mppi_ik_cuda_lib.so`` must be compiled from ``_mppi_ik_cuda_kernel.cu`` before this module can be imported: - bash src/pyroffi/cuda_kernels/build_mppi_ik_cuda.sh + bash build_kernels/build_mppi_ik_cuda.sh Provides one primitive: @@ -34,7 +34,7 @@ def _load_and_register() -> None: if not lib_path.exists(): raise RuntimeError( f"MPPI-IK CUDA library not found at {lib_path}.\n" - "Compile it first with: bash src/pyroffi/cuda_kernels/build_mppi_ik_cuda.sh\n" + "Compile it first with: bash build_kernels/build_mppi_ik_cuda.sh\n" ) lib = ctypes.CDLL(str(lib_path)) diff --git a/src/pyroffi/cuda_kernels/_mppi_ik_cuda_kernel.cu b/src/pyroffi/cuda_kernels/ik/_mppi_ik_cuda_kernel.cu similarity index 99% rename from src/pyroffi/cuda_kernels/_mppi_ik_cuda_kernel.cu rename to src/pyroffi/cuda_kernels/ik/_mppi_ik_cuda_kernel.cu index ce7e179f..c1b5ece1 100644 --- a/src/pyroffi/cuda_kernels/_mppi_ik_cuda_kernel.cu +++ b/src/pyroffi/cuda_kernels/ik/_mppi_ik_cuda_kernel.cu @@ -25,7 +25,7 @@ * 6. Removed cudaStreamSynchronize (XLA owns stream scheduling). * * Build with: - * bash src/pyroffi/cuda_kernels/build_mppi_ik_cuda.sh + * bash build_kernels/build_mppi_ik_cuda.sh */ #include "_ik_cuda_helpers.cuh" diff --git a/src/pyroffi/cuda_kernels/_sqp_ik_cuda.py b/src/pyroffi/cuda_kernels/ik/_sqp_ik_cuda.py similarity index 97% rename from src/pyroffi/cuda_kernels/_sqp_ik_cuda.py rename to src/pyroffi/cuda_kernels/ik/_sqp_ik_cuda.py index 6687428a..09ebea3f 100644 --- a/src/pyroffi/cuda_kernels/_sqp_ik_cuda.py +++ b/src/pyroffi/cuda_kernels/ik/_sqp_ik_cuda.py @@ -3,7 +3,7 @@ The companion shared library ``_sqp_ik_cuda_lib.so`` must be compiled from ``_sqp_ik_cuda_kernel.cu`` before this module can be imported: - bash src/pyroffi/cuda_kernels/build_sqp_ik_cuda.sh + bash build_kernels/build_sqp_ik_cuda.sh Provides one primitive: @@ -34,7 +34,7 @@ def _load_and_register() -> None: if not lib_path.exists(): raise RuntimeError( f"SQP-IK CUDA library not found at {lib_path}.\n" - "Compile it first with: bash src/pyroffi/cuda_kernels/build_sqp_ik_cuda.sh\n" + "Compile it first with: bash build_kernels/build_sqp_ik_cuda.sh\n" ) lib = ctypes.CDLL(str(lib_path)) diff --git a/src/pyroffi/cuda_kernels/_sqp_ik_cuda_kernel.cu b/src/pyroffi/cuda_kernels/ik/_sqp_ik_cuda_kernel.cu similarity index 99% rename from src/pyroffi/cuda_kernels/_sqp_ik_cuda_kernel.cu rename to src/pyroffi/cuda_kernels/ik/_sqp_ik_cuda_kernel.cu index b8207cd0..c9784cbf 100644 --- a/src/pyroffi/cuda_kernels/_sqp_ik_cuda_kernel.cu +++ b/src/pyroffi/cuda_kernels/ik/_sqp_ik_cuda_kernel.cu @@ -28,7 +28,7 @@ * - Inner projected gradient loop in float32. * * Build with: - * bash src/pyroffi/cuda_kernels/build_sqp_ik_cuda.sh + * bash build_kernels/build_sqp_ik_cuda.sh */ #include "_ik_cuda_helpers.cuh" diff --git a/src/pyroffi/cuda_kernels/region_ik/__init__.py b/src/pyroffi/cuda_kernels/region_ik/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/pyroffi/cuda_kernels/_brownian_motion_ik_cuda.py b/src/pyroffi/cuda_kernels/region_ik/_brownian_motion_ik_cuda.py similarity index 96% rename from src/pyroffi/cuda_kernels/_brownian_motion_ik_cuda.py rename to src/pyroffi/cuda_kernels/region_ik/_brownian_motion_ik_cuda.py index fc9616e8..f6676349 100644 --- a/src/pyroffi/cuda_kernels/_brownian_motion_ik_cuda.py +++ b/src/pyroffi/cuda_kernels/region_ik/_brownian_motion_ik_cuda.py @@ -2,7 +2,7 @@ Compile the companion shared library first: - bash src/pyroffi/cuda_kernels/build_brownian_motion_ik_cuda.sh + bash build_kernels/build_brownian_motion_ik_cuda.sh The kernel accepts per-problem box bounds (box_mins / box_maxs with shape (n_problems, 3)), enabling multiple distinct regions to be solved in a @@ -30,7 +30,7 @@ def _load_and_register() -> None: if not lib_path.exists(): raise RuntimeError( f"Brownian-motion IK CUDA library not found at {lib_path}.\n" - "Compile it first with: bash src/pyroffi/cuda_kernels/build_brownian_motion_ik_cuda.sh\n" + "Compile it first with: bash build_kernels/build_brownian_motion_ik_cuda.sh\n" ) lib = ctypes.CDLL(str(lib_path)) diff --git a/src/pyroffi/cuda_kernels/_brownian_motion_ik_cuda_kernel.cu b/src/pyroffi/cuda_kernels/region_ik/_brownian_motion_ik_cuda_kernel.cu similarity index 99% rename from src/pyroffi/cuda_kernels/_brownian_motion_ik_cuda_kernel.cu rename to src/pyroffi/cuda_kernels/region_ik/_brownian_motion_ik_cuda_kernel.cu index 1d55584b..4a730dbd 100644 --- a/src/pyroffi/cuda_kernels/_brownian_motion_ik_cuda_kernel.cu +++ b/src/pyroffi/cuda_kernels/region_ik/_brownian_motion_ik_cuda_kernel.cu @@ -32,7 +32,7 @@ * 6. Build with --use_fast_math for hardware SFU (sqrtf, rsqrtf, etc.). * * Build: - * bash src/pyroffi/cuda_kernels/build_brownian_motion_ik_cuda.sh + * bash build_kernels/build_brownian_motion_ik_cuda.sh */ #include "_ik_cuda_helpers.cuh" diff --git a/src/pyroffi/cuda_kernels/_hit_and_run_ik_cuda.py b/src/pyroffi/cuda_kernels/region_ik/_hit_and_run_ik_cuda.py similarity index 96% rename from src/pyroffi/cuda_kernels/_hit_and_run_ik_cuda.py rename to src/pyroffi/cuda_kernels/region_ik/_hit_and_run_ik_cuda.py index 542dbb18..6844e67d 100644 --- a/src/pyroffi/cuda_kernels/_hit_and_run_ik_cuda.py +++ b/src/pyroffi/cuda_kernels/region_ik/_hit_and_run_ik_cuda.py @@ -2,7 +2,7 @@ Compile the companion shared library first: - bash src/pyroffi/cuda_kernels/build_hit_and_run_ik_cuda.sh + bash build_kernels/build_hit_and_run_ik_cuda.sh """ from __future__ import annotations @@ -26,7 +26,7 @@ def _load_and_register() -> None: if not lib_path.exists(): raise RuntimeError( f"Hit-and-run IK CUDA library not found at {lib_path}.\n" - "Compile it first with: bash src/pyroffi/cuda_kernels/build_hit_and_run_ik_cuda.sh\n" + "Compile it first with: bash build_kernels/build_hit_and_run_ik_cuda.sh\n" ) lib = ctypes.CDLL(str(lib_path)) diff --git a/src/pyroffi/cuda_kernels/_hit_and_run_ik_cuda_kernel.cu b/src/pyroffi/cuda_kernels/region_ik/_hit_and_run_ik_cuda_kernel.cu similarity index 99% rename from src/pyroffi/cuda_kernels/_hit_and_run_ik_cuda_kernel.cu rename to src/pyroffi/cuda_kernels/region_ik/_hit_and_run_ik_cuda_kernel.cu index 05349c14..843cbcbd 100644 --- a/src/pyroffi/cuda_kernels/_hit_and_run_ik_cuda_kernel.cu +++ b/src/pyroffi/cuda_kernels/region_ik/_hit_and_run_ik_cuda_kernel.cu @@ -2,7 +2,7 @@ * Hit-and-run NLP sampling CUDA kernel. * * Build with: - * bash src/pyroffi/cuda_kernels/build_hit_and_run_ik_cuda.sh + * bash build_kernels/build_hit_and_run_ik_cuda.sh */ #include "_ik_cuda_helpers.cuh" diff --git a/src/pyroffi/cuda_kernels/_svgd_region_ik_cuda.py b/src/pyroffi/cuda_kernels/region_ik/_svgd_region_ik_cuda.py similarity index 97% rename from src/pyroffi/cuda_kernels/_svgd_region_ik_cuda.py rename to src/pyroffi/cuda_kernels/region_ik/_svgd_region_ik_cuda.py index 51209649..38e20197 100644 --- a/src/pyroffi/cuda_kernels/_svgd_region_ik_cuda.py +++ b/src/pyroffi/cuda_kernels/region_ik/_svgd_region_ik_cuda.py @@ -3,7 +3,7 @@ The companion shared library ``_svgd_region_ik_cuda_lib.so`` must be compiled from ``_svgd_region_ik_cuda_kernel.cu`` before this module can be imported: - bash src/pyroffi/cuda_kernels/build_svgd_region_ik_cuda.sh + bash build_kernels/build_svgd_region_ik_cuda.sh Provides one primitive: @@ -56,7 +56,7 @@ def _load_and_register() -> None: if not lib_path.exists(): raise RuntimeError( f"SVGD region-IK CUDA library not found at {lib_path}.\n" - "Compile it first with: bash src/pyroffi/cuda_kernels/build_svgd_region_ik_cuda.sh\n" + "Compile it first with: bash build_kernels/build_svgd_region_ik_cuda.sh\n" ) lib = ctypes.CDLL(str(lib_path)) diff --git a/src/pyroffi/cuda_kernels/_svgd_region_ik_cuda_kernel.cu b/src/pyroffi/cuda_kernels/region_ik/_svgd_region_ik_cuda_kernel.cu similarity index 99% rename from src/pyroffi/cuda_kernels/_svgd_region_ik_cuda_kernel.cu rename to src/pyroffi/cuda_kernels/region_ik/_svgd_region_ik_cuda_kernel.cu index 47531e4d..d6912154 100644 --- a/src/pyroffi/cuda_kernels/_svgd_region_ik_cuda_kernel.cu +++ b/src/pyroffi/cuda_kernels/region_ik/_svgd_region_ik_cuda_kernel.cu @@ -27,7 +27,7 @@ * bandwidth is degenerate (particles collapsed) the caller-supplied fallback * value is used instead. * - * Build with: bash src/pyroffi/cuda_kernels/build_svgd_region_ik_cuda.sh + * Build with: bash build_kernels/build_svgd_region_ik_cuda.sh */ #include "_ik_cuda_helpers.cuh" diff --git a/src/pyroffi/cuda_kernels/trajopt/__init__.py b/src/pyroffi/cuda_kernels/trajopt/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/pyroffi/cuda_kernels/_chomp_trajopt_cuda.py b/src/pyroffi/cuda_kernels/trajopt/_chomp_trajopt_cuda.py similarity index 98% rename from src/pyroffi/cuda_kernels/_chomp_trajopt_cuda.py rename to src/pyroffi/cuda_kernels/trajopt/_chomp_trajopt_cuda.py index 8b852ec7..bd96628a 100644 --- a/src/pyroffi/cuda_kernels/_chomp_trajopt_cuda.py +++ b/src/pyroffi/cuda_kernels/trajopt/_chomp_trajopt_cuda.py @@ -3,7 +3,7 @@ The companion shared library _chomp_trajopt_cuda_lib.so must be compiled from _chomp_trajopt_cuda_kernel.cu before this module can be imported: - bash src/pyroffi/cuda_kernels/build_chomp_trajopt_cuda.sh + bash build_kernels/build_chomp_trajopt_cuda.sh """ from __future__ import annotations @@ -34,7 +34,7 @@ def _load_and_register() -> None: raise RuntimeError( f"CHOMP TrajOpt CUDA library not found at {lib_path}.\n" "Compile it first with:\n" - " bash src/pyroffi/cuda_kernels/build_chomp_trajopt_cuda.sh\n" + " bash build_kernels/build_chomp_trajopt_cuda.sh\n" ) lib = ctypes.CDLL(str(lib_path)) diff --git a/src/pyroffi/cuda_kernels/_chomp_trajopt_cuda_kernel.cu b/src/pyroffi/cuda_kernels/trajopt/_chomp_trajopt_cuda_kernel.cu similarity index 100% rename from src/pyroffi/cuda_kernels/_chomp_trajopt_cuda_kernel.cu rename to src/pyroffi/cuda_kernels/trajopt/_chomp_trajopt_cuda_kernel.cu diff --git a/src/pyroffi/cuda_kernels/_ls_trajopt_cuda.py b/src/pyroffi/cuda_kernels/trajopt/_ls_trajopt_cuda.py similarity index 98% rename from src/pyroffi/cuda_kernels/_ls_trajopt_cuda.py rename to src/pyroffi/cuda_kernels/trajopt/_ls_trajopt_cuda.py index 4aa165e6..fa9424c2 100644 --- a/src/pyroffi/cuda_kernels/_ls_trajopt_cuda.py +++ b/src/pyroffi/cuda_kernels/trajopt/_ls_trajopt_cuda.py @@ -3,7 +3,7 @@ The companion shared library ``_ls_trajopt_cuda_lib.so`` must be compiled from ``_ls_trajopt_cuda_kernel.cu`` before this module can be imported: - bash src/pyroffi/cuda_kernels/build_ls_trajopt_cuda.sh + bash build_kernels/build_ls_trajopt_cuda.sh """ from __future__ import annotations @@ -34,7 +34,7 @@ def _load_and_register() -> None: raise RuntimeError( f"LS TrajOpt CUDA library not found at {lib_path}.\n" "Compile it first with:\n" - " bash src/pyroffi/cuda_kernels/build_ls_trajopt_cuda.sh\n" + " bash build_kernels/build_ls_trajopt_cuda.sh\n" ) lib = ctypes.CDLL(str(lib_path)) diff --git a/src/pyroffi/cuda_kernels/_ls_trajopt_cuda_kernel.cu b/src/pyroffi/cuda_kernels/trajopt/_ls_trajopt_cuda_kernel.cu similarity index 99% rename from src/pyroffi/cuda_kernels/_ls_trajopt_cuda_kernel.cu rename to src/pyroffi/cuda_kernels/trajopt/_ls_trajopt_cuda_kernel.cu index fc6de5a3..32770d30 100644 --- a/src/pyroffi/cuda_kernels/_ls_trajopt_cuda_kernel.cu +++ b/src/pyroffi/cuda_kernels/trajopt/_ls_trajopt_cuda_kernel.cu @@ -5,7 +5,7 @@ * runs a diagonal-Gauss-Newton / LM update inside the kernel. * * Build with: - * bash src/pyroffi/cuda_kernels/build_ls_trajopt_cuda.sh + * bash build_kernels/build_ls_trajopt_cuda.sh */ #include "_ik_cuda_helpers.cuh" diff --git a/src/pyroffi/cuda_kernels/_sco_trajopt_cuda.py b/src/pyroffi/cuda_kernels/trajopt/_sco_trajopt_cuda.py similarity index 99% rename from src/pyroffi/cuda_kernels/_sco_trajopt_cuda.py rename to src/pyroffi/cuda_kernels/trajopt/_sco_trajopt_cuda.py index 8db6654a..09003830 100644 --- a/src/pyroffi/cuda_kernels/_sco_trajopt_cuda.py +++ b/src/pyroffi/cuda_kernels/trajopt/_sco_trajopt_cuda.py @@ -3,7 +3,7 @@ The companion shared library ``_sco_trajopt_cuda_lib.so`` must be compiled from ``_sco_trajopt_cuda_kernel.cu`` before this module can be imported: - bash src/pyroffi/cuda_kernels/build_sco_trajopt_cuda.sh + bash build_kernels/build_sco_trajopt_cuda.sh Provides: @@ -49,7 +49,7 @@ def _load_and_register() -> None: raise RuntimeError( f"SCO TrajOpt CUDA library not found at {lib_path}.\n" "Compile it first with:\n" - " bash src/pyroffi/cuda_kernels/build_sco_trajopt_cuda.sh\n" + " bash build_kernels/build_sco_trajopt_cuda.sh\n" ) lib = ctypes.CDLL(str(lib_path)) diff --git a/src/pyroffi/cuda_kernels/_sco_trajopt_cuda_kernel.cu b/src/pyroffi/cuda_kernels/trajopt/_sco_trajopt_cuda_kernel.cu similarity index 98% rename from src/pyroffi/cuda_kernels/_sco_trajopt_cuda_kernel.cu rename to src/pyroffi/cuda_kernels/trajopt/_sco_trajopt_cuda_kernel.cu index 80b39fd8..cc685ee8 100644 --- a/src/pyroffi/cuda_kernels/_sco_trajopt_cuda_kernel.cu +++ b/src/pyroffi/cuda_kernels/trajopt/_sco_trajopt_cuda_kernel.cu @@ -15,7 +15,7 @@ * workspace. * * Build with: - * bash src/pyroffi/cuda_kernels/build_sco_trajopt_cuda.sh + * bash build_kernels/build_sco_trajopt_cuda.sh */ #include "_ik_cuda_helpers.cuh" @@ -101,6 +101,12 @@ struct ScoTrajoptGraphCache { namespace ffi = xla::ffi; +// Under jax.pmap each visible GPU runs this handler on its own host thread, so +// the CUDA graph cache must be kept per-device (indexed by the current device +// ordinal). A single shared cache would race across device threads and replay +// one device's graph on another (illegal access). +static constexpr int PYROFFI_MAX_GPUS = 16; + // --------------------------------------------------------------------------- // Compile-time limits // --------------------------------------------------------------------------- @@ -1064,7 +1070,7 @@ void sco_trajopt_kernel( // XLA FFI handler — with CUDA graph caching // --------------------------------------------------------------------------- -static ScoTrajoptGraphCache s_trajopt_cache; +static ScoTrajoptGraphCache s_trajopt_cache_pool[PYROFFI_MAX_GPUS]; static ffi::Error ScoTrajoptCudaImpl( cudaStream_t stream, @@ -1101,6 +1107,13 @@ static ffi::Error ScoTrajoptCudaImpl( ffi::Result> out_costs, ffi::Result> out_workspace) { + int _dev = 0; + cudaGetDevice(&_dev); + if (_dev < 0 || _dev >= PYROFFI_MAX_GPUS) + return ffi::Error(ffi::ErrorCode::kInternal, + "CUDA device ordinal exceeds PYROFFI_MAX_GPUS (rebuild with a larger limit)."); + ScoTrajoptGraphCache& s_trajopt_cache = s_trajopt_cache_pool[_dev]; + const int B = static_cast(init_trajs.dimensions()[0]); const int T = static_cast(init_trajs.dimensions()[1]); const int n_act = static_cast(init_trajs.dimensions()[2]); diff --git a/src/pyroffi/cuda_kernels/_stomp_trajopt_cuda.py b/src/pyroffi/cuda_kernels/trajopt/_stomp_trajopt_cuda.py similarity index 98% rename from src/pyroffi/cuda_kernels/_stomp_trajopt_cuda.py rename to src/pyroffi/cuda_kernels/trajopt/_stomp_trajopt_cuda.py index b7c29255..ff13de84 100644 --- a/src/pyroffi/cuda_kernels/_stomp_trajopt_cuda.py +++ b/src/pyroffi/cuda_kernels/trajopt/_stomp_trajopt_cuda.py @@ -3,7 +3,7 @@ The companion shared library ``_stomp_trajopt_cuda_lib.so`` must be compiled from ``_stomp_trajopt_cuda_kernel.cu`` before this module can be imported: - bash src/pyroffi/cuda_kernels/build_stomp_trajopt_cuda.sh + bash build_kernels/build_stomp_trajopt_cuda.sh Provides: @@ -45,7 +45,7 @@ def _load_and_register() -> None: raise RuntimeError( f"STOMP TrajOpt CUDA library not found at {lib_path}.\n" "Compile it first with:\n" - " bash src/pyroffi/cuda_kernels/build_stomp_trajopt_cuda.sh\n" + " bash build_kernels/build_stomp_trajopt_cuda.sh\n" ) lib = ctypes.CDLL(str(lib_path)) diff --git a/src/pyroffi/cuda_kernels/_stomp_trajopt_cuda_kernel.cu b/src/pyroffi/cuda_kernels/trajopt/_stomp_trajopt_cuda_kernel.cu similarity index 99% rename from src/pyroffi/cuda_kernels/_stomp_trajopt_cuda_kernel.cu rename to src/pyroffi/cuda_kernels/trajopt/_stomp_trajopt_cuda_kernel.cu index 393dd50d..bad95764 100644 --- a/src/pyroffi/cuda_kernels/_stomp_trajopt_cuda_kernel.cu +++ b/src/pyroffi/cuda_kernels/trajopt/_stomp_trajopt_cuda_kernel.cu @@ -21,7 +21,7 @@ * - Replaces serial thread-0 softmax with parallel reduction * * Build with: - * bash src/pyroffi/cuda_kernels/build_stomp_trajopt_cuda.sh + * bash build_kernels/build_stomp_trajopt_cuda.sh */ #include "_ik_cuda_helpers.cuh" diff --git a/src/pyroffi/optimization_engines/_chomp_optimization.py b/src/pyroffi/optimization_engines/_chomp_optimization.py index 7ba14ab6..bf8a5152 100644 --- a/src/pyroffi/optimization_engines/_chomp_optimization.py +++ b/src/pyroffi/optimization_engines/_chomp_optimization.py @@ -407,7 +407,7 @@ def chomp_trajopt( final_trajs: All optimized trajectories. [B, T, DOF]. """ if use_cuda: - from ..cuda_kernels._chomp_trajopt_cuda import chomp_trajopt_cuda + from ..cuda_kernels.trajopt._chomp_trajopt_cuda import chomp_trajopt_cuda return chomp_trajopt_cuda( init_trajs, start, goal, robot, robot_coll, world_geoms, opt_cfg ) diff --git a/src/pyroffi/optimization_engines/_hjcd_ik.py b/src/pyroffi/optimization_engines/_hjcd_ik.py index 5bb71826..aabc8516 100644 --- a/src/pyroffi/optimization_engines/_hjcd_ik.py +++ b/src/pyroffi/optimization_engines/_hjcd_ik.py @@ -695,7 +695,7 @@ def _hjcd_solve_cuda_jit( constraint_args: tuple = (), constraint_weights: Float[Array, "n_constraints"] | None = None, ) -> Float[Array, "n_act"]: - from ..cuda_kernels._hjcd_ik_cuda import hjcd_ik_coarse_cuda, hjcd_ik_lm_cuda + from ..cuda_kernels.ik._hjcd_ik_cuda import hjcd_ik_coarse_cuda, hjcd_ik_lm_cuda n_act = robot.joints.num_actuated_joints lower = robot.joints.lower_limits @@ -895,7 +895,7 @@ def hjcd_solve_cuda( implicit device synchronisations. Requires ``_hjcd_ik_cuda_lib.so`` to be compiled first: - bash src/pyroffi/cuda_kernels/build_hjcd_ik_cuda.sh + bash build_kernels/build_hjcd_ik_cuda.sh Kinematic constraints Because the CUDA kernel cannot call arbitrary Python/JAX functions, @@ -1116,7 +1116,7 @@ def _hjcd_solve_cuda_batch_jit( constraint_args: tuple = (), constraint_weights: Float[Array, "n_constraints"] | None = None, ) -> Float[Array, "n_problems n_act"]: - from ..cuda_kernels._hjcd_ik_cuda import hjcd_ik_coarse_cuda, hjcd_ik_lm_cuda + from ..cuda_kernels.ik._hjcd_ik_cuda import hjcd_ik_coarse_cuda, hjcd_ik_lm_cuda n_act = robot.joints.num_actuated_joints lower = robot.joints.lower_limits diff --git a/src/pyroffi/optimization_engines/_ls_ik.py b/src/pyroffi/optimization_engines/_ls_ik.py index 2fe98424..f3162830 100644 --- a/src/pyroffi/optimization_engines/_ls_ik.py +++ b/src/pyroffi/optimization_engines/_ls_ik.py @@ -579,7 +579,7 @@ def _ls_ik_solve_cuda_jit( constraint_args: tuple = (), constraint_weights: Float[Array, "n_constraints"] | None = None, ) -> Float[Array, "n_act"]: - from ..cuda_kernels._ls_ik_cuda import ls_ik_cuda + from ..cuda_kernels.ik._ls_ik_cuda import ls_ik_cuda n_act = robot.joints.num_actuated_joints lower = robot.joints.lower_limits @@ -707,7 +707,7 @@ def ls_ik_solve_cuda( the JAX solver but with no Python overhead per step. Requires ``_ls_ik_cuda_lib.so`` compiled from ``_ls_ik_cuda_kernel.cu``: - bash src/pyroffi/cuda_kernels/build_ls_ik_cuda.sh + bash build_kernels/build_ls_ik_cuda.sh Multi-EE support Pass multiple end-effectors via ``target_link_indices`` (tuple) and @@ -923,7 +923,7 @@ def _ls_ik_solve_cuda_batch_jit( constraint_args: tuple = (), constraint_weights: Float[Array, "n_constraints"] | None = None, ) -> Float[Array, "n_problems n_act"]: - from ..cuda_kernels._ls_ik_cuda import ls_ik_cuda + from ..cuda_kernels.ik._ls_ik_cuda import ls_ik_cuda n_act = robot.joints.num_actuated_joints lower = robot.joints.lower_limits diff --git a/src/pyroffi/optimization_engines/_ls_trajopt_optimization.py b/src/pyroffi/optimization_engines/_ls_trajopt_optimization.py index e70ff092..e72e86b1 100644 --- a/src/pyroffi/optimization_engines/_ls_trajopt_optimization.py +++ b/src/pyroffi/optimization_engines/_ls_trajopt_optimization.py @@ -517,7 +517,7 @@ def ls_trajopt( key = jax.random.PRNGKey(0) if use_cuda and opt_cfg.use_legacy_cuda_kernel: - from ..cuda_kernels._ls_trajopt_cuda import ls_trajopt_cuda + from ..cuda_kernels.trajopt._ls_trajopt_cuda import ls_trajopt_cuda best_traj, costs, final_trajs = ls_trajopt_cuda( init_trajs=init_trajs, diff --git a/src/pyroffi/optimization_engines/_mppi_ik.py b/src/pyroffi/optimization_engines/_mppi_ik.py index 1d2673e3..d9d21b16 100644 --- a/src/pyroffi/optimization_engines/_mppi_ik.py +++ b/src/pyroffi/optimization_engines/_mppi_ik.py @@ -638,7 +638,7 @@ def _mppi_ik_solve_cuda_jit( constraint_args: tuple = (), constraint_weights: Float[Array, "n_constraints"] | None = None, ) -> tuple[Float[Array, "n_act"], Array]: - from ..cuda_kernels._mppi_ik_cuda import mppi_ik_cuda + from ..cuda_kernels.ik._mppi_ik_cuda import mppi_ik_cuda n_act = robot.joints.num_actuated_joints lower = robot.joints.lower_limits @@ -770,7 +770,7 @@ def mppi_ik_solve_cuda( """CUDA MPPI+L-BFGS IK: coarse particle search then gradient refinement. Requires ``_mppi_ik_cuda_lib.so`` compiled from ``_mppi_ik_cuda_kernel.cu``: - bash src/pyroffi/cuda_kernels/build_mppi_ik_cuda.sh + bash build_kernels/build_mppi_ik_cuda.sh Stage 1 (MPPI) At each of ``n_mppi_iters`` iterations, ``n_particles`` Gaussian @@ -983,7 +983,7 @@ def _mppi_ik_solve_cuda_batch_jit( constraint_args: tuple = (), constraint_weights: Float[Array, "n_constraints"] | None = None, ) -> Float[Array, "n_problems n_act"]: - from ..cuda_kernels._mppi_ik_cuda import mppi_ik_cuda + from ..cuda_kernels.ik._mppi_ik_cuda import mppi_ik_cuda n_act = robot.joints.num_actuated_joints lower = robot.joints.lower_limits diff --git a/src/pyroffi/optimization_engines/_region_ik.py b/src/pyroffi/optimization_engines/_region_ik.py index d9e432bb..82bc78f3 100644 --- a/src/pyroffi/optimization_engines/_region_ik.py +++ b/src/pyroffi/optimization_engines/_region_ik.py @@ -116,7 +116,7 @@ def _brownian_motion_batch_select_jit( threads_per_block: int = 128, ) -> tuple[Array, Array, Array, Array, Array]: """JIT-compiled CUDA batch solve + per-target restart winner selection.""" - from ..cuda_kernels._brownian_motion_ik_cuda import brownian_motion_ik_cuda + from ..cuda_kernels.region_ik._brownian_motion_ik_cuda import brownian_motion_ik_cuda cfgs, errs, ee_points, target_points = brownian_motion_ik_cuda( seeds=seeds, @@ -196,7 +196,7 @@ def _svgd_region_batch_select_jit( threads_per_block: int = 128, ) -> tuple[Array, Array, Array, Array, Array]: """JIT-compiled SVGD CUDA batch solve + per-target restart winner selection.""" - from ..cuda_kernels._svgd_region_ik_cuda import svgd_region_ik_cuda + from ..cuda_kernels.region_ik._svgd_region_ik_cuda import svgd_region_ik_cuda cfgs, errs, ee_points, target_points = svgd_region_ik_cuda( seeds=seeds, @@ -895,7 +895,7 @@ def _hit_and_run_batch_select_jit( noise_std: float, threads_per_block: int = 128, ) -> tuple[Array, Array, Array, Array, Array]: - from ..cuda_kernels._hit_and_run_ik_cuda import hit_and_run_ik_cuda + from ..cuda_kernels.region_ik._hit_and_run_ik_cuda import hit_and_run_ik_cuda cfgs, errs, ee_points, target_points = hit_and_run_ik_cuda( seeds=seeds, diff --git a/src/pyroffi/optimization_engines/_sco_optimization.py b/src/pyroffi/optimization_engines/_sco_optimization.py index 7f282e15..ee55b9a8 100644 --- a/src/pyroffi/optimization_engines/_sco_optimization.py +++ b/src/pyroffi/optimization_engines/_sco_optimization.py @@ -535,7 +535,7 @@ def sco_trajopt( final_trajs: All optimized trajectories. [B, T, DOF]. """ if use_cuda: - from ..cuda_kernels._sco_trajopt_cuda import sco_trajopt_cuda + from ..cuda_kernels.trajopt._sco_trajopt_cuda import sco_trajopt_cuda return sco_trajopt_cuda( init_trajs, start, goal, robot, robot_coll, world_geoms, opt_cfg ) diff --git a/src/pyroffi/optimization_engines/_sqp_ik.py b/src/pyroffi/optimization_engines/_sqp_ik.py index 3fd40389..4b6a22a6 100644 --- a/src/pyroffi/optimization_engines/_sqp_ik.py +++ b/src/pyroffi/optimization_engines/_sqp_ik.py @@ -447,7 +447,7 @@ def _sqp_ik_solve_cuda_jit( constraint_args: tuple = (), constraint_weights: Float[Array, "n_constraints"] | None = None, ) -> Float[Array, "n_act"]: - from ..cuda_kernels._sqp_ik_cuda import sqp_ik_cuda + from ..cuda_kernels.ik._sqp_ik_cuda import sqp_ik_cuda n_act = robot.joints.num_actuated_joints lower = robot.joints.lower_limits @@ -573,7 +573,7 @@ def sqp_ik_solve_cuda( limits as hard QP constraints. Requires ``_sqp_ik_cuda_lib.so`` compiled from ``_sqp_ik_cuda_kernel.cu``: - bash src/pyroffi/cuda_kernels/build_sqp_ik_cuda.sh + bash build_kernels/build_sqp_ik_cuda.sh Multi-EE support Pass multiple end-effectors via ``target_link_indices`` (tuple) and @@ -779,7 +779,7 @@ def _sqp_ik_solve_cuda_batch_jit( constraint_args: tuple = (), constraint_weights: Float[Array, "n_constraints"] | None = None, ) -> Float[Array, "n_problems n_act"]: - from ..cuda_kernels._sqp_ik_cuda import sqp_ik_cuda + from ..cuda_kernels.ik._sqp_ik_cuda import sqp_ik_cuda n_act = robot.joints.num_actuated_joints lower = robot.joints.lower_limits diff --git a/src/pyroffi/optimization_engines/_stomp_optimization.py b/src/pyroffi/optimization_engines/_stomp_optimization.py index f74e07dc..081f1aab 100644 --- a/src/pyroffi/optimization_engines/_stomp_optimization.py +++ b/src/pyroffi/optimization_engines/_stomp_optimization.py @@ -816,7 +816,7 @@ def stomp_trajopt( upper = robot.joints.upper_limits.astype(jnp.float32) if use_cuda: - from ..cuda_kernels._stomp_trajopt_cuda import stomp_trajopt_cuda + from ..cuda_kernels.trajopt._stomp_trajopt_cuda import stomp_trajopt_cuda best_traj, costs, final_trajs = stomp_trajopt_cuda( init_trajs, start, goal, robot, robot_coll, world_geoms, opt_cfg, key=key ) diff --git a/src/pyroffi/vamp_kernels/_edge_validation_ffi.hh b/src/pyroffi/vamp_kernels/_edge_validation_ffi.hh new file mode 100644 index 00000000..ec535ed8 --- /dev/null +++ b/src/pyroffi/vamp_kernels/_edge_validation_ffi.hh @@ -0,0 +1,349 @@ +// Generic JAX FFI handlers for VAMP's CPU collision checker. +// +// This header is robot-agnostic: include it from a small translation unit that +// first defines a `vamp::robots::` struct (e.g. one emitted by cricket's +// `generate_robot_source`) and the macros below, and it will emit two XLA FFI +// custom-call handlers specialised for that robot: +// +// * validate_configs_ — per-configuration validity (fused FK + CC). +// a : F32 [B, dim] configurations (raw joint values) +// -> r : PRED [B] true == collision-free +// +// * validate_edges_ — batch edge validation (the gtmp branch's +// `validate_motion_batch`, but exposed through the JAX FFI). +// a : F32 [E, dim] edge start configs +// b : F32 [E, dim] edge goal configs +// -> r : PRED [E] true == whole edge collision-free +// +// Both handlers take the world geometry as flat float buffers so the buffer +// layout matches pyroffi's existing CUDA binary checker +// (`_extract_world_arrays`), plus a CAPT point-cloud buffer so point-cloud +// obstacles are checked too: +// +// spheres : F32 [Ms, 4] (cx, cy, cz, r) +// capsules : F32 [Mc, 7] (ax, ay, az, bx, by, bz, r) endpoints +// cuboids : F32 [Mb, 15] (cx,cy,cz, axis1(3), axis2(3), axis3(3), half(3)) +// points : F32 [Mp, 3] (x, y, z) CAPT cloud +// attrs : r_min, r_max, r_point (CAPT query/point radii) +// +// HalfSpace obstacles are intentionally not handled here: VAMP's +// `collision::Environment` has no half-space primitive. Callers that need a +// ground plane should pass a large flat cuboid instead. + +#pragma once + +#if defined(VAMP_JAX_ROBOT) && defined(VAMP_JAX_ROBOT_NAME) + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace ffi = xla::ffi; + +namespace vamp::binding::jax +{ + static constexpr const std::size_t rake = vamp::FloatVectorWidth; + using EnvironmentF = vamp::collision::Environment; + using EnvironmentVector = vamp::collision::Environment>; + + // Reconstruct a scalar (float) Environment from the flat FFI buffers. + // + // The CAPT point cloud is built only when at least one finite point is + // supplied. Building the affordance tree is O(n log n); for repeated edge + // batches against a static cloud the caller should keep the cloud fixed so + // JAX caches the buffer on device, but the tree itself is rebuilt per call — + // a future optimisation could memoise it keyed by the points pointer. + inline auto build_environment( + const float *spheres, std::size_t n_spheres, + const float *capsules, std::size_t n_capsules, + const float *cuboids, std::size_t n_cuboids, + const float *points, std::size_t n_points, + float capt_r_min, float capt_r_max, float capt_r_point) -> EnvironmentF + { + EnvironmentF env; + + env.spheres.reserve(n_spheres); + for (std::size_t i = 0; i < n_spheres; ++i) + { + const float *s = &spheres[i * 4]; + env.spheres.emplace_back(vamp::collision::factory::sphere::flat(s[0], s[1], s[2], s[3])); + } + + env.capsules.reserve(n_capsules); + for (std::size_t i = 0; i < n_capsules; ++i) + { + const float *c = &capsules[i * 7]; + env.capsules.emplace_back(vamp::collision::factory::capsule::endpoints::flat( + c[0], c[1], c[2], c[3], c[4], c[5], c[6])); + } + + env.cuboids.reserve(n_cuboids); + for (std::size_t i = 0; i < n_cuboids; ++i) + { + const float *b = &cuboids[i * 15]; + // (cx,cy,cz, axis1(3), axis2(3), axis3(3), half(3)) maps directly to + // the Cuboid(center, axis1, axis2, axis3, half-extents) constructor. + env.cuboids.emplace_back(vamp::collision::Cuboid( + b[0], b[1], b[2], + b[3], b[4], b[5], + b[6], b[7], b[8], + b[9], b[10], b[11], + b[12], b[13], b[14])); + } + + if (n_points > 0) + { + std::vector cloud; + cloud.reserve(n_points); + for (std::size_t i = 0; i < n_points; ++i) + { + const float *p = &points[i * 3]; + cloud.push_back(vamp::collision::Point{p[0], p[1], p[2]}); + } + env.pointclouds.emplace_back(cloud, capt_r_min, capt_r_max, capt_r_point); + } + + env.sort(); + return env; + } + + // FNV-1a over a raw byte range, chained into a running hash. + inline auto hash_bytes(const void *data, std::size_t nbytes, std::uint64_t h) noexcept + -> std::uint64_t + { + const auto *p = static_cast(data); + for (std::size_t i = 0; i < nbytes; ++i) + { + h ^= p[i]; + h *= 1099511628211ULL; + } + return h; + } + + // Content hash of the world buffers (obstacle counts, raw float payloads, and + // CAPT radii). Two calls with byte-identical worlds hash equal, so the + // memoised Environment below is reused; the 64-bit FNV-1a collision risk + // (~2^-64) is negligible for a perf cache over fixed-layout buffers. + inline auto world_hash( + const float *spheres, std::size_t n_spheres, + const float *capsules, std::size_t n_capsules, + const float *cuboids, std::size_t n_cuboids, + const float *points, std::size_t n_points, + float capt_r_min, float capt_r_max, float capt_r_point) noexcept -> std::uint64_t + { + std::uint64_t h = 1469598103934665603ULL; // FNV offset basis + const std::size_t sizes[4] = {n_spheres, n_capsules, n_cuboids, n_points}; + h = hash_bytes(sizes, sizeof(sizes), h); + h = hash_bytes(spheres, n_spheres * 4 * sizeof(float), h); + h = hash_bytes(capsules, n_capsules * 7 * sizeof(float), h); + h = hash_bytes(cuboids, n_cuboids * 15 * sizeof(float), h); + h = hash_bytes(points, n_points * 3 * sizeof(float), h); + const float attrs[3] = {capt_r_min, capt_r_max, capt_r_point}; + h = hash_bytes(attrs, sizeof(attrs), h); + return h; + } + + // Build (or reuse) the SIMD Environment for these world buffers. + // + // The JIT-compiled handler is called repeatedly against the *same* static + // world during planning / benchmarking, yet each call would otherwise rebuild + // the obstacle vectors, re-sort, and reconstruct the CAPT affordance tree + // (O(n log n)) before a single config is checked — pure per-call overhead. + // We memoise the last-built EnvironmentVector keyed by ``world_hash`` so an + // unchanged world is free after the first call. + // + // The cached Environment is heap-allocated and *intentionally never freed*: + // (1) a previously-returned reference may still be in use by a concurrent + // kernel when the world changes, and (2) it keeps the cache statics + // trivially destructible — a function-local ``static`` with a non-trivial + // destructor (e.g. a ``shared_ptr``) makes the compiler emit an + // ``__cxa_atexit(__dso_handle)`` registration that cricket's ORC JIT + // cannot relocate. Worlds change rarely, so the leak is bounded. + inline auto environment_for( + const float *spheres, std::size_t n_spheres, + const float *capsules, std::size_t n_capsules, + const float *cuboids, std::size_t n_cuboids, + const float *points, std::size_t n_points, + float capt_r_min, float capt_r_max, float capt_r_point) -> const EnvironmentVector & + { + const std::uint64_t h = world_hash( + spheres, n_spheres, capsules, n_capsules, cuboids, n_cuboids, + points, n_points, capt_r_min, capt_r_max, capt_r_point); + + static std::mutex mtx; // trivially destructible + static std::uint64_t cached_hash = 0; + static const EnvironmentVector *cached = nullptr; // leaked; never freed + + std::lock_guard lock(mtx); + if (cached != nullptr and cached_hash == h) + { + return *cached; + } + + const EnvironmentF env_f = build_environment( + spheres, n_spheres, capsules, n_capsules, cuboids, n_cuboids, + points, n_points, capt_r_min, capt_r_max, capt_r_point); + cached = new EnvironmentVector(env_f); + cached_hash = h; + return *cached; + } + + // Build a robot Configuration from a dense [dim] row. + // + // We must NOT construct directly from the raw row pointer: a FloatVector + // rounds the dimension up to the SIMD width and loads all rounded lanes, so + // an unaligned load off the final row reads out-of-bounds padding, which + // perturbs l2_norm and hence the edge sample count. Mirror VAMP's + // validate_motion_batch: zero-initialise an aligned buffer, copy `dimension` + // scalars, then load. + template + inline auto make_configuration(const float *row) noexcept -> typename Robot::Configuration + { + typename Robot::ConfigurationBuffer buf{}; // zero-initialised padding + for (std::size_t d = 0; d < Robot::dimension; ++d) + { + buf[d] = row[d]; + } + return typename Robot::Configuration(buf.data()); + } + + template + inline auto validate_configs_impl( + ffi::Buffer a, + ffi::Buffer spheres, + ffi::Buffer capsules, + ffi::Buffer cuboids, + ffi::Buffer points, + float capt_r_min, + float capt_r_max, + float capt_r_point, + ffi::ResultBuffer r) noexcept -> ffi::Error + { + const auto a_d = a.dimensions(); + const std::size_t B = a_d[0]; + const float *a_data = a.typed_data(); + bool *r_data = r->typed_data(); + + const EnvironmentVector &env = environment_for( + spheres.typed_data(), spheres.dimensions()[0], + capsules.typed_data(), capsules.dimensions()[0], + cuboids.typed_data(), cuboids.dimensions()[0], + points.typed_data(), points.dimensions()[0], + capt_r_min, capt_r_max, capt_r_point); + +#ifdef _OPENMP +#pragma omp parallel for schedule(dynamic, 1000) +#endif + for (std::size_t i = 0; i < B; ++i) + { + const auto c = make_configuration(&a_data[i * Robot::dimension]); + // A zero-length motion reduces to a single fused FK + CC pass. + r_data[i] = vamp::planning::validate_motion(c, c, env); + } + + return ffi::Error::Success(); + } + + template + inline auto validate_edges_impl( + ffi::Buffer a, + ffi::Buffer b, + ffi::Buffer spheres, + ffi::Buffer capsules, + ffi::Buffer cuboids, + ffi::Buffer points, + float capt_r_min, + float capt_r_max, + float capt_r_point, + ffi::ResultBuffer r) noexcept -> ffi::Error + { + const auto a_d = a.dimensions(); + const auto b_d = b.dimensions(); + const std::size_t E = a_d[0]; + if (b_d[0] != a_d[0]) + { + return ffi::Error(ffi::ErrorCode::kInvalidArgument, + "validate_edges expects a and b to have the same number of rows"); + } + + const float *a_data = a.typed_data(); + const float *b_data = b.typed_data(); + bool *r_data = r->typed_data(); + + const EnvironmentVector &env = environment_for( + spheres.typed_data(), spheres.dimensions()[0], + capsules.typed_data(), capsules.dimensions()[0], + cuboids.typed_data(), cuboids.dimensions()[0], + points.typed_data(), points.dimensions()[0], + capt_r_min, capt_r_max, capt_r_point); + +#ifdef _OPENMP +#pragma omp parallel for schedule(dynamic, 1000) +#endif + for (std::size_t i = 0; i < E; ++i) + { + const auto ca = make_configuration(&a_data[i * Robot::dimension]); + const auto cb = make_configuration(&b_data[i * Robot::dimension]); + // validate_motion samples the open interval (0, 1]: it checks the goal + // and the interior but assumes the start is already valid (the usual + // planner contract, and what VAMP's own validate_motion[_batch] does). + // A "valid" edge therefore guarantees the goal endpoint and interior + // are collision-free; callers needing the start checked should + // validate it separately (check_collision_free). + r_data[i] = + vamp::planning::validate_motion(ca, cb, env); + } + + return ffi::Error::Success(); + } +} // namespace vamp::binding::jax + +#define VAMP_PASTE(A, B) A##B +#define VAMP_XSTRING(A) VAMP_STRING(A) +#define VAMP_STRING(A) #A +#define VAMP_CONFIGS_SYMBOL(robot_name) VAMP_PASTE(validate_configs_, robot_name) +#define VAMP_EDGES_SYMBOL(robot_name) VAMP_PASTE(validate_edges_, robot_name) + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + VAMP_CONFIGS_SYMBOL(VAMP_JAX_ROBOT_NAME), + vamp::binding::jax::validate_configs_impl, + ffi::Ffi::Bind() + .Arg>() // a [B, dim] + .Arg>() // spheres [Ms, 4] + .Arg>() // capsules [Mc, 7] + .Arg>() // cuboids [Mb, 15] + .Arg>() // points [Mp, 3] + .Attr("capt_r_min") + .Attr("capt_r_max") + .Attr("capt_r_point") + .Ret>() // [B] validity +); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + VAMP_EDGES_SYMBOL(VAMP_JAX_ROBOT_NAME), + vamp::binding::jax::validate_edges_impl, + ffi::Ffi::Bind() + .Arg>() // a [E, dim] + .Arg>() // b [E, dim] + .Arg>() // spheres [Ms, 4] + .Arg>() // capsules [Mc, 7] + .Arg>() // cuboids [Mb, 15] + .Arg>() // points [Mp, 3] + .Attr("capt_r_min") + .Attr("capt_r_max") + .Attr("capt_r_point") + .Ret>() // [E] validity +); + +#endif // VAMP_JAX_ROBOT && VAMP_JAX_ROBOT_NAME diff --git a/src/pyroffi/vamp_kernels/_robot_edge_validation_tu.cc.in b/src/pyroffi/vamp_kernels/_robot_edge_validation_tu.cc.in new file mode 100644 index 00000000..d19288ac --- /dev/null +++ b/src/pyroffi/vamp_kernels/_robot_edge_validation_tu.cc.in @@ -0,0 +1,33 @@ +// Per-robot translation unit compiled by cricket's JIT at runtime. +// +// pyroffi substitutes the @PLACEHOLDERS@ (simple string replacement, not CMake +// configure) and hands the result to `cricket.jit.ClangCompiler`. The two +// `extern "C"` accessors give the JIT session stable, unmangled symbol names to +// look up; each returns the address of an XLA FFI custom-call handler that JAX +// registers with `jax.ffi.register_ffi_target(..., platform="cpu")`. +// +// Placeholders: +// @ROBOT_HEADER@ absolute path to the cricket-generated robot header +// @FFI_HEADER@ absolute path to _edge_validation_ffi.hh +// @ROBOT_TYPE@ fully-qualified robot struct, e.g. vamp::robots::Panda +// @ROBOT_NAME@ lowercase token used as the handler symbol suffix + +#include "@ROBOT_HEADER@" + +#define VAMP_JAX_ROBOT @ROBOT_TYPE@ +#define VAMP_JAX_ROBOT_NAME @ROBOT_NAME@ + +#include "@FFI_HEADER@" + +#undef VAMP_JAX_ROBOT +#undef VAMP_JAX_ROBOT_NAME + +extern "C" void *pyroffi_get_validate_configs() noexcept +{ + return reinterpret_cast(&validate_configs_@ROBOT_NAME@); +} + +extern "C" void *pyroffi_get_validate_edges() noexcept +{ + return reinterpret_cast(&validate_edges_@ROBOT_NAME@); +} diff --git a/tests/_yourdfpy_compat.py b/tests/_yourdfpy_compat.py new file mode 100644 index 00000000..af783c27 --- /dev/null +++ b/tests/_yourdfpy_compat.py @@ -0,0 +1,62 @@ +"""Process-local shim: make ``yourdfpy==0.0.58`` work under numpy >= 2. + +yourdfpy 0.0.58 computes joint FK with ``float(q)`` where ``q`` can be a size-1 +array. numpy >= 2 makes ``float()`` an error, so loading any URDF +with revolute/continuous joints raises ``TypeError: only 0-dimensional arrays +can be converted to Python scalars``. + +This re-binds ``URDF._forward_kinematics_joint`` with an identical implementation +that coerces ``q`` to a scalar safely. It does NOT modify any installed files +and is a no-op once yourdfpy/numpy are mutually compatible. + +Call :func:`apply` once before loading URDFs. +""" + +from __future__ import annotations + + +def apply() -> bool: + """Install the shim if needed. Returns True if it was applied.""" + import numpy as np + import trimesh.transformations as tra + from yourdfpy.urdf import URDF + + # Already compatible? (numpy < 2 accepts float() on size-1 arrays.) + try: + float(np.zeros(1)) + return False + except TypeError: + pass + + def _forward_kinematics_joint(self, joint, q=None): + origin = np.eye(4) if joint.origin is None else joint.origin + + if joint.mimic is not None: + if joint.mimic.joint in self.actuated_joint_names: + mimic_joint_index = self.actuated_joint_names.index(joint.mimic.joint) + q = ( + self._cfg[mimic_joint_index] * joint.mimic.multiplier + + joint.mimic.offset + ) + else: + q = 0.0 + joint.mimic.offset + + if joint.type in ["revolute", "prismatic", "continuous"]: + if q is None: + q = self.cfg[ + self.actuated_dof_indices[ + self.actuated_joint_names.index(joint.name) + ] + ] + if joint.type == "prismatic": + matrix = origin @ tra.translation_matrix(q * joint.axis) + else: + angle = float(np.asarray(q).reshape(-1)[0]) + matrix = origin @ tra.rotation_matrix(angle, joint.axis) + else: + matrix = origin + + return matrix, q + + URDF._forward_kinematics_joint = _forward_kinematics_joint + return True diff --git a/tests/bench_collision.py b/tests/bench_collision.py index 80aab55e..938f2a4d 100644 --- a/tests/bench_collision.py +++ b/tests/bench_collision.py @@ -14,9 +14,16 @@ skipped and all-+1 distances are returned. NOT differentiable — do not use in trajopt. + Binary checkers (bool per config — for sampling-based planning / edges): + CUDA-Binary — CUDABinaryCollisionChecker (fused FK, GPU) + CUDA-Binary-Coarse — CUDABinaryCollisionChecker with coarse-first guard + VAMP-CPU — VAMPCPUCollisionChecker (JIT-compiled VAMP fkcc, CPU) + Operations: compute_world_collision_distance(robot, cfg, world_geom) compute_self_collision_distance(robot, cfg) + check_collision_free(robot, cfg, world_geom) [binary backends] + check_edges_collision_free(robot, edges, world_geom) [binary backends] Metrics (per backend × batch size): ms/call — wall-clock milliseconds for the full batched call @@ -29,7 +36,7 @@ Prerequisites: pip install robot_descriptions - bash src/pyroffi/cuda_kernels/build_collision_cuda.sh (for CUDA backends) + bash build_kernels/build_collision_cuda.sh (for CUDA backends) (pynvml optional, for GPU monitoring: pip install nvidia-ml-py) Neural training: @@ -57,6 +64,7 @@ from pyroffi.collision import ( CUDARobotCollisionChecker, + CUDABinaryCollisionChecker, NeuralRobotCollision, RobotCollision, RobotCollisionSpherized, @@ -66,6 +74,24 @@ HalfSpace, ) +try: + from pyroffi.collision import VAMPCPUCollisionChecker +except Exception: # cricket not built + VAMPCPUCollisionChecker = None + +try: + from pyroffi.collision import RoboGPUCollisionChecker +except Exception: + RoboGPUCollisionChecker = None + +# yourdfpy 0.0.58 is incompatible with numpy >= 2 (float() on size-1 arrays); +# install a process-local shim so URDF loading works. No-op when unneeded. +try: + import _yourdfpy_compat + _yourdfpy_compat.apply() +except Exception: + pass + # Optional GPU monitoring try: import pynvml as _pynvml @@ -86,13 +112,22 @@ _REPO_ROOT = pathlib.Path(__file__).resolve().parents[1] SPHERIZED_URDF = _REPO_ROOT / "resources" / "panda" / "panda_spherized.urdf" COARSE_SPHERIZED_URDF = _REPO_ROOT / "resources" / "panda" / "panda_spherized_coarse.urdf" +PANDA_SRDF = _REPO_ROOT / "resources" / "panda" / "panda.srdf" # Number of random configs to use as the test set N_WARMUP = 3 # JIT / kernel warm-up calls (results discarded) N_TIMED = 7 # timed repetitions; median is reported # Batch sizes to sweep -BATCH_SIZES = [1, 64, 512, 2048] +BATCH_SIZES = [1, 64, 512, 2048, 4096, 8192, 16384] + +# Number of discretised points per edge for the *CUDA* binary edge check (VAMP +# discretises internally at the robot resolution, so it only takes endpoints). +EDGE_GRANULARITY = 16 + +# RoboGPU point-cloud scene parameters +N_POINT_CLOUD = 8192 # environment points in the point cloud +POINT_CLOUD_R_ENV = 0.01 # sphere radius per environment point (metres) # World scene: a small set of obstacles (Sphere + Capsule + Box) N_WORLD_SPHERES = 4 @@ -324,6 +359,66 @@ def _bench_backend( return results + +def _bench_binary_backend( + label: str, + checker, + robot, + world_geom, + cfgs_by_batch: dict[int, jnp.ndarray], + op_tag: str = "binary", +) -> list[BenchResult]: + """Benchmark a *binary* collision checker's ``check_collision_free``. + + Binary checkers (CUDA-Binary, VAMP-CPU) return one bool per configuration + rather than a distance matrix. ``robot`` is passed through for API parity + (VAMP ignores it — FK is baked into its JIT binary).""" + results = [] + for B, cfgs in cfgs_by_batch.items(): + try: + fn = lambda c=cfgs: checker.check_collision_free(robot, c, world_geom) + logger.debug(f" {label} {op_tag} B={B}: warming up ({N_WARMUP}×)...") + _warmup(fn) + _, t_s, pk_gpu, pk_vram = _time_fn_gpu(fn) + results.append(BenchResult( + label=label, batch_size=B, op=op_tag, + ms_call=t_s * 1e3, ms_cfg=t_s * 1e3 / B, + peak_gpu=pk_gpu, peak_vram=pk_vram, + )) + except Exception as exc: + logger.warning(f" {label} {op_tag} B={B}: SKIPPED ({exc})") + return results + + +def _bench_edges_backend( + label: str, + checker, + robot, + world_geom, + edges_by_batch: dict[int, jnp.ndarray], + op_tag: str = "edges", +) -> list[BenchResult]: + """Benchmark batch edge validation (``check_edges_collision_free``). + + ``edges_by_batch[B]`` is shaped for the checker's convention: ``[E, G, n]`` + pre-discretised points for the CUDA checker, ``[E, 2, n]`` endpoints for + VAMP. ``ms_cfg`` reports per-edge time.""" + results = [] + for E, edges in edges_by_batch.items(): + try: + fn = lambda e=edges: checker.check_edges_collision_free(robot, e, world_geom) + logger.debug(f" {label} {op_tag} E={E}: warming up ({N_WARMUP}×)...") + _warmup(fn) + _, t_s, pk_gpu, pk_vram = _time_fn_gpu(fn) + results.append(BenchResult( + label=label, batch_size=E, op=op_tag, + ms_call=t_s * 1e3, ms_cfg=t_s * 1e3 / E, + peak_gpu=pk_gpu, peak_vram=pk_vram, + )) + except Exception as exc: + logger.warning(f" {label} {op_tag} E={E}: SKIPPED ({exc})") + return results + # --------------------------------------------------------------------------- # Table formatting (mirrors bench_ik.py style) # --------------------------------------------------------------------------- @@ -428,13 +523,19 @@ def main(args) -> None: print("=" * 80) # ── Robot (capsule model) ────────────────────────────────────────────── + # The capsule model + distance backends pull in trimesh/scipy convex-hull + # fitting; --binary-only skips all of that and benchmarks just the binary + # checkers (which only need sphere primitives / the JIT'd VAMP binary). print("\nLoading robot ...") - urdf = load_robot_description(f"{args.robot}_description") - robot_cap = pk.Robot.from_urdf(urdf) - n_act_cap = robot_cap.joints.num_actuated_joints - lo_cap = np.asarray(robot_cap.joints.lower_limits) - hi_cap = np.asarray(robot_cap.joints.upper_limits) - print(f" {args.robot}_description : {n_act_cap} actuated DOF") + urdf = robot_cap = None + lo_cap = hi_cap = None + if not args.binary_only: + urdf = load_robot_description(f"{args.robot}_description") + robot_cap = pk.Robot.from_urdf(urdf) + n_act_cap = robot_cap.joints.num_actuated_joints + lo_cap = np.asarray(robot_cap.joints.lower_limits) + hi_cap = np.asarray(robot_cap.joints.upper_limits) + print(f" {args.robot}_description : {n_act_cap} actuated DOF") # ── Robot (sphere model) — separate URDF with sphere primitives ──────── sph_urdf_path = args.spherized_urdf @@ -462,8 +563,10 @@ def main(args) -> None: # ── Collision models ─────────────────────────────────────────────────── print("\nBuilding collision models ...") - coll_cap = RobotCollision.from_urdf(urdf) - print(f" RobotCollision : {coll_cap.num_links} links") + coll_cap = None + if not args.binary_only: + coll_cap = RobotCollision.from_urdf(urdf) + print(f" RobotCollision : {coll_cap.num_links} links") coll_sph = RobotCollisionSpherized.from_urdf(urdf_sph) print(f" RobotCollisionSpherized : {coll_sph.num_links} links") @@ -475,19 +578,20 @@ def main(args) -> None: cuda_available = False cuda_cap = cuda_sph = cuda_sph_coarse = None - try: - cuda_cap = CUDARobotCollisionChecker(coll_cap) - cuda_sph = CUDARobotCollisionChecker(coll_sph) - if coll_sph_coarse is not None: - cuda_sph_coarse = CUDARobotCollisionChecker(coll_sph, coarse_inner=coll_sph_coarse) - cuda_available = True - print(" CUDARobotCollisionChecker: OK (JAX FFI library loaded)") - if cuda_sph_coarse is not None: - print(" CUDARobotCollisionChecker (coarse-first): OK") - else: - print(" CUDARobotCollisionChecker (coarse-first): SKIP (no coarse URDF)") - except RuntimeError as e: - print(f" CUDARobotCollisionChecker: SKIP ({e})") + if not args.binary_only: + try: + cuda_cap = CUDARobotCollisionChecker(coll_cap) + cuda_sph = CUDARobotCollisionChecker(coll_sph) + if coll_sph_coarse is not None: + cuda_sph_coarse = CUDARobotCollisionChecker(coll_sph, coarse_inner=coll_sph_coarse) + cuda_available = True + print(" CUDARobotCollisionChecker: OK (JAX FFI library loaded)") + if cuda_sph_coarse is not None: + print(" CUDARobotCollisionChecker (coarse-first): OK") + else: + print(" CUDARobotCollisionChecker (coarse-first): SKIP (no coarse URDF)") + except RuntimeError as e: + print(f" CUDARobotCollisionChecker: SKIP ({e})") # ── World scene ──────────────────────────────────────────────────────── print("\nBuilding world scene ...") @@ -509,15 +613,17 @@ def main(args) -> None: # ── Random configs per batch size (one set per DOF count) ───────────── print("\nGenerating configs ...") max_B = max(BATCH_SIZES) - cfgs_cap_np = rng.uniform(lo_cap, hi_cap, (max_B, lo_cap.shape[0])).astype(np.float32) cfgs_sph_np = rng.uniform(lo_sph, hi_sph, (max_B, lo_sph.shape[0])).astype(np.float32) - - cfgs_cap_by_batch: dict[int, jnp.ndarray] = {B: jnp.array(cfgs_cap_np[:B]) for B in BATCH_SIZES} cfgs_sph_by_batch: dict[int, jnp.ndarray] = {B: jnp.array(cfgs_sph_np[:B]) for B in BATCH_SIZES} + cfgs_cap_by_batch: dict[int, jnp.ndarray] = {} + if not args.binary_only: + cfgs_cap_np = rng.uniform(lo_cap, hi_cap, (max_B, lo_cap.shape[0])).astype(np.float32) + cfgs_cap_by_batch = {B: jnp.array(cfgs_cap_np[:B]) for B in BATCH_SIZES} + # ── Neural SDF (one model per primitive type) ─────────────────────────── neural_models: dict[str, NeuralRobotCollision] = {} - if not args.skip_neural: + if not args.binary_only and not args.skip_neural: print(f"\nTraining NeuralRobotCollision per primitive type " f"(samples={args.neural_samples}, epochs={NEURAL_EPOCHS}) ...") neural_base = NeuralRobotCollision.from_existing( @@ -546,110 +652,286 @@ def main(args) -> None: all_results: list[BenchResult] = [] - # --- Self-collision benchmarks (primitive-independent) ----------------- - print("\n Self-collision benchmarks ...") - - print(" JAX-Capsule self ...") - all_results += _bench_backend( - "JAX-Capsule", coll_cap, robot_cap, world_geom, cfgs_cap_by_batch, - skip_world=True, use_vmap=True, - ) + # --- Distance backends (capsule / sphere / neural; --binary-only skips) - + if not args.binary_only: + # --- Self-collision benchmarks (primitive-independent) ------------- + print("\n Self-collision benchmarks ...") - print(" JAX-Sphere self ...") - all_results += _bench_backend( - "JAX-Sphere", coll_sph, robot_sph, world_geom, cfgs_sph_by_batch, - skip_world=True, use_vmap=True, - ) - - if neural_models: - # Neural self-collision delegates to JAX-Capsule (same kernel) - import copy - for r in all_results: - if r.label == "JAX-Capsule" and r.op == "self": - nr = copy.copy(r) - nr.label = "JAX-Neural" - all_results.append(nr) - - if cuda_available: - print(" CUDA-Capsule self ...") + print(" JAX-Capsule self ...") all_results += _bench_backend( - "CUDA-Capsule", cuda_cap, robot_cap, world_geom, cfgs_cap_by_batch, - skip_world=True, + "JAX-Capsule", coll_cap, robot_cap, world_geom, cfgs_cap_by_batch, + skip_world=True, use_vmap=True, ) - print(" CUDA-Sphere self ...") - all_results += _bench_backend( - "CUDA-Sphere", cuda_sph, robot_sph, world_geom, cfgs_sph_by_batch, - skip_world=True, - ) - if cuda_sph_coarse is not None: - print(" CUDA-Sphere-Coarse self ...") - all_results += _bench_backend( - "CUDA-Sphere-Coarse", cuda_sph_coarse, robot_sph, world_geom, - cfgs_sph_by_batch, skip_world=True, - ) - - # --- World-collision benchmarks (per primitive type) -------------------- - for prim_name, (prim_geom, prim_count) in world_primitives.items(): - print(f"\n World-collision benchmarks (obstacle={prim_name}, M={prim_count}) ...") - op_tag = f"world-{prim_name}" - print(f" JAX-Capsule ...") + print(" JAX-Sphere self ...") all_results += _bench_backend( - "JAX-Capsule", coll_cap, robot_cap, prim_geom, cfgs_cap_by_batch, - skip_self=True, use_vmap=True, world_op_tag=op_tag, + "JAX-Sphere", coll_sph, robot_sph, world_geom, cfgs_sph_by_batch, + skip_world=True, use_vmap=True, ) - print(f" JAX-Sphere ...") - all_results += _bench_backend( - "JAX-Sphere", coll_sph, robot_sph, prim_geom, cfgs_sph_by_batch, - skip_self=True, use_vmap=True, world_op_tag=op_tag, - ) + if neural_models: + # Neural self-collision delegates to JAX-Capsule (same kernel) + import copy + for r in all_results: + if r.label == "JAX-Capsule" and r.op == "self": + nr = copy.copy(r) + nr.label = "JAX-Neural" + all_results.append(nr) - if prim_name in neural_models: - print(f" JAX-Neural ...") + if cuda_available: + print(" CUDA-Capsule self ...") all_results += _bench_backend( - "JAX-Neural", neural_models[prim_name], robot_cap, prim_geom, - cfgs_cap_by_batch, - skip_self=True, use_vmap=True, world_op_tag=op_tag, + "CUDA-Capsule", cuda_cap, robot_cap, world_geom, cfgs_cap_by_batch, + skip_world=True, + ) + print(" CUDA-Sphere self ...") + all_results += _bench_backend( + "CUDA-Sphere", cuda_sph, robot_sph, world_geom, cfgs_sph_by_batch, + skip_world=True, ) + if cuda_sph_coarse is not None: + print(" CUDA-Sphere-Coarse self ...") + all_results += _bench_backend( + "CUDA-Sphere-Coarse", cuda_sph_coarse, robot_sph, world_geom, + cfgs_sph_by_batch, skip_world=True, + ) - if cuda_available: - print(f" CUDA-Capsule ...") + # --- World-collision benchmarks (per primitive type) -------------- + for prim_name, (prim_geom, prim_count) in world_primitives.items(): + print(f"\n World-collision benchmarks (obstacle={prim_name}, M={prim_count}) ...") + op_tag = f"world-{prim_name}" + + print(f" JAX-Capsule ...") all_results += _bench_backend( - "CUDA-Capsule", cuda_cap, robot_cap, prim_geom, cfgs_cap_by_batch, - skip_self=True, world_op_tag=op_tag, + "JAX-Capsule", coll_cap, robot_cap, prim_geom, cfgs_cap_by_batch, + skip_self=True, use_vmap=True, world_op_tag=op_tag, ) - print(f" CUDA-Sphere ...") + + print(f" JAX-Sphere ...") all_results += _bench_backend( - "CUDA-Sphere", cuda_sph, robot_sph, prim_geom, cfgs_sph_by_batch, - skip_self=True, world_op_tag=op_tag, + "JAX-Sphere", coll_sph, robot_sph, prim_geom, cfgs_sph_by_batch, + skip_self=True, use_vmap=True, world_op_tag=op_tag, ) - if cuda_sph_coarse is not None: - print(f" CUDA-Sphere-Coarse ...") + + if prim_name in neural_models: + print(f" JAX-Neural ...") + all_results += _bench_backend( + "JAX-Neural", neural_models[prim_name], robot_cap, prim_geom, + cfgs_cap_by_batch, + skip_self=True, use_vmap=True, world_op_tag=op_tag, + ) + + if cuda_available: + print(f" CUDA-Capsule ...") all_results += _bench_backend( - "CUDA-Sphere-Coarse", cuda_sph_coarse, robot_sph, prim_geom, - cfgs_sph_by_batch, skip_self=True, world_op_tag=op_tag, + "CUDA-Capsule", cuda_cap, robot_cap, prim_geom, cfgs_cap_by_batch, + skip_self=True, world_op_tag=op_tag, ) + print(f" CUDA-Sphere ...") + all_results += _bench_backend( + "CUDA-Sphere", cuda_sph, robot_sph, prim_geom, cfgs_sph_by_batch, + skip_self=True, world_op_tag=op_tag, + ) + if cuda_sph_coarse is not None: + print(f" CUDA-Sphere-Coarse ...") + all_results += _bench_backend( + "CUDA-Sphere-Coarse", cuda_sph_coarse, robot_sph, prim_geom, + cfgs_sph_by_batch, skip_self=True, world_op_tag=op_tag, + ) + + # --- Binary collision-check benchmarks (bool per config / per edge) ------ + # These are a different operation from the distance backends above: the + # binary checkers return collision-free verdicts and are intended for + # sampling-based planning / edge validation. World is restricted to spheres + # (the VAMP backend has no half-space primitive and we keep the comparison + # apples-to-apples across backends). + if not args.skip_binary: + print("\n Binary collision-check benchmarks (op=binary, world=Sphere) ...") + bin_world = world_spheres + + bin_checkers: list[tuple] = [] + # CUDA binary checkers reuse the fine (and coarse) spherized models. + try: + cuda_bin = CUDABinaryCollisionChecker(coll_sph) + bin_checkers.append(("CUDA-Binary", cuda_bin, robot_sph, cfgs_sph_by_batch)) + if coll_sph_coarse is not None: + cuda_bin_coarse = CUDABinaryCollisionChecker( + coll_sph, coarse_inner=coll_sph_coarse + ) + bin_checkers.append( + ("CUDA-Binary-Coarse", cuda_bin_coarse, robot_sph, cfgs_sph_by_batch) + ) + print(" CUDABinaryCollisionChecker: OK") + except Exception as exc: + print(f" CUDABinaryCollisionChecker: SKIP ({exc})") + + # VAMP CPU checker — JIT-compiled from the spherized URDF (its own DOF). + vamp_cfgs_by_batch = None + if VAMPCPUCollisionChecker is not None and not args.skip_vamp: + try: + t0 = time.perf_counter() + vamp = VAMPCPUCollisionChecker(args.spherized_urdf, srdf_path=args.srdf) + print(f" VAMPCPUCollisionChecker: OK (dim={vamp.dimension}, " + f"built/loaded in {time.perf_counter()-t0:.2f}s)") + nv = vamp.dimension + vamp_np = rng.uniform(-1.5, 1.5, (max_B, nv)).astype(np.float32) + vamp_cfgs_by_batch = {B: jnp.array(vamp_np[:B]) for B in BATCH_SIZES} + bin_checkers.append(("VAMP-CPU", vamp, None, vamp_cfgs_by_batch)) + except Exception as exc: + print(f" VAMPCPUCollisionChecker: SKIP ({exc})") + + for label, checker, robot_arg, cfgs_bb in bin_checkers: + print(f" {label} binary ...") + all_results += _bench_binary_backend( + label, checker, robot_arg, bin_world, cfgs_bb + ) + + # ── Edge validation ───────────────────────────────────────────────── + print(f"\n Edge-validation benchmarks (op=edges, G={EDGE_GRANULARITY} for " + f"CUDA; VAMP discretises internally) ...") + for label, checker, robot_arg, cfgs_bb in bin_checkers: + # Build edges from consecutive config pairs in this checker's DOF. + n_dof = cfgs_bb[max(BATCH_SIZES)].shape[-1] + a_np = rng.uniform(-1.2, 1.2, (max_B, n_dof)).astype(np.float32) + b_np = rng.uniform(-1.2, 1.2, (max_B, n_dof)).astype(np.float32) + edges_bb = {} + for E in BATCH_SIZES: + a = a_np[:E] + b = b_np[:E] + if label.startswith("VAMP"): + edges_bb[E] = jnp.asarray(np.stack([a, b], axis=1)) # [E,2,n] + else: + # CUDA-Binary and RoboGPU both take pre-discretised [E,G,n] + ts = np.linspace(0.0, 1.0, EDGE_GRANULARITY, dtype=np.float32) + edges_bb[E] = jnp.asarray( + a[:, None, :] * (1 - ts)[None, :, None] + + b[:, None, :] * ts[None, :, None] + ) # [E,G,n] + print(f" {label} edges ...") + all_results += _bench_edges_backend( + label, checker, robot_arg, bin_world, edges_bb + ) + + # ── Isolated point-cloud comparison: RoboGPU (GPU) vs CAPT (CPU) ───────── + # Both backends process the SAME point cloud (RoboGPU via an OptiX BVH, CAPT + # via VAMP's Collision-Affording Point Tree), so this is the apples-to-apples + # head-to-head. The regular world is empty (one far-away sphere) to isolate + # the cost of the point-cloud query. + if not args.skip_pointcloud: + print(f"\n Point-cloud benchmarks: RoboGPU vs CAPT " + f"(Mp={args.pc_points}, r_env={args.pc_r_env}) ...") + # Empty regular world so only the point-cloud path is exercised. + empty_world = Sphere.from_center_and_radius( + center=jnp.array([[100.0, 100.0, 100.0]]), radius=jnp.array([0.01])) + pc_np = rng.uniform(-0.6, 0.6, (args.pc_points, 3)).astype(np.float32) + pc_j = jnp.array(pc_np) + + pc_checkers: list[tuple] = [] + + if RoboGPUCollisionChecker is not None and not args.skip_robogpu: + try: + t0 = time.perf_counter() + robogpu = RoboGPUCollisionChecker( + coll_sph, edge_granularity=EDGE_GRANULARITY) + robogpu.set_world(empty_world, point_cloud=pc_j, r_env=args.pc_r_env) + print(f" RoboGPU: OK (built in {time.perf_counter()-t0:.2f}s)") + pc_checkers.append( + ("RoboGPU", robogpu, robot_sph, cfgs_sph_by_batch)) + except Exception as exc: + print(f" RoboGPU: SKIP ({exc})") + + capt_cfgs_by_batch = None + if VAMPCPUCollisionChecker is not None and not args.skip_vamp: + try: + t0 = time.perf_counter() + capt = VAMPCPUCollisionChecker(args.spherized_urdf, srdf_path=args.srdf) + # CAPT: env points inflated by r_env; robot sphere radius range + # spans [0, pc_r_env*?] — VAMP uses its own internal spherization, + # so capt_r_max only bounds the broadphase, set generously. + capt.set_world( + empty_world, point_cloud=pc_j, + capt_r_min=0.0, capt_r_max=0.2, capt_r_point=args.pc_r_env) + nv = capt.dimension + capt_np = rng.uniform(-1.5, 1.5, (max_B, nv)).astype(np.float32) + capt_cfgs_by_batch = {B: jnp.array(capt_np[:B]) for B in BATCH_SIZES} + print(f" CAPT: OK (dim={nv}, built/loaded in " + f"{time.perf_counter()-t0:.2f}s)") + pc_checkers.append(("CAPT", capt, None, capt_cfgs_by_batch)) + except Exception as exc: + print(f" CAPT: SKIP ({exc})") + + # Binary check on the point cloud. + for label, checker, robot_arg, cfgs_bb in pc_checkers: + print(f" {label} pc-binary ...") + all_results += _bench_binary_backend( + label, checker, robot_arg, empty_world, cfgs_bb, op_tag="pc-binary") + + # Edge validation on the point cloud. + print(f"\n Point-cloud edge validation (G={EDGE_GRANULARITY} for RoboGPU; " + f"CAPT discretises internally) ...") + for label, checker, robot_arg, cfgs_bb in pc_checkers: + n_dof = cfgs_bb[max(BATCH_SIZES)].shape[-1] + a_np = rng.uniform(-1.2, 1.2, (max_B, n_dof)).astype(np.float32) + b_np = rng.uniform(-1.2, 1.2, (max_B, n_dof)).astype(np.float32) + edges_bb = {} + for E in BATCH_SIZES: + a, b = a_np[:E], b_np[:E] + if label == "CAPT": + edges_bb[E] = jnp.asarray(np.stack([a, b], axis=1)) # [E,2,n] + else: + ts = np.linspace(0.0, 1.0, EDGE_GRANULARITY, dtype=np.float32) + edges_bb[E] = jnp.asarray( + a[:, None, :] * (1 - ts)[None, :, None] + + b[:, None, :] * ts[None, :, None]) # [E,G,n] + print(f" {label} pc-edges ...") + all_results += _bench_edges_backend( + label, checker, robot_arg, empty_world, edges_bb, op_tag="pc-edges") # ── Print tables ─────────────────────────────────────────────────────── print("\n\n") - for prim_name in world_primitives: - op_tag = f"world-{prim_name}" - _print_table(f"World collision distance (obstacle={prim_name})", all_results, op_tag) - _print_table("Self-collision distance", all_results, "self") + if not args.binary_only: + for prim_name in world_primitives: + op_tag = f"world-{prim_name}" + _print_table(f"World collision distance (obstacle={prim_name})", all_results, op_tag) + _print_table("Self-collision distance", all_results, "self") + if not args.skip_binary: + _print_table("Binary collision check (bool/config, world=Sphere)", + all_results, "binary") + _print_table("Edge validation (bool/edge, world=Sphere)", + all_results, "edges") + if not args.skip_pointcloud: + _print_table("Point-cloud collision check (bool/config, RoboGPU vs CAPT)", + all_results, "pc-binary") + _print_table("Point-cloud edge validation (bool/edge, RoboGPU vs CAPT)", + all_results, "pc-edges") # Speed-up tables print("\n\n") print("=" * 80) print(" Speed-up summary") print("=" * 80) - for prim_name in world_primitives: - op_tag = f"world-{prim_name}" - _print_speedup_table(all_results, op_tag, baseline_label="JAX-Capsule") - _print_speedup_table(all_results, "self", baseline_label="JAX-Capsule") + if not args.binary_only: + for prim_name in world_primitives: + op_tag = f"world-{prim_name}" + _print_speedup_table(all_results, op_tag, baseline_label="JAX-Capsule") + _print_speedup_table(all_results, "self", baseline_label="JAX-Capsule") + if not args.skip_binary: + # Binary checkers form their own family; compare against the CPU VAMP + # baseline (falls back to "—" if VAMP was skipped). + _print_speedup_table(all_results, "binary", baseline_label="VAMP-CPU") + _print_speedup_table(all_results, "edges", baseline_label="VAMP-CPU") + if not args.skip_pointcloud: + # Isolated head-to-head: RoboGPU (GPU OptiX) vs CAPT (CPU VAMP) on the + # same point cloud. Baseline is CAPT so >1× means RoboGPU is faster. + _print_speedup_table(all_results, "pc-binary", baseline_label="CAPT") + _print_speedup_table(all_results, "pc-edges", baseline_label="CAPT") # ── Numerical agreement check ────────────────────────────────────────── - if cuda_available: + # Guarded: the prebuilt CUDA *distance* kernel may not match the local GPU + # arch ("no kernel image"), in which case we skip the check rather than + # crashing the whole benchmark. + try: + if cuda_available: print("\n\n" + "=" * 80) print(" Numerical agreement check (batch=64, max abs diff)") print("=" * 80) @@ -749,6 +1031,10 @@ def main(args) -> None: print(f" self coarse-clear: {n_clear_self}/{n_total_self} " f"({100*n_clear_self/n_total_self:.0f}%)" f" — fine kernel skipped for {n_clear_self} configs") + except Exception as exc: + print(f"\n Numerical agreement check SKIPPED ({type(exc).__name__}: {exc})") + print(" (Likely a CUDA distance-kernel arch mismatch — rebuild with " + "bash build_kernels/build_collision_cuda.sh)") print("\nDone.") @@ -768,8 +1054,27 @@ def main(args) -> None: type=pathlib.Path, help="Path to the coarse spherized URDF for CUDA-Sphere-Coarse " f"(default: {COARSE_SPHERIZED_URDF})") + parser.add_argument("--srdf", default=str(PANDA_SRDF), + type=pathlib.Path, + help=f"SRDF for the VAMP CPU checker (default: {PANDA_SRDF})") parser.add_argument("--skip-neural", action="store_true", help="Skip neural SDF training and benchmarking") + parser.add_argument("--skip-binary", action="store_true", + help="Skip binary collision-check + edge-validation benchmarks") + parser.add_argument("--skip-vamp", action="store_true", + help="Skip the VAMP CPU backend (still runs CUDA binary checkers)") + parser.add_argument("--skip-robogpu", action="store_true", + help="Skip the RoboGPU OptiX backend") + parser.add_argument("--skip-pointcloud", action="store_true", + help="Skip the isolated RoboGPU-vs-CAPT point-cloud benchmark") + parser.add_argument("--pc-points", type=int, default=N_POINT_CLOUD, + help=f"Number of random environment points in the RoboGPU " + f"point cloud (default: {N_POINT_CLOUD})") + parser.add_argument("--pc-r-env", type=float, default=POINT_CLOUD_R_ENV, + help=f"Sphere radius for each env point (default: {POINT_CLOUD_R_ENV})") + parser.add_argument("--binary-only", action="store_true", + help="Only run the binary collision-check + edge benchmarks " + "(skips the capsule/sphere/neural distance backends)") parser.add_argument("--neural-samples", type=int, default=NEURAL_SAMPLES, help=f"Training set size for neural SDF (default: {NEURAL_SAMPLES})") args = parser.parse_args() diff --git a/tests/bench_fk.py b/tests/bench_fk.py index 8e3c3521..128580ec 100644 --- a/tests/bench_fk.py +++ b/tests/bench_fk.py @@ -18,7 +18,7 @@ Prerequisites: 1. A CUDA-capable GPU must be available. 2. The CUDA FK library must be compiled: - bash src/pyroffi/cuda_kernels/build_fk_cuda.sh + bash build_kernels/build_fk_cuda.sh 3. Local spherized URDFs must be present under ``resources/``. """ diff --git a/tests/bench_ik.py b/tests/bench_ik.py index 788dfc83..9900a7ef 100644 --- a/tests/bench_ik.py +++ b/tests/bench_ik.py @@ -34,10 +34,10 @@ Prerequisites: 1. A CUDA-capable GPU. 2. CUDA libraries compiled: - bash src/pyroffi/cuda_kernels/build_hjcd_ik_cuda.sh - bash src/pyroffi/cuda_kernels/build_ls_ik_cuda.sh - bash src/pyroffi/cuda_kernels/build_sqp_ik_cuda.sh - bash src/pyroffi/cuda_kernels/build_mppi_ik_cuda.sh + bash build_kernels/build_hjcd_ik_cuda.sh + bash build_kernels/build_ls_ik_cuda.sh + bash build_kernels/build_sqp_ik_cuda.sh + bash build_kernels/build_mppi_ik_cuda.sh 3. robot_descriptions installed: pip install robot_descriptions 4. (Optional) Flax model for Learned-IK: diff --git a/tests/debug_fk_cuda.py b/tests/debug_fk_cuda.py index 337e5d9c..538ee8f4 100644 --- a/tests/debug_fk_cuda.py +++ b/tests/debug_fk_cuda.py @@ -15,7 +15,7 @@ def _fk_joints_jax(robot, cfg): def _fk_joints_cuda(robot, cfg): """Return (n_joints, 7) world transforms via CUDA kernel directly.""" - from pyroffi.cuda_kernels._fk_cuda import fk_cuda + from pyroffi.cuda_kernels.fk._fk_cuda import fk_cuda return np.array(fk_cuda( cfg=cfg, twists=robot.joints.twists, @@ -47,7 +47,7 @@ def diagnose(robot_name: str, batch: int = 1, seed: int = 42): fk_jax_jit = jax.jit(robot._forward_kinematics_joints) fk_cuda_raw = jax.jit(lambda c: _fk_joints_cuda_jit(robot, c)) - from pyroffi.cuda_kernels._fk_cuda import fk_cuda + from pyroffi.cuda_kernels.fk._fk_cuda import fk_cuda fk_cuda_jit = jax.jit(lambda c: fk_cuda( cfg=c, twists=robot.joints.twists, diff --git a/tests/test_fk_cuda.py b/tests/test_fk_cuda.py index e98100df..288dbc29 100644 --- a/tests/test_fk_cuda.py +++ b/tests/test_fk_cuda.py @@ -6,7 +6,7 @@ Prerequisites: 1. A CUDA-capable GPU must be available. 2. The CUDA FK library must be compiled: - bash src/pyroffi/cuda_kernels/build_fk_cuda.sh + bash build_kernels/build_fk_cuda.sh 3. Local spherized URDFs must be present under ``resources/``. """ diff --git a/tests/test_robogpu_collision.py b/tests/test_robogpu_collision.py new file mode 100644 index 00000000..4a77b35b --- /dev/null +++ b/tests/test_robogpu_collision.py @@ -0,0 +1,99 @@ +"""Correctness test for the RoboGPU OptiX point-cloud collision path. + +Independent oracle: pyroffi's own FK (`RobotCollisionSpherized.at_config`) gives +world-frame robot spheres; we brute-force test each against the env point cloud +(point i collides iff dist(center, p_i) < r_robot + r_env). Self-collision is +disabled in the checker so the OptiX point-cloud stage is what's under test. + +Skipped unless a CUDA GPU is present and the RoboGPU library has been built: + + bash build_kernels/build_robogpu_collision.sh + +Run: + pytest tests/test_robogpu_collision.py -s +""" +from __future__ import annotations + +import pathlib + +import numpy as np +import pytest + +jnp = pytest.importorskip("jax.numpy") +import jax # noqa: E402 +import yourdfpy # noqa: E402 + +import pyroffi as pk # noqa: E402 +from pyroffi.collision import RobotCollisionSpherized, Sphere # noqa: E402 + +RES = pathlib.Path(__file__).resolve().parent.parent / "resources" / "panda" + + +def _checker(coll): + try: + from pyroffi.collision import RoboGPUCollisionChecker + except Exception as exc: # pragma: no cover + pytest.skip(f"RoboGPU checker unavailable: {exc}") + try: + return RoboGPUCollisionChecker(coll) + except RuntimeError as exc: # library not built / no OptiX + pytest.skip(str(exc)) + + +def test_robogpu_pointcloud_matches_oracle(): + urdf = yourdfpy.URDF.load(str(RES / "panda_spherized.urdf")) + robot = pk.Robot.from_urdf(urdf) + coll = RobotCollisionSpherized.from_urdf(urdf) + + rng = np.random.default_rng(3) + B = 384 + home = np.array([0.0, -0.785, 0.0, -2.356, 0.0, 1.571, 0.785], dtype=np.float32) + cfgs = jnp.array((home[None, :] + rng.uniform(-0.7, 0.7, (B, 7))).astype(np.float32)) + + Mp = 300 + pc = rng.uniform(-0.8, 0.8, (Mp, 3)).astype(np.float32) + pc_j = jnp.array(pc) + R_ENV = 0.05 + far = Sphere.from_center_and_radius( + center=jnp.array([[100.0, 100.0, 100.0]]), radius=jnp.array([0.01])) + + # RoboGPU with self-collision disabled so only the point-cloud path runs. + rg = _checker(coll) + rg._f_pair_i = jnp.zeros((0,), dtype=jnp.int32) + rg._f_pair_j = jnp.zeros((0,), dtype=jnp.int32) + rg._cached_robot_id = None + rg._jit_fn = None + rg.set_world(far, point_cloud=pc_j, r_env=R_ENV) + + try: + v_rg = np.asarray(rg.check_collision_free(robot, cfgs)).astype(int) + except Exception as exc: # CUDA/OptiX runtime failure → skip, not fail + pytest.skip(f"RoboGPU kernel did not run: {exc}") + + # Independent oracle: world-frame robot spheres via FK, brute-force vs cloud. + geom = jax.vmap(lambda c: coll.at_config(robot, c))(cfgs) + centers = np.asarray(geom.pose.translation()) # [B, S, 3] + radii = np.asarray(geom.size).reshape(centers.shape[:-1]) # [B, S] + + pc_np = np.asarray(pc) + v_oracle = np.ones(B, dtype=int) + for b in range(B): + c, r = centers[b], radii[b] + valid = r > 0 + c, r = c[valid], r[valid] + d2 = ((c[:, None, :] - pc_np[None, :, :]) ** 2).sum(-1) + if np.any(d2 < (r[:, None] + R_ENV) ** 2): + v_oracle[b] = 0 + + mismatch = int((v_rg != v_oracle).sum()) + # A real free/hit mix proves the point-cloud path is actually exercised. + assert 0.02 < v_rg.mean() < 0.98, f"degenerate verdict distribution: {v_rg.mean()}" + assert mismatch == 0, ( + f"{mismatch}/{B} verdict mismatches vs brute-force oracle " + f"(RoboGPU free={v_rg.mean():.3f}, oracle free={v_oracle.mean():.3f})" + ) + + +if __name__ == "__main__": + test_robogpu_pointcloud_matches_oracle() + print("PASS") diff --git a/tests/test_vamp_cpu_collision.py b/tests/test_vamp_cpu_collision.py new file mode 100644 index 00000000..aabbd747 --- /dev/null +++ b/tests/test_vamp_cpu_collision.py @@ -0,0 +1,449 @@ +"""Validate the JIT-compiled VAMP CPU collision checker. + +This exercises pyroffi's :class:`VAMPCPUCollisionChecker`, which JIT-compiles a +robot-specialised VAMP collision checker through cricket and invokes it via the +JAX FFI on CPU. + +The whole suite is skipped unless cricket (with JIT) is importable AND a `clang` +binary is on PATH — cricket's JIT driver shells out to clang at runtime. Build +cricket first with: + + bash build_kernels/build_cricket_jit.sh + +Oracles are kept self-contained (no dependence on matching pyroffi's spherized +model to VAMP's own spherization): + + * shape / dtype of the verdicts, + * a real mix of free / in-collision configs once obstacles are placed, + * edge-vs-config consistency: a valid edge must have a free goal endpoint + (VAMP samples ``(0, 1]`` — the start is assumed pre-validated), + * a hand-built edge that plainly crosses an obstacle is rejected while a + clear edge in free space is accepted. + +Run: + pytest tests/test_vamp_cpu_collision.py -s +""" + +from __future__ import annotations + +import pathlib +import shutil +import time + +import numpy as np +import pytest + +jnp = pytest.importorskip("jax.numpy") +import jax # noqa: E402 + +import pyroffi as pk # noqa: E402 +from pyroffi.collision import Sphere # noqa: E402 + +RESOURCE_ROOT = pathlib.Path(__file__).resolve().parent.parent / "resources" +# Use the spherized URDF: cricket reads the sphere primitives directly, so no +# meshes need to be resolved on disk. +URDF = RESOURCE_ROOT / "panda" / "panda_spherized.urdf" +SRDF = RESOURCE_ROOT / "panda" / "panda.srdf" + + +def _checker(): + """Build the VAMP CPU checker or skip if the toolchain isn't present.""" + if shutil.which("clang") is None: + pytest.skip("clang not on PATH; cricket JIT cannot run") + try: + from pyroffi.collision import VAMPCPUCollisionChecker + except Exception as exc: # pragma: no cover + pytest.skip(f"VAMP checker unavailable: {exc}") + try: + return VAMPCPUCollisionChecker(URDF, srdf_path=SRDF) + except RuntimeError as exc: # cricket not built / not importable + pytest.skip(str(exc)) + + +@pytest.fixture(scope="module") +def checker(): + return _checker() + + +def _vamp_panda(): + """Return VAMP's nanobind `panda` robot module or skip if unavailable. + + This is the upstream, hand-written/-spherised VAMP collision checker exposed + through nanobind (`import vamp`). pyroffi's :class:`VAMPCPUCollisionChecker` + JIT-compiles its own kernel from the *spherized URDF*; benchmarking the two + side by side shows the cost of going through the JAX FFI + cricket JIT versus + calling the native binding directly. + """ + vamp = pytest.importorskip("vamp") + return vamp + + +def _vamp_env_from_world(vamp, world): + """Build a native VAMP `Environment` from a pyroffi `Sphere` world. + + Reading the centers/radii straight off the pyroffi geometry guarantees both + checkers see the *same* obstacles, so the only thing that differs is the + robot's own sphere model (pyroffi's spherized URDF vs VAMP's built-in one) — + which is why we report agreement as info but never assert verdict equality. + """ + centers = np.asarray(world.pose.translation()).reshape(-1, 3) + radii = np.asarray(world.radius).reshape(-1) + env = vamp.Environment() + for c, r in zip(centers, radii): + env.add_sphere(vamp.Sphere([float(c[0]), float(c[1]), float(c[2])], float(r))) + return env + + +# VAMP's edge discretisation is controlled by two compile-time constants: +# * the SIMD rake width (AVX2 -> 8 floats per vector), and +# * the robot's planning ``resolution`` (Panda is 32, which pyroffi's checker +# also bakes in via its ``resolution=32`` default). +# We replicate them here so the *raw VAMP* edge check below samples each edge at +# exactly the same density pyroffi's kernel does — see ``_vamp_fine_samples``. +VAMP_RESOLUTION = 32 + + +def _vamp_rake() -> int: + """SIMD float-vector width VAMP compiles to on this CPU (its sample stride). + + VAMP picks the widest available: AVX-512 -> 16, AVX2/AVX -> 8, else a 4-wide + SSE/NEON fallback. This must match the build or the replicated sample + positions (and hence verdicts) drift from pyroffi's kernel. + """ + try: + flags = pathlib.Path("/proc/cpuinfo").read_text() + except OSError: + return 8 + if "avx512f" in flags: + return 16 + if "avx2" in flags or " avx " in flags: + return 8 + return 4 + + +def _vamp_fine_samples(a, b, rake: int, resolution: int): + """Replicate the exact sample points of ``validate_motion``. + + VAMP validates an edge a->b by checking configurations at fractions + ``m / (rake * n)`` for ``m = 1 .. rake*n`` (the open interval ``(0, 1]`` — the + start is assumed pre-validated), where ``n = max(ceil(dist / rake * res), 1)`` + and ``dist = ||b - a||``. Reproducing that fraction set lets us drive the raw + VAMP checker at the *same* fine resolution pyroffi uses, so a verdict is the + AND of the per-sample checks and the two backends agree (modulo tiny + float-accumulation differences at a collision boundary). + + Returns the flat ``[S, dim]`` sample buffer plus the ``[E]`` segment offsets + into it (the start index of each edge), suitable for ``reduceat``. + """ + a = np.asarray(a, dtype=np.float32) + b = np.asarray(b, dtype=np.float32) + v = b - a + dist = np.linalg.norm(v, axis=1) + n = np.maximum(np.ceil(dist / rake * resolution), 1.0).astype(np.int64) + counts = rake * n # samples per edge [E] + offsets = np.concatenate([[0], np.cumsum(counts)[:-1]]).astype(np.int64) + + seg = np.repeat(np.arange(a.shape[0]), counts) # edge id per sample + within = np.arange(counts.sum()) - np.repeat(offsets, counts) # 0-based + frac = ((within + 1) / counts[seg]).astype(np.float32)[:, None] + samples = (a[seg] + v[seg] * frac).astype(np.float32) + return samples, offsets + + +def _vamp_fine_edges(panda, env, samples, offsets): + """Raw VAMP edge verdicts at fine resolution: AND per-sample fkcc per edge.""" + per_sample = np.asarray(panda.validate_motion_batch(samples, samples, env)) + return np.logical_and.reduceat(per_sample, offsets) + + +def test_configs_shape_and_mix(checker): + n = checker.dimension + cfg = jnp.asarray( + np.random.RandomState(1).uniform(-1.2, 1.2, size=(512, n)), dtype=jnp.float32 + ) + world = Sphere.from_center_and_radius( + center=jnp.array([[0.3, 0.0, 0.6], [0.0, 0.35, 0.7], [-0.25, 0.0, 0.5]]), + radius=jnp.array([0.13, 0.12, 0.12]), + ) + free = np.asarray(checker.check_collision_free(None, cfg, world)) + assert free.shape == (512,) + assert free.dtype == bool + n_free = int(free.sum()) + print(f"[vamp-configs] free {n_free}/512") + assert 0 < n_free < 512, "vacuous test — need a mix of free / colliding configs" + + +def test_edges_consistent_with_endpoints(checker): + n = checker.dimension + world = Sphere.from_center_and_radius( + center=jnp.array([[0.3, 0.0, 0.6], [0.0, 0.35, 0.7]]), + radius=jnp.array([0.15, 0.14]), + ) + rng = np.random.RandomState(5) + E = 128 + a = jnp.asarray(rng.uniform(-1.2, 1.2, size=(E, n)), dtype=jnp.float32) + b = jnp.asarray(rng.uniform(-1.2, 1.2, size=(E, n)), dtype=jnp.float32) + edges = jnp.stack([a, b], axis=1) # [E, 2, n] + + edge_ok = np.asarray(checker.check_edges_collision_free(None, edges, world)) + assert edge_ok.shape == (E,) + + b_free = np.asarray(checker.check_collision_free(None, b, world)) + # VAMP samples the open interval (0, 1], so a valid edge guarantees the *goal* + # endpoint is free (the start is assumed pre-validated by the planner; see + # check_edges_collision_free). The goal being free is the necessary condition. + assert np.all(~edge_ok | b_free), "edge marked valid with a colliding goal endpoint" + assert 0 < int(edge_ok.sum()) < E, "edge test vacuous — need a mix" + print(f"[vamp-edges] valid {int(edge_ok.sum())}/{E}") + + +def test_batch_matches_per_edge_and_is_deterministic(checker): + """The OpenMP batch edge kernel must equal validating each edge on its own. + + This is the real correctness invariant of the batch handler (independent of + VAMP's internal discretisation): row *i* of the batch result depends only on + edge *i*, so any batch/loop disagreement would indicate a data-race or + indexing bug in the parallel loop. Also checks run-to-run determinism. + """ + n = checker.dimension + world = Sphere.from_center_and_radius( + center=jnp.array([[0.3, 0.0, 0.6], [0.0, 0.35, 0.7]]), + radius=jnp.array([0.13, 0.12]), + ) + rng = np.random.RandomState(7) + E = 256 + a = jnp.asarray(rng.uniform(-1.2, 1.2, size=(E, n)), dtype=jnp.float32) + b = jnp.asarray(rng.uniform(-1.2, 1.2, size=(E, n)), dtype=jnp.float32) + edges = jnp.stack([a, b], axis=1) + + batch = np.asarray(checker.check_edges_collision_free(None, edges, world)) + again = np.asarray(checker.check_edges_collision_free(None, edges, world)) + assert np.array_equal(batch, again), "edge validation is not deterministic" + + # Validate a handful of edges one at a time and compare to the batch rows. + idx = rng.choice(E, size=16, replace=False) + per_edge = np.array( + [ + bool( + np.asarray( + checker.check_edges_collision_free(None, edges[i : i + 1], world) + )[0] + ) + for i in idx + ] + ) + assert np.array_equal(per_edge, batch[idx]), "batch result differs from per-edge" + assert 0 < int(batch.sum()) < E, "batch test vacuous — need a mix of edges" + print(f"[vamp-batch] valid {int(batch.sum())}/{E}; batch==per-edge, deterministic") + + +def test_point_cloud_capt(checker): + """A point-cloud (CAPT) wall must invalidate configs that were free without it.""" + n = checker.dimension + far = Sphere.from_center_and_radius( + center=jnp.array([[100.0, 100.0, 100.0]]), radius=jnp.array([1e-3]) + ) + cfg = jnp.asarray( + np.random.RandomState(1).uniform(-1.2, 1.2, size=(512, n)), dtype=jnp.float32 + ) + base = int(np.asarray(checker.check_collision_free(None, cfg, far)).sum()) + + gx, gz = np.meshgrid(np.linspace(0.15, 0.5, 25), np.linspace(0.2, 1.0, 25)) + pts = np.stack([gx.ravel(), np.zeros(gx.size), gz.ravel()], axis=1).astype(np.float32) + withpc = int( + np.asarray( + checker.check_collision_free( + None, cfg, far, point_cloud=jnp.asarray(pts), capt=(0.0, 1.0, 0.04) + ) + ).sum() + ) + print(f"[vamp-capt] free without cloud {base}/512; with cloud {withpc}/512") + assert withpc < base, "CAPT point-cloud wall removed no free configurations" + + +def _time_call(fn, repeats): + """Return the best wall-clock time (seconds) over `repeats` runs of `fn`. + + We report the minimum rather than the mean: it's the cleanest estimate of the + kernel's intrinsic cost, least polluted by scheduler noise / other processes. + `np.asarray(...)` forces the FFI result to be materialised so we're not timing + lazy JAX dispatch. + """ + best = float("inf") + for _ in range(repeats): + t0 = time.perf_counter() + np.asarray(fn()) + best = min(best, time.perf_counter() - t0) + return best + + +def test_timing_profile(checker): + """Basic timing profile of config / edge validation across batch sizes. + + This is a profiling aid rather than a hard assertion: it warms up the JIT once + (so compilation isn't counted) and then reports best-of-N latency and + throughput for a range of batch sizes. The only assertion is that timing + actually succeeded for every batch size. + """ + n = checker.dimension + world = Sphere.from_center_and_radius( + center=jnp.array([[0.3, 0.0, 0.6], [0.0, 0.35, 0.7], [-0.25, 0.0, 0.5]]), + radius=jnp.array([0.13, 0.12, 0.12]), + ) + rng = np.random.RandomState(11) + batch_sizes = [1, 8, 64, 256, 1024, 4096] + repeats = 5 + + # Warm up the JIT-compiled kernels once so compilation isn't timed. + warm_cfg = jnp.asarray(rng.uniform(-1.2, 1.2, size=(8, n)), dtype=jnp.float32) + warm_edges = jnp.stack([warm_cfg, warm_cfg], axis=1) + np.asarray(checker.check_collision_free(None, warm_cfg, world)) + np.asarray(checker.check_edges_collision_free(None, warm_edges, world)) + + print("\n[vamp-timing] config validation (best of %d):" % repeats) + print(f"{'batch':>8} {'total (ms)':>12} {'per-cfg (us)':>14} {'cfg/s':>14}") + for bs in batch_sizes: + cfg = jnp.asarray(rng.uniform(-1.2, 1.2, size=(bs, n)), dtype=jnp.float32) + cfg.block_until_ready() + dt = _time_call(lambda c=cfg: checker.check_collision_free(None, c, world), repeats) + print( + f"{bs:>8} {dt * 1e3:>12.3f} {dt / bs * 1e6:>14.3f} {bs / dt:>14.0f}" + ) + + print("\n[vamp-timing] edge validation (best of %d):" % repeats) + print(f"{'batch':>8} {'total (ms)':>12} {'per-edge (us)':>14} {'edge/s':>14}") + for bs in batch_sizes: + a = jnp.asarray(rng.uniform(-1.2, 1.2, size=(bs, n)), dtype=jnp.float32) + b = jnp.asarray(rng.uniform(-1.2, 1.2, size=(bs, n)), dtype=jnp.float32) + edges = jnp.stack([a, b], axis=1) + edges.block_until_ready() + dt = _time_call( + lambda e=edges: checker.check_edges_collision_free(None, e, world), repeats + ) + print( + f"{bs:>8} {dt * 1e3:>12.3f} {dt / bs * 1e6:>14.3f} {bs / dt:>14.0f}" + ) + + +def test_benchmark_against_vamp_nanobind(checker): + """Benchmark pyroffi's VAMP CPU checker against the native VAMP nanobind API. + + For both config and edge validation we time pyroffi's JAX-FFI kernel and the + upstream `vamp.panda` binding on identical obstacle worlds and report + best-of-N throughput plus the pyroffi/native speedup. Notes on fairness: + + * Config validation: VAMP exposes only a *single*-config `validate`, so the + native number is a Python loop (per-call binding overhead dominates at + small batches); pyroffi validates the whole batch in one FFI call. + * Edge validation: VAMP's nanobind `validate_motion`/`validate_motion_batch` + are hardcoded to discretisation *resolution 1*, while pyroffi's kernel + runs at the robot's full ``resolution`` (32). To compare like with like + — both checking each edge *finely* — we drive raw VAMP at resolution 32 + by validating the exact same sample points pyroffi's kernel visits (see + :func:`_vamp_fine_samples`) through one batched `validate_motion_batch` + call and AND-reducing per edge. Because both now sample identically, the + edge verdicts agree (reported below). Caveat: pyroffi can early-out per + edge on the first colliding sample; the flat-batch replication checks + every sample, so this row slightly favours pyroffi. + + The only assertion is that both backends produced a result for every batch + size — config verdicts can still differ because the two use different robot + sphere models (printed for information only). + """ + vamp = _vamp_panda() + panda = vamp.panda + n = checker.dimension + assert panda.dimension() == n, "VAMP panda DOF disagrees with pyroffi checker" + + world = Sphere.from_center_and_radius( + center=jnp.array([[0.3, 0.0, 0.6], [0.0, 0.35, 0.7], [-0.25, 0.0, 0.5]]), + radius=jnp.array([0.13, 0.12, 0.12]), + ) + env = _vamp_env_from_world(vamp, world) + + rng = np.random.RandomState(11) + batch_sizes = [1, 8, 64, 256, 1024, 4096] + repeats = 5 + + # Warm up pyroffi's JIT-compiled kernels once so compilation isn't timed. + warm = jnp.asarray(rng.uniform(-1.2, 1.2, size=(8, n)), dtype=jnp.float32) + np.asarray(checker.check_collision_free(None, warm, world)) + np.asarray( + checker.check_edges_collision_free(None, jnp.stack([warm, warm], axis=1), world) + ) + + # --- agreement sanity (informational) --------------------------------- + cfg_np = rng.uniform(-1.2, 1.2, size=(256, n)).astype(np.float32) + pk_free = np.asarray(checker.check_collision_free(None, jnp.asarray(cfg_np), world)) + vamp_free = np.array([panda.validate(c, env) for c in cfg_np]) + agree = float(np.mean(pk_free == vamp_free)) * 100.0 + print( + f"\n[vamp-bench] config-verdict agreement pyroffi vs native VAMP: " + f"{agree:.1f}% (differ due to distinct sphere models)" + ) + + print("\n[vamp-bench] config validation (best of %d):" % repeats) + hdr = f"{'batch':>8} {'pyroffi cfg/s':>16} {'vamp cfg/s':>16} {'speedup':>10}" + print(hdr) + for bs in batch_sizes: + cfg_np = rng.uniform(-1.2, 1.2, size=(bs, n)).astype(np.float32) + cfg_j = jnp.asarray(cfg_np) + cfg_j.block_until_ready() + dt_pk = _time_call( + lambda c=cfg_j: checker.check_collision_free(None, c, world), repeats + ) + dt_vamp = _time_call( + lambda c=cfg_np: np.array([panda.validate(x, env) for x in c]), repeats + ) + print( + f"{bs:>8} {bs / dt_pk:>16.0f} {bs / dt_vamp:>16.0f} " + f"{dt_vamp / dt_pk:>9.2f}x" + ) + + # --- edge-verdict agreement at matched (fine) resolution -------------- + rake = _vamp_rake() + a_np = rng.uniform(-1.2, 1.2, size=(256, n)).astype(np.float32) + b_np = rng.uniform(-1.2, 1.2, size=(256, n)).astype(np.float32) + edges = jnp.stack([jnp.asarray(a_np), jnp.asarray(b_np)], axis=1) + pk_edge = np.asarray(checker.check_edges_collision_free(None, edges, world)) + samples, offsets = _vamp_fine_samples(a_np, b_np, rake, VAMP_RESOLUTION) + vamp_edge = _vamp_fine_edges(panda, env, samples, offsets) + edge_agree = float(np.mean(pk_edge == vamp_edge)) * 100.0 + print( + f"\n[vamp-bench] edge-verdict agreement pyroffi (res {VAMP_RESOLUTION}) vs " + f"raw VAMP (res {VAMP_RESOLUTION}, rake {rake}): {edge_agree:.1f}%" + ) + + print("\n[vamp-bench] edge validation @ resolution %d (best of %d):" + % (VAMP_RESOLUTION, repeats)) + hdr = f"{'batch':>8} {'pyroffi edge/s':>16} {'vamp edge/s':>16} {'speedup':>10}" + print(hdr) + for bs in batch_sizes: + a_np = rng.uniform(-1.2, 1.2, size=(bs, n)).astype(np.float32) + b_np = rng.uniform(-1.2, 1.2, size=(bs, n)).astype(np.float32) + edges = jnp.stack([jnp.asarray(a_np), jnp.asarray(b_np)], axis=1) + edges.block_until_ready() + # Sample positions for the raw-VAMP fine check are precomputed outside the + # timed region so we measure VAMP's kernel, not numpy interpolation. + samples, offsets = _vamp_fine_samples(a_np, b_np, rake, VAMP_RESOLUTION) + dt_pk = _time_call( + lambda e=edges: checker.check_edges_collision_free(None, e, world), repeats + ) + dt_vamp = _time_call( + lambda s=samples, o=offsets: _vamp_fine_edges(panda, env, s, o), repeats + ) + print( + f"{bs:>8} {bs / dt_pk:>16.0f} {bs / dt_vamp:>16.0f} " + f"{dt_vamp / dt_pk:>9.2f}x" + ) + + +if __name__ == "__main__": + c = _checker() + test_configs_shape_and_mix(c) + test_edges_consistent_with_endpoints(c) + test_batch_matches_per_edge_and_is_deterministic(c) + test_point_cloud_capt(c) + test_timing_profile(c) + test_benchmark_against_vamp_nanobind(c) + print("\nAll VAMP CPU collision checks passed.")