Train and deploy Vision-Language-Action models natively on Apple Silicon.
6 lines of Python. Image in, robot action out. Runs entirely on your Mac.
```python
from mlx_vla.models.modeling_vla import VLAForAction
from mlx_vla.inference.pipeline import VLAPipeline
model = VLAForAction(vision_backbone="clip", action_type="continuous", action_dim=7)
pipeline = VLAPipeline(model=model, unnorm_key="bridge_orig")
action = pipeline.predict(image="camera.jpg", language="pick up the cup")
# action: (7,) -- [x, y, z, rx, ry, rz, gripper]
```

| Configuration | Inference | Training | Throughput |
|---|---|---|---|
| CLIP + Continuous | 1.9 ms | 17.5 ms/step | 847 img/s |
| DINOv2 + Diffusion | 2.7 ms | 24.9 ms/step | 626 img/s |
| SigLIP + Discrete | 2.2 ms | 25.0 ms/step | 746 img/s |
| SAM + Chunking | 1.5 ms | 17.5 ms/step | 1530 img/s |
Run `python examples/benchmark.py` to reproduce on your hardware.
```bash
git clone https://github.com/ArchishmanSengupta/mlx-vla.git
cd mlx-vla
pip install -e .
```

Optional extras:

```bash
pip install -e ".[all]"       # everything (dev + training + data)
pip install -e ".[training]"  # mlx-lm for full LLM backbones
pip install -e ".[data]"      # h5py for HDF5 datasets
pip install -e ".[dev]"       # pytest, black, ruff, mypy
```

Requirements: Python 3.10+, Apple Silicon Mac (M1/M2/M3/M4), MLX >= 0.18.0
```bash
python examples/quickstart.py
```

6 lines -- creates a model, runs inference, prints a 7-DOF action. Under 1 second.
Train with the Python API:

```python
from mlx_vla.train_vla import train_vla

trainer = train_vla(
    model="openvla-7b",
    dataset="bridge_v2",
    use_lora=True,
    output_dir="./output",
    num_epochs=3,
)
trainer.train()
```

Or from the CLI:

```bash
# Train with YAML config
mlx-vla train --config examples/train_lora.yaml
# Train with CLI flags
mlx-vla train --model openvla-7b --dataset bridge_v2 --use-lora --epochs 5
# Inference
mlx-vla infer --model ./my_model --image img.jpg --instruction "pick up the cup"
# Generate a default config
mlx-vla create-config --output config.yaml
```

```bash
python examples/demo.py
```

End-to-end: creates dataset, trains model, saves, loads, runs inference, streams actions. Takes ~2 seconds.
| Component | Options | Default |
|---|---|---|
| Vision encoder | `clip`, `dinov2`, `siglip`, `sam` | `clip` |
| Fusion | `cross_attention`, `concat`, `gated`, `qkv_fusion` | `cross_attention` |
| Action head | `discrete`, `continuous`, `diffusion`, `chunking` | `discrete` |
| Action dim | any int | 7 |
| Action horizon | any int (multi-step for diffusion/chunking) | 1 |
| Robot platform | `bridge_orig`, `widowx_250`, `franka`, `panda`, `kuka` | `bridge_orig` |
Everything is config-driven via YAML or Python dataclasses. See `examples/train_lora.yaml` for a full example.
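As an illustration, a non-default combination can also be built directly in Python. The sketch below only uses constructor arguments that appear elsewhere in this README (`vision_backbone`, `action_type`, `action_dim`, `action_horizon`); how the fusion strategy is passed to `VLAForAction` is not shown here, so it is left at its default:

```python
from mlx_vla.models.modeling_vla import VLAForAction

# SigLIP encoder with an action-chunking head that predicts 8 future steps.
model = VLAForAction(
    vision_backbone="siglip",  # clip, dinov2, siglip, or sam
    action_type="chunking",    # discrete, continuous, diffusion, or chunking
    action_dim=7,
    action_horizon=8,
)
```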
Prepare a directory with JSON episode files:
```
my_data/
  episode_0.json   # [{"image": "img_0.png", "action": [0.1, ...], "language": "pick up cup"}, ...]
  img_0.png
```
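To smoke-test the format, you can generate a single synthetic episode like this (a minimal sketch; it assumes Pillow and NumPy are available, and writes only the `image`, `action`, and `language` fields shown above):

```python
# Write one synthetic episode matching the JSON layout above.
import json
import os

import numpy as np
from PIL import Image

os.makedirs("my_data", exist_ok=True)
Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8)).save("my_data/img_0.png")

episode = [
    {
        "image": "img_0.png",
        "action": [0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],  # 7-DOF action
        "language": "pick up cup",
    }
]
with open("my_data/episode_0.json", "w") as f:
    json.dump(episode, f)
```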
Then train on it:

```python
from mlx_vla.models.modeling_vla import VLAForAction
from mlx_vla.data.dataset import EpisodeDataset
from mlx_vla.training.trainer import VLATrainer
from mlx_vla.core import VLATrainingArguments

model = VLAForAction(vision_backbone="clip", action_type="continuous", action_dim=7, image_size=224)
dataset = EpisodeDataset("./my_data", split="train")
args = VLATrainingArguments(output_dir="./output", num_train_epochs=10, learning_rate=1e-4)

trainer = VLATrainer(model=model, args=args, train_dataset=dataset)
trainer.train()
model.save("./my_model")
```

LoRA fine-tuning:

```python
from mlx_vla.training.lora import apply_lora, merge_lora
model = VLAForAction(action_type="discrete", action_dim=7)
model = apply_lora(model, rank=8, alpha=16, target_modules=["query_proj", "value_proj"])
# ... train ...
merged = merge_lora(model) # merge weights for zero-overhead inference
merged.save("./deployed_model")
```

Diffusion policy with a multi-step action horizon:

```python
model = VLAForAction(
vision_backbone="dinov2",
action_type="diffusion",
action_dim=7,
action_horizon=4, # predict 4 future timesteps
)
actions = model.predict_action(pixel_values, input_ids)
# shape: (batch, 4, 7) -- 4-step action trajectory
```

Streaming inference with the pipeline:

```python
from mlx_vla.inference.pipeline import VLAPipeline
pipeline = VLAPipeline(model="./my_model", unnorm_key="franka")
for action in pipeline.stream_actions(camera_stream(), language="stack the cups"):
    robot.send(action)  # 7-DOF action at each frame
```

Saving and loading:

```python
model.save("./my_model")  # saves config.json + model.npz
loaded = VLAForAction.load("./my_model")  # restores architecture + weights
```

```
mlx_vla/
  models/
    vision.py         # CLIP, DINOv2, SigLIP, SAM encoders
    language.py       # Language model wrappers (MLX-LM, embedding fallback)
    fusion.py         # Vision-language fusion (4 strategies)
    action_heads.py   # Action heads (discrete, continuous, diffusion, chunking)
    modeling_vla.py   # VLAForAction -- the main model
  training/
    trainer.py        # VLATrainer (gradient accumulation, eval, callbacks)
    lora.py           # LoRA apply / merge
    optimizers.py     # AdamW, Adam, SGD, RMSprop, Lion + schedulers
    callbacks.py      # Checkpoint, logging, early stopping
  data/
    dataset.py        # VLADataset, RLDSDataset, BridgeDataset, EpisodeDataset
    collator.py       # Batch collation, image preprocessing, action normalization
    dataloader.py     # Batching and shuffling
    normalizer.py     # Per-robot action unnormalization
  inference/
    pipeline.py       # VLAPipeline -- image in, action out
  utils/
    config.py         # YAML/JSON config system with typed dataclasses
  cli/
    main.py           # CLI: train, infer, export, create-config
```
| Model | Vision | Fusion | Action | Notes |
|---|---|---|---|---|
| `openvla-7b` | CLIP (1024) | cross_attention | discrete | LLaMA-7B backbone |
| `openvla-3b` | CLIP (768) | cross_attention | discrete | Smaller variant |
| `llava-1.5-7b` | CLIP (1024) | concat | continuous | LLaVA-based |
| `octo-small` | DINOv2 (384) | concat | diffusion | 4-step horizon |
| `octo-base` | DINOv2 (768) | concat | diffusion | 4-step horizon |
Supported datasets:

- RLDS (requires `tensorflow_datasets`): `bridge_v2`, `oxe/bridge_v2`, `oxe/rx1`, `oxe/franka_kitchen`, `oxe/taco`, `aloha`
- Local: a directory of JSON episodes, a single JSON file, or HDF5
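Both kinds plug into the APIs shown above. A rough sketch (it assumes `train_vla` accepts any of the RLDS names listed, which this README only demonstrates for `bridge_v2`):

```python
# RLDS datasets go through the training entry point (needs tensorflow_datasets).
from mlx_vla.train_vla import train_vla
trainer = train_vla(model="openvla-7b", dataset="oxe/franka_kitchen",
                    use_lora=True, output_dir="./output")

# Local JSON episodes use EpisodeDataset, as in the custom-data example above.
from mlx_vla.data.dataset import EpisodeDataset
dataset = EpisodeDataset("./my_data", split="train")
```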
| Robot | `unnorm_key` | Workspace | Gripper |
|---|---|---|---|
| Bridge/WidowX | `bridge_orig`, `widowx_250` | +/-150mm | 0-1 |
| Franka | `franka` | +/-0.5m | 0-1 |
| Panda | `panda` | +/-0.5m | 0-1 |
| Kuka | `kuka` | +/-0.5m | -1 to 1 |
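The `unnorm_key` selects per-robot statistics (see `data/normalizer.py`) used to map normalized model outputs back into the robot's physical ranges. As a rough illustration only, not the library's actual code, unnormalizing one position component looks like this:

```python
# Illustration of per-robot action unnormalization (not mlx-vla's exact scheme).
# A normalized value in [-1, 1] is mapped back into the robot's physical range.
def unnormalize(a_norm: float, low: float, high: float) -> float:
    return low + (a_norm + 1.0) * (high - low) / 2.0

# e.g. a Franka x-translation with a +/-0.5 m workspace:
x_meters = unnormalize(0.2, low=-0.5, high=0.5)  # -> 0.1 m
```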
pip install -e ".[dev]"
python -m pytest tests/ -v281 tests covering models, training, inference, LoRA, data loading, config, edge cases, and all component combinations.
| File | Description |
|---|---|
| `examples/quickstart.py` | 6-line inference demo |
| `examples/demo.py` | Full end-to-end demo (dataset + train + inference + streaming) |
| `examples/benchmark.py` | Performance benchmarks across all configurations |
| `examples/train_lora.yaml` | YAML config for LoRA fine-tuning |
| `examples/train_diffusion.yaml` | YAML config for diffusion policy |
MIT License - see LICENSE for details.