Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions test/quantization/config/test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from tico.quantization.config.ptq import PTQConfig
from tico.quantization.config.specs import affine, mx
from tico.quantization.wrapq.dtypes import DType
from tico.quantization.wrapq.dtypes import MXDtype
from tico.quantization.wrapq.observers.mx import MXObserver
from tico.quantization.wrapq.observers.minmax import MinMaxObserver
from tico.quantization.wrapq.observers.mx import MXObserver
from tico.quantization.wrapq.qscheme import QScheme

Expand All @@ -48,10 +51,64 @@ def test_build_norm_override_from_quant_specs(self):
self.assertEqual(override["act_out"]["dtype"], DType.uint(8))
self.assertEqual(override["weight"]["dtype"], DType.uint(4))
self.assertEqual(override["weight"]["qscheme"], QScheme.PER_CHANNEL_ASYMM)
self.assertEqual(
override["weight"]["observer"],
MinMaxObserver,
)

def test_build_norm_override_empty_when_no_specs(self):
self.assertEqual(_build_norm_override(norm=None, norm_weight=None), {})

def test_build_norm_override_weight_observer_not_overridden_by_io_observer(self):
"""Weight observer must always be derived from weight dtype, never from io_observer."""
mx8 = MXDtype(elem_format="int8")
override = _build_norm_override(
norm_dtype=None,
norm_weight_dtype=DType.int(16),
norm_io_dtype=mx8,
norm_io_observer=MXObserver,
)

# Weight observer must be MinMaxObserver (from DType.int(16)), NOT MXObserver
self.assertEqual(
override["weight"]["observer"],
MinMaxObserver,
)
# I/O observers must be MXObserver
self.assertEqual(
override["act_in"]["observer"],
MXObserver,
)
self.assertEqual(
override["act_out"]["observer"],
MXObserver,
)

def test_build_norm_override_weight_observer_not_overridden_by_io_observer(self):
"""Weight observer must always be derived from weight dtype, never from io_observer."""
mx8 = MXDtype(elem_format="int8")
override = _build_norm_override(
norm_dtype=None,
norm_weight_dtype=DType.int(16),
norm_io_dtype=mx8,
norm_io_observer=MXObserver,
)

# Weight observer must be MinMaxObserver (from DType.int(16)), NOT MXObserver
self.assertEqual(
override["weight"]["observer"],
MinMaxObserver,
)
# I/O observers must be MXObserver
self.assertEqual(
override["act_in"]["observer"],
MXObserver,
)
self.assertEqual(
override["act_out"]["observer"],
MXObserver,
)


class TestLlamaOverrideBuilders(unittest.TestCase):
def test_build_llama_layer_overrides(self):
Expand Down Expand Up @@ -95,6 +152,131 @@ def test_build_llama_overrides(self):
QScheme.PER_CHANNEL_ASYMM,
)

def test_build_llama_layer_overrides_with_linear_io_dtype(self):
"""linear_io_dtype produces act_in/act_out on linear projections and fine-grained activations."""
mx8 = MXDtype(elem_format="int8")
overrides = _build_llama_layer_overrides(
linear_weight_dtype=DType.uint(4),
linear_io_dtype=mx8,
)

# Linear projections get act_in/act_out with MX observer
for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]:
self.assertEqual(
overrides["self_attn"][proj]["act_in"]["dtype"], mx8
)
self.assertEqual(
overrides["self_attn"][proj]["act_in"]["observer"], MXObserver
)
self.assertEqual(
overrides["self_attn"][proj]["act_out"]["dtype"], mx8
)

# Fine-grained activations (driven by linear_io_dtype)
self.assertEqual(
overrides["self_attn"]["hidden"]["dtype"], mx8
)
self.assertEqual(
overrides["self_attn"]["attn_mask"]["dtype"], mx8
)
self.assertEqual(
overrides["self_attn"]["logits"]["dtype"], mx8
)
self.assertEqual(
overrides["mlp"]["mul"]["dtype"], mx8
)
self.assertEqual(
overrides["attn_mask"]["dtype"], mx8
)
self.assertEqual(
overrides["mlp_residual_out"]["dtype"], mx8
)
self.assertEqual(
overrides["self_attn_residual_out"]["dtype"], mx8
)

