Skip to content

Commit b6e6549

Browse files
Bug fix for BitNet model quantization
1 parent d523873 commit b6e6549

4 files changed

Lines changed: 262 additions & 9 deletions

File tree

README.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,14 @@ Inference on CPU for a 1.58-bit LLM decoding step. Click the image to view the o
3232
```bash
3333
git clone https://github.com/UIC-InDeXLab/RSR-Core.git
3434
cd RSR-Core
35-
pip install -e .
35+
pip install -e . --no-build-isolation
3636
```
3737

3838
#### Building the kernels
3939

40+
Both CPU and CUDA kernels are automatically built during `pip install -e . --no-build-isolation`.
41+
You can also build them manually:
42+
4043
**CPU kernels** — Compile the C shared libraries via the provided Makefiles.
4144
Requires `gcc` with AVX2 and OpenMP support.
4245

@@ -45,8 +48,9 @@ make -C kernels/bit_1/cpu
4548
make -C kernels/bit_1_58/cpu
4649
```
4750

48-
**CUDA kernels** — No manual build step needed. CUDA kernels are JIT-compiled
49-
by PyTorch on first use (`torch.utils.cpp_extension`). Requirements:
51+
**CUDA kernels** — Pre-built during install if a GPU is available. If not,
52+
they are JIT-compiled by PyTorch on first use (`torch.utils.cpp_extension`).
53+
Requirements:
5054
- CUDA toolkit (matching your PyTorch build)
5155
- `ninja` (`pip install ninja`)
5256

integrations/hf/model_infer.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,10 +582,37 @@ def load_hf_model(
582582

583583
# The @torch.compile-decorated unpack_weights in transformers' BitNet
584584
# integration fails on CPU with dynamo. Force eager execution.
585+
_prev_suppress = torch._dynamo.config.suppress_errors
585586
torch._dynamo.config.suppress_errors = True
586587

587588
model = AutoModelForCausalLM.from_pretrained(model_source, **load_kwargs)
588589

590+
torch._dynamo.config.suppress_errors = _prev_suppress
591+
592+
# Work around a bug in transformers' BitNetDeserialize.convert: it unpacks
593+
# ternary weights with dtype=uint8 (the storage dtype) instead of the
594+
# model's compute dtype, so -1 wraps to 255 and F.linear gets a dtype
595+
# mismatch. Only apply to BitNet models (detected via quantization_config).
596+
# Fix by reinterpreting uint8 as int8 then casting to the model's dtype.
597+
_is_bitnet = getattr(model.config, "quantization_config", None) is not None and (
598+
getattr(model.config.quantization_config, "quant_method", None) == "bitnet"
599+
or (isinstance(model.config.quantization_config, dict)
600+
and model.config.quantization_config.get("quant_method") == "bitnet")
601+
)
602+
if _is_bitnet:
603+
# Determine the correct target dtype: use the explicitly requested dtype,
604+
# otherwise infer from the non-quantized parameters already in the model.
605+
if dtype:
606+
_target_dtype = getattr(torch, dtype)
607+
else:
608+
_non_uint8 = [
609+
p.dtype for p in model.parameters() if p.dtype != torch.uint8
610+
]
611+
_target_dtype = _non_uint8[0] if _non_uint8 else torch.bfloat16
612+
for _name, param in model.named_parameters():
613+
if param.dtype == torch.uint8:
614+
param.data = param.data.view(torch.int8).to(_target_dtype)
615+
589616
# bitsandbytes models are already placed by device_map; skip .to()
590617
if quantize not in ("8bit", "4bit"):
591618
model = model.to(device)

setup.py

Lines changed: 114 additions & 6 deletions
Original file line numberDiff line numberDiff line change
"""Custom build: compile CPU and CUDA kernels during pip install."""

import subprocess
import sys
import os
from setuptools import setup
from setuptools.command.build_py import build_py
from setuptools.command.develop import develop

# Absolute path to the repository root (the directory containing setup.py);
# all kernel paths below are resolved relative to it.
ROOT = os.path.dirname(os.path.abspath(__file__))

# Makefile-based CPU kernel directories, built with `make -C <dir>`.
CPU_KERNEL_DIRS = [
    os.path.join(ROOT, "kernels", "bit_1", "cpu"),
    os.path.join(ROOT, "kernels", "bit_1_58", "cpu"),
]

# CUDA kernel source directories: bit_1 kernels are torch-JIT compiled,
# bit_1_58 additionally has a standalone nvcc-built shared library.
CUDA_KERNEL_DIR_BIT1 = os.path.join(ROOT, "kernels", "bit_1", "cuda")
CUDA_KERNEL_DIR_BIT158 = os.path.join(ROOT, "kernels", "bit_1_58", "cuda")
def _build_cpu_kernels():
    """Run `make` in every CPU kernel directory that ships a Makefile."""
    buildable = (
        kernel_dir
        for kernel_dir in CPU_KERNEL_DIRS
        if os.path.isdir(kernel_dir)
        and os.path.isfile(os.path.join(kernel_dir, "Makefile"))
    )
    for kernel_dir in buildable:
        subprocess.check_call(["make", "-C", kernel_dir])

2226

27+
def _print_cuda_skip_warning():
28+
"""Print a warning that CUDA kernels were not pre-built."""
29+
BOLD_RED = "\033[1;31m"
30+
RESET = "\033[0m"
31+
YELLOW = "\033[33m"
32+
print()
33+
print(f"{YELLOW}setup.py: CUDA not available — CUDA kernels were not pre-built.{RESET}")
34+
print(f"{YELLOW} They will be JIT-compiled on the first CUDA run, if available.{RESET}")
35+
print()
36+
print(f" {BOLD_RED}FOR BENCHMARKS PAY ATTENTION TO FIRST BUILD TIME{RESET}")
37+
print()
38+
39+
def _build_cuda_kernels():
    """JIT-compile all CUDA kernels so the first run has zero compilation delay.

    Best-effort by design: if torch is not importable or no CUDA device is
    visible, a warning is printed and the build is skipped (PyTorch will then
    JIT-compile the kernels on first use). Individual kernel build failures
    are reported but never abort the install.
    """
    try:
        import torch
        if not torch.cuda.is_available():
            _print_cuda_skip_warning()
            return
    except ImportError:
        _print_cuda_skip_warning()
        return

    from torch.utils.cpp_extension import load

    # Build only for the installed GPU's architecture to keep compiles fast.
    major, minor = torch.cuda.get_device_capability()
    os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"

    # Ensure ninja (installed into this interpreter's bin dir) is on PATH,
    # since torch's JIT extension builder shells out to it.
    bindir = os.path.dirname(sys.executable)
    path_entries = os.environ.get("PATH", "").split(os.pathsep)
    if bindir and bindir not in path_entries:
        os.environ["PATH"] = os.pathsep.join([bindir, *path_entries])

    def _jit_compile(kernel_dir, kernels):
        # Compile each (extension_name, source_file) pair found in kernel_dir;
        # missing sources are skipped, failures are logged and ignored.
        for name, source in kernels:
            source_path = os.path.join(kernel_dir, source)
            if not os.path.isfile(source_path):
                continue
            print(f"setup.py: JIT compiling {name} ...")
            try:
                load(
                    name=name,
                    sources=[source_path],
                    extra_cuda_cflags=["-O3", "--use_fast_math"],
                    verbose=False,
                )
            except Exception as e:
                print(f"setup.py: WARNING: failed to compile {name}: {e}")

    # -- bit_1 CUDA kernels (torch JIT) --
    _jit_compile(CUDA_KERNEL_DIR_BIT1, [
        ("rsr_cuda_v5_9", "rsr_v5_9.cu"),
        ("rsr_cuda_v5_8", "rsr_v5_8.cu"),
        ("rsr_cuda_v5_6", "rsr_v5_6.cu"),
        ("rsr_cuda_v4_10", "rsr_v4_10.cu"),
    ])

    # -- bit_1_58 CUDA kernels (torch JIT) --
    _jit_compile(CUDA_KERNEL_DIR_BIT158, [
        ("rsr_ternary_cuda_v2_0", "rsr_ternary_v2_0.cu"),
    ])

    # -- bit_1_58 BitNet kernel (compiled directly with nvcc to a .so) --
    bitnet_source = os.path.join(CUDA_KERNEL_DIR_BIT158, "bitnet_kernels.cu")
    bitnet_lib = os.path.join(CUDA_KERNEL_DIR_BIT158, "libbitnet.so")
    if os.path.isfile(bitnet_source) and not os.path.isfile(bitnet_lib):
        cuda_home = os.environ.get("CUDA_HOME", "/usr/local/cuda")
        nvcc = os.path.join(cuda_home, "bin", "nvcc")
        if os.path.isfile(nvcc):
            arch = f"{major}{minor}"
            cmd = [
                nvcc, "-std=c++17", "--shared", "--compiler-options", "-fPIC",
                "-O3", "--use_fast_math", "-lineinfo",
                f"-gencode=arch=compute_{arch},code=sm_{arch}",
                f"-gencode=arch=compute_{arch},code=compute_{arch}",
                bitnet_source, "-o", bitnet_lib,
            ]
            # fix: was an f-string with no placeholders (ruff F541)
            print("setup.py: compiling libbitnet.so ...")
            try:
                subprocess.run(cmd, cwd=CUDA_KERNEL_DIR_BIT158, check=True,
                               capture_output=True, text=True)
            except subprocess.CalledProcessError as e:
                # fix: with capture_output=True the exception message alone
                # hides nvcc's diagnostics — surface stderr explicitly.
                print(f"setup.py: WARNING: failed to compile libbitnet.so: {e}")
                if e.stderr:
                    print(e.stderr)
            except OSError as e:
                print(f"setup.py: WARNING: failed to compile libbitnet.so: {e}")
def _build_all_kernels():
    """Build CPU kernels, then pre-build CUDA kernels (best effort)."""
    for build_step in (_build_cpu_kernels, _build_cuda_kernels):
        build_step()
class BuildPyWithKernels(build_py):
    # Regular (non-editable) install hook: compile all kernels before the
    # standard build_py step copies Python sources into the build tree.
    def run(self):
        """Build CPU/CUDA kernels, then run the normal build_py command."""
        _build_all_kernels()
        super().run()
class DevelopWithKernels(develop):
    # Editable install hook (`pip install -e .`): compile all kernels before
    # the standard develop step links the package in place.
    def run(self):
        """Build CPU/CUDA kernels, then run the normal develop command."""
        _build_all_kernels()
        super().run()
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import { useEffect, useState } from "react";
2+
import { Link } from "react-router-dom";
3+
import { listModels, listMultipliers, getSystemInfo } from "../api";
4+
import Card from "../components/Card";
5+
6+
function StatCard({ label, value, sub, to, loading }) {
7+
const inner = (
8+
<div className="text-center">
9+
{loading ? (
10+
<div className="flex justify-center">
11+
<div className="h-8 w-16 rounded bg-gray-700 animate-pulse" />
12+
</div>
13+
) : (
14+
<p className="text-3xl font-bold text-cyan-400">{value}</p>
15+
)}
16+
<p className="text-sm text-gray-400 mt-1">{label}</p>
17+
{loading ? (
18+
<div className="flex justify-center mt-0.5">
19+
<div className="h-3 w-24 rounded bg-gray-700 animate-pulse" />
20+
</div>
21+
) : (
22+
sub && <p className="text-xs text-gray-600 mt-0.5">{sub}</p>
23+
)}
24+
</div>
25+
);
26+
if (to) return <Link to={to} className="block hover:scale-105 transition-transform">{inner}</Link>;
27+
return inner;
28+
}
29+
30+
export default function DashboardPage() {
31+
const [models, setModels] = useState([]);
32+
const [multipliers, setMultipliers] = useState([]);
33+
const [sys, setSys] = useState(null);
34+
const [loading, setLoading] = useState(true);
35+
36+
useEffect(() => {
37+
Promise.all([
38+
listModels().then(setModels).catch(() => {}),
39+
listMultipliers().then(setMultipliers).catch(() => {}),
40+
getSystemInfo().then(setSys).catch(() => {}),
41+
]).finally(() => setLoading(false));
42+
}, []);
43+
44+
const cpuModels = models.filter((m) => m.device === "cpu").length;
45+
const cudaModels = models.filter((m) => m.device === "cuda").length;
46+
const totalSize = models.reduce((s, m) => s + m.size_mb, 0);
47+
48+
return (
49+
<div className="max-w-5xl mx-auto space-y-6">
50+
<div>
51+
<h1 className="text-2xl font-bold text-white">Dashboard</h1>
52+
<p className="text-gray-500 text-sm mt-1">RSR-core project overview</p>
53+
</div>
54+
55+
<div className="grid grid-cols-2 md:grid-cols-4 gap-4">
56+
<Card><StatCard label="Preprocessed Models" value={models.length} sub={`${cpuModels} CPU / ${cudaModels} CUDA`} to="/models" loading={loading} /></Card>
57+
<Card><StatCard label="Multipliers" value={multipliers.length} to="/multipliers" loading={loading} /></Card>
58+
<Card><StatCard label="Total Size" value={`${(totalSize / 1024).toFixed(1)} GB`} sub="preprocessed data" loading={loading} /></Card>
59+
<Card><StatCard label="CUDA" value={sys?.cuda_available ? "Available" : "N/A"} sub={sys?.cuda_device || "CPU only"} loading={loading} /></Card>
60+
</div>
61+
62+
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
63+
<Card title="Quick Actions">
64+
<div className="space-y-2">
65+
<Link to="/preprocess" className="block w-full text-left px-4 py-3 rounded-lg bg-gray-800 hover:bg-gray-750 hover:bg-cyan-500/5 border border-gray-700 transition-colors">
66+
<span className="text-sm font-medium text-gray-200">Preprocess a Model</span>
67+
<span className="block text-xs text-gray-500 mt-0.5">Search HuggingFace and apply RSR preprocessing</span>
68+
</Link>
69+
<Link to="/inference" className="block w-full text-left px-4 py-3 rounded-lg bg-gray-800 hover:bg-cyan-500/5 border border-gray-700 transition-colors">
70+
<span className="text-sm font-medium text-gray-200">Run Inference</span>
71+
<span className="block text-xs text-gray-500 mt-0.5">Generate text with RSR-accelerated models</span>
72+
</Link>
73+
<Link to="/benchmarks" className="block w-full text-left px-4 py-3 rounded-lg bg-gray-800 hover:bg-cyan-500/5 border border-gray-700 transition-colors">
74+
<span className="text-sm font-medium text-gray-200">View Benchmarks</span>
75+
<span className="block text-xs text-gray-500 mt-0.5">Compare RSR performance vs baselines</span>
76+
</Link>
77+
</div>
78+
</Card>
79+
80+
<Card title="Preprocessed Models">
81+
{loading ? (
82+
<div className="space-y-2">
83+
{[...Array(3)].map((_, i) => (
84+
<div key={i} className="flex items-center justify-between px-3 py-2 bg-gray-800 rounded-lg">
85+
<div className="space-y-1.5">
86+
<div className="h-4 w-32 rounded bg-gray-700 animate-pulse" />
87+
<div className="h-3 w-20 rounded bg-gray-700 animate-pulse" />
88+
</div>
89+
<div className="h-5 w-10 rounded bg-gray-700 animate-pulse" />
90+
</div>
91+
))}
92+
</div>
93+
) : models.length === 0 ? (
94+
<p className="text-gray-500 text-sm">No preprocessed models yet. <Link to="/preprocess" className="text-cyan-400 hover:underline">Preprocess one</Link>.</p>
95+
) : (
96+
<div className="space-y-2 max-h-64 overflow-auto">
97+
{models.map((m) => (
98+
<div key={m.name} className="flex items-center justify-between px-3 py-2 bg-gray-800 rounded-lg">
99+
<div>
100+
<p className="text-sm text-gray-200">{m.name}</p>
101+
<p className="text-xs text-gray-500">{m.num_layers} layers, k={m.k}</p>
102+
</div>
103+
<span className={`text-xs px-2 py-0.5 rounded ${m.device === "cuda" ? "bg-green-500/10 text-green-400" : "bg-blue-500/10 text-blue-400"}`}>
104+
{m.device}
105+
</span>
106+
</div>
107+
))}
108+
</div>
109+
)}
110+
</Card>
111+
</div>
112+
</div>
113+
);
114+
}

0 commit comments

Comments
 (0)