diff --git a/contrib/models/SmolVLA-Libero/README.md b/contrib/models/SmolVLA-Libero/README.md new file mode 100644 index 00000000..c30dd1d2 --- /dev/null +++ b/contrib/models/SmolVLA-Libero/README.md @@ -0,0 +1,235 @@ +# Contrib Model: SmolVLA-Libero + +NeuronX Distributed Inference port of `HuggingFaceVLA/smolvla_libero` — a +SmolVLM2-backed flow-matching vision-language-action (VLA) policy fine-tuned +on the LIBERO benchmark. Three compiled subgraphs, maximally on-Neuron, +written in the per-model NxDI structure. + +## Model Information + +- **HuggingFace ID:** `HuggingFaceVLA/smolvla_libero` +- **Backbone:** `HuggingFaceTB/SmolVLM2-500M-Instruct` (full 32-layer text decoder) +- **Model Type:** Flow-matching VLA (SigLIP vision + SmolLM-style VLM + action expert) +- **Action head:** 32-layer expert with self/cross-attn alternation, 50-step action chunk, 10-step Euler denoising +- **License:** Check HuggingFace model card + +## Architecture Details + +| Component | Where | Subgraph | +|--------------------------------------------------------|------------|----------| +| SigLIP vision encoder (12 layers, hidden=768) | **Neuron** | #1 | +| Pixel-shuffle 4× + connector + scale by sqrt(960) | **Neuron** | #1 | +| Lang token embed + scale by sqrt(960) | **Neuron** | #2 | +| State projection (32 → 960) | **Neuron** | #2 | +| VLM 32-layer text decoder (eager GQA, RoPE, RMSNorm) | **Neuron** | #2 | +| Pad-aware position_ids + 2D attention mask | **Neuron** | #2 / #3 | +| Action expert: 16× self-attn (concat past KV) layers | **Neuron** | #3 | +| Action expert: 16× cross-attn (Q from suffix) layers | **Neuron** | #3 | +| Sinusoidal timestep embedding | **Neuron** | #3 | +| Action in/out projections + time MLP | **Neuron** | #3 | +| Image preprocessing (resize-with-pad, normalize) | CPU | — | +| Tokenization | CPU | — | +| 10-step Euler denoising loop | CPU | — | + +**Deviations from "everything on Neuron":** + +1. The 10-step Euler loop runs on CPU. 
Static-shape compilation cannot host + a Python `for step in range(N)` as a single graph; the loop body is the + compiled subgraph. Each step calls NEFF #3 with the updated `noisy_actions`. +2. Tokenization, image flip / resize-with-pad, and state-vector composition + run on CPU because they are data-loading, not model compute. + +**Hardware constraint flagged:** `tp_degree = 1` because +`num_attention_heads = 15` and `num_kv_heads = 5` — neither divides cleanly +into the 4 Neuron cores on `trn3pd98.3xlarge`. The NxDI parallel primitives +(`ColumnParallelLinear`, `RowParallelLinear`, `ParallelEmbedding`) are still +used throughout, so this code is portable to instances with divisor-friendly +head counts. On this instance, 3 of 4 cores idle; the model fits comfortably +in one core's HBM with vast headroom. + +## Validation Results + +**Validated:** 2026-05-06 +**Configuration:** TP=1, batch_size=1, bfloat16 +**Instance:** trn3pd98.3xlarge +**NxDI:** 2.29 (`/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference`) + +### Test Results + +| Check | Result | +|--------------------------------------------------------------|----------------------------------------------| +| Vision NEFF vs HF SmolVLM2 vision (single image) | cos_sim = 0.99990 | +| Prefix KV layer 0..31 vs lerobot CPU | max abs diff ≤ 0.4 (BF16) | +| Full action chunk vs lerobot CPU (matched noise) | cos_sim = 0.9999, mean abs diff = 0.007 | +| Full action chunk: Neuron vs lerobot CPU (this `test_model.py`) | cos_sim = 0.999921, mean abs diff = 0.0015 | +| Closed-loop LIBERO `libero_object` task 0 | success | +| End-to-end inference latency (one chunk) | warm p50 = 62.5 ms (10 iters) | + +The numerical match against the lerobot CPU reference replicates four +lerobot-specific quirks that aren't in the SmolVLM2 HF config: + +1. **`resize_with_pad` pads top+left only** (image lands in the bottom-right + corner of the 512×512 frame), not centered. +2. 
**Pad-aware attention**: dynamic 2D mask + cumsum-based position_ids that + skip padding tokens. A static prefix-LM mask leaks attention into pad-token + positions. +3. **RoPE max_wavelength = 10000** (lerobot hardcodes this in `apply_rope`); + the SmolVLM2 HF config says 100000, but lerobot trained the model with 10000. +4. **Image flip** in the LIBERO env (180° rotate, both H and W) per the + `libero_processor` step in `lerobot.processor.env_processor`. + +## Inference Flow + +``` +images [2 cams × [B, 3, 512, 512]] lang_token_ids [B, 48] lang_mask [B, 48] state [B, 32] + | | | | + [Neuron NEFF #1] | | | + Vision (per camera) | | | + | | | | + [B, 128, 960] vision_features | | | + |______________________________|____________________|_________________| + | + [Neuron NEFF #2] + VLM Prefix (32 layers, pad-aware) + | + prefix_keys, prefix_values + [32, B, 177, 5, 64] each + | + ┌─────────────────────┴─────────────────────┐ + │ CPU Euler loop (10 steps) │ + │ for t in [1.0, 0.9, ..., 0.1]: │ + │ v_t = NEFF#3(x_t, t, K, V, pad) │ + │ x_t += dt * v_t │ + └─────────────────────┬─────────────────────┘ + | + action_chunk [B, 50, 32] + (first 7 dims used by env) +``` + +## Source Layout + +``` +SmolVLA-Libero/ +├── README.md +├── src/ +│ ├── __init__.py +│ ├── config_constants.py # All architecture constants from the checkpoint +│ ├── modeling_smolvla.py # SmolVLAPolicy: orchestrator (compile / load / generate) +│ ├── modeling_smolvla_vision.py # SigLIP-12L + connector (NEFF #1) +│ ├── modeling_smolvla_text.py # VLM 32L prefix + Action expert 32L denoiser (NEFF #2 + #3) +│ ├── neuron_action_head_base.py # NeuronDenoisingConfig — ModelWrapper-compatible config shim +│ ├── weight_mapping.py # HF safetensors -> 3 per-subgraph state dicts +│ └── run_inference.py # CLI: compile / run / benchmark (synthetic inputs) +└── test/ + ├── __init__.py + ├── integration/ + │ ├── __init__.py + │ └── test_model.py # Smoke + numerical tests against lerobot CPU + └── unit/ + └── __init__.py 
+``` + +We add `_text` because SmolVLA has separate text-prefix and action-expert +subgraphs that the existing per-model layout (e.g. `pixtral/` with +`modeling_pixtral.py` + `modeling_pixtral_vision.py`) does not need. + +## Usage + +### Setup + +```bash +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +# Download checkpoint (one-time) +python -c "from huggingface_hub import snapshot_download; \ + print(snapshot_download(repo_id='HuggingFaceVLA/smolvla_libero'))" +``` + +### Compile + run via CLI + +```bash +cd contrib/models/SmolVLA-Libero/src + +# Compile (one-time, ~90s wall clock for 3 NEFFs) +python run_inference.py --action compile \ + --hf-checkpoint /home/ubuntu/.cache/huggingface/hub/models--HuggingFaceVLA--smolvla_libero/snapshots// \ + --neff-dir /home/ubuntu/smol_vla_neff_libero + +# Run inference (synthetic inputs, p50 / p99 latency, NaN check) +python run_inference.py --action run \ + --hf-checkpoint /home/ubuntu/.cache/huggingface/hub/models--HuggingFaceVLA--smolvla_libero/snapshots// \ + --neff-dir /home/ubuntu/smol_vla_neff_libero +``` + +### Programmatic + +```python +import sys +from pathlib import Path +sys.path.insert(0, str(Path("contrib/models/SmolVLA-Libero/src"))) + +from modeling_smolvla import SmolVLAPolicy + +policy = SmolVLAPolicy(hf_checkpoint_dir="", tp_degree=1) +policy.load("/home/ubuntu/smol_vla_neff_libero") + +# images: list of NUM_CAMERAS tensors, each [B, 3, 512, 512] BF16 +# lang_token_ids: [B, 48] INT32 +# lang_mask: [B, 48] BOOL (True = real token, False = pad) +# state: [B, 32] FP32 (already normalized, zero-padded) +action_chunk = policy.generate(images, lang_token_ids, state, lang_mask=lang_mask) +# action_chunk: [B, 50, 32] FP32 (first 7 dims used by LIBERO) +``` + +## Compatibility Matrix + +| Instance / NxDI | 2.29 | +|-----------------|------| +| Trn3 | ✅ Working | +| Trn2 | Not tested | +| Trn1 / Inf2 | Not tested | + +## Testing + +The integration test compiles the three NEFFs (or reuses an 
already-compiled +directory), loads them, and runs three checks: + +1. **Smoke** — full pipeline returns a finite `[B, 50, 32]` action chunk with + non-zero variance. +2. **Warm latency** — p50 latency over 5 iterations is under a generous bound + (1 s; expected ~65 ms on `trn3pd98.3xlarge`). +3. **Neuron vs lerobot CPU parity** (NxDI accuracy check) — loads the + upstream `lerobot.SmolVLAPolicy` from the same HF checkpoint, runs a + CPU forward with identical inputs and identical seeded initial noise, + and asserts cosine similarity ≥ 0.99 and mean abs diff < 0.05 against + the Neuron action chunk. This is the SmolVLA equivalent of the logit + validation NxDI uses for CausalLM contrib models — it validates that + the Neuron port reproduces the reference implementation, not just + self-consistency. Skipped automatically if `lerobot` is not installed. + +```bash +# One-time, point the test at a checkpoint and a NEFF output directory: +export SMOLVLA_CKPT=/home/ubuntu/.cache/huggingface/hub/models--HuggingFaceVLA--smolvla_libero/snapshots// +export SMOLVLA_NEFF=/home/ubuntu/smol_vla_neff_libero + +# Run +pytest contrib/models/SmolVLA-Libero/test/integration/test_model.py --capture=tee-sys + +# Or directly +cd contrib/models/SmolVLA-Libero +python test/integration/test_model.py +``` + +The first invocation compiles the three NEFFs into `$SMOLVLA_NEFF` (~90 s). +Subsequent runs reuse the compiled artifacts and only re-load + execute. + +## Example Checkpoints + +- [`HuggingFaceVLA/smolvla_libero`](https://huggingface.co/HuggingFaceVLA/smolvla_libero) + — used for the validation results above + +## Maintainer + +Community contribution. + +**Last Updated:** 2026-05-06 diff --git a/contrib/models/SmolVLA-Libero/src/__init__.py b/contrib/models/SmolVLA-Libero/src/__init__.py new file mode 100644 index 00000000..d2aa75f0 --- /dev/null +++ b/contrib/models/SmolVLA-Libero/src/__init__.py @@ -0,0 +1,6 @@ +"""SmolVLA-Libero NxDI port for AWS Trainium. 
+ +This directory is intended to be imported flat — add it to ``sys.path`` and +import the modules directly (e.g. ``from modeling_smolvla import SmolVLAPolicy``). +This matches the convention used by other models under ``contrib/models/``. +""" diff --git a/contrib/models/SmolVLA-Libero/src/config_constants.py b/contrib/models/SmolVLA-Libero/src/config_constants.py new file mode 100644 index 00000000..85d74b07 --- /dev/null +++ b/contrib/models/SmolVLA-Libero/src/config_constants.py @@ -0,0 +1,109 @@ +""" +SmolVLA architecture constants. + +All numbers extracted directly from the HuggingFaceVLA/smolvla_libero +checkpoint and config.json. Every other file in this port imports from here +— no hardcoded shapes anywhere else. + +Source: + HF model id : HuggingFaceVLA/smolvla_libero + Backbone : HuggingFaceTB/SmolVLM2-500M-Instruct (full 32-layer text model) + expert_width_multiplier : 0.5 (expert hidden = 480) + chunk size : 50 (action prediction horizon) + num steps : 10 (Euler denoising steps, on CPU) +""" + +# --------------------------------------------------------------------------- +# Vision encoder (SigLIP) +# --------------------------------------------------------------------------- + +VISION_NUM_LAYERS = 12 +VISION_HIDDEN = 768 +VISION_INTERMEDIATE = 3072 +VISION_NUM_HEADS = 12 # 768 / 64 +VISION_HEAD_DIM = 64 +VISION_PATCH_SIZE = 16 +VISION_IMAGE_SIZE = 512 +VISION_NUM_PATCHES = (VISION_IMAGE_SIZE // VISION_PATCH_SIZE) ** 2 # 1024 +VISION_LAYER_NORM_EPS = 1e-6 + +# Connector pixel-shuffle: scale_factor=4 → 4x4 spatial merge → 16x token reduction +PIXEL_SHUFFLE_SCALE = 4 +VISION_TOKENS_PER_IMAGE = VISION_NUM_PATCHES // (PIXEL_SHUFFLE_SCALE ** 2) # 64 +CONNECTOR_INPUT_DIM = VISION_HIDDEN * (PIXEL_SHUFFLE_SCALE ** 2) # 12288 + +# --------------------------------------------------------------------------- +# VLM text backbone (SmolLM2-style — LlamaDecoderLayer with GQA) +# --------------------------------------------------------------------------- + 
+VLM_NUM_LAYERS = 32 +VLM_HIDDEN = 960 +VLM_INTERMEDIATE = 2560 +VLM_NUM_HEADS = 15 +VLM_NUM_KV_HEADS = 5 +VLM_HEAD_DIM = 64 +VLM_KV_DIM = VLM_NUM_KV_HEADS * VLM_HEAD_DIM # 320 +VLM_RMS_NORM_EPS = 1e-5 +VLM_ROPE_THETA = 10000.0 # lerobot.smolvlm_with_expert.apply_rope hardcodes max_wavelength=10000 +VLM_VOCAB_SIZE = 49280 + +# --------------------------------------------------------------------------- +# Action expert (Llama-style with cross-attn override every other layer) +# --------------------------------------------------------------------------- + +EXPERT_NUM_LAYERS = 32 +EXPERT_HIDDEN = 480 # round(960 * 0.5) +EXPERT_INTERMEDIATE = 1280 # confirmed from checkpoint +EXPERT_NUM_HEADS = 15 # same as VLM +EXPERT_NUM_KV_HEADS = 5 # same as VLM +EXPERT_HEAD_DIM = 64 +EXPERT_KV_DIM = EXPERT_NUM_KV_HEADS * EXPERT_HEAD_DIM # 320 +EXPERT_Q_DIM = EXPERT_NUM_HEADS * EXPERT_HEAD_DIM # 960 +SELF_ATTN_EVERY_N_LAYERS = 2 +# Even-indexed expert layers (0,2,...,30) are "self-attn" — concat past VLM K/V +# with new expert K/V along seq dim. k/v_proj input dim = expert hidden (480). +# Odd-indexed expert layers (1,3,...,31) are "cross-attn" — Q from expert, +# K/V from past VLM K/V projected through k/v_proj. k/v_proj input dim = 320. 
+ +# --------------------------------------------------------------------------- +# Action / state projections (flow-matching action head) +# --------------------------------------------------------------------------- + +MAX_STATE_DIM = 32 +MAX_ACTION_DIM = 32 +ACTION_CHUNK_SIZE = 50 # config.chunk_size / n_action_steps +NUM_DENOISE_STEPS = 10 + +# Sinusoidal timestep embedding parameters +TIMESTEP_EMBED_DIM = EXPERT_HIDDEN # 480 — output dim of sin/cos block +TIMESTEP_MIN_PERIOD = 0.004 +TIMESTEP_MAX_PERIOD = 4.0 + +# action_time_mlp_in: Linear(960, 480) — (action_emb 480 ⊕ time_emb 480) → 480 +# action_time_mlp_out: Linear(480, 480) +ACTION_TIME_MLP_IN_DIM = EXPERT_HIDDEN * 2 # 960 + +# --------------------------------------------------------------------------- +# Sequence layout (static at compile time) +# --------------------------------------------------------------------------- + +NUM_CAMERAS = 2 # image (agentview), image2 (wrist) +NUM_TEXT_TOKENS = 48 # tokenizer_max_length +NUM_VISION_TOKENS_TOTAL = NUM_CAMERAS * VISION_TOKENS_PER_IMAGE # 128 +NUM_STATE_TOKENS = 1 +PREFIX_LEN = NUM_VISION_TOKENS_TOTAL + NUM_TEXT_TOKENS + NUM_STATE_TOKENS # 177 +SUFFIX_LEN = ACTION_CHUNK_SIZE # 50 +FULL_LEN = PREFIX_LEN + SUFFIX_LEN # 227 + +# --------------------------------------------------------------------------- +# Neuron runtime +# --------------------------------------------------------------------------- + +# DEVIATION FLAG: tp_degree=1 because num_attention_heads=15 and +# num_kv_heads=5 — neither divides cleanly into the 4 cores available on +# trn3pd98.3xlarge. Production NxDI parallel primitives are still used so the +# code stays portable to instances where head counts allow real TP, but on +# this hardware sharding effectively no-ops. 
+DEFAULT_TP_DEGREE = 1 +BATCH_SIZE = 1 +TORCH_DTYPE_STR = "bfloat16" diff --git a/contrib/models/SmolVLA-Libero/src/modeling_smolvla.py b/contrib/models/SmolVLA-Libero/src/modeling_smolvla.py new file mode 100644 index 00000000..79a3b6d0 --- /dev/null +++ b/contrib/models/SmolVLA-Libero/src/modeling_smolvla.py @@ -0,0 +1,333 @@ +""" +SmolVLA top-level application class. + +Orchestrates the three compiled subgraphs and the CPU-side Euler loop: + + images, lang_ids, state + -> Vision NEFF (per camera) -> stack [B, 128, 960] + -> Prefix NEFF -> prefix_keys, prefix_values + -> CPU loop over 10 Euler steps: + Denoise NEFF(noisy_actions, t, K, V, pad) -> v_t + noisy_actions <- noisy_actions + dt * v_t + -> action chunk [B, 50, 32] + +The subgraphs are compiled via NxDI's ModelBuilder, which initializes +parallel_state so ColumnParallelLinear / RowParallelLinear use the parallel +path. tp_degree=1 on this hardware (see config_constants.py — 15 attn heads +do not divide evenly across the 4 cores). + +DEVIATIONS FROM "everything on Neuron": + - The 10-step Euler loop runs on CPU. Static-shape compilation cannot host + a Python `for step in range(N)` — the loop body is the compiled graph. + - Image preprocessing (resize, normalize) and tokenization run on CPU. + Dataloading, not model compute. 
+""" + +from __future__ import annotations + +import logging +import os +import time +from typing import List + +import torch +import torch.nn as nn + +from neuronx_distributed.trace.model_builder import ModelBuilder, BaseModelInstance +from safetensors.torch import load_file + +import config_constants as C +from neuron_action_head_base import NeuronDenoisingConfig, COMPILED_MODEL_FILE_NAME +from modeling_smolvla_vision import SmolVLAVisionEncoder +from modeling_smolvla_text import SmolVLAPrefixModel, SmolVLADenoiseStep +from weight_mapping import load_hf_state_dict, split_hf_state_dict + +logger = logging.getLogger("SmolVLA") + +# --------------------------------------------------------------------------- +# NEFF directory layout +# --------------------------------------------------------------------------- + +VISION_NEFF_SUBDIR = "vision" +PREFIX_NEFF_SUBDIR = "prefix" +DENOISE_NEFF_SUBDIR = "denoise" + + +def _make_config(tp_degree: int = C.DEFAULT_TP_DEGREE, + batch_size: int = C.BATCH_SIZE) -> NeuronDenoisingConfig: + """Construct a config object compatible with NxDI's ModelWrapper.""" + return NeuronDenoisingConfig( + batch_size=batch_size, + tp_degree=tp_degree, + action_chunk_size=C.ACTION_CHUNK_SIZE, + action_dim=C.MAX_ACTION_DIM, + num_conditioning_tokens=C.PREFIX_LEN, + conditioning_hidden_size=C.VLM_HIDDEN, + timestep_embed_dim=C.TIMESTEP_EMBED_DIM, + torch_dtype=torch.bfloat16, + ) + + +def _compiler_args() -> str: + """Standard DiT-safe compiler args (see action_head_translation.md).""" + return ( + "--auto-cast=none " + "-O1 " + "--tensorizer-options='" + "--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=1'" + ) + + +# --------------------------------------------------------------------------- +# SmolVLAPolicy — compile / load / generate +# --------------------------------------------------------------------------- + +class SmolVLAPolicy(nn.Module): + """ + End-to-end SmolVLA policy on Trainium. 
+ + Lifecycle: + compile(save_dir, hf_checkpoint) -> writes 3 NEFFs and 3 sharded weight dirs + load(save_dir) -> loads 3 NEFFs to Neuron + generate(images, lang_ids, state) -> [B, 50, 32] action chunk + """ + + def __init__(self, + hf_checkpoint_dir: str, + tp_degree: int = C.DEFAULT_TP_DEGREE, + batch_size: int = C.BATCH_SIZE): + super().__init__() + self.hf_checkpoint_dir = hf_checkpoint_dir + self.config = _make_config(tp_degree=tp_degree, batch_size=batch_size) + + # ModelBuilder needs a callable that constructs the nn.Module after + # parallel_state is initialized. Use BaseModelInstance with module_cls + # set to a no-arg lambda. + # Cache the HF state dict slices for use as ModelBuilder checkpoint loaders + self._hf_sd = None + self._vision_sd = None + self._prefix_sd = None + self._denoise_sd = None + + # Loaded NEFFs + self._vision_traced = None + self._prefix_traced = None + self._denoise_traced = None + + # ---------------------------------------------------------------------- + # Checkpoint loaders for ModelBuilder + # ---------------------------------------------------------------------- + + def _ensure_hf_loaded(self): + if self._hf_sd is None: + self._hf_sd = load_hf_state_dict(self.hf_checkpoint_dir) + self._vision_sd, self._prefix_sd, self._denoise_sd = split_hf_state_dict(self._hf_sd) + # Cast everything to bf16 (cast_type='config' default) + for sd in (self._vision_sd, self._prefix_sd, self._denoise_sd): + for k, v in list(sd.items()): + if torch.is_floating_point(v) and v.dtype != torch.bfloat16: + sd[k] = v.to(torch.bfloat16) + + def _vision_loader(self, mmap=False): + self._ensure_hf_loaded() + return self._vision_sd + + def _prefix_loader(self, mmap=False): + self._ensure_hf_loaded() + return self._prefix_sd + + def _denoise_loader(self, mmap=False): + self._ensure_hf_loaded() + return self._denoise_sd + + # ---------------------------------------------------------------------- + # Compile + # 
---------------------------------------------------------------------- + + def _build_one(self, tag: str, module_cls, example_inputs, save_dir: str, ckpt_loader): + """Compile a single subgraph via ModelBuilder and shard its weights. + + Args: + tag: Subgraph name (used as the key in builder.add). + module_cls: No-arg callable returning an nn.Module. Called by + BaseModelInstance.load_module() AFTER parallel_state is up. + example_inputs: List[Tuple[Tensor, ...]] — one tuple per bucket. + save_dir: Output dir for NEFF + sharded weights. + ckpt_loader: Callable returning the state_dict for this subgraph. + """ + os.makedirs(save_dir, exist_ok=True) + builder = ModelBuilder( + router=None, + tp_degree=self.config.neuron_config.tp_degree, + pp_degree=1, + ep_degree=1, + world_size=self.config.neuron_config.tp_degree, + start_rank_id=0, + local_ranks_size=self.config.neuron_config.tp_degree, + checkpoint_loader=ckpt_loader, + compiler_workdir=os.path.join(save_dir, "compiler_workdir"), + ) + instance = BaseModelInstance(module_cls=module_cls, input_output_aliases={}) + builder.add( + key=tag, + model_instance=instance, + example_inputs=example_inputs, + compiler_args=_compiler_args(), + ) + traced = builder.trace(initialize_model_weights=False) + torch.jit.save(traced, os.path.join(save_dir, COMPILED_MODEL_FILE_NAME)) + sharded_dir = os.path.join(save_dir, "weights") + os.makedirs(sharded_dir, exist_ok=True) + builder.shard_checkpoint(serialize_path=sharded_dir + "/") + del traced + logger.info(f"Compiled {tag} -> {save_dir}") + + def compile(self, save_root: str): + """Compile all three NEFFs to save_root/{vision,prefix,denoise}/.""" + self._ensure_hf_loaded() + + B = self.config.neuron_config.batch_size + + # 1. 
Vision + vision_inputs = [( + torch.zeros(B, 3, C.VISION_IMAGE_SIZE, C.VISION_IMAGE_SIZE, dtype=torch.bfloat16), + )] + vdir = os.path.join(save_root, VISION_NEFF_SUBDIR) + logger.info("=== Compiling vision encoder ===") + t0 = time.monotonic() + self._build_one("vision_encoder", SmolVLAVisionEncoder, vision_inputs, vdir, self._vision_loader) + logger.info(f"Vision compile time: {time.monotonic()-t0:.1f}s") + + # 2. Prefix + prefix_inputs = [( + torch.zeros(B, C.NUM_VISION_TOKENS_TOTAL, C.VLM_HIDDEN, dtype=torch.bfloat16), + torch.zeros(B, C.NUM_TEXT_TOKENS, dtype=torch.int32), + torch.ones(B, C.NUM_TEXT_TOKENS, dtype=torch.bool), + torch.zeros(B, C.MAX_STATE_DIM, dtype=torch.float32), + )] + pdir = os.path.join(save_root, PREFIX_NEFF_SUBDIR) + logger.info("=== Compiling VLM prefix ===") + t0 = time.monotonic() + self._build_one("prefix", SmolVLAPrefixModel, prefix_inputs, pdir, self._prefix_loader) + logger.info(f"Prefix compile time: {time.monotonic()-t0:.1f}s") + + # 3. Denoise step + denoise_inputs = [( + torch.zeros(B, C.ACTION_CHUNK_SIZE, C.MAX_ACTION_DIM, dtype=torch.float32), + torch.zeros(B, dtype=torch.float32), + torch.zeros(C.VLM_NUM_LAYERS, B, C.PREFIX_LEN, C.VLM_NUM_KV_HEADS, C.VLM_HEAD_DIM, dtype=torch.bfloat16), + torch.zeros(C.VLM_NUM_LAYERS, B, C.PREFIX_LEN, C.VLM_NUM_KV_HEADS, C.VLM_HEAD_DIM, dtype=torch.bfloat16), + torch.ones(B, C.PREFIX_LEN, dtype=torch.bool), + )] + ddir = os.path.join(save_root, DENOISE_NEFF_SUBDIR) + logger.info("=== Compiling denoise step ===") + t0 = time.monotonic() + self._build_one("denoise_step", SmolVLADenoiseStep, denoise_inputs, ddir, self._denoise_loader) + logger.info(f"Denoise compile time: {time.monotonic()-t0:.1f}s") + + # ---------------------------------------------------------------------- + # Load + # ---------------------------------------------------------------------- + + def _load_one(self, save_dir: str): + traced = torch.jit.load(os.path.join(save_dir, COMPILED_MODEL_FILE_NAME)) + weights = [] + 
local_ranks = self.config.neuron_config.tp_degree + for rank in range(local_ranks): + ckpt = load_file(os.path.join(save_dir, "weights", f"tp{rank}_sharded_checkpoint.safetensors")) + weights.append(ckpt) + start_rank = torch.tensor([0], dtype=torch.int32) + traced.nxd_model.initialize(weights, start_rank) + return traced + + def load(self, save_root: str): + """Load three NEFFs and their pre-sharded weights to Neuron device.""" + logger.info("Loading vision NEFF...") + self._vision_traced = self._load_one(os.path.join(save_root, VISION_NEFF_SUBDIR)) + logger.info("Loading prefix NEFF...") + self._prefix_traced = self._load_one(os.path.join(save_root, PREFIX_NEFF_SUBDIR)) + logger.info("Loading denoise NEFF...") + self._denoise_traced = self._load_one(os.path.join(save_root, DENOISE_NEFF_SUBDIR)) + + # ---------------------------------------------------------------------- + # Inference + # ---------------------------------------------------------------------- + + def _embed_cameras(self, images: List[torch.Tensor]) -> torch.Tensor: + """Run vision NEFF once per camera and concat -> [B, 128, 960].""" + outs = [] + for img in images: + assert img.shape == (self.config.neuron_config.batch_size, 3, + C.VISION_IMAGE_SIZE, C.VISION_IMAGE_SIZE), ( + f"Each camera image must be [B, 3, {C.VISION_IMAGE_SIZE}, {C.VISION_IMAGE_SIZE}]" + ) + out = self._vision_traced.nxd_model.forward([img.to(torch.bfloat16)]) + # ModelBuilder traced forward returns a list/tuple — unwrap + out = out[0] if isinstance(out, (list, tuple)) else out + outs.append(out) + return torch.cat(outs, dim=1) # [B, 128, 960] + + @torch.no_grad() + def generate( + self, + images: List[torch.Tensor], # length-NUM_CAMERAS list, each [B, 3, 512, 512] + lang_token_ids: torch.Tensor, # [B, NUM_TEXT_TOKENS] INT32 + state: torch.Tensor, # [B, 32] FP32 + lang_mask: torch.Tensor = None, # [B, NUM_TEXT_TOKENS] BOOL (defaults to all-True) + num_steps: int = C.NUM_DENOISE_STEPS, + noise: torch.Tensor = None, # [B, 
ACTION_CHUNK_SIZE, MAX_ACTION_DIM] FP32, optional + ) -> torch.Tensor: + """Run the full pipeline: vision -> prefix -> N denoise steps -> action chunk. + + ``noise`` controls the Euler-loop initial state. Pass an explicit tensor + (e.g. for parity testing against a CPU reference) to bypass the + ``torch.randn`` default. + """ + assert len(images) == C.NUM_CAMERAS, f"Expected {C.NUM_CAMERAS} camera tensors" + assert isinstance(num_steps, int) + + B = state.shape[0] + if lang_mask is None: + lang_mask = torch.ones(B, C.NUM_TEXT_TOKENS, dtype=torch.bool) + lang_mask = lang_mask.to(torch.bool) + + # 1. Vision (returns fp32 from compiled graph; cast to bf16 for prefix) + vision_features = self._embed_cameras(images).to(torch.bfloat16) # [B, NV, 960] + + # 2. Prefix -> KV cache (with attention_mask) + prefix_out = self._prefix_traced.nxd_model.forward([ + vision_features, + lang_token_ids.to(torch.int32), + lang_mask, + state.to(torch.float32), + ]) + pk, pv = prefix_out + pk = pk.to(torch.bfloat16) + pv = pv.to(torch.bfloat16) + + # Build the prefix-wide pad mask: vision (all valid) + lang_mask + state (valid) + prefix_pad = torch.cat([ + torch.ones(B, C.NUM_VISION_TOKENS_TOTAL, dtype=torch.bool), + lang_mask, + torch.ones(B, C.NUM_STATE_TOKENS, dtype=torch.bool), + ], dim=1) # [B, PREFIX_LEN] + + # 3. 
Euler loop on CPU + if noise is not None: + assert tuple(noise.shape) == (B, C.ACTION_CHUNK_SIZE, C.MAX_ACTION_DIM), ( + f"noise must have shape [B={B}, {C.ACTION_CHUNK_SIZE}, {C.MAX_ACTION_DIM}], " + f"got {tuple(noise.shape)}" + ) + x_t = noise.to(torch.float32) + else: + x_t = torch.randn(B, C.ACTION_CHUNK_SIZE, C.MAX_ACTION_DIM, dtype=torch.float32) + dt = -1.0 / num_steps + for step in range(num_steps): + t = 1.0 + step * dt + t_tensor = torch.tensor([t] * B, dtype=torch.float32) + v_t = self._denoise_traced.nxd_model.forward([x_t, t_tensor, pk, pv, prefix_pad]) + v_t = v_t[0] if isinstance(v_t, (list, tuple)) else v_t + x_t = x_t + dt * v_t.to(torch.float32) + + return x_t # [B, 50, 32] diff --git a/contrib/models/SmolVLA-Libero/src/modeling_smolvla_text.py b/contrib/models/SmolVLA-Libero/src/modeling_smolvla_text.py new file mode 100644 index 00000000..08ebd58a --- /dev/null +++ b/contrib/models/SmolVLA-Libero/src/modeling_smolvla_text.py @@ -0,0 +1,709 @@ +""" +SmolVLA text & expert subgraphs +================================ + +Defines two compiled subgraphs: + + Subgraph #2 PrefixWrapper + VLM 32-layer text decoder, fills KV cache. + + Input: + vision_features [B, 128, 960] BF16 — from vision encoder + lang_token_ids [B, 48] INT32 + state [B, 32] FP32 + Output (stacked across layers): + prefix_keys [32, B, 177, 5, 64] BF16 + prefix_values [32, B, 177, 5, 64] BF16 + + Subgraph #3 DenoiseStepWrapper + One Euler step of the action expert. + + Input: + noisy_actions [B, 50, 32] FP32 + timestep [B] FP32 (one scalar per batch element) + prefix_keys [32, B, 177, 5, 64] BF16 + prefix_values [32, B, 177, 5, 64] BF16 + Output: + v_t [B, 50, 32] FP32 + +Expert layer alternation (config: self_attn_every_n_layers=2): + Even layers (0, 2, ..., 30) "self-attn" + Q from suffix; K/V = concat(past_VLM_KV, suffix_KV) over seq dim. + K/V dim = 320 (5 KV heads × 64). q_proj outputs 960 (15 q heads × 64). + RoPE on Q and K with positions 177..226 (continuing prefix). 
+ Attention: Q[50] × K[227], full bidirectional within suffix. + + Odd layers (1, 3, ..., 31) "cross-attn" + Q from suffix expert hidden; K/V from cached VLM K/V re-projected + through expert k_proj/v_proj (input 320 → output 320). + RoPE on Q only with positions 0..49. + Attention: Q[50] × K[177]. + +The interleaving exactly mirrors `forward_attn_layer` and +`forward_cross_attn_layer` in the lerobot SmolVLA source. + +Sinusoidal timestep embedding runs INSIDE the compiled denoise graph. +The frequency table is a register_buffer (pre-computed in __init__), so +no torch.linspace/arange is invoked during forward — see nxdi_background.md +"Dynamic Constants". +""" + +from __future__ import annotations + +import math +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, + ParallelEmbedding, +) +from neuronx_distributed_inference.models.model_wrapper import ModelWrapper + +import config_constants as C +from neuron_action_head_base import NeuronDenoisingConfig + + +# --------------------------------------------------------------------------- +# Parallel-linear helpers (TP=1 fallback safe — see config_constants.py) +# --------------------------------------------------------------------------- + +def _col(in_f: int, out_f: int, bias: bool = False) -> nn.Module: + if parallel_state.model_parallel_is_initialized(): + return ColumnParallelLinear( + in_f, out_f, bias=bias, gather_output=False, + dtype=torch.bfloat16, + tensor_model_parallel_group=parallel_state.get_tensor_model_parallel_group(), + ) + return nn.Linear(in_f, out_f, bias=bias) + + +def _row(in_f: int, out_f: int, bias: bool = False) -> nn.Module: + if parallel_state.model_parallel_is_initialized(): + return RowParallelLinear( + in_f, out_f, bias=bias, input_is_parallel=True, + 
dtype=torch.bfloat16, + tensor_model_parallel_group=parallel_state.get_tensor_model_parallel_group(), + ) + return nn.Linear(in_f, out_f, bias=bias) + + +def _embed(num_emb: int, dim: int) -> nn.Module: + if parallel_state.model_parallel_is_initialized(): + return ParallelEmbedding( + num_emb, dim, + dtype=torch.bfloat16, + shard_across_embedding=True, + ) + return nn.Embedding(num_emb, dim) + + +# --------------------------------------------------------------------------- +# RoPE (Llama-style, applied to a [B, S, H, D] tensor) +# --------------------------------------------------------------------------- + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + half = x.shape[-1] // 2 + return torch.cat((-x[..., half:], x[..., :half]), dim=-1) + + +def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: + """ + x: [B, S, H, D] + cos: [B, S, D] -> broadcast to [B, S, 1, D] + sin: [B, S, D] + """ + cos = cos.unsqueeze(2) + sin = sin.unsqueeze(2) + return (x * cos) + (_rotate_half(x) * sin) + + +class _RoPECache(nn.Module): + """Pre-computed RoPE cos/sin tables; indexed by position_ids at forward.""" + def __init__(self, head_dim: int, max_pos: int, base: float): + super().__init__() + # inv_freq: [head_dim // 2] + inv_freq = 1.0 / ( + base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim) + ) + # positions: [max_pos] + positions = torch.arange(max_pos, dtype=torch.float32) + # freqs: [max_pos, head_dim // 2] + freqs = positions.unsqueeze(1) * inv_freq.unsqueeze(0) + # emb: [max_pos, head_dim] (concat of [freqs, freqs] to match Llama convention) + emb = torch.cat([freqs, freqs], dim=-1) + self.register_buffer("cos_cached", emb.cos().to(torch.bfloat16), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(torch.bfloat16), persistent=False) + + def forward(self, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # position_ids: [B, S] -> cos/sin: [B, S, head_dim] + cos = 
self.cos_cached[position_ids] + sin = self.sin_cached[position_ids] + return cos, sin + + +# --------------------------------------------------------------------------- +# Eager GQA attention (used by both VLM prefix and expert layers) +# --------------------------------------------------------------------------- + +def _eager_gqa_attention( + q: torch.Tensor, # [B, S_q, H, D] + k: torch.Tensor, # [B, S_kv, KH, D] + v: torch.Tensor, # [B, S_kv, KH, D] + attn_mask_2d: torch.Tensor, # [B, S_q, S_kv] BOOL + head_dim: int, +) -> torch.Tensor: + B, S_q, H, D = q.shape + S_kv, KH = k.shape[1], k.shape[2] + groups = H // KH + + # Repeat K/V to match Q heads (GQA expansion) + k = k[:, :, :, None, :].expand(B, S_kv, KH, groups, D).reshape(B, S_kv, H, D) + v = v[:, :, :, None, :].expand(B, S_kv, KH, groups, D).reshape(B, S_kv, H, D) + + # Compute in fp32 to match HF eager_attention_forward upcast (see modeling line 528-540) + q32 = q.to(torch.float32).transpose(1, 2) # [B, H, S_q, D] + k32 = k.to(torch.float32).transpose(1, 2) # [B, H, S_kv, D] + attn = torch.matmul(q32, k32.transpose(2, 3)) * (head_dim ** -0.5) + big_neg = torch.finfo(torch.float32).min + # broadcast mask [B, S_q, S_kv] -> [B, 1, S_q, S_kv] + attn = torch.where(attn_mask_2d.unsqueeze(1), attn, big_neg) + probs = F.softmax(attn, dim=-1).to(v.dtype) + out = torch.matmul(probs, v.transpose(1, 2)) # [B, H, S_q, D] + out = out.transpose(1, 2).reshape(B, S_q, H * D) + return out + + +# --------------------------------------------------------------------------- +# RMS norm +# --------------------------------------------------------------------------- + +class _RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float): + super().__init__() + self.weight = nn.Parameter(torch.ones(dim)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # cast to fp32 for the variance computation (Llama convention) + in_dtype = x.dtype + x_f32 = x.to(torch.float32) + var = x_f32.pow(2).mean(-1, 
keepdim=True) + x_f32 = x_f32 * torch.rsqrt(var + self.eps) + return (x_f32 * self.weight).to(in_dtype) + + +# --------------------------------------------------------------------------- +# Llama-style MLP (gated SiLU) +# --------------------------------------------------------------------------- + +class _LlamaMLP(nn.Module): + def __init__(self, hidden: int, intermediate: int): + super().__init__() + self.gate_proj = _col(hidden, intermediate, bias=False) + self.up_proj = _col(hidden, intermediate, bias=False) + self.down_proj = _row(intermediate, hidden, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + + +# --------------------------------------------------------------------------- +# VLM decoder layer (standard GQA self-attention, used in prefix pass only) +# --------------------------------------------------------------------------- + +class _VLMSelfAttention(nn.Module): + def __init__(self): + super().__init__() + H, KH, D = C.VLM_NUM_HEADS, C.VLM_NUM_KV_HEADS, C.VLM_HEAD_DIM + self.num_heads = H + self.num_kv_heads = KH + self.head_dim = D + self.q_proj = _col(C.VLM_HIDDEN, H * D, bias=False) + self.k_proj = _col(C.VLM_HIDDEN, KH * D, bias=False) + self.v_proj = _col(C.VLM_HIDDEN, KH * D, bias=False) + self.o_proj = _row(H * D, C.VLM_HIDDEN, bias=False) + + +class VLMDecoderLayer(nn.Module): + """ + LlamaDecoderLayer split apart so the prefix pass can return per-layer K/V + in the same step as the residual update. 
+ """ + def __init__(self): + super().__init__() + self.input_layernorm = _RMSNorm(C.VLM_HIDDEN, C.VLM_RMS_NORM_EPS) + self.self_attn = _VLMSelfAttention() + self.post_attention_layernorm = _RMSNorm(C.VLM_HIDDEN, C.VLM_RMS_NORM_EPS) + self.mlp = _LlamaMLP(C.VLM_HIDDEN, C.VLM_INTERMEDIATE) + + def forward( + self, + hidden_states: torch.Tensor, # [B, 241, 960] + cos: torch.Tensor, # [B, 241, 64] + sin: torch.Tensor, # [B, 241, 64] + attention_mask_2d: torch.Tensor,# [B, 241, 241] BOOL + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + residual = hidden_states + x = self.input_layernorm(hidden_states) + + B, S, _ = x.shape + H, KH, D = self.self_attn.num_heads, self.self_attn.num_kv_heads, self.self_attn.head_dim + q = self.self_attn.q_proj(x).view(B, S, H, D) + k = self.self_attn.k_proj(x).view(B, S, KH, D) + v = self.self_attn.v_proj(x).view(B, S, KH, D) + + q = _apply_rope(q, cos, sin) + k = _apply_rope(k, cos, sin) + + attn = _eager_gqa_attention(q, k, v, attention_mask_2d, D) + attn = attn.to(self.self_attn.o_proj.weight.dtype if hasattr(self.self_attn.o_proj, "weight") else attn.dtype) + x = residual + self.self_attn.o_proj(attn) + + residual2 = x + x = self.post_attention_layernorm(x) + x = self.mlp(x) + x = residual2 + x + + return x, k, v + + +# --------------------------------------------------------------------------- +# Prefix model — the thing that goes into NEFF #2 +# --------------------------------------------------------------------------- + +class SmolVLAPrefixModel(nn.Module): + def __init__(self): + super().__init__() + self.embed_tokens = _embed(C.VLM_VOCAB_SIZE, C.VLM_HIDDEN) + self.layers = nn.ModuleList( + [VLMDecoderLayer() for _ in range(C.VLM_NUM_LAYERS)] + ) + self.norm = _RMSNorm(C.VLM_HIDDEN, C.VLM_RMS_NORM_EPS) + # state_proj is owned by the prefix: state -> 1 prefix token + self.state_proj = nn.Linear(C.MAX_STATE_DIM, C.VLM_HIDDEN, bias=True) + + # RoPE table over the largest position the prefix will see + self.rope = 
_RoPECache(C.VLM_HEAD_DIM, C.FULL_LEN, C.VLM_ROPE_THETA) + + # Constants for embed scaling (sqrt(hidden) — see modeling_smolvla.py:684) + self.register_buffer( + "lang_emb_scale", + torch.tensor(C.VLM_HIDDEN ** 0.5, dtype=torch.bfloat16), + persistent=False, + ) + + # Block-attention markers for the prefix-LM cumsum mask: + # image+lang are one block (mark=0 for all but the first), state is a + # separate block (mark=1 at state position to start a new block). + # Pad-aware variant: pad lang positions are skipped via the pad mask. + att_marks = torch.zeros(C.PREFIX_LEN, dtype=torch.int64) + att_marks[C.NUM_VISION_TOKENS_TOTAL + C.NUM_TEXT_TOKENS:] = 1 + self.register_buffer( + "prefix_att_marks", + att_marks.unsqueeze(0), # [1, PREFIX_LEN] + persistent=False, + ) + # constant ones for the always-valid vision and state regions + self.register_buffer( + "vision_pad_const", + torch.ones(1, C.NUM_VISION_TOKENS_TOTAL, dtype=torch.bool), + persistent=False, + ) + self.register_buffer( + "state_pad_const", + torch.ones(1, C.NUM_STATE_TOKENS, dtype=torch.bool), + persistent=False, + ) + + def forward( + self, + vision_features: torch.Tensor, # [B, 192, 960] — already scaled by sqrt(hidden) + lang_token_ids: torch.Tensor, # [B, NUM_TEXT_TOKENS] INT32 + lang_mask: torch.Tensor, # [B, NUM_TEXT_TOKENS] BOOL (True = valid token) + state: torch.Tensor, # [B, 32] + ) -> Tuple[torch.Tensor, torch.Tensor]: + B = vision_features.shape[0] + + # 1. embed lang and scale + lang_emb = self.embed_tokens(lang_token_ids).to(torch.bfloat16) * self.lang_emb_scale # [B, S_lang, 960] + + # 2. project state and add token dim + state_emb = self.state_proj(state.to(torch.bfloat16)).unsqueeze(1) # [B, 1, 960] + + # 3. concat: image + lang + state + prefix = torch.cat([vision_features, lang_emb, state_emb], dim=1) # [B, PREFIX_LEN, 960] + + # 4. Pad-aware position ids and attention mask. 
+ # Build full pad mask [B, PREFIX_LEN] = vision_ones | lang_mask | state_ones + pad_mask = torch.cat([ + self.vision_pad_const.expand(B, -1), + lang_mask.to(torch.bool), + self.state_pad_const.expand(B, -1), + ], dim=1) # [B, PREFIX_LEN] + + # position_ids = cumsum(pad_mask) - 1, clamped at 0 + position_ids = torch.cumsum(pad_mask.to(torch.int64), dim=1) - 1 + position_ids = torch.clamp(position_ids, min=0) # [B, PREFIX_LEN] + cos, sin = self.rope(position_ids) # [B, PREFIX_LEN, 64] + + # 2D attention mask = (cumsum-prefix-LM mask) AND (pad outer product) + att_marks = self.prefix_att_marks.expand(B, -1) # [B, PREFIX_LEN] + cumsum_att = torch.cumsum(att_marks, dim=1) # [B, PREFIX_LEN] + att_2d = cumsum_att.unsqueeze(1) <= cumsum_att.unsqueeze(2) # [B, PREFIX_LEN, PREFIX_LEN] + pad_2d = pad_mask.unsqueeze(1) & pad_mask.unsqueeze(2) # [B, PREFIX_LEN, PREFIX_LEN] + attn_mask = att_2d & pad_2d # [B, PREFIX_LEN, PREFIX_LEN] + + # 5. 16 layers, collect per-layer K/V + keys: List[torch.Tensor] = [] + values: List[torch.Tensor] = [] + x = prefix + for layer in self.layers: + x, k, v = layer(x, cos, sin, attn_mask) + keys.append(k) + values.append(v) + + # x is unused after final layer (we only need K/V from the prefix), but + # the final norm matches the HF reference exactly. Keep it cheap-out: + # the result is discarded. + _ = self.norm(x) + + prefix_keys = torch.stack(keys, dim=0) # [16, B, 241, 5, 64] + prefix_values = torch.stack(values, dim=0) + return prefix_keys, prefix_values + + +# --------------------------------------------------------------------------- +# Expert decoder layers: even index = self-attn, odd index = cross-attn +# --------------------------------------------------------------------------- + +class _ExpertSelfAttnLayer(nn.Module): + """ + Even-indexed expert layer. + + Q from suffix expert hidden (720). K/V projected from suffix expert hidden + (720 → 320), then concatenated with the past VLM K/V (320 each) along the + seq dim. 
RoPE applied with prefix-continued positions (241..290) on Q and + on the suffix-portion of K only — past K already has RoPE baked in from + the prefix pass. + """ + def __init__(self): + super().__init__() + H, KH, D = C.EXPERT_NUM_HEADS, C.EXPERT_NUM_KV_HEADS, C.EXPERT_HEAD_DIM + self.head_dim = D + self.num_heads = H + self.num_kv_heads = KH + + self.input_layernorm = _RMSNorm(C.EXPERT_HIDDEN, C.VLM_RMS_NORM_EPS) + self.q_proj = _col(C.EXPERT_HIDDEN, H * D, bias=False) # 720 -> 960 + self.k_proj = _col(C.EXPERT_HIDDEN, KH * D, bias=False) # 720 -> 320 + self.v_proj = _col(C.EXPERT_HIDDEN, KH * D, bias=False) # 720 -> 320 + self.o_proj = _row(H * D, C.EXPERT_HIDDEN, bias=False) # 960 -> 720 + self.post_attention_layernorm = _RMSNorm(C.EXPERT_HIDDEN, C.VLM_RMS_NORM_EPS) + self.mlp = _LlamaMLP(C.EXPERT_HIDDEN, C.EXPERT_INTERMEDIATE) + + def forward( + self, + suffix_hidden: torch.Tensor, # [B, 50, 720] + past_k: torch.Tensor, # [B, 241, 5, 64] + past_v: torch.Tensor, # [B, 241, 5, 64] + suffix_cos: torch.Tensor, # [B, 50, 64] — positions 241..290 + suffix_sin: torch.Tensor, # [B, 50, 64] + attention_mask_2d: torch.Tensor,# [B, 50, 291] BOOL + ) -> torch.Tensor: + residual = suffix_hidden + x = self.input_layernorm(suffix_hidden) + + B, S = x.shape[:2] + H, KH, D = self.num_heads, self.num_kv_heads, self.head_dim + + q = self.q_proj(x).view(B, S, H, D) + k = self.k_proj(x).view(B, S, KH, D) + v = self.v_proj(x).view(B, S, KH, D) + + q = _apply_rope(q, suffix_cos, suffix_sin) + k = _apply_rope(k, suffix_cos, suffix_sin) + + # Concat past + new along seq dim -> 291 keys/values + full_k = torch.cat([past_k, k], dim=1) + full_v = torch.cat([past_v, v], dim=1) + + attn = _eager_gqa_attention(q, full_k, full_v, attention_mask_2d, D) + x = residual + self.o_proj(attn.to(suffix_hidden.dtype)) + + residual2 = x + x = self.post_attention_layernorm(x) + x = self.mlp(x) + return residual2 + x + + +class _ExpertCrossAttnLayer(nn.Module): + """ + Odd-indexed expert layer. 
+ + Q from suffix expert hidden (720). K/V from cached VLM K/V re-projected + through expert k_proj/v_proj (320 → 320). RoPE on Q only with positions + 0..49 (independent of prefix positions — see modeling_smolvla.py:365). + """ + def __init__(self): + super().__init__() + H, KH, D = C.EXPERT_NUM_HEADS, C.EXPERT_NUM_KV_HEADS, C.EXPERT_HEAD_DIM + self.head_dim = D + self.num_heads = H + self.num_kv_heads = KH + + self.input_layernorm = _RMSNorm(C.EXPERT_HIDDEN, C.VLM_RMS_NORM_EPS) + self.q_proj = _col(C.EXPERT_HIDDEN, H * D, bias=False) # 720 -> 960 + self.k_proj = _col(C.EXPERT_KV_DIM, KH * D, bias=False) # 320 -> 320 + self.v_proj = _col(C.EXPERT_KV_DIM, KH * D, bias=False) # 320 -> 320 + self.o_proj = _row(H * D, C.EXPERT_HIDDEN, bias=False) # 960 -> 720 + self.post_attention_layernorm = _RMSNorm(C.EXPERT_HIDDEN, C.VLM_RMS_NORM_EPS) + self.mlp = _LlamaMLP(C.EXPERT_HIDDEN, C.EXPERT_INTERMEDIATE) + + def forward( + self, + suffix_hidden: torch.Tensor, # [B, 50, 720] + past_k: torch.Tensor, # [B, 241, 5, 64] — VLM cached K + past_v: torch.Tensor, # [B, 241, 5, 64] + suffix_cos: torch.Tensor, # [B, 50, 64] — positions 0..49 + suffix_sin: torch.Tensor, # [B, 50, 64] + attention_mask_2d: torch.Tensor,# [B, 50, 241] BOOL + ) -> torch.Tensor: + residual = suffix_hidden + x = self.input_layernorm(suffix_hidden) + + B, S = x.shape[:2] + H, KH, D = self.num_heads, self.num_kv_heads, self.head_dim + + q = self.q_proj(x).view(B, S, H, D) + q = _apply_rope(q, suffix_cos, suffix_sin) + + # Re-project past VLM K/V through expert k_proj/v_proj (320 -> 320) + past_k_flat = past_k.reshape(B, C.PREFIX_LEN, KH * D) + past_v_flat = past_v.reshape(B, C.PREFIX_LEN, KH * D) + k = self.k_proj(past_k_flat).view(B, C.PREFIX_LEN, KH, D) + v = self.v_proj(past_v_flat).view(B, C.PREFIX_LEN, KH, D) + # No RoPE on K (VLM K already had prefix RoPE applied during prefix pass) + + attn = _eager_gqa_attention(q, k, v, attention_mask_2d, D) + x = residual + self.o_proj(attn.to(suffix_hidden.dtype)) + 
+ residual2 = x + x = self.post_attention_layernorm(x) + x = self.mlp(x) + return residual2 + x + + +# --------------------------------------------------------------------------- +# Suffix embedder + denoise step (compiled subgraph #3) +# --------------------------------------------------------------------------- + +class _SinusoidalTimestepEmbedder(nn.Module): + """ + Pure-Neuron sinusoidal positional embedding for the diffusion timestep. + + Instead of `torch.linspace` inside forward (which becomes a dynamic + constant and bloats NEFF compile-time path names), the period table is + computed once in __init__ and registered as a buffer. + """ + def __init__(self, dim: int, min_period: float, max_period: float): + super().__init__() + assert dim % 2 == 0, "Sinusoidal embedding dim must be even." + fraction = torch.linspace(0.0, 1.0, dim // 2, dtype=torch.float32) + period = min_period * (max_period / min_period) ** fraction + # angular_freq[i] = 2*pi / period[i] + ang_freq = (2.0 * math.pi) / period + self.register_buffer("angular_freq", ang_freq, persistent=False) + self.dim = dim + + def forward(self, time: torch.Tensor) -> torch.Tensor: + # time: [B] FP32 in [0, 1] -> emb: [B, dim] BF16 + # angular_freq: [dim/2] + x = time.unsqueeze(-1) * self.angular_freq.unsqueeze(0) # [B, dim/2] + emb = torch.cat([torch.sin(x), torch.cos(x)], dim=-1) # [B, dim] + return emb.to(torch.bfloat16) + + +class SmolVLADenoiseStep(nn.Module): + """ + One Euler step of the action expert. + + Inputs are exactly the five tensors the compiled NEFF needs + (noisy_actions, timestep, prefix_keys, prefix_values, prefix_pad_mask). + CPU-side Euler integration calls this with new noisy_actions on each step. 
+ """ + def __init__(self): + super().__init__() + # Embed suffix: action_in_proj + sinusoidal time + action_time_mlp + self.action_in_proj = nn.Linear(C.MAX_ACTION_DIM, C.EXPERT_HIDDEN, bias=True) + self.action_time_mlp_in = nn.Linear(C.ACTION_TIME_MLP_IN_DIM, C.EXPERT_HIDDEN, bias=True) + self.action_time_mlp_out = nn.Linear(C.EXPERT_HIDDEN, C.EXPERT_HIDDEN, bias=True) + self.action_out_proj = nn.Linear(C.EXPERT_HIDDEN, C.MAX_ACTION_DIM, bias=True) + self.timestep_embedder = _SinusoidalTimestepEmbedder( + C.TIMESTEP_EMBED_DIM, C.TIMESTEP_MIN_PERIOD, C.TIMESTEP_MAX_PERIOD + ) + + # 16 expert layers: even idx self-attn, odd idx cross-attn + layers = [] + for i in range(C.EXPERT_NUM_LAYERS): + if i % C.SELF_ATTN_EVERY_N_LAYERS == 0: + layers.append(_ExpertSelfAttnLayer()) + else: + layers.append(_ExpertCrossAttnLayer()) + self.layers = nn.ModuleList(layers) + self.norm = _RMSNorm(C.EXPERT_HIDDEN, C.VLM_RMS_NORM_EPS) + + # RoPE caches for both layer types + self.rope = _RoPECache(C.EXPERT_HEAD_DIM, C.FULL_LEN, C.VLM_ROPE_THETA) + + # Cumsum-based block-attention pattern over the FULL sequence. 
+ # prefix has [0]*(vis+lang) + [1]*1 (state starts a new block) + # suffix has [1]*50 (each suffix token starts a new block) + full_att_marks = torch.zeros(C.FULL_LEN, dtype=torch.int64) + full_att_marks[C.NUM_VISION_TOKENS_TOTAL + C.NUM_TEXT_TOKENS:] = 1 + self.register_buffer( + "full_att_marks", + full_att_marks.unsqueeze(0), # [1, FULL_LEN] + persistent=False, + ) + self.register_buffer( + "suffix_pad_const", + torch.ones(1, C.SUFFIX_LEN, dtype=torch.bool), + persistent=False, + ) + self.register_buffer( + "suffix_arange", + torch.arange(C.SUFFIX_LEN, dtype=torch.int64).unsqueeze(0), # [1, 50] + persistent=False, + ) + + def _embed_suffix( + self, noisy_actions: torch.Tensor, timestep: torch.Tensor + ) -> torch.Tensor: + # noisy_actions: [B, 50, 32] timestep: [B] fp32 + action_emb = self.action_in_proj(noisy_actions.to(torch.bfloat16)) # [B, 50, 720] + time_emb = self.timestep_embedder(timestep) # [B, 720] + time_emb = time_emb.unsqueeze(1).expand_as(action_emb) # [B, 50, 720] + cat = torch.cat([action_emb, time_emb], dim=-1) # [B, 50, 1440] + x = self.action_time_mlp_in(cat) + x = F.silu(x) + x = self.action_time_mlp_out(x) + return x # [B, 50, 720] + + def forward( + self, + noisy_actions: torch.Tensor, # [B, 50, 32] FP32 in (from CPU) + timestep: torch.Tensor, # [B] FP32 + prefix_keys: torch.Tensor, # [L, B, PREFIX_LEN, 5, 64] BF16 + prefix_values: torch.Tensor, # [L, B, PREFIX_LEN, 5, 64] BF16 + prefix_pad_mask: torch.Tensor, # [B, PREFIX_LEN] BOOL + ) -> torch.Tensor: # [B, 50, 32] FP32 + + B = noisy_actions.shape[0] + suffix = self._embed_suffix(noisy_actions, timestep) # [B, 50, hidden] + + # Suffix position_ids for self-attn = prefix_offset + 0..49 (per lerobot + # `position_ids = prefix_offsets + cumsum(suffix_pad_masks) - 1`). + # For cross-attn, RoPE on Q only with positions 0..49 (independent — + # see modeling_smolvla.py:365). 
+ prefix_pad_b = prefix_pad_mask.to(torch.bool) + prefix_offset = prefix_pad_b.to(torch.int64).sum(dim=1, keepdim=True) # [B, 1] + self_pos = prefix_offset + self.suffix_arange.expand(B, -1) # [B, 50] + cross_pos = self.suffix_arange.expand(B, -1) # [B, 50] + + self_cos, self_sin = self.rope(self_pos) + cross_cos, cross_sin = self.rope(cross_pos) + + # Self-attn mask over [B, 50, FULL_LEN]: cumsum-block AND pad-2D. + full_pad = torch.cat( + [prefix_pad_b, self.suffix_pad_const.expand(B, -1)], dim=1, + ) # [B, FULL_LEN] + att_marks = self.full_att_marks.expand(B, -1) # [B, FULL_LEN] + cumsum_att = torch.cumsum(att_marks, dim=1) # [B, FULL_LEN] + att_2d = cumsum_att.unsqueeze(1) <= cumsum_att.unsqueeze(2) # [B, FULL_LEN, FULL_LEN] + pad_2d = full_pad.unsqueeze(1) & full_pad.unsqueeze(2) # [B, FULL_LEN, FULL_LEN] + full_mask = att_2d & pad_2d # [B, FULL_LEN, FULL_LEN] + self_mask = full_mask[:, C.PREFIX_LEN:, :] # [B, 50, FULL_LEN] + cross_mask = prefix_pad_b.unsqueeze(1).expand(B, C.SUFFIX_LEN, -1) # [B, 50, PREFIX_LEN] + + x = suffix + for i, layer in enumerate(self.layers): + past_k = prefix_keys[i] # [B, 241, 5, 64] + past_v = prefix_values[i] + if i % C.SELF_ATTN_EVERY_N_LAYERS == 0: + x = layer(x, past_k, past_v, self_cos, self_sin, self_mask) + else: + x = layer(x, past_k, past_v, cross_cos, cross_sin, cross_mask) + + x = self.norm(x) + # The HF reference upcasts to fp32 before action_out_proj. Linear in + # bf16 followed by fp32 cast preserves numeric accuracy adequately + # (action_out_proj is a single 720->32 projection at the end of 16 + # layers of bf16 attention) and keeps the Linear dtype-matched. 
+ v_t = self.action_out_proj(x).to(torch.float32) + return v_t + + +# --------------------------------------------------------------------------- +# Wrappers for ModelBuilder compilation +# --------------------------------------------------------------------------- + +class SmolVLAPrefixWrapper(ModelWrapper): + tag = "prefix" + + def __init__(self, config: NeuronDenoisingConfig): + nn.Module.__init__(self) + super().__init__(config=config, model_cls=type(self)) + self.config = config + self.model = None + + def load_module(self): + self.model = SmolVLAPrefixModel().bfloat16().eval() + + def forward(self, vision_features, lang_token_ids, lang_mask, state): + return self.model(vision_features, lang_token_ids, lang_mask, state) + + def input_generator(self): + B = self.config.neuron_config.batch_size + return [( + torch.zeros(B, C.NUM_VISION_TOKENS_TOTAL, C.VLM_HIDDEN, dtype=torch.bfloat16), + torch.zeros(B, C.NUM_TEXT_TOKENS, dtype=torch.int32), + torch.ones(B, C.NUM_TEXT_TOKENS, dtype=torch.bool), + torch.zeros(B, C.MAX_STATE_DIM, dtype=torch.float32), + )] + + def load_state_dict(self, state_dict, strict=True, **kwargs): + return super().load_state_dict(state_dict, strict=strict, **kwargs) + + +class SmolVLADenoiseWrapper(ModelWrapper): + tag = "denoise_step" + + def __init__(self, config: NeuronDenoisingConfig): + nn.Module.__init__(self) + super().__init__(config=config, model_cls=type(self)) + self.config = config + self.model = None + + def load_module(self): + self.model = SmolVLADenoiseStep().bfloat16().eval() + + def forward(self, noisy_actions, timestep, prefix_keys, prefix_values, prefix_pad_mask): + return self.model(noisy_actions, timestep, prefix_keys, prefix_values, prefix_pad_mask) + + def input_generator(self): + B = self.config.neuron_config.batch_size + return [( + torch.zeros(B, C.ACTION_CHUNK_SIZE, C.MAX_ACTION_DIM, dtype=torch.float32), + torch.zeros(B, dtype=torch.float32), + torch.zeros(C.VLM_NUM_LAYERS, B, C.PREFIX_LEN, C.VLM_NUM_KV_HEADS, 
C.VLM_HEAD_DIM, dtype=torch.bfloat16), + torch.zeros(C.VLM_NUM_LAYERS, B, C.PREFIX_LEN, C.VLM_NUM_KV_HEADS, C.VLM_HEAD_DIM, dtype=torch.bfloat16), + torch.ones(B, C.PREFIX_LEN, dtype=torch.bool), + )] + + def load_state_dict(self, state_dict, strict=True, **kwargs): + return super().load_state_dict(state_dict, strict=strict, **kwargs) diff --git a/contrib/models/SmolVLA-Libero/src/modeling_smolvla_vision.py b/contrib/models/SmolVLA-Libero/src/modeling_smolvla_vision.py new file mode 100644 index 00000000..21602b03 --- /dev/null +++ b/contrib/models/SmolVLA-Libero/src/modeling_smolvla_vision.py @@ -0,0 +1,271 @@ +""" +SmolVLA vision subgraph +======================= + +Compiled subgraph #1 of three: + + Input: pixel_values [B, 3, 512, 512] BF16 + Output: vision_features [B, 64, 960] BF16 + +Pipeline (all on Neuron): + + SigLIPVisionTransformer (12 layers, hidden=768) + → patch_embedding (Conv2d 16x16, stride 16) → [B, 1024, 768] + → + position_embedding[1024, 768] + → 12x SigLIP encoder layer (pre-LN architecture, GELU MLP, eager attn) + → post_layernorm + + SmolVLMConnector + → pixel_shuffle x4 [B, 1024, 768] → [B, 64, 12288] + → modality_projection.proj Linear(12288, 960, bias=False) + +The Neuron compiler accepts Conv2d in eval mode for the patch embedding (it +unfolds internally), so we keep the Conv2d as in the HF source rather than +pre-unfolding, which keeps weight mapping trivial. + +DEVIATIONS from "everything on Neuron": + - None for this subgraph. Image preprocessing (PIL → resize → normalize) + runs on CPU, but that's data-loading, not model compute. + +Per-camera vs all-cameras: + This subgraph runs once per camera (B=1, single image at a time). The + caller concatenates the three outputs along the token dimension to get + [B, 192, 960] before passing into the prefix subgraph. This avoids a 3x + increase in patch-embedding tile size at compile time. 
+""" + +from __future__ import annotations + +from typing import List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, +) +from neuronx_distributed_inference.models.model_wrapper import ModelWrapper + +import config_constants as C +from neuron_action_head_base import NeuronDenoisingConfig + + +# --------------------------------------------------------------------------- +# Parallel-linear helpers (TP=1 on this instance, see config_constants.py) +# --------------------------------------------------------------------------- + +def _col(in_f: int, out_f: int, bias: bool = True) -> nn.Module: + if parallel_state.model_parallel_is_initialized(): + return ColumnParallelLinear( + in_f, out_f, + bias=bias, + gather_output=False, + dtype=torch.bfloat16, + tensor_model_parallel_group=parallel_state.get_tensor_model_parallel_group(), + ) + return nn.Linear(in_f, out_f, bias=bias) + + +def _row(in_f: int, out_f: int, bias: bool = True) -> nn.Module: + if parallel_state.model_parallel_is_initialized(): + return RowParallelLinear( + in_f, out_f, + bias=bias, + input_is_parallel=True, + dtype=torch.bfloat16, + tensor_model_parallel_group=parallel_state.get_tensor_model_parallel_group(), + ) + return nn.Linear(in_f, out_f, bias=bias) + + +# --------------------------------------------------------------------------- +# SigLIP encoder +# --------------------------------------------------------------------------- + +class SigLIPAttention(nn.Module): + def __init__(self): + super().__init__() + H, D = C.VISION_HIDDEN, C.VISION_HEAD_DIM + self.num_heads = C.VISION_NUM_HEADS + self.head_dim = D + self.scale = D ** -0.5 + # SigLIP attention has bias on q/k/v and out_proj + self.q_proj = _col(H, H, bias=True) + self.k_proj = _col(H, H, bias=True) + self.v_proj = _col(H, H, bias=True) + 
self.out_proj = _row(H, H, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, _ = x.shape + H, D = self.num_heads, self.head_dim + q = self.q_proj(x).view(B, N, H, D).transpose(1, 2) # [B, H, N, D] + k = self.k_proj(x).view(B, N, H, D).transpose(1, 2) + v = self.v_proj(x).view(B, N, H, D).transpose(1, 2) + attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale + attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(x.dtype) + out = torch.matmul(attn, v) # [B, H, N, D] + out = out.transpose(1, 2).reshape(B, N, H * D) + return self.out_proj(out) + + +class SigLIPMLP(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = _col(C.VISION_HIDDEN, C.VISION_INTERMEDIATE, bias=True) + self.fc2 = _row(C.VISION_INTERMEDIATE, C.VISION_HIDDEN, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.fc2(F.gelu(self.fc1(x), approximate="tanh")) + + +class SigLIPEncoderLayer(nn.Module): + """Pre-norm transformer block (matches HF SiglipEncoderLayer).""" + def __init__(self): + super().__init__() + self.layer_norm1 = nn.LayerNorm(C.VISION_HIDDEN, eps=C.VISION_LAYER_NORM_EPS) + self.self_attn = SigLIPAttention() + self.layer_norm2 = nn.LayerNorm(C.VISION_HIDDEN, eps=C.VISION_LAYER_NORM_EPS) + self.mlp = SigLIPMLP() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.self_attn(self.layer_norm1(x)) + x = x + self.mlp(self.layer_norm2(x)) + return x + + +class SmolVLMVisionModel(nn.Module): + """SigLIP vision tower + post-layernorm.""" + def __init__(self): + super().__init__() + self.patch_embedding = nn.Conv2d( + in_channels=3, + out_channels=C.VISION_HIDDEN, + kernel_size=C.VISION_PATCH_SIZE, + stride=C.VISION_PATCH_SIZE, + padding="valid", + ) + self.position_embedding = nn.Embedding( + C.VISION_NUM_PATCHES, C.VISION_HIDDEN + ) + self.register_buffer( + "position_ids", + torch.arange(C.VISION_NUM_PATCHES).unsqueeze(0), + persistent=False, + ) + self.layers = nn.ModuleList( + [SigLIPEncoderLayer() for 
_ in range(C.VISION_NUM_LAYERS)] + ) + self.post_layernorm = nn.LayerNorm(C.VISION_HIDDEN, eps=C.VISION_LAYER_NORM_EPS) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + # [B, 3, 512, 512] -> [B, 768, 32, 32] -> [B, 1024, 768] + x = self.patch_embedding(pixel_values) + x = x.flatten(2).transpose(1, 2) + x = x + self.position_embedding(self.position_ids) + for layer in self.layers: + x = layer(x) + x = self.post_layernorm(x) + return x # [B, 1024, 768] + + +class SmolVLMConnector(nn.Module): + """Pixel-shuffle 4x then linear projection to VLM hidden size.""" + def __init__(self): + super().__init__() + self.scale_factor = C.PIXEL_SHUFFLE_SCALE + # HF naming: connector.modality_projection.proj + self.modality_projection_proj = _col( + C.CONNECTOR_INPUT_DIM, C.VLM_HIDDEN, bias=False + ) + + def pixel_shuffle(self, x: torch.Tensor) -> torch.Tensor: + # x: [B, 1024, 768] -> [B, 64, 12288] + B, N, D = x.shape + H = W = int(N ** 0.5) # 32 + s = self.scale_factor # 4 + x = x.view(B, H, W, D) + x = x.view(B, H, W // s, D * s) + x = x.permute(0, 2, 1, 3).contiguous() + x = x.view(B, W // s, H // s, D * s * s) + x = x.permute(0, 2, 1, 3).contiguous() + x = x.view(B, (H // s) * (W // s), D * s * s) + return x + + def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor: + x = self.pixel_shuffle(image_hidden_states) + return self.modality_projection_proj(x) # [B, 64, 960] + + +# --------------------------------------------------------------------------- +# Combined vision encoder model — the thing that gets compiled to a NEFF +# --------------------------------------------------------------------------- + +class SmolVLAVisionEncoder(nn.Module): + """ + pixel_values [B, 3, 512, 512] BF16 -> vision_features [B, 64, 960] BF16 + + The SigLIP image embeddings are scaled by sqrt(hidden_dim) AFTER the + connector to match `embed_prefix` in the HF source (see modeling_smolvla.py + line 654-659). 
The scale is folded into this subgraph so the prefix + subgraph receives ready-to-concat embeddings. + """ + def __init__(self): + super().__init__() + self.vision_model = SmolVLMVisionModel() + self.connector = SmolVLMConnector() + # img_emb * sqrt(VLM_HIDDEN) — pre-compute scalar + self.register_buffer( + "vlm_hidden_sqrt", + torch.tensor(C.VLM_HIDDEN ** 0.5, dtype=torch.bfloat16), + persistent=False, + ) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + x = self.vision_model(pixel_values) # [B, 1024, 768] + x = self.connector(x) # [B, 64, 960] + x = x * self.vlm_hidden_sqrt + return x + + +# --------------------------------------------------------------------------- +# Vision wrapper for ModelBuilder compilation +# --------------------------------------------------------------------------- + +class SmolVLAVisionWrapper(ModelWrapper): + """ + Wraps SmolVLAVisionEncoder for compilation via NxDI's ModelBuilder. + + Takes the same minimal NeuronDenoisingConfig used by the action head — its + fields are a strict superset of what ModelWrapper needs. 
+ """ + + tag = "vision_encoder" + + def __init__(self, config: NeuronDenoisingConfig): + nn.Module.__init__(self) + super().__init__(config=config, model_cls=type(self)) + self.config = config + self.model = None # constructed in load_module() — see nxdi_background.md + + def load_module(self): + # parallel_state is active here under ModelBuilder + self.model = SmolVLAVisionEncoder() + self.model = self.model.bfloat16().eval() + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + return self.model(pixel_values) + + def input_generator(self) -> List[Tuple[torch.Tensor, ...]]: + return [( + torch.zeros( + self.config.neuron_config.batch_size, 3, + C.VISION_IMAGE_SIZE, C.VISION_IMAGE_SIZE, + dtype=torch.bfloat16, + ), + )] + + def load_state_dict(self, state_dict, strict=True, **kwargs): + return super().load_state_dict(state_dict, strict=strict, **kwargs) diff --git a/contrib/models/SmolVLA-Libero/src/neuron_action_head_base.py b/contrib/models/SmolVLA-Libero/src/neuron_action_head_base.py new file mode 100644 index 00000000..da26476f --- /dev/null +++ b/contrib/models/SmolVLA-Libero/src/neuron_action_head_base.py @@ -0,0 +1,92 @@ +""" +neuron_action_head_base +======================= + +Minimal config shim used by the three SmolVLA subgraphs. + +SmolVLA is a flow-matching VLA (vision-language-action) policy, not a CausalLM, +so it cannot reuse NxDI's stock ``InferenceConfig`` — that config carries +LLM-specific fields (KV-cache layout, sequence buckets, vocab size, etc.) that +have no meaning here. + +``NeuronDenoisingConfig`` exposes the small set of attributes that +``ModelWrapper.__init__()`` actually reads (``neuron_config.torch_dtype``, +``neuron_config.tp_degree``, ``neuron_config.batch_size``, ``pad_token_id``) +plus a handful of action-head-specific fields used by the per-subgraph wrappers +in ``modeling_smolvla_vision.py`` and ``modeling_smolvla_text.py``. 
+""" + +from __future__ import annotations + +from types import SimpleNamespace + +import torch + + +COMPILED_MODEL_FILE_NAME = "model.pt" + + +class NeuronDenoisingConfig: + """Minimal config that satisfies ``ModelWrapper.__init__()``. + + LLM-specific fields are stubbed as ``None`` / ``False`` / ``0``. Action-head + fields (``action_chunk_size`` etc.) are exposed at the top level so each + subgraph wrapper can pull static input shapes from them. + """ + + def __init__( + self, + batch_size: int, + tp_degree: int, + action_chunk_size: int, + action_dim: int, + num_conditioning_tokens: int, + conditioning_hidden_size: int, + timestep_embed_dim: int, + torch_dtype: torch.dtype = torch.bfloat16, + ): + self.neuron_config = SimpleNamespace( + # Core — required by ModelWrapper + torch_dtype=torch_dtype, + tp_degree=tp_degree, + batch_size=batch_size, + # Compiler tuning + cc_pipeline_tiling_factor=1, + logical_nc_config=1, + # LLM features — all disabled (we are not a CausalLM) + is_block_kv_layout=False, + is_prefix_caching=False, + is_medusa=False, + token_generation_batches=None, + async_mode=False, + scratchpad_page_size=None, + attn_block_tkg_nki_kernel_enabled=False, + enable_long_context_mode=False, + layer_boundary_markers=False, + dma_order_config=None, + enable_spill_reload_dge=False, + target=None, + quantized=False, + quantization_dtype=None, + kv_cache_quant=False, + quantized_mlp_kernel_enabled=False, + activation_quantization_type=None, + enable_output_completion_notifications=False, + # Weight loading + save_sharded_checkpoint=True, + start_rank_id=0, + local_ranks_size=tp_degree, + cast_type="config", + # Parallelism + pp_degree=1, + ep_degree=1, + world_size=tp_degree, + ) + self.pad_token_id = 0 + + # Action-head specific + self.action_chunk_size = action_chunk_size + self.action_dim = action_dim + self.num_conditioning_tokens = num_conditioning_tokens + self.conditioning_hidden_size = conditioning_hidden_size + self.timestep_embed_dim = 
timestep_embed_dim diff --git a/contrib/models/SmolVLA-Libero/src/run_inference.py b/contrib/models/SmolVLA-Libero/src/run_inference.py new file mode 100644 index 00000000..f2841af9 --- /dev/null +++ b/contrib/models/SmolVLA-Libero/src/run_inference.py @@ -0,0 +1,146 @@ +""" +Entry point for SmolVLA inference on AWS Trainium. + +Usage: + cd contrib/models/SmolVLA-Libero/src + + # First time only — compile the 3 NEFFs (one-shot, takes ~90s total) + python run_inference.py --action compile \\ + --hf-checkpoint /path/to/HuggingFaceVLA/smolvla_libero \\ + --neff-dir /path/to/output_neffs + + # Run inference (load NEFFs + benchmark with synthetic inputs) + python run_inference.py --action run \\ + --hf-checkpoint /path/to/HuggingFaceVLA/smolvla_libero \\ + --neff-dir /path/to/output_neffs + +What runs where: + On Neuron (3 compiled NEFFs): + 1. SigLIP vision encoder (12 layers) + connector + scaling + 2. VLM prefix decoder (32 SmolLM layers, returns full KV cache) + 3. Per-step action expert denoiser (32 expert layers, alternating + self/cross attn over the cached prefix KV) — including the + sinusoidal timestep embedding + + On CPU (necessarily; flagged as deviations from "everything on Neuron"): + - The 10-step Euler loop (Python control flow, can't fuse into + a single static-shape NEFF without unrolling into a 10x larger graph) + - Tokenization and image preprocessing (data loading, not model + compute — moving them to Neuron has no perf benefit) + +Hardware constraints (also flagged): + - tp_degree=1 because neither num_attention_heads=15 nor num_kv_heads=5 + is divisible by the 4 Neuron cores on trn3pd98.3xlarge. + Production NxDI parallel primitives (ColumnParallelLinear / + RowParallelLinear / ParallelEmbedding) are still used so the code is + portable to setups whose tp_degree divides the head counts. On this + instance, 3 of 4 cores idle. The model (450M params, ~900 MB BF16) + fits in one core's HBM with vast headroom.
+""" + +from __future__ import annotations + +import argparse +import logging +import os +import time + +import torch + +# Make sibling modules importable when running as a script: +# ``python run_inference.py ...`` from inside ``src/`` +# or ``python contrib/models/SmolVLA-Libero/src/run_inference.py ...`` +import sys +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +import config_constants as C +from modeling_smolvla import SmolVLAPolicy + +logger = logging.getLogger("smolvla.run_inference") + + +def _make_dummy_inputs(batch_size: int = 1): + """Synthetic inputs — same shapes the real policy expects.""" + images = [ + torch.randn(batch_size, 3, C.VISION_IMAGE_SIZE, C.VISION_IMAGE_SIZE, + dtype=torch.bfloat16) + for _ in range(C.NUM_CAMERAS) + ] + lang = torch.randint( + 0, C.VLM_VOCAB_SIZE, + (batch_size, C.NUM_TEXT_TOKENS), + dtype=torch.int32, + ) + state = torch.zeros(batch_size, C.MAX_STATE_DIM, dtype=torch.float32) + return images, lang, state + + +def cmd_compile(args): + policy = SmolVLAPolicy( + hf_checkpoint_dir=args.hf_checkpoint, + tp_degree=args.tp_degree, + batch_size=args.batch_size, + ) + t0 = time.monotonic() + policy.compile(args.neff_dir) + print(f"All 3 NEFFs compiled in {time.monotonic()-t0:.1f}s -> {args.neff_dir}") + + +def cmd_run(args): + policy = SmolVLAPolicy( + hf_checkpoint_dir=args.hf_checkpoint, + tp_degree=args.tp_degree, + batch_size=args.batch_size, + ) + print("Loading 3 NEFFs to Neuron...") + t0 = time.monotonic() + policy.load(args.neff_dir) + print(f"Loaded in {time.monotonic()-t0:.1f}s") + + images, lang, state = _make_dummy_inputs(args.batch_size) + + print("Cold inference (first call includes lazy device init)...") + t0 = time.monotonic() + chunk = policy.generate(images, lang, state, num_steps=args.num_steps) + cold_ms = (time.monotonic() - t0) * 1000 + print( + f" cold: {cold_ms:.1f} ms shape={tuple(chunk.shape)} " + f"hasNaN={torch.isnan(chunk).any().item()} " + f"mean={chunk.mean().item():.4f} 
std={chunk.std().item():.4f}" + ) + + print(f"Warm benchmark — {args.bench_iters} iterations:") + timings = [] + for i in range(args.bench_iters): + t0 = time.monotonic() + chunk = policy.generate(images, lang, state, num_steps=args.num_steps) + timings.append((time.monotonic() - t0) * 1000) + timings.sort() + p50 = timings[len(timings) // 2] + p99 = timings[int(len(timings) * 0.99)] + print(f" p50={p50:.1f} ms p99={p99:.1f} ms min={timings[0]:.1f} ms max={timings[-1]:.1f} ms") + + +def main(): + logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s: %(message)s") + parser = argparse.ArgumentParser(description="SmolVLA Trainium inference") + parser.add_argument("--action", choices=["compile", "run"], required=True) + parser.add_argument("--hf-checkpoint", required=True, + help="Path to HuggingFaceVLA/smolvla_libero HF snapshot dir") + parser.add_argument("--neff-dir", required=True, + help="Where to save (or load) the 3 compiled NEFFs") + parser.add_argument("--tp-degree", type=int, default=C.DEFAULT_TP_DEGREE) + parser.add_argument("--batch-size", type=int, default=C.BATCH_SIZE) + parser.add_argument("--num-steps", type=int, default=C.NUM_DENOISE_STEPS, + help="Number of Euler steps in the denoising loop") + parser.add_argument("--bench-iters", type=int, default=20) + args = parser.parse_args() + + if args.action == "compile": + cmd_compile(args) + else: + cmd_run(args) + + +if __name__ == "__main__": + main() diff --git a/contrib/models/SmolVLA-Libero/src/weight_mapping.py b/contrib/models/SmolVLA-Libero/src/weight_mapping.py new file mode 100644 index 00000000..9a965ff4 --- /dev/null +++ b/contrib/models/SmolVLA-Libero/src/weight_mapping.py @@ -0,0 +1,151 @@ +""" +HF checkpoint -> Neuron subgraph weight mapping for SmolVLA. 
+ +The lerobot/smolvla_base checkpoint has this top-level layout (500 keys): + + model.vlm_with_expert.vlm.model.vision_model.* (SigLIP + post_layernorm) + model.vlm_with_expert.vlm.model.connector.modality_projection.proj.* + model.vlm_with_expert.vlm.model.text_model.* (16-layer SmolLM) + model.vlm_with_expert.lm_expert.* (16-layer expert) + model.action_in_proj / action_out_proj + model.action_time_mlp_in / action_time_mlp_out + model.state_proj + model.vlm_with_expert.vlm.lm_head.weight (unused at inference) + +This file slices the flat HF state-dict into three per-subgraph state-dicts +matching the Neuron module trees. + +Per-subgraph outputs: + vision_state_dict -> SmolVLAVisionEncoder + prefix_state_dict -> SmolVLAPrefixModel + denoise_state_dict -> SmolVLADenoiseStep +""" + +from __future__ import annotations + +import os +from typing import Dict, Tuple + +import torch +from safetensors.torch import load_file + +import config_constants as C + + +# --------------------------------------------------------------------------- +# Top-level loader +# --------------------------------------------------------------------------- + +def load_hf_state_dict(checkpoint_dir: str) -> Dict[str, torch.Tensor]: + """Load ``model.safetensors`` from the checkpoint directory into a flat dict.""" + sd_path = os.path.join(checkpoint_dir, "model.safetensors") + if not os.path.isfile(sd_path): + raise FileNotFoundError(f"No model.safetensors at {sd_path}") + return load_file(sd_path) + + +def split_hf_state_dict( + hf_sd: Dict[str, torch.Tensor], +) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + """Return (vision_sd, prefix_sd, denoise_sd) — three per-subgraph dicts.""" + return ( + _build_vision_sd(hf_sd), + _build_prefix_sd(hf_sd), + _build_denoise_sd(hf_sd), + ) + + +# --------------------------------------------------------------------------- +# Vision encoder mapping +# --------------------------------------------------------------------------- + +def
_build_vision_sd(hf: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + HF prefix: model.vlm_with_expert.vlm.model.vision_model.<...> + model.vlm_with_expert.vlm.model.connector.modality_projection.proj.<...> + Neuron: vision_model.<...> (SigLIP) + connector.modality_projection_proj.<...> (renamed: '.' to '_') + """ + out: Dict[str, torch.Tensor] = {} + vp = "model.vlm_with_expert.vlm.model.vision_model." + cp = "model.vlm_with_expert.vlm.model.connector.modality_projection.proj." + + # --- vision tower --- + for k, v in hf.items(): + if k.startswith(vp): + tail = k[len(vp):] + # encoder.layers.N.<...> -> layers.N.<...> + tail = tail.replace("encoder.layers.", "layers.") + # embeddings.patch_embedding.* / embeddings.position_embedding.* + # -> patch_embedding.* / position_embedding.* + tail = tail.replace("embeddings.", "") + out[f"vision_model.{tail}"] = v + + # --- connector --- + for k, v in hf.items(): + if k.startswith(cp): + tail = k[len(cp):] # 'weight' + out[f"connector.modality_projection_proj.{tail}"] = v + return out + + +# --------------------------------------------------------------------------- +# Prefix (VLM text decoder) mapping +# --------------------------------------------------------------------------- + +def _build_prefix_sd(hf: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + HF prefix: model.vlm_with_expert.vlm.model.text_model.<...> + model.state_proj.<...> + Neuron: embed_tokens.weight + layers.N. + norm.weight + state_proj.<...> + """ + out: Dict[str, torch.Tensor] = {} + tp = "model.vlm_with_expert.vlm.model.text_model." 
+ for k, v in hf.items(): + if k.startswith(tp): + tail = k[len(tp):] + out[tail] = v + # state_proj + out["state_proj.weight"] = hf["model.state_proj.weight"] + out["state_proj.bias"] = hf["model.state_proj.bias"] + return out + + +# --------------------------------------------------------------------------- +# Denoise (action expert) mapping +# --------------------------------------------------------------------------- + +def _build_denoise_sd(hf: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + HF prefix: model.vlm_with_expert.lm_expert.<...> + model.action_in_proj / action_out_proj + model.action_time_mlp_in / action_time_mlp_out + Neuron: layers.N.<...> (renamed: self_attn.*_proj -> *_proj at layer level) + norm.weight + action_in_proj.<...> + action_out_proj.<...> + action_time_mlp_in.<...> + action_time_mlp_out.<...> + """ + out: Dict[str, torch.Tensor] = {} + ep = "model.vlm_with_expert.lm_expert." + + for k, v in hf.items(): + if k.startswith(ep): + tail = k[len(ep):] + # Strip the "self_attn." segment so attention projections sit + # directly on the layer module (matches the flatter layer modules + # _ExpertSelfAttnLayer / _ExpertCrossAttnLayer).
+ tail = tail.replace("self_attn.", "") + out[tail] = v + + # action / time MLP / out_proj + for name in ("action_in_proj", "action_out_proj", + "action_time_mlp_in", "action_time_mlp_out"): + out[f"{name}.weight"] = hf[f"model.{name}.weight"] + out[f"{name}.bias"] = hf[f"model.{name}.bias"] + + return out diff --git a/contrib/models/SmolVLA-Libero/test/__init__.py b/contrib/models/SmolVLA-Libero/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/SmolVLA-Libero/test/integration/__init__.py b/contrib/models/SmolVLA-Libero/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/SmolVLA-Libero/test/integration/test_model.py b/contrib/models/SmolVLA-Libero/test/integration/test_model.py new file mode 100644 index 00000000..cffcb882 --- /dev/null +++ b/contrib/models/SmolVLA-Libero/test/integration/test_model.py @@ -0,0 +1,296 @@ +#!/usr/bin/env python3 +""" +Integration tests for the SmolVLA-Libero NeuronX port. + +Tests: + + * ``test_smoke_synthetic_chunk`` — one forward pass through all three + NEFFs with synthetic inputs; checks shape, finiteness, non-zero variance. + * ``test_warm_latency`` — light p50 latency sanity bound. + * ``test_lerobot_cpu_neuron_parity`` — Neuron vs upstream lerobot CPU + action-chunk parity. Loads ``lerobot.SmolVLAPolicy`` from the same HF + checkpoint, runs a CPU forward with identical inputs and identical + seeded initial noise, and asserts cos-sim ≥ 0.99 against the Neuron + output. This is the SmolVLA equivalent of the logit validation NxDI + uses for CausalLM contrib models — it validates that the Neuron port + reproduces the reference implementation, not just self-consistency. + Skipped automatically if ``lerobot`` is not installed. + +Required environment variables: + + SMOLVLA_CKPT Path to the HuggingFaceVLA/smolvla_libero snapshot directory. + SMOLVLA_NEFF Output directory for the three compiled NEFFs. 
If it does + not yet contain compiled artifacts, the test will compile + them (≈ 90 s on trn3pd98.3xlarge). + +Run: + + pytest contrib/models/SmolVLA-Libero/test/integration/test_model.py --capture=tee-sys + +or directly: + + cd contrib/models/SmolVLA-Libero + python test/integration/test_model.py +""" + +from __future__ import annotations + +import os +import sys +import time +from pathlib import Path + +import pytest +import torch + +# Make ``src/`` importable without a package install. +_SRC = Path(__file__).resolve().parents[2] / "src" +if str(_SRC) not in sys.path: + sys.path.insert(0, str(_SRC)) + +import config_constants as C # noqa: E402 +from modeling_smolvla import ( # noqa: E402 + DENOISE_NEFF_SUBDIR, + PREFIX_NEFF_SUBDIR, + VISION_NEFF_SUBDIR, + SmolVLAPolicy, +) +from neuron_action_head_base import COMPILED_MODEL_FILE_NAME # noqa: E402 + + +# --------------------------------------------------------------------------- +# Configuration via env vars +# --------------------------------------------------------------------------- + +CKPT_ENV = "SMOLVLA_CKPT" +NEFF_ENV = "SMOLVLA_NEFF" + + +def _require_env(name: str) -> str: + value = os.environ.get(name) + if not value: + pytest.skip(f"{name} is not set; skipping SmolVLA integration test") + return value + + +def _all_neffs_present(neff_root: str) -> bool: + for sub in (VISION_NEFF_SUBDIR, PREFIX_NEFF_SUBDIR, DENOISE_NEFF_SUBDIR): + if not (Path(neff_root) / sub / COMPILED_MODEL_FILE_NAME).is_file(): + return False + return True + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def hf_checkpoint() -> str: + return _require_env(CKPT_ENV) + + +@pytest.fixture(scope="module") +def neff_dir() -> str: + return _require_env(NEFF_ENV) + + +@pytest.fixture(scope="module") +def policy(hf_checkpoint: str, neff_dir: str) -> SmolVLAPolicy: + p = SmolVLAPolicy( 
+ hf_checkpoint_dir=hf_checkpoint, + tp_degree=C.DEFAULT_TP_DEGREE, + batch_size=C.BATCH_SIZE, + ) + if not _all_neffs_present(neff_dir): + t0 = time.monotonic() + p.compile(neff_dir) + print(f"[smolvla] compile: {time.monotonic() - t0:.1f}s -> {neff_dir}") + t0 = time.monotonic() + p.load(neff_dir) + print(f"[smolvla] load: {time.monotonic() - t0:.1f}s") + return p + + +# --------------------------------------------------------------------------- +# Synthetic inputs (deterministic) +# --------------------------------------------------------------------------- + + +def _make_synthetic_inputs(batch_size: int, seed: int = 0): + g = torch.Generator().manual_seed(seed) + images = [ + torch.randn( + batch_size, + 3, + C.VISION_IMAGE_SIZE, + C.VISION_IMAGE_SIZE, + generator=g, + dtype=torch.bfloat16, + ) + for _ in range(C.NUM_CAMERAS) + ] + lang = torch.randint( + 0, + C.VLM_VOCAB_SIZE, + (batch_size, C.NUM_TEXT_TOKENS), + generator=g, + dtype=torch.int32, + ) + state = torch.zeros(batch_size, C.MAX_STATE_DIM, dtype=torch.float32) + return images, lang, state + + +# --------------------------------------------------------------------------- +# CPU reference (lerobot) +# --------------------------------------------------------------------------- + + +def _load_lerobot_reference_policy(hf_checkpoint: str): + """Load the upstream ``lerobot`` SmolVLA policy from the same checkpoint. + + This is the *reference implementation* — the model the NxDI port is + expected to match. We compare the Neuron output against this CPU + forward pass, which is the SmolVLA equivalent of the logit validation + NxDI uses for CausalLM contrib models. 
+ """ + lerobot = pytest.importorskip( + "lerobot", + reason="lerobot is required for the CPU reference accuracy test", + ) + from lerobot.policies.smolvla.modeling_smolvla import ( + SmolVLAPolicy as LerobotSmolVLAPolicy, + ) + + pol = LerobotSmolVLAPolicy.from_pretrained(hf_checkpoint).cpu().eval() + return pol + + +def _cos_sim(a: torch.Tensor, b: torch.Tensor) -> float: + af = a.flatten().to(torch.float32) + bf = b.flatten().to(torch.float32) + return torch.nn.functional.cosine_similarity(af, bf, dim=0).item() + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_smoke_synthetic_chunk(policy: SmolVLAPolicy): + """One forward pass through all three NEFFs with synthetic inputs.""" + images, lang, state = _make_synthetic_inputs(batch_size=C.BATCH_SIZE) + chunk = policy.generate(images, lang, state) + + expected = (C.BATCH_SIZE, C.ACTION_CHUNK_SIZE, C.MAX_ACTION_DIM) + assert tuple(chunk.shape) == expected, ( + f"Expected action chunk shape {expected}, got {tuple(chunk.shape)}" + ) + + assert torch.isfinite(chunk).all(), "action chunk contains NaN or Inf" + assert chunk.std().item() > 0.0, "action chunk has zero variance — graph likely failed silently" + + +def test_warm_latency(policy: SmolVLAPolicy): + """Light p50 latency check — sanity bound, not a benchmark.""" + images, lang, state = _make_synthetic_inputs(batch_size=C.BATCH_SIZE, seed=1) + + # Warm-up + policy.generate(images, lang, state) + + timings_ms = [] + for _ in range(5): + t0 = time.monotonic() + policy.generate(images, lang, state) + timings_ms.append((time.monotonic() - t0) * 1000.0) + timings_ms.sort() + p50 = timings_ms[len(timings_ms) // 2] + print(f"[smolvla] warm p50 latency: {p50:.1f} ms over {len(timings_ms)} iters") + + # Generous upper bound — the full pipeline runs in ~65 ms warm on + # trn3pd98.3xlarge. 
1 s catches "something is dramatically wrong" without + # being flaky on slower hardware or under load. + assert p50 < 1000.0, f"warm p50 latency unexpectedly high: {p50:.1f} ms" + + +def test_lerobot_cpu_neuron_parity(policy: SmolVLAPolicy, hf_checkpoint: str): + """Neuron vs upstream lerobot CPU action-chunk parity (NxDI accuracy check). + + The lerobot ``SmolVLAPolicy`` is the reference implementation the NxDI + port targets. We load it from the same checkpoint, run a CPU forward + with identical inputs and identical initial noise, and assert the Neuron + output matches via cosine similarity and mean abs diff. + """ + B = C.BATCH_SIZE + + # Synthetic inputs — fp32 floats for image pixels so they feed both paths + # cleanly. Seeded ``Generator`` so each test run is reproducible. + g = torch.Generator().manual_seed(2) + images_fp32 = [ + torch.randn( + B, 3, C.VISION_IMAGE_SIZE, C.VISION_IMAGE_SIZE, + generator=g, dtype=torch.float32, + ) + for _ in range(C.NUM_CAMERAS) + ] + lang = torch.randint( + 0, C.VLM_VOCAB_SIZE, (B, C.NUM_TEXT_TOKENS), + generator=g, dtype=torch.long, + ) + lang_mask = torch.ones(B, C.NUM_TEXT_TOKENS, dtype=torch.bool) + state = torch.zeros(B, C.MAX_STATE_DIM, dtype=torch.float32) + + # Shared initial noise — fed to both paths so the only numerical + # difference is the model implementation, not the random starting point. 
+ noise = torch.randn( + B, C.ACTION_CHUNK_SIZE, C.MAX_ACTION_DIM, + generator=torch.Generator().manual_seed(123), dtype=torch.float32, + ) + + # --- Reference: lerobot CPU --- + lerobot_pol = _load_lerobot_reference_policy(hf_checkpoint) + img_masks = [torch.ones(B, dtype=torch.bool) for _ in images_fp32] + with torch.no_grad(): + chunk_cpu = lerobot_pol.model.sample_actions( + images=images_fp32, + img_masks=img_masks, + lang_tokens=lang, + lang_masks=lang_mask, + state=state, + noise=noise, + ) + + # --- Neuron port --- + images_neuron = [img.to(torch.bfloat16) for img in images_fp32] + chunk_neuron = policy.generate( + images_neuron, + lang.to(torch.int32), + state, + lang_mask=lang_mask, + noise=noise, + ).cpu() + + cos = _cos_sim(chunk_neuron, chunk_cpu) + max_abs = (chunk_neuron - chunk_cpu).abs().max().item() + mean_abs = (chunk_neuron - chunk_cpu).abs().mean().item() + print( + f"[smolvla] Neuron vs lerobot CPU parity: " + f"cos_sim={cos:.6f} max_abs={max_abs:.4f} mean_abs={mean_abs:.4f}" + ) + + # Cosine similarity is the primary acceptance criterion: it is invariant + # to bf16 magnitude noise and accumulated rounding across the 10 Euler + # steps. The README documents 0.9999 vs lerobot CPU on real LIBERO + # inputs; we use a slightly looser bound (0.99) for synthetic inputs + # which can amplify low-magnitude divergence. Mean abs diff catches + # systemic divergence. 
+ assert cos >= 0.99, f"Neuron vs lerobot cos_sim too low: {cos:.6f}" + assert mean_abs < 0.05, f"Neuron vs lerobot mean abs diff too high: {mean_abs:.4f}" + + +# --------------------------------------------------------------------------- +# Allow `python test/integration/test_model.py` invocation +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + sys.exit(pytest.main([__file__, "-v", "--capture=tee-sys"])) diff --git a/contrib/models/SmolVLA-Libero/test/unit/__init__.py b/contrib/models/SmolVLA-Libero/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b