Skip to content

Commit e9ecf08

Browse files
Merge pull request #441 from puneetmatharu/fix-answer-questions-py
Patch `answer_questions.py`
2 parents 3114c62 + 153f2c2 commit e9ecf08

1 file changed

Lines changed: 55 additions & 39 deletions

File tree

ML-Frameworks/pytorch-aarch64/examples/answer_questions.py

Lines changed: 55 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,13 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
import sys
5+
# System packages
6+
import argparse
67
import random
8+
import sys
9+
import time
10+
11+
# Installed packages
712
import torch
813
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
914
from torchao.quantization.quant_api import (
@@ -13,11 +18,24 @@
1318
from torchao.quantization.granularity import PerAxis
1419
from torchao.quantization.quant_primitives import MappingType
1520

21+
# Local modules
1622
from utils import nlp
1723

18-
import time
1924

20-
import argparse
25+
def get_best_span_from_scores(start_scores, end_scores, max_answer_len, top_k):
26+
best_score = -1e18
27+
(start_idx, end_idx) = (0, 0)
28+
topk_start_posns = torch.topk(start_scores[0], k=top_k).indices.tolist()
29+
topk_end_posns = torch.topk(end_scores[0], k=top_k).indices.tolist()
30+
for start in topk_start_posns:
31+
for end in topk_end_posns:
32+
if (end < start) or ((end + 1) - start > max_answer_len):
33+
continue
34+
score = start_scores[0][start].item() + end_scores[0][end].item()
35+
if score > best_score:
36+
(best_score, start_idx, end_idx) = (score, start, end)
37+
return (start_idx, end_idx)
38+
2139

2240
def main():
2341
"""
@@ -46,7 +64,7 @@ def main():
4664
subject = args.get("subject","")
4765
context = ""
4866
question = args.get("question","")
49-
answer = args.get("answer","")
67+
answer = ""
5068
squadid = args.get("squadid","")
5169

5270
# Setup the question, either from a specified SQuAD record
@@ -80,13 +98,8 @@ def main():
8098
i_record = 0
8199
else:
82100
if subject:
83-
print(
84-
"Picking a question at random on the subject: ",
85-
subject,
86-
)
87-
squad_records = squad_data.loc[
88-
squad_data["subject"] == subject
89-
]
101+
print("Picking a question at random on the subject: ", subject)
102+
squad_records = squad_data.loc[squad_data["subject"] == subject]
90103
else:
91104
print(
92105
"No SQuAD ID or question provided, picking one at random!"
@@ -109,16 +122,17 @@ def main():
109122
question = squad_records["question"].iloc[i_record]
110123
answer = squad_records["answer"].iloc[i_record]
111124

125+
# Select model and tokenizer
112126
if args["bert_large"]:
113127
model_hf_path = "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad"
114128
model_name = "BERT Large"
115129
else:
116130
model_hf_path = "distilbert-base-uncased-distilled-squad"
117131
model_name = "DistilBERT"
118-
119132
token = AutoTokenizer.from_pretrained(model_hf_path, return_token_type_ids=True)
120133
model = AutoModelForQuestionAnswering.from_pretrained(model_hf_path)
121134

135+
# Optional: quantize
122136
if args["quantize"]:
123137
quantize_(
124138
model,
@@ -133,50 +147,52 @@ def main():
133147
filter_fn=lambda m, _: isinstance(m, torch.nn.Linear),
134148
)
135149

136-
encoding = token.encode_plus(
137-
question,
138-
context,
139-
max_length=512, truncation=True
140-
)
141-
142-
input_ids, attention_mask = (
143-
encoding["input_ids"],
144-
encoding["attention_mask"],
145-
)
150+
# Encode context
151+
encoding = token.encode_plus(question, context, max_length=512, truncation=True)
152+
(input_ids, attention_mask) = (encoding["input_ids"], encoding["attention_mask"])
146153

154+
# Warm-up
147155
if args["warmup"]:
148156
model(
149157
torch.tensor([input_ids]),
150158
attention_mask=torch.tensor([attention_mask]),
151159
return_dict=False,
152160
)
153161

162+
# Process
154163
start_time = time.time()
155-
start_scores, end_scores = model(
156-
torch.tensor([input_ids]),
157-
attention_mask=torch.tensor([attention_mask]),
158-
return_dict=False,
159-
)
164+
with torch.no_grad():
165+
start_scores, end_scores = model(
166+
torch.tensor([input_ids]),
167+
attention_mask=torch.tensor([attention_mask]),
168+
return_dict=False,
169+
)
160170
end_time = time.time()
161171

162-
answer_ids = input_ids[
163-
torch.argmax(start_scores) : torch.argmax(end_scores) + 1
164-
]
165-
answer_tokens = token.convert_ids_to_tokens(
166-
answer_ids, skip_special_tokens=True
167-
)
168-
answer_tokens_to_string = token.convert_tokens_to_string(answer_tokens)
172+
# Post-process scores to find most likely answer
173+
(start_idx, end_idx) = get_best_span_from_scores(
174+
start_scores, end_scores, max_answer_len=30, top_k=20)
175+
176+
# Decode answer
177+
answer_ids = input_ids[start_idx:end_idx + 1]
178+
answer_tokens_to_string = token.decode(answer_ids, skip_special_tokens=True).strip()
169179

170180
# Display results
171181
print(f"\n{model_name} question answering example.")
172182
print("======================================")
173183
print("Reading from: ", subject, source)
174-
print("\nContext: ", context)
175-
print(f"Inference time: {end_time - start_time}s")
184+
max_context_to_print = 1000
185+
if len(context) <= max_context_to_print:
186+
print("\nContext: ", context)
187+
else:
188+
print(f"\nContext (limited to {max_context_to_print} chars):")
189+
print(context[:max_context_to_print])
190+
print("...")
176191
print("--")
177-
print("Question: ", question)
178-
print("Answer: ", answer_tokens_to_string)
179-
print("Reference Answer: ", answer)
192+
print("Question:", question)
193+
print("Answer:", answer_tokens_to_string)
194+
print("Reference Answer:", answer)
195+
print(f"Inference time: {end_time - start_time:.6f}s")
180196

181197

182198
if __name__ == "__main__":

0 commit comments

Comments (0)