def test_build_llama_layer_overrides_with_rms_norm_io(self):
"""rms_norm_io_dtype produces act_in/act_out on norms and mlp.act_in."""
mx8 = MXDtype(elem_format="int8")
overrides = _build_llama_layer_overrides(
linear_weight_dtype=DType.uint(4),
norm_weight_dtype=DType.int(16),
rms_norm_io_dtype=mx8,
)

# Norm act_in/act_out
for norm in ["input_layernorm", "post_attention_layernorm"]:
self.assertEqual(overrides[norm]["act_in"]["dtype"], mx8)
self.assertEqual(overrides[norm]["act_in"]["observer"], MXObserver)
self.assertEqual(overrides[norm]["act_out"]["dtype"], mx8)

# mlp.act_in (driven by rms_norm_io_dtype)
self.assertEqual(overrides["mlp"]["act_in"]["dtype"], mx8)

# self_attn.hidden is now driven by linear_io_dtype, not rms_norm_io_dtype
self.assertNotIn("hidden", overrides["self_attn"])

def test_build_llama_layer_overrides_with_softmax_override(self):
"""softmax_dtype produces override on self_attn.softmax and mask_add."""
mx8 = MXDtype(elem_format="int8")
overrides = _build_llama_layer_overrides(
linear_weight_dtype=DType.uint(4),
softmax_dtype=mx8,
)

self.assertEqual(overrides["self_attn"]["softmax"]["dtype"], mx8)
self.assertEqual(overrides["self_attn"]["softmax"]["observer"], MXObserver)
self.assertEqual(overrides["self_attn"]["mask_add"]["dtype"], mx8)
self.assertEqual(overrides["self_attn"]["mask_add"]["observer"], MXObserver)

def test_build_llama_overrides_with_linear_io_produces_causal_mask(self):
"""linear_io_dtype produces model-level causal_mask override."""
mx8 = MXDtype(elem_format="int8")
overrides = _build_llama_overrides(
num_hidden_layers=1,
linear_weight_dtype=DType.uint(4),
linear_io_dtype=mx8,
)

self.assertEqual(overrides["model"]["causal_mask"]["dtype"], mx8)
self.assertEqual(
overrides["model"]["causal_mask"]["observer"], MXObserver
)

def test_build_llama_overrides_lm_head_gets_act_in_act_out(self):
"""lm_head gets full linear desc (weight + act_in + act_out) when io is specified."""
mx8 = MXDtype(elem_format="int8")
overrides = _build_llama_overrides(
num_hidden_layers=1,
linear_weight_dtype=DType.uint(4),
lm_head_weight_dtype=DType.uint(8),
linear_io_dtype=mx8,
)

self.assertEqual(overrides["lm_head"]["act_in"]["dtype"], mx8)
self.assertEqual(overrides["lm_head"]["act_out"]["dtype"], mx8)
self.assertEqual(overrides["lm_head"]["weight"]["dtype"], DType.uint(8))

def test_no_fine_grained_overrides_when_no_io_specified(self):
"""No fine-grained activation overrides when no io dtype/observer is given."""
overrides = _build_llama_layer_overrides(
linear_weight_dtype=DType.uint(4),
)

# No act_in/act_out on linear projections
for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]:
self.assertNotIn("act_in", overrides["self_attn"][proj])
self.assertNotIn("act_out", overrides["self_attn"][proj])

# No fine-grained activations
self.assertNotIn("attn_mask", overrides["self_attn"])
self.assertNotIn("softmax", overrides["self_attn"])
self.assertNotIn("hidden", overrides["self_attn"])
self.assertNotIn("mul", overrides.get("mlp", {}))
self.assertNotIn("attn_mask", overrides)
self.assertNotIn("self_attn_residual_out", overrides)
self.assertNotIn("mlp_residual_out", overrides)


class TestBuildLlmPtqConfig(unittest.TestCase):
def test_build_llm_ptq_config_llama(self):
Expand Down
Loading