From d433f4dd9f4f2b7e4825be87ddf2615f561b3e9e Mon Sep 17 00:00:00 2001 From: seongwoo Date: Wed, 24 Jun 2026 17:55:21 +0900 Subject: [PATCH] [quantization] Fix weight dtype mapping This commit fixes weight dtype mapping by routing every PTQ weight dtype in quantize_using_PTQ through a new _weight_dtype_from_bits helper: 4- and 8-bit weights map to DType.uint, 16-bit weights map to DType.int, and any other bit-width raises ValueError. Spin-rotation weights now follow the same mapping instead of always using DType.int. TICO-DCO-1.0-Signed-off-by: seongwoo --- .../quantize_full_qmodel_with_gptq.py | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py index fa666bd2..f7701596 100644 --- a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py +++ b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py @@ -86,6 +86,22 @@ # "float16": torch.float16, } +_SUPPORTED_WEIGHT_BITS = (4, 8, 16) + + +def _weight_dtype_from_bits(bits: int) -> DType: + """Return the PTQ weight dtype for a supported bit-width.""" + if bits in (4, 8): + return DType.uint(bits) + if bits == 16: + return DType.int(bits) + + raise ValueError( + f"Unsupported weight bit-width: {bits}. " + f"Expected one of {_SUPPORTED_WEIGHT_BITS}." 
+ ) + + # Hardcoded dataset settings DATASET_NAME = "wikitext" DATASET_CONFIG = "wikitext-2-raw-v1" @@ -949,15 +965,16 @@ def quantize_using_PTQ(q_m, calib_inputs, args): model_type="llama", num_hidden_layers=len(q_m.model.layers), activation=affine(DType.int(16)), - linear_weight=affine(DType.uint(args.linear_weight_bits)), - embedding_weight=affine(DType.uint(args.embedding_weight_bits)), - lm_head_weight=affine(DType.uint(args.lm_head_weight_bits)), + weight=affine(_weight_dtype_from_bits(16)), + linear_weight=affine(_weight_dtype_from_bits(args.linear_weight_bits)), + embedding_weight=affine(_weight_dtype_from_bits(args.embedding_weight_bits)), + lm_head_weight=affine(_weight_dtype_from_bits(args.lm_head_weight_bits)), spin_rotation_weight=( None if args.no_spinquant - else affine(DType.int(args.spin_rotation_weight_bits)) + else affine(_weight_dtype_from_bits(args.spin_rotation_weight_bits)) ), - norm_weight=affine(DType.int(16)), + norm_weight=affine(_weight_dtype_from_bits(16)), strict_wrap=True, profile=args.profile, )