Skip to content

Commit 10ab0ad

Browse files
authored
Merge pull request #401 from puneetmatharu/patch-quantization-in-answer-questions
Patch quantization in PyTorch examples
2 parents 0c8573c + 202681c commit 10ab0ad

3 files changed

Lines changed: 49 additions & 11 deletions

File tree

ML-Frameworks/pytorch-aarch64/examples/answer_questions.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,16 @@
1919
import random
2020
import torch
2121
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
22+
from torchao.quantization.quant_api import (
23+
Int8DynamicActivationIntxWeightConfig,
24+
quantize_,
25+
)
26+
from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
27+
PackedLinearInt8DynamicActivationIntxWeightLayout,
28+
Target,
29+
)
30+
from torchao.quantization.granularity import PerAxis
31+
from torchao.quantization.quant_primitives import MappingType
2232

2333
from utils import nlp
2434

@@ -43,7 +53,7 @@ def main():
4353
parser.add_argument("--bert-large", action='store_true',
4454
help="Use BERT large instead of DistilBERT")
4555
parser.add_argument("--quantize", action='store_true',
46-
help="Quantize the model to int8 using dynamic quantization")
56+
help="Quantize the model to int4 using dynamic quantization")
4757
parser.add_argument("--warmup", action='store_true',
4858
help="Run warmup")
4959

@@ -127,10 +137,20 @@ def main():
127137
model = AutoModelForQuestionAnswering.from_pretrained(model_hf_path)
128138

129139
if args["quantize"]:
130-
model = torch.ao.quantization.quantize_dynamic(
140+
layout = PackedLinearInt8DynamicActivationIntxWeightLayout(target=Target.ATEN)
141+
quantize_(
131142
model,
132-
{torch.nn.Linear},
133-
dtype=torch.qint8)
143+
Int8DynamicActivationIntxWeightConfig(
144+
weight_scale_dtype=torch.float32,
145+
weight_granularity=PerAxis(0),
146+
weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR,
147+
layout=layout,
148+
weight_dtype=torch.int4,
149+
intx_packing_format="opaque_aten_kleidiai",
150+
version=2,
151+
),
152+
filter_fn=lambda m, _: isinstance(m, torch.nn.Linear),
153+
)
134154

135155
encoding = token.encode_plus(
136156
question,

ML-Frameworks/pytorch-aarch64/examples/llama_vision_instruct.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
PackedLinearInt8DynamicActivationIntxWeightLayout,
3030
Target,
3131
)
32-
from torchao.quantization.granularity import PerGroup, PerAxis
32+
from torchao.quantization.granularity import PerAxis
3333
from torchao.quantization.quant_primitives import MappingType
3434
import numpy as np
3535
import os
@@ -53,7 +53,7 @@ def main(args):
5353
layout=layout,
5454
weight_dtype=torch.int4,
5555
intx_packing_format="opaque_aten_kleidiai",
56-
version=1,
56+
version=2,
5757
),
5858
)
5959

ML-Frameworks/pytorch-aarch64/examples/quantized_linear.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# *******************************************************************************
2-
# Copyright 2024 Arm Limited and affiliates.
2+
# Copyright 2024-2025 Arm Limited and affiliates.
33
# SPDX-License-Identifier: Apache-2.0
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,10 +16,19 @@
1616
# *******************************************************************************
1717

1818
import sys
19-
import os
2019

2120
import torch
2221
import torch.nn as nn
22+
from torchao.quantization.quant_api import (
23+
Int8DynamicActivationIntxWeightConfig,
24+
quantize_,
25+
)
26+
from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
27+
PackedLinearInt8DynamicActivationIntxWeightLayout,
28+
Target,
29+
)
30+
from torchao.quantization.granularity import PerAxis
31+
from torchao.quantization.quant_primitives import MappingType
2332

2433
import time
2534

@@ -53,10 +62,19 @@ def forward(self, x):
5362
model(data)
5463
fp32_runtimes.append(time.time() - t0)
5564

56-
model = torch.ao.quantization.quantize_dynamic(
65+
quantize_(
5766
model,
58-
{torch.nn.Linear},
59-
dtype=torch.qint8)
67+
Int8DynamicActivationIntxWeightConfig(
68+
weight_scale_dtype=torch.float32,
69+
weight_granularity=PerAxis(0),
70+
weight_mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR,
71+
layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target=Target.ATEN),
72+
weight_dtype=torch.int4,
73+
intx_packing_format="opaque_aten_kleidiai",
74+
version=2,
75+
),
76+
filter_fn=lambda m, _: isinstance(m, torch.nn.Linear),
77+
)
6078

6179
# Quantized
6280
runtimes = []

0 commit comments

Comments (0)