
ptychoml

Neural network inference for ptychography. Runs PtychoViT models via TensorRT.

About PtychoViT

PtychoViT is a Vision Transformer (ViT) adapted for ptychographic reconstruction. It takes a batch of diffraction patterns and directly predicts amplitude and phase estimates, orders of magnitude faster than iterative methods such as difference map (DM) and maximum likelihood (ML), enabling real-time feedback during live scans.

The model is developed at Argonne National Laboratory (ANL). Training code lives in the ptycho-vit repo (private, maintained at ANL).

ptychoml handles the inference side only — taking a trained model exported to ONNX, converting it to a TensorRT engine, and running fast batched inference on a GPU.

Architecture

What this repo is: a pure computation library for ML-based ptychographic reconstruction. Loads pre-built TensorRT engines and runs inference on diffraction patterns.

What this repo is not:

  • Pipeline orchestration → see NSLS2/holoptycho
  • Iterative (DM) reconstruction → see NSLS2/ptycho
  • Model training → see ptycho-vit (PyTorch training code maintained by ANL)

Design principle: no I/O, no framework deps (Holoscan, MPI, etc.). Return data to the caller; let the caller decide where it goes.

Install

git clone git@github.com:NSLS2/ptychoml.git
cd ptychoml
pixi install

Requires an NVIDIA GPU with a CUDA 12-capable driver, and pixi for environment management.

Usage

Python API:

from ptychoml import PtychoViTInference

with PtychoViTInference(engine_path="model.engine", gpu=0) as session:
    pred, indices = session.predict(diff_amp, image_indices)
    # pred.shape == (B, 2, H, W) or (B, H, W)
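
The session expects host-side NumPy arrays. A minimal sketch of preparing a batch; the shapes and dtypes below are assumptions inferred from the call above, not a documented contract:

```python
import numpy as np

# Assumed input layout: a batch of B diffraction-amplitude frames, H x W pixels each.
B, H, W = 8, 128, 128
diff_amp = np.random.rand(B, H, W).astype(np.float32)

# One index per frame, used to map predictions back to their scan positions.
image_indices = np.arange(B, dtype=np.int64)

assert diff_amp.shape == (B, H, W)
assert image_indices.shape == (B,)
```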

Build a TensorRT engine from ONNX:

pixi run build-engine --onnx model.onnx --output model.engine
# or
ptychoml-build-engine --onnx model.onnx --output model.engine

Or from the Python API:

from ptychoml import build_engine, save_engine

engine = build_engine("model.onnx", fp16=False, tf32=True)
save_engine(engine, "model.engine")

Run inference on an HDF5 dataset:

pixi run predict --engine model.engine --data scan_1234.h5 --output results.h5
# or
ptychoml-predict --engine model.engine --data scan_1234.h5 --output results.h5

By default, diffraction amplitudes are read from the diffamp dataset key (matching the format used by holoptycho). Use --dataset to specify a different key:

pixi run predict --engine model.engine --data scan.h5 --output results.h5 --dataset entry/data/data

Additional options:

| Flag | Description |
| --- | --- |
| --gpu N | CUDA device ordinal (default: 0) |
| --shifted | Set if the input data has already been fftshift'd |
| --dataset KEY | HDF5 dataset key for diffraction amplitudes (default: diffamp) |

The output HDF5 file contains a predictions dataset with shape (N, 2, H, W) or (N, H, W) depending on the model. If the input file has a points dataset (scan positions), it is copied through to the output.
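
Purely as an illustration of that output layout, the sketch below writes a file in the described shape and reads it back with h5py. The (N, 2, H, W) case is shown; treating channel 0 as amplitude and channel 1 as phase, and the (N, 2) shape of points, are assumptions, not documented here:

```python
import h5py
import numpy as np

# Illustrative only: mimic the described output layout, then read it back.
N, H, W = 4, 64, 64
with h5py.File("results_demo.h5", "w") as f:
    f.create_dataset("predictions", data=np.zeros((N, 2, H, W), dtype=np.float32))
    f.create_dataset("points", data=np.zeros((N, 2), dtype=np.float64))  # assumed shape

with h5py.File("results_demo.h5", "r") as f:
    pred = f["predictions"][...]
    amplitude, phase = pred[:, 0], pred[:, 1]  # assumed channel order
    points = f["points"][...] if "points" in f else None

assert amplitude.shape == (N, H, W)
```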

Run tests

pixi run test

The suite runs on machines without a GPU or TensorRT: GPU-dependent tests are skipped via pytest.importorskip gates.
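
The gating pattern looks roughly like this. The real tests gate on the tensorrt module; json stands in below so the sketch runs anywhere, and the test name is hypothetical:

```python
import pytest

# pytest.importorskip returns the module if it imports, and otherwise skips the
# dependent tests instead of failing. The real suite gates on "tensorrt";
# "json" stands in here so this sketch runs on any machine.
trt = pytest.importorskip("json")

def test_predict_shapes():
    # GPU-dependent assertions would go here, behind the gate above.
    assert trt is not None
```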

Related repos

| Repo | Role |
| --- | --- |
| NSLS2/holoptycho | Streaming Holoscan pipeline that uses ptychoml for live inference |
| NSLS2/ptycho | Iterative DM reconstruction kernels |
| ptycho-vit | PyTorch training code; produces the ONNX files consumed by ptychoml |