22#
33# SPDX-License-Identifier: Apache-2.0
44
5- import sys
5+ # System packages
6+ import argparse
67import random
8+ import sys
9+ import time
10+
11+ # Installed packages
712import torch
813from transformers import AutoTokenizer , AutoModelForQuestionAnswering
914from torchao .quantization .quant_api import (
1318from torchao .quantization .granularity import PerAxis
1419from torchao .quantization .quant_primitives import MappingType
1520
21+ # Local modules
1622from utils import nlp
1723
18- import time
1924
20- import argparse
def get_best_span_from_scores(start_scores, end_scores, max_answer_len, top_k):
    """Select the most likely (start, end) answer span from QA logits.

    Only the ``top_k`` highest-scoring start and end positions are
    considered (pruning the full quadratic search), and a candidate span
    is kept only if it is well-formed (end not before start) and no
    longer than ``max_answer_len`` tokens.

    Args:
        start_scores: start-position logits; indexed as ``start_scores[0]``
            (batch of one — assumed shape ``(1, seq_len)``, per the caller).
        end_scores: end-position logits, same shape convention.
        max_answer_len: maximum allowed span length in tokens (inclusive).
        top_k: number of top start/end positions to consider.

    Returns:
        Tuple ``(start_idx, end_idx)`` of the best span; ``(0, 0)`` when
        no candidate pair satisfies the constraints.
    """
    start_logits = start_scores[0]
    end_logits = end_scores[0]
    candidate_starts = torch.topk(start_logits, k=top_k).indices.tolist()
    candidate_ends = torch.topk(end_logits, k=top_k).indices.tolist()

    # (score, start, end); the sentinel score guarantees any valid span wins.
    best = (-1e18, 0, 0)
    for s in candidate_starts:
        for e in candidate_ends:
            # Reject malformed spans and spans exceeding the length cap.
            if e < s:
                continue
            if e - s + 1 > max_answer_len:
                continue
            span_score = start_logits[s].item() + end_logits[e].item()
            # Strict '>' keeps the earliest-seen span on ties, matching
            # the top-k iteration order.
            if span_score > best[0]:
                best = (span_score, s, e)
    return (best[1], best[2])
38+
2139
2240def main ():
2341 """
@@ -46,7 +64,7 @@ def main():
4664 subject = args .get ("subject" ,"" )
4765 context = ""
4866 question = args .get ("question" ,"" )
49- answer = args . get ( "answer" , "" )
67+ answer = ""
5068 squadid = args .get ("squadid" ,"" )
5169
5270 # Setup the question, either from a specified SQuAD record
@@ -80,13 +98,8 @@ def main():
8098 i_record = 0
8199 else :
82100 if subject :
83- print (
84- "Picking a question at random on the subject: " ,
85- subject ,
86- )
87- squad_records = squad_data .loc [
88- squad_data ["subject" ] == subject
89- ]
101+ print ("Picking a question at random on the subject: " , subject )
102+ squad_records = squad_data .loc [squad_data ["subject" ] == subject ]
90103 else :
91104 print (
92105 "No SQuAD ID or question provided, picking one at random!"
@@ -109,16 +122,17 @@ def main():
109122 question = squad_records ["question" ].iloc [i_record ]
110123 answer = squad_records ["answer" ].iloc [i_record ]
111124
125+ # Select model and tokenizer
112126 if args ["bert_large" ]:
113127 model_hf_path = "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad"
114128 model_name = "BERT Large"
115129 else :
116130 model_hf_path = "distilbert-base-uncased-distilled-squad"
117131 model_name = "DistilBERT"
118-
119132 token = AutoTokenizer .from_pretrained (model_hf_path , return_token_type_ids = True )
120133 model = AutoModelForQuestionAnswering .from_pretrained (model_hf_path )
121134
135+ # Optional: quantize
122136 if args ["quantize" ]:
123137 quantize_ (
124138 model ,
@@ -133,50 +147,52 @@ def main():
133147 filter_fn = lambda m , _ : isinstance (m , torch .nn .Linear ),
134148 )
135149
136- encoding = token .encode_plus (
137- question ,
138- context ,
139- max_length = 512 , truncation = True
140- )
141-
142- input_ids , attention_mask = (
143- encoding ["input_ids" ],
144- encoding ["attention_mask" ],
145- )
150+ # Encode context
151+ encoding = token .encode_plus (question , context , max_length = 512 , truncation = True )
152+ (input_ids , attention_mask ) = (encoding ["input_ids" ], encoding ["attention_mask" ])
146153
154+ # Warm-up
147155 if args ["warmup" ]:
148156 model (
149157 torch .tensor ([input_ids ]),
150158 attention_mask = torch .tensor ([attention_mask ]),
151159 return_dict = False ,
152160 )
153161
162+ # Process
154163 start_time = time .time ()
155- start_scores , end_scores = model (
156- torch .tensor ([input_ids ]),
157- attention_mask = torch .tensor ([attention_mask ]),
158- return_dict = False ,
159- )
164+ with torch .no_grad ():
165+ start_scores , end_scores = model (
166+ torch .tensor ([input_ids ]),
167+ attention_mask = torch .tensor ([attention_mask ]),
168+ return_dict = False ,
169+ )
160170 end_time = time .time ()
161171
162- answer_ids = input_ids [
163- torch . argmax ( start_scores ) : torch . argmax ( end_scores ) + 1
164- ]
165- answer_tokens = token . convert_ids_to_tokens (
166- answer_ids , skip_special_tokens = True
167- )
168- answer_tokens_to_string = token .convert_tokens_to_string ( answer_tokens )
172+ # Post-process scores to find most likely answer
173+ ( start_idx , end_idx ) = get_best_span_from_scores (
174+ start_scores , end_scores , max_answer_len = 30 , top_k = 20 )
175+
176+ # Decode answer
177+ answer_ids = input_ids [ start_idx : end_idx + 1 ]
178+ answer_tokens_to_string = token .decode ( answer_ids , skip_special_tokens = True ). strip ( )
169179
170180 # Display results
171181 print (f"\n { model_name } question answering example." )
172182 print ("======================================" )
173183 print ("Reading from: " , subject , source )
174- print ("\n Context: " , context )
175- print (f"Inference time: { end_time - start_time } s" )
184+ max_context_to_print = 1000
185+ if len (context ) <= max_context_to_print :
186+ print ("\n Context: " , context )
187+ else :
188+ print (f"\n Context (limited to { max_context_to_print } chars):" )
189+ print (context [:max_context_to_print ])
190+ print ("..." )
176191 print ("--" )
177- print ("Question: " , question )
178- print ("Answer: " , answer_tokens_to_string )
179- print ("Reference Answer: " , answer )
192+ print ("Question:" , question )
193+ print ("Answer:" , answer_tokens_to_string )
194+ print ("Reference Answer:" , answer )
195+ print (f"Inference time: { end_time - start_time :.6f} s" )
180196
181197
182198if __name__ == "__main__" :
0 commit comments