
Commit 608226c

feat: Add 4-bit quantization support for LLM inference on Apple Silicon
This PR adds quantized tensor operations to EMLX, enabling efficient large language model inference on Apple Silicon GPUs. It powers a pure Elixir LLM inference stack achieving 135 tok/s on Qwen3-8B-4bit.

## Motivation

Running an 8B-parameter model requires 16GB+ at fp16 (8 × 10⁹ parameters × 2 bytes). With 4-bit quantization, roughly 4.5 bits per weight once per-group scales and biases are counted, the same model fits in ~5GB, enabling inference on consumer Macs.

This work is part of a broader effort to bring production LLM inference to the Elixir ecosystem:

- bobby_posts: Pure Elixir Qwen3-8B inference (135 tok/s)
- bobby_posts_adapters: LoRA fine-tuning for personalized generation
- bumblebee_quantized: Quantized model loading for Bumblebee
- safetensors_ex: MLX 4-bit safetensors format support

## Implementation

### NIFs (c_src/emlx_nif.cpp)

Three new NIFs wrap MLX's quantization functions:

- `quantized_matmul(x, w, scales, biases, transpose, group_size, bits)`
- `dequantize(w, scales, biases, group_size, bits)`
- `quantize(w, group_size, bits)`

### Backend Integration (lib/emlx/backend.ex)

Per Paulo's feedback, quantization metadata is stored directly on the Backend struct (not in a nested map):

```elixir
defstruct [:ref, :shape, :type, :data, :scales, :biases, :group_size]
```

When `Nx.dot` detects a quantized tensor (`scales != nil`), it automatically dispatches to `quantized_matmul`. The tensor type `{:s, 4}` carries the bit width, so `bits` is not stored separately.

### User API (lib/emlx/quantization.ex)

A clean user-facing module with comprehensive documentation:

```elixir
# Quantize weights
{q_weight, scales, biases} = EMLX.Quantization.quantize(weight)

# Create a tensor for Nx operations
qt = EMLX.Quantization.tensor(q_weight, scales, biases, shape)

# Nx.dot automatically uses quantized_matmul
result = Nx.dot(input, qt)
```

### Elixir API (lib/emlx.ex)

Low-level functions for direct NIF access:

- `EMLX.quantized_matmul/7`
- `EMLX.dequantize/5`
- `EMLX.quantize/3`
- `EMLX.quantized_tensor/5`

## MLX 4-bit Format

MLX uses group-wise affine quantization. Each unpacked 4-bit value is rescaled with its group's scale and bias:

```
dequantized[i] = scales[i div group_size] * int4[i] + biases[i div group_size]
```

Weights are packed as uint32 (8 int4 values per uint32). With group_size=64:

- Weight `[out, in]` becomes `[out, in/8]` as uint32
- Scales: `[out, in/group_size]` as bfloat16
- Biases: `[out, in/group_size]` as bfloat16

## Tests

33 tests covering:

- Low-level NIF operations (6 tests)
- Backend integration with Nx.dot (9 tests)
- EMLX.Quantization module API (18 tests)
- End-to-end LLM inference patterns

## Performance

On Apple M-series hardware with Qwen3-8B-4bit:

- Token generation speed: ~135 tok/s
- Memory: 4-5GB vs 16GB at fp16
- 14x faster than Python mlx_lm (9.5 tok/s)

## Bumblebee Integration Path

With this merged, quantized models can use EMLX as a pure backend:

1. The model loader detects quantized safetensors.
2. It creates an `EMLX.Quantization.tensor` for each quantized weight.
3. The model definition is unchanged; `Nx.dot` works transparently.
4. The EMLX backend handles all dispatch.

This enables upstreaming quantized model support to Bumblebee without changing the serving interface.

## References

- Use case: https://github.com/notactuallytreyanastasio/bobby_posts
- PR discussion: #96
- MLX quantization: https://ml-explore.github.io/mlx/build/html/python/nn.html

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
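To make the memory figures and the 4-bit packing layout described in the message above concrete, here is a small byte-count sketch in plain Elixir (no EMLX calls). It is illustrative only: the 4096×4096 matrix size is an assumption, and the constants are the defaults quoted above (group_size 64, bits 4).

```elixir
# Byte-count sketch for one [out, in] weight matrix in the MLX 4-bit layout.
out_features = 4096
in_features = 4096
group_size = 64
bits = 4

packed_cols = div(in_features * bits, 32)      # 8 int4 values per uint32 -> 512 columns
groups = div(in_features, group_size)          # 64 groups per row

packed_bytes = out_features * packed_cols * 4  # uint32 packed weights
scale_bytes = out_features * groups * 2        # bfloat16 scales
bias_bytes = out_features * groups * 2         # bfloat16 biases
fp16_bytes = out_features * in_features * 2    # same matrix at fp16

IO.inspect({packed_cols, groups}, label: "packed cols / groups")
IO.inspect((packed_bytes + scale_bytes + bias_bytes) / fp16_bytes, label: "size vs fp16")
# => 0.28125 of the fp16 size, i.e. about 4.5 bits per weight; scaling that to
#    8B parameters is where the ~4.5-5GB figure above comes from.
```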
1 parent 728c5a3 commit 608226c

7 files changed

Lines changed: 1201 additions & 7 deletions


c_src/emlx_nif.cpp

Lines changed: 62 additions & 1 deletion
@@ -970,6 +970,63 @@ NIF(as_strided) {
   TENSOR(mlx::core::as_strided(*t, shape, strides, offset, device));
 }
 
+// ============================================================================
+// Quantization Operations (for 4-bit model support)
+// ============================================================================
+
+// quantized_matmul - Multiplies x with a quantized weight matrix w
+// This is the key operation for efficient 4-bit inference
+// MLX API: quantized_matmul(x, w, scales, biases, transpose, group_size, bits, stream)
+NIF(quantized_matmul) {
+  TENSOR_PARAM(0, x);       // Input tensor [batch, seq, hidden]
+  TENSOR_PARAM(1, w);       // Quantized weights [out, in/8] (uint32 packed)
+  TENSOR_PARAM(2, scales);  // Scales [out, in/group_size] (bfloat16)
+  TENSOR_PARAM(3, biases);  // Biases [out, in/group_size] (bfloat16)
+  PARAM(4, bool, transpose);
+  PARAM(5, int, group_size);
+  PARAM(6, int, bits);
+  DEVICE_PARAM(7, device);
+
+  TENSOR(mlx::core::quantized_matmul(
+      *x, *w, *scales, *biases, transpose, group_size, bits, device));
+}
+
+// dequantize - Converts quantized weights back to float
+// Useful for debugging and verification
+// MLX API: dequantize(w, scales, biases, group_size, bits, stream)
NIF(dequantize) {
+  TENSOR_PARAM(0, w);       // Quantized weights (uint32 packed)
+  TENSOR_PARAM(1, scales);  // Scales (bfloat16)
+  TENSOR_PARAM(2, biases);  // Biases (bfloat16)
+  PARAM(3, int, group_size);
+  PARAM(4, int, bits);
+  DEVICE_PARAM(5, device);
+
+  TENSOR(mlx::core::dequantize(*w, *scales, *biases, group_size, bits, device));
+}
+
+// quantize - Quantizes a float tensor to packed format
+// Returns tuple of {weights, scales, biases}
+// MLX API: quantize(w, group_size, bits, stream) -> tuple<array, array, array>
+NIF(quantize) {
+  TENSOR_PARAM(0, w);  // Float weights to quantize
+  PARAM(1, int, group_size);
+  PARAM(2, int, bits);
+  DEVICE_PARAM(3, device);
+
+  try {
+    auto [qw, scales, biases] = mlx::core::quantize(*w, group_size, bits, device);
+
+    ERL_NIF_TERM result_tuple[3];
+    result_tuple[0] = create_tensor_resource(env, qw);
+    result_tuple[1] = create_tensor_resource(env, scales);
+    result_tuple[2] = create_tensor_resource(env, biases);
+
+    return nx::nif::ok(env, enif_make_tuple3(env, result_tuple[0], result_tuple[1], result_tuple[2]));
+  }
+  CATCH()
+}
+
 static ErlNifFunc nif_funcs[] = {
     {"strides", 1, strides},
     {"as_strided", 5, as_strided},
@@ -1087,7 +1144,11 @@ static ErlNifFunc nif_funcs[] = {
     {"max", 4, max},
     {"min", 4, min},
     {"clip", 4, clip},
-    {"tri_inv", 3, tri_inv}
+    {"tri_inv", 3, tri_inv},
+    // Quantization operations
+    {"quantized_matmul", 8, quantized_matmul},
+    {"dequantize", 6, dequantize},
+    {"quantize", 4, quantize}
 };
 
 // Update the NIF initialization

lib/emlx.ex

Lines changed: 141 additions & 0 deletions
@@ -258,6 +258,97 @@ defmodule EMLX do
   defvalue scalar_type(tensor)
   defvalue shape(tensor)
 
+  ## Quantization operations (for 4-bit model support)
+
+  @doc """
+  Performs quantized matrix multiplication.
+
+  This is the key operation for efficient 4-bit inference. It multiplies `x` with
+  quantized weights `w` (packed as uint32), using scales and biases for
+  dequantization during the computation.
+
+  ## Parameters
+  - `x` - Input tensor (e.g., {batch, seq, hidden})
+  - `w` - Quantized weights as uint32 (8 int4 values packed per uint32)
+  - `scales` - Per-group scale factors (bfloat16)
+  - `biases` - Per-group zero points (bfloat16)
+  - `transpose` - Whether to transpose weights (default: true)
+  - `group_size` - Number of weights per scale/bias group (default: 64)
+  - `bits` - Quantization bits (default: 4)
+  """
+  @mlx_function {:quantized_matmul, 8}
+  def quantized_matmul(
+        {dev_x, ref_x} = _tensor_x,
+        {dev_w, ref_w} = _tensor_w,
+        {dev_s, ref_s} = _tensor_scales,
+        {dev_b, ref_b} = _tensor_biases,
+        transpose \\ true,
+        group_size \\ 64,
+        bits \\ 4
+      )
+      when is_tensor(dev_x, ref_x) and is_tensor(dev_w, ref_w) and
+             is_tensor(dev_s, ref_s) and is_tensor(dev_b, ref_b) do
+    device = merge_device(merge_device(dev_x, dev_w), merge_device(dev_s, dev_b))
+    mlx_device = mlx_device!(device, -1)
+
+    EMLX.NIF.quantized_matmul(ref_x, ref_w, ref_s, ref_b, transpose, group_size, bits, mlx_device)
+    |> unwrap_tensor!(device)
+  end
+
+  @doc """
+  Dequantizes packed weights to floating point.
+
+  Converts quantized weights back to their original floating point representation.
+  Useful for debugging and verification.
+
+  ## Parameters
+  - `w` - Quantized weights as uint32 (packed int4 values)
+  - `scales` - Per-group scale factors
+  - `biases` - Per-group zero points
+  - `group_size` - Number of weights per group (default: 64)
+  - `bits` - Quantization bits (default: 4)
+  """
+  @mlx_function {:dequantize, 6}
+  def dequantize(
+        {dev_w, ref_w} = _tensor_w,
+        {dev_s, ref_s} = _tensor_scales,
+        {dev_b, ref_b} = _tensor_biases,
+        group_size \\ 64,
+        bits \\ 4
+      )
+      when is_tensor(dev_w, ref_w) and is_tensor(dev_s, ref_s) and is_tensor(dev_b, ref_b) do
+    device = merge_device(dev_w, merge_device(dev_s, dev_b))
+    mlx_device = mlx_device!(device, -1)
+
+    EMLX.NIF.dequantize(ref_w, ref_s, ref_b, group_size, bits, mlx_device)
+    |> unwrap_tensor!(device)
+  end
+
+  @doc """
+  Quantizes a floating point tensor to packed format.
+
+  Returns a tuple of `{quantized_weights, scales, biases}` where:
+  - `quantized_weights` - Packed uint32 tensor (8 int4 values per uint32)
+  - `scales` - Per-group scale factors
+  - `biases` - Per-group zero points
+
+  ## Parameters
+  - `w` - Float tensor to quantize
+  - `group_size` - Number of weights per group (default: 64)
+  - `bits` - Quantization bits (default: 4)
+  """
+  @mlx_function {:quantize, 4}
+  def quantize({dev_w, ref_w} = _tensor_w, group_size \\ 64, bits \\ 4)
+      when is_tensor(dev_w, ref_w) do
+    device = dev_w
+    mlx_device = mlx_device!(device, -1)
+
+    {weights_ref, scales_ref, biases_ref} =
+      EMLX.NIF.quantize(ref_w, group_size, bits, mlx_device) |> unwrap!()
+
+    {{device, weights_ref}, {device, scales_ref}, {device, biases_ref}}
+  end
+
   def to_blob({device, ref} = tensor) when is_tensor(device, ref) do
     # Two-step to_blob: eval on main scheduler, then copy on dirty scheduler
     eval(tensor)
@@ -323,6 +414,56 @@ defmodule EMLX do
   defvalue item(tensor)
   defvalue strides(tensor)
 
+  # ============================================================================
+  # Quantized Tensor Operations (Backend-Integrated)
+  # ============================================================================
+
+  @doc """
+  Creates a quantized Nx.Tensor with backend-level quantization options.
+
+  This creates an Nx.Tensor where the EMLX.Backend struct contains
+  quantization metadata. When this tensor is used in `Nx.dot`, the
+  backend automatically dispatches to `quantized_matmul`.
+
+  ## Parameters
+
+  - `weight_ref` - EMLX device ref for packed uint32 weights
+  - `scales_ref` - EMLX device ref for per-group scale factors
+  - `biases_ref` - EMLX device ref for per-group zero points
+  - `original_shape` - Shape before quantization {out_features, in_features}
+
+  ## Options
+
+  - `:bits` - Quantization bits (default: 4)
+  - `:group_size` - Weights per scale/bias group (default: 64)
+
+  ## Example
+
+      # Quantize weights
+      {q_weight, scales, biases} = EMLX.quantize(weight_tensor, 64, 4)
+
+      # Create quantized Nx.Tensor
+      quantized = EMLX.quantized_tensor(q_weight, scales, biases, {512, 4096})
+
+      # Standard Nx.dot automatically uses quantized_matmul!
+      result = Nx.dot(input, quantized)
+  """
+  def quantized_tensor(weight_ref, scales_ref, biases_ref, original_shape, opts \\ []) do
+    EMLX.Backend.quantized_tensor(weight_ref, scales_ref, biases_ref, original_shape, opts)
+  end
+
+  @doc """
+  Converts an EMLX device ref back to an Nx.Tensor.
+
+  ## Example
+
+      result_ref = EMLX.some_operation(input)
+      result_tensor = EMLX.to_nx(result_ref)
+  """
+  def to_nx({device, ref} = device_ref) when is_atom(device) and is_reference(ref) do
+    EMLX.Backend.to_nx(device_ref)
+  end
+
   @behaviour Nx.Defn.Compiler
 
   @impl Nx.Defn.Compiler
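As a usage reference for the low-level functions added in this file, here is a minimal quantize → dequantize round-trip. It is an illustrative sketch, not code from this commit: it assumes an `EMLX.Backend.from_nx/1` helper (or equivalent) produces the `{device, ref}` pairs these functions take, since the diff only shows the inverse direction (`EMLX.to_nx/1`), and it uses the documented defaults `group_size: 64`, `bits: 4`.

```elixir
# Hypothetical round-trip using EMLX.quantize/3, EMLX.dequantize/5 and EMLX.to_nx/1.
weight =
  Nx.iota({128, 256}, type: :f32, backend: EMLX.Backend)
  |> Nx.divide(1000)

# Assumed helper: convert an EMLX-backed Nx tensor into the {device, ref} pair
# expected by the low-level API; check lib/emlx/backend.ex for the actual name.
w_ref = EMLX.Backend.from_nx(weight)

# Pack to the MLX 4-bit format
{q_ref, scales_ref, biases_ref} = EMLX.quantize(w_ref, 64, 4)

# Dequantize and compare against the original weights
recovered =
  EMLX.dequantize(q_ref, scales_ref, biases_ref, 64, 4)
  |> EMLX.to_nx()

weight
|> Nx.subtract(recovered)
|> Nx.abs()
|> Nx.reduce_max()
|> IO.inspect(label: "max abs quantization error")
```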
