@@ -164,6 +164,8 @@ def _try_translate_n_batch(
         builder = TranslationResultBuilder(input_tokens)
         for token, score in zip(output["translation_tokens"], output["token_scores"]):
             builder.append_token(token, TranslationSources.NMT, exp(score))
+        if output["sequence_score"] is not None:
+            builder.set_sequence_confidence(exp(output["sequence_score"]))
         word_pairs: Optional[Collection[Union[AlignedWordPair, Tuple[int, int]]]] = None
         if output.get("token_attentions") is not None:
             src_indices = torch.argmax(output["token_attentions"], dim=1).tolist()
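Note on the hunk above: both `token_scores` and the new `sequence_score` arrive as log-probabilities, so `exp()` maps them to confidences in (0, 1]. A minimal standalone sketch of that conversion (the numbers are invented, and the mean-of-logs reading of the sequence score assumes Hugging Face's default length_penalty of 1.0):

from math import exp

token_logprobs = [-0.105, -0.223, -0.011]  # hypothetical log P(token | prefix)
token_confidences = [exp(lp) for lp in token_logprobs]  # ~[0.90, 0.80, 0.99]

# With length_penalty=1.0, a beam's sequence score is the sum of its token
# log-probs divided by length, so exp() of it is the geometric mean of the
# token probabilities and also lands in (0, 1].
sequence_logprob = sum(token_logprobs) / len(token_logprobs)
sequence_confidence = exp(sequence_logprob)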
@@ -257,36 +259,56 @@ def _forward(self, model_inputs, **generate_kwargs):
             output_ids = output.sequences
             beam_indices = output.beam_indices
             scores = output.scores
+            assert scores is not None and beam_indices is not None
+            sequences_scores = output.sequences_scores
             attentions = output.cross_attentions
         elif isinstance(output, GreedySearchEncoderDecoderOutput):
             output_ids = output.sequences
-            beam_indices = torch.zeros_like(output_ids)
+            beam_indices = None
             assert output.scores is not None
-            scores = tuple(torch.nn.functional.log_softmax(logits, dim=-1) for logits in output.scores)
+            scores = output.scores
+            sequences_scores = None
             attentions = output.cross_attentions
         else:
             raise RuntimeError("Cannot postprocess the output of the model.")

-        assert beam_indices is not None and scores is not None
-        out_b = output_ids.shape[0]
+        transition_scores = cast(
+            torch.Tensor,
+            self.model.compute_transition_scores(
+                output_ids,  # type: ignore
+                scores,  # type: ignore
+                beam_indices,  # type: ignore
+                normalize_logits=True,
+            ),
+        )
+
+        if beam_indices is None:
+            beam_indices = torch.zeros_like(output_ids)
+
+        out_b, seq_len = output_ids.shape
         num_beams = scores[0].shape[0] // in_b
         n_sequences = out_b // in_b
+
+        ts_len = transition_scores.shape[1]
+        if ts_len == seq_len:
+            token_logprobs = transition_scores
+        elif ts_len == seq_len - 1:
+            token_logprobs = torch.cat(
+                [
+                    torch.zeros(out_b, 1, device=transition_scores.device, dtype=transition_scores.dtype),
+                    transition_scores,
+                ],
+                dim=1,
+            )
+        else:
+            raise RuntimeError(
+                f"Unexpected transition_scores length {ts_len} for sequences length {seq_len}. "
+                "Cannot align token scores robustly."
+            )
+
         start_index = 0
         if self.model.config.decoder_start_token_id is not None:
             start_index = 1
-        indices = torch.stack(
-            (
-                torch.arange(output_ids.shape[1] - start_index, device=output_ids.device).expand(in_b, n_sequences, -1),
-                torch.reshape(beam_indices[:, start_index:] % num_beams, (in_b, n_sequences, -1)),
-                torch.reshape(output_ids[:, start_index:], (in_b, n_sequences, -1)),
-            ),
-            dim=3,
-        )
-        scores = torch.stack(scores, dim=0).reshape(len(scores), in_b, num_beams, -1).transpose(0, 1)
-        scores = torch_gather_nd(scores, indices, 1)
-        if self.model.config.decoder_start_token_id is not None:
-            scores = torch.cat((torch.zeros(scores.shape[0], scores.shape[1], 1, device=scores.device), scores), dim=2)
-
         if generate_kwargs["output_attentions"] is True:
             assert attentions is not None
             num_heads = attentions[0][0].shape[1]
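For context: the hunk above drops the hand-rolled log-softmax/gather pipeline in favor of Hugging Face's `GenerationMixin.compute_transition_scores`, which recovers per-token log-probabilities directly from `generate()` outputs. A self-contained usage sketch (the checkpoint name and input text are illustrative placeholders):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "Helsinki-NLP/opus-mt-en-de"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

inputs = tokenizer("This is a test.", return_tensors="pt")
out = model.generate(
    **inputs,
    num_beams=2,
    return_dict_in_generate=True,
    output_scores=True,
)
# beam_indices is needed for beam search; greedy decoding can pass None.
transition_scores = model.compute_transition_scores(
    out.sequences, out.scores, out.beam_indices, normalize_logits=True
)
# transition_scores holds one log-prob per generated token, per sequence.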
@@ -320,13 +342,15 @@ def _forward(self, model_inputs, **generate_kwargs):
                 ),
                 dim=2,
             )
+        output_ids = output_ids.reshape(in_b, n_sequences, seq_len)
+        token_logprobs = token_logprobs.reshape(in_b, n_sequences, seq_len)

-        output_ids = output_ids.reshape(in_b, n_sequences, *output_ids.shape[1:])
         return {
             "input_ids": model_inputs["input_ids"],
             "input_tokens": input_tokens,
             "output_ids": output_ids,
-            "scores": scores,
+            "scores": token_logprobs,
+            "sequences_scores": sequences_scores,
             "attentions": attentions,
         }

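The `ts_len` branch added in `_forward` exists because `compute_transition_scores` scores only generated tokens: when the decoder is seeded with a start token, `sequences` is one column longer than the returned scores. A toy illustration of the zero-pad that re-aligns them (shapes are what matter; the values are invented):

import torch

out_b, seq_len = 1, 4                                   # sequences include the decoder start token
transition_scores = torch.tensor([[-0.1, -0.2, -0.3]])  # only seq_len - 1 tokens were scored

# Pad a zero log-prob (probability 1.0) for the unscored start token so that
# token_logprobs[:, i] lines up with output_ids[:, i].
token_logprobs = torch.cat(
    [torch.zeros(out_b, 1, dtype=transition_scores.dtype), transition_scores],
    dim=1,
)
assert token_logprobs.shape == (out_b, seq_len)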
@@ -346,24 +370,17 @@ def postprocess(self, model_outputs, clean_up_tokenization_spaces=False):
         records = []

         has_attentions = model_outputs.get("attentions") is not None and model_outputs["attentions"][0] is not None
-        if has_attentions:
-            zipped = zip(
-                model_outputs["output_ids"][0],
-                model_outputs["scores"][0],
-                model_outputs["attentions"][0],
-            )
-        else:
-            zipped = zip(
-                model_outputs["output_ids"][0],
-                model_outputs["scores"][0],
-            )
-
+        has_sequence_scores = model_outputs["sequences_scores"] is not None
+        zipped = zip(
+            model_outputs["output_ids"][0],
+            model_outputs["scores"][0],
+            model_outputs["sequences_scores"] if has_sequence_scores else iter(lambda: None, 1),
+            model_outputs["attentions"][0] if has_attentions else iter(lambda: None, 1),
+        )
         for item in zipped:
-            if has_attentions:
-                output_ids, scores, attentions = cast(Tuple[torch.Tensor, torch.Tensor, torch.Tensor], item)
-            else:
-                output_ids, scores = cast(Tuple[torch.Tensor, torch.Tensor], item)
-                attentions = None
+            output_ids, scores, sequence_score, attentions = cast(
+                Tuple[torch.Tensor, torch.Tensor, Optional[float], Optional[torch.Tensor]], item
+            )

             output_tokens: List[str] = []
             output_indices: List[int] = []
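One idiom in the rewritten `zip()` above deserves a gloss: `iter(lambda: None, 1)` is the two-argument form of `iter`, which calls the lambda until it returns the sentinel `1`. Since the lambda always returns `None`, the result is an endless stream of `None`s, and `zip` truncates it to the length of the finite inputs. A quick demonstration:

nones = iter(lambda: None, 1)          # infinite: the lambda never returns the sentinel 1
print(list(zip([10, 20, 30], nones)))  # [(10, None), (20, None), (30, None)]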
@@ -379,6 +396,7 @@ def postprocess(self, model_outputs, clean_up_tokenization_spaces=False):
                 "input_tokens": input_tokens,
                 "translation_tokens": output_tokens,
                 "token_scores": scores,
+                "sequence_score": sequence_score,
                 "translation_text": self.tokenizer.decode(
                     output_ids,
                     skip_special_tokens=True,