
Commit 64d06d4

more benchmarks
1 parent f1617cb commit 64d06d4

6 files changed

Lines changed: 181 additions & 5 deletions


.gitignore

Lines changed: 4 additions & 0 deletions
@@ -21,3 +21,7 @@ __pycache__/
 # Benchmark outputs & downloads
 data/
 benchmarks/
+*.gguf
+.cache/
+.matplotlib/
+.fontconfig/

BENCHMARKS.md

Lines changed: 111 additions & 2 deletions
@@ -1,7 +1,11 @@
-#Quantization Benchmark Suite
+# Quantization Benchmark Suite
 
 This repository now provides a reproducible benchmark that compares FP32,
-post - training ternary quantization(PTQ), and quantization - aware training(QAT) through a small Fashion-MNIST classifier. The script is located at `scripts/ternary_quantization_benchmark.py` and is designed to log accuracy, latency, and storage so you can understand the benefits of moving from float32 weights to ternary-trained representations.
+post-training ternary quantization (PTQ), and quantization-aware training (QAT)
+through a small Fashion-MNIST classifier. The script is located at
+`scripts/ternary_quantization_benchmark.py` and is designed to log accuracy,
+latency, and storage so you can understand the benefits of moving from float32
+weights to ternary-trained representations.
 
 ## Benchmark matrix
 
@@ -61,6 +65,67 @@ Expected output:
 - Console summary with size, accuracy/loss, and images/s for baseline/PTQ/QAT.
 - `benchmarks/vit_cifar10_baseline.json` with stage metrics and model metadata.
 
+### Fast-mode recipes (quick baselines)
+
+Use these when you want a low-latency run to confirm the pipeline without
+waiting for full PTQ/QAT loops.
+
+ViT size + accuracy baseline (skip throughput, minimal eval):
+
+```bash
+python scripts/vit_ptq_qat_benchmark.py \
+  --model-id google/vit-base-patch16-224 \
+  --device cpu \
+  --threshold 0.45 \
+  --batch-size 16 \
+  --max-train-samples 256 \
+  --max-eval-samples 128 \
+  --eval-batches 1 \
+  --max-eval-batches 1 \
+  --skip-throughput \
+  --json-output benchmarks/vit_cifar10_quick.json
+```
+
+Observed output (CPU, size-only run with `--max-eval-batches 0` + `--skip-throughput`):
+- baseline size: 0.32 GiB
+- PTQ size: 0.03 GiB
+- accuracy/loss/images_per_s: 0.0 (skipped)
+
+Phi-3 baseline PPL only (skip latency + PTQ PPL/QAT):
+
+```bash
+python scripts/phi3_ptq_qat_benchmark.py \
+  --model-id microsoft/Phi-3-mini-4k-instruct \
+  --device cpu \
+  --dtype float32 \
+  --max-eval-tokens 512 \
+  --eval-texts 16 \
+  --max-new-tokens 16 \
+  --skip-latency \
+  --skip-ptq-ppl \
+  --json-output benchmarks/phi3_baseline_ppl.json
+```
+
+Status: PTQ PPL + short QAT pending (CPU-only PTQ conversion exceeded 2h locally). Resume on GPU:
+
+```bash
+python scripts/phi3_ptq_qat_benchmark.py \
+  --model-id microsoft/Phi-3-mini-4k-instruct \
+  --device auto \
+  --dtype bfloat16 \
+  --threshold 0.45 \
+  --max-eval-tokens 128 \
+  --eval-texts 2 \
+  --max-new-tokens 0 \
+  --skip-latency \
+  --run-qat \
+  --qat-steps 5 \
+  --train-split 'train[:10]' \
+  --json-output benchmarks/phi3_ptq_qat_fast.json
+```
+
+Note: PTQ still runs on CPU (`t81.torch` fallback), so keep enough host RAM available.
+
 ### 4) GGUF export + load check
 
 ```bash
@@ -148,6 +213,50 @@ Each row contains:
 
 Use this CSV to plot accuracy vs. storage or compare latency across the three modes.
 
+## JSON artifact schema (ViT + Phi-3)
+
+The ViT and Phi-3 scripts emit JSON when you pass `--json-output`. These files
+are intended to be committed alongside baseline numbers.
+
+ViT JSON keys (from `scripts/vit_ptq_qat_benchmark.py`):
+
+```json
+{
+  "model_id": "google/vit-base-patch16-224",
+  "dataset": "cifar10",
+  "device": "cpu",
+  "threshold": 0.45,
+  "baseline": {"size_gib": 0.00, "accuracy": 0.0, "loss": 0.0, "images_per_s": 0.0},
+  "ptq": {"size_gib": 0.00, "accuracy": 0.0, "loss": 0.0, "images_per_s": 0.0},
+  "qat": null
+}
+```
+
+Phi-3 JSON keys (from `scripts/phi3_ptq_qat_benchmark.py`):
+
+```json
+{
+  "model_id": "microsoft/Phi-3-mini-4k-instruct",
+  "dataset": "wikitext-2-raw-v1",
+  "device": "cpu",
+  "dtype": "float32",
+  "threshold": 0.45,
+  "max_eval_tokens": 1024,
+  "eval_texts": 32,
+  "max_new_tokens": 64,
+  "skip_latency": true,
+  "skip_ptq_ppl": false,
+  "run_qat": false,
+  "qat_steps": 5,
+  "train_split": "train[:1%]",
+  "learning_rate": 5e-5,
+  "compression_ratio": 0.0,
+  "baseline": {"size_gib": 0.00, "ppl": 0.0, "tok_s": 0.0},
+  "ptq": {"size_gib": 0.00, "ppl": null, "tok_s": 0.0},
+  "qat": null
+}
+```
+
 ## Diagrams
 
 View the [benchmark comparison diagram](docs/diagrams/benchmarks.mermaid.md) for a quick latency/storage summary that highlights the 15–22× wins.
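
As a quick way to consume these artifacts, here is a minimal sketch (not part of this commit) that reads a JSON file emitted above and prints a size comparison; it assumes only the keys shown in the schemas and the output path used in the fast-mode recipes:

```python
# Sketch: summarize a benchmark JSON artifact (schema as documented above).
import json
from pathlib import Path


def summarize(path: str) -> None:
    data = json.loads(Path(path).read_text())
    baseline, ptq = data["baseline"], data["ptq"]
    print(f"{data['model_id']} ({data['device']})")
    print(f"  baseline size: {baseline['size_gib']:.2f} GiB")
    print(f"  PTQ size:      {ptq['size_gib']:.2f} GiB")
    if ptq["size_gib"]:  # guard against size-only / skipped runs
        print(f"  compression:   {baseline['size_gib'] / ptq['size_gib']:.1f}x")


if __name__ == "__main__":
    summarize("benchmarks/vit_cifar10_quick.json")  # path from the recipe above
```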

docs/ROADMAP.md

Lines changed: 2 additions & 0 deletions
@@ -65,13 +65,15 @@ Recent work has delivered parts of this roadmap:
 * **Recommendation 3** — Python entry-points table added to `docs/python-api.md` and `docs/python-cookbook.md`, with links from `docs/index.md`. **In progress (benchmark visibility added in `README.md`, `BENCHMARKS.md`, and the Phi-3 notebook).**
 * **GGUF compatibility** — Phi-3 export validated (`phi3-tq1-fixed12.gguf`); QKV split experiment reverted for llama.cpp parity.
 * **QAT benchmark groundwork** — ViT CIFAR-10 PTQ/QAT script added with size-only baseline captured; Phi-3 baseline PPL captured (PTQ PPL/QAT pending).
+* **GPU fallback safety** — `t81.torch` now warns + falls back to CPU for PTQ when tensors originate on GPU; smoke test added and troubleshooting docs updated.
 
 ### Status timeline (recent highlights)
 
 * Python entry-point discoverability refreshed (docs landing page + cookbook + API entry table).
 * Phi-3 GGUF export validated with llama.cpp baseline metrics captured for reference.
 * CLI documentation updated to call out Phi-3 GGUF compatibility expectations.
 * ViT size-only baseline logged; Phi-3 baseline PPL captured with PTQ PPL/QAT queued.
+* GPU fallback behavior documented; `.gitignore` hardened against GGUF/cache artifacts.
 
 ### High-impact next priorities (effort vs. impact)
 
docs/troubleshooting.md

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ This guide complements [README.md](../README.md) and the other reference pages b
 - **Progress bar missing?** The progress reporting relies on `tqdm`; install it (`pip install tqdm`) if the CLI skips bars or prints raw percentages.
 - **Meta device / accelerate offload errors.** When converting large Hugging Face checkpoints with the default `device_map=auto`, Accelerate may place many layers onto disk/`meta`. If `t81 convert`/`t81 gguf` (or the legacy `t81-convert`/`t81-gguf` scripts) later tries to call `.to("cpu")`, you’ll hit `NotImplementedError: Cannot copy out of meta tensor` or `RuntimeError: You can't move a model that has some modules offloaded to cpu or disk.` Always rerun with `--force-cpu-device-map` or `--device-map none/cpu` so the checkpoints stay in host RAM, and set `ACCELERATE_DISABLE=1` or `HF_ACCELERATE_DISABLE=1` before launching the CLI so no accelerate hooks re-enable offloading. This makes every `nn.Linear` serializable and avoids the meta-device save failure that occurs after the “Some parameters are on the meta device” log.
 - **Large GGUF conversions.** Extremely large ternary bundles (Gemma 3.x / Llama 3.x) may exhaust RAM with older readers because the whole file was loaded before parsing. The new `t81.gguf.read_gguf` implementation parses metadata, tensor infos, and tensor payloads directly from the file handle, seeks to each sorted tensor offset, and never slices the entire bundle into memory. When you still hit memory pressure or Matplotlib font-cache warnings, define `MPLCONFIGDIR=$PWD/data/cache/matplotlib` and `FONTCONFIG_PATH=$PWD/data/cache/fontconfig`, prefer `--force-cpu-device-map`, and keep `ACCELERATE_DISABLE=1`/`HF_ACCELERATE_DISABLE=1` set before rerunning the CLI so every tensor stays on the CPU.
+- **GPU PTQ fallback.** `t81.torch.TernaryTensor.from_float` currently quantizes on CPU; when your model lives on GPU it will warn and move tensors to CPU for PTQ, then return outputs to the original device. Keep enough host RAM available and avoid meta/offload tensors (`device_map=auto`) if you plan to run PTQ PPL or short QAT loops.
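
A minimal sketch of what that fallback looks like in practice (not part of this commit; assumes a CUDA-enabled PyTorch build and the warning text added in `t81/torch/__init__.py`):

```python
# Sketch: trigger and observe the GPU -> CPU PTQ fallback described above.
import warnings

import torch
import t81.torch as t81_torch

weight = torch.randn(3, 16, device="cuda")  # weights start on the GPU
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always", RuntimeWarning)
    ternary = t81_torch.TernaryTensor.from_float(weight, threshold=0.45)
assert any("moving tensor" in str(w.message) for w in caught)  # fallback warned

rhs = torch.randn(16, 4, device="cuda")
out = torch.matmul(ternary, rhs)  # GEMM runs on CPU; output returns to CUDA
assert out.device.type == "cuda"
```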
 
 ## Testing & benchmarking hiccups

t81/torch/__init__.py

Lines changed: 21 additions & 3 deletions
@@ -8,6 +8,7 @@
 from __future__ import annotations
 
 from typing import Any, Callable, Dict, Mapping, Optional, Sequence
+import warnings
 
 import numpy as np
 import torch
@@ -50,8 +51,17 @@ def _quantize_tensor(tensor: torch.Tensor, threshold: float = 0.5) -> torch.Tens
 
 
 def _to_cpu_float(tensor: torch.Tensor) -> torch.Tensor:
+    if tensor.is_meta:
+        raise NotImplementedError(
+            "t81.trit does not support meta tensors; load weights on CPU or disable offload."
+        )
     if tensor.device.type != "cpu":
-        raise NotImplementedError("t81.trit currently only supports CPU tensors")
+        warnings.warn(
+            f"t81.trit runs on CPU; moving tensor from {tensor.device} to CPU.",
+            RuntimeWarning,
+            stacklevel=2,
+        )
+        return tensor.detach().to(device="cpu", dtype=torch.float32)
     return tensor.detach().to(dtype=torch.float32, copy=False)
 
 
@@ -272,7 +282,11 @@ def forward(ctx, ternary_weight: TernaryTensor, rhs: torch.Tensor) -> torch.Tens
         rhs_cpu = _to_cpu_float(rhs)
         ctx.save_for_backward(rhs_cpu)
         ctx.ternary = ternary_weight
-        return ternary_weight._compute_gemm(rhs_cpu)
+        ctx.rhs_device = rhs.device
+        output = ternary_weight._compute_gemm(rhs_cpu)
+        if rhs.device.type != "cpu":
+            output = output.to(rhs.device)
+        return output
 
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor) -> tuple[None, torch.Tensor]:
@@ -281,8 +295,12 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[None, torch.Tensor]:
             _limbs_to_trits(ctx.ternary._packed, ctx.ternary._rows, ctx.ternary._k_limbs)
             .astype(np.float32)
         )[:, : ctx.ternary._k_actual]
+        grad_output_cpu = grad_output.to(device="cpu")
         # Gradient for rhs follows the usual matmul gradient formula.
-        grad_rhs = weight_float.transpose(-2, -1).matmul(grad_output)
+        grad_rhs = weight_float.transpose(-2, -1).matmul(grad_output_cpu)
+        rhs_device = getattr(ctx, "rhs_device", torch.device("cpu"))
+        if rhs_device.type != "cpu":
+            grad_rhs = grad_rhs.to(rhs_device)
         return None, grad_rhs
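
Taken together, these changes let GPU activations flow through the CPU-only ternary GEMM in both directions. A minimal sketch of the round-trip including the gradient path (not part of this commit; assumes a CUDA device and that `torch.matmul` on a `TernaryTensor` dispatches to the autograd function above):

```python
import torch
import t81.torch as t81_torch

# CPU weights quantized to ternary; activations live on the GPU.
weight = torch.linspace(-1.0, 1.0, steps=48, dtype=torch.float32).reshape(3, 16)
ternary = t81_torch.TernaryTensor.from_float(weight, threshold=0.45)

rhs = torch.randn(16, 4, device="cuda", requires_grad=True)
out = torch.matmul(ternary, rhs)  # forward: rhs copied to CPU, output moved back
out.sum().backward()              # backward: grad computed on CPU, moved to CUDA

assert out.device.type == "cuda"
assert rhs.grad is not None and rhs.grad.device.type == "cuda"
```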

tests/python/test_torch_ternary.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+import warnings
+
+import pytest
+
+
+torch = pytest.importorskip("torch")
+
+
+def _best_device() -> torch.device | None:
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        return torch.device("mps")
+    return None
+
+
+def test_ternary_tensor_cpu_roundtrip():
+    import t81.torch as t81_torch
+
+    weight = torch.linspace(-1.0, 1.0, steps=48, dtype=torch.float32).reshape(3, 16)
+    ternary = t81_torch.TernaryTensor.from_float(weight, threshold=0.45)
+    rhs = torch.randn(16, 4, dtype=torch.float32)
+    out = torch.matmul(ternary, rhs)
+    assert out.shape == (3, 4)
+    assert out.device.type == "cpu"
+
+
+def test_ternary_tensor_gpu_fallback_warning():
+    device = _best_device()
+    if device is None:
+        pytest.skip("No GPU/MPS device available for fallback test.")
+
+    import t81.torch as t81_torch
+
+    weight = torch.linspace(-1.0, 1.0, steps=48, dtype=torch.float32, device=device).reshape(3, 16)
+    with warnings.catch_warnings(record=True) as caught:
+        warnings.simplefilter("always", RuntimeWarning)
+        ternary = t81_torch.TernaryTensor.from_float(weight, threshold=0.45)
+    assert any("moving tensor" in str(item.message) for item in caught)
+    rhs = torch.randn(16, 4, dtype=torch.float32, device=device)
+    out = torch.matmul(ternary, rhs)
+    assert out.device.type == device.type
