Skip to content

Commit 9cc4cd0

Browse files
authored
Add sequence confidence to pretranslations (#279)
1 parent b9219fc commit 9cc4cd0

9 files changed

Lines changed: 95 additions & 36 deletions

machine/jobs/nmt_engine_build_job.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def _batch_inference(
115115
check_canceled()
116116
for i, result in enumerate(engine.translate_batch(seg_batch)):
117117
pretranslations[current_inference_step + i]["translation"] = result.translation
118+
pretranslations[current_inference_step + i]["sequenceConfidence"] = result.sequence_confidence
118119
current_inference_step += len(seg_batch)
119120
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
120121

machine/jobs/translation_file_service.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class PretranslationInfo(TypedDict):
2020
sourceTokens: List[str] # noqa: N815
2121
translationTokens: List[str] # noqa: N815
2222
alignment: str
23+
sequenceConfidence: float # noqa: N815
2324

2425

2526
class TranslationFileService:
@@ -98,6 +99,7 @@ def generator() -> Generator[PretranslationInfo, None, None]:
9899
sourceTokens=list(),
99100
translationTokens=list(),
100101
alignment="",
102+
sequenceConfidence=0,
101103
)
102104

103105
return ContextManagedGenerator(generator())

machine/translation/huggingface/hugging_face_nmt_engine.py

Lines changed: 54 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ def _try_translate_n_batch(
164164
builder = TranslationResultBuilder(input_tokens)
165165
for token, score in zip(output["translation_tokens"], output["token_scores"]):
166166
builder.append_token(token, TranslationSources.NMT, exp(score))
167+
if output["sequence_score"] is not None:
168+
builder.set_sequence_confidence(exp(output["sequence_score"]))
167169
word_pairs: Optional[Collection[Union[AlignedWordPair, Tuple[int, int]]]] = None
168170
if output.get("token_attentions") is not None:
169171
src_indices = torch.argmax(output["token_attentions"], dim=1).tolist()
@@ -257,36 +259,56 @@ def _forward(self, model_inputs, **generate_kwargs):
257259
output_ids = output.sequences
258260
beam_indices = output.beam_indices
259261
scores = output.scores
262+
assert scores is not None and beam_indices is not None
263+
sequences_scores = output.sequences_scores
260264
attentions = output.cross_attentions
261265
elif isinstance(output, GreedySearchEncoderDecoderOutput):
262266
output_ids = output.sequences
263-
beam_indices = torch.zeros_like(output_ids)
267+
beam_indices = None
264268
assert output.scores is not None
265-
scores = tuple(torch.nn.functional.log_softmax(logits, dim=-1) for logits in output.scores)
269+
scores = output.scores
270+
sequences_scores = None
266271
attentions = output.cross_attentions
267272
else:
268273
raise RuntimeError("Cannot postprocess the output of the model.")
269274

270-
assert beam_indices is not None and scores is not None
271-
out_b = output_ids.shape[0]
275+
transition_scores = cast(
276+
torch.Tensor,
277+
self.model.compute_transition_scores(
278+
output_ids, # type: ignore
279+
scores, # type: ignore
280+
beam_indices, # type: ignore
281+
normalize_logits=True,
282+
),
283+
)
284+
285+
if beam_indices is None:
286+
beam_indices = torch.zeros_like(output_ids)
287+
288+
out_b, seq_len = output_ids.shape
272289
num_beams = scores[0].shape[0] // in_b
273290
n_sequences = out_b // in_b
291+
292+
ts_len = transition_scores.shape[1]
293+
if ts_len == seq_len:
294+
token_logprobs = transition_scores
295+
elif ts_len == seq_len - 1:
296+
token_logprobs = torch.cat(
297+
[
298+
torch.zeros(out_b, 1, device=transition_scores.device, dtype=transition_scores.dtype),
299+
transition_scores,
300+
],
301+
dim=1,
302+
)
303+
else:
304+
raise RuntimeError(
305+
f"Unexpected transition_scores length {ts_len} for sequences length {seq_len}. "
306+
"Cannot align token scores robustly."
307+
)
308+
274309
start_index = 0
275310
if self.model.config.decoder_start_token_id is not None:
276311
start_index = 1
277-
indices = torch.stack(
278-
(
279-
torch.arange(output_ids.shape[1] - start_index, device=output_ids.device).expand(in_b, n_sequences, -1),
280-
torch.reshape(beam_indices[:, start_index:] % num_beams, (in_b, n_sequences, -1)),
281-
torch.reshape(output_ids[:, start_index:], (in_b, n_sequences, -1)),
282-
),
283-
dim=3,
284-
)
285-
scores = torch.stack(scores, dim=0).reshape(len(scores), in_b, num_beams, -1).transpose(0, 1)
286-
scores = torch_gather_nd(scores, indices, 1)
287-
if self.model.config.decoder_start_token_id is not None:
288-
scores = torch.cat((torch.zeros(scores.shape[0], scores.shape[1], 1, device=scores.device), scores), dim=2)
289-
290312
if generate_kwargs["output_attentions"] is True:
291313
assert attentions is not None
292314
num_heads = attentions[0][0].shape[1]
@@ -320,13 +342,15 @@ def _forward(self, model_inputs, **generate_kwargs):
320342
),
321343
dim=2,
322344
)
345+
output_ids = output_ids.reshape(in_b, n_sequences, seq_len)
346+
token_logprobs = token_logprobs.reshape(in_b, n_sequences, seq_len)
323347

324-
output_ids = output_ids.reshape(in_b, n_sequences, *output_ids.shape[1:])
325348
return {
326349
"input_ids": model_inputs["input_ids"],
327350
"input_tokens": input_tokens,
328351
"output_ids": output_ids,
329-
"scores": scores,
352+
"scores": token_logprobs,
353+
"sequences_scores": sequences_scores,
330354
"attentions": attentions,
331355
}
332356

@@ -346,24 +370,17 @@ def postprocess(self, model_outputs, clean_up_tokenization_spaces=False):
346370
records = []
347371

348372
has_attentions = model_outputs.get("attentions") is not None and model_outputs["attentions"][0] is not None
349-
if has_attentions:
350-
zipped = zip(
351-
model_outputs["output_ids"][0],
352-
model_outputs["scores"][0],
353-
model_outputs["attentions"][0],
354-
)
355-
else:
356-
zipped = zip(
357-
model_outputs["output_ids"][0],
358-
model_outputs["scores"][0],
359-
)
360-
373+
has_sequence_scores = model_outputs["sequences_scores"] is not None
374+
zipped = zip(
375+
model_outputs["output_ids"][0],
376+
model_outputs["scores"][0],
377+
model_outputs["sequences_scores"] if has_sequence_scores else iter(lambda: None, 1),
378+
model_outputs["attentions"][0] if has_attentions else iter(lambda: None, 1),
379+
)
361380
for item in zipped:
362-
if has_attentions:
363-
output_ids, scores, attentions = cast(Tuple[torch.Tensor, torch.Tensor, torch.Tensor], item)
364-
else:
365-
output_ids, scores = cast(Tuple[torch.Tensor, torch.Tensor], item)
366-
attentions = None
381+
output_ids, scores, sequence_score, attentions = cast(
382+
Tuple[torch.Tensor, torch.Tensor, Optional[float], Optional[torch.Tensor]], item
383+
)
367384

368385
output_tokens: List[str] = []
369386
output_indices: List[int] = []
@@ -379,6 +396,7 @@ def postprocess(self, model_outputs, clean_up_tokenization_spaces=False):
379396
"input_tokens": input_tokens,
380397
"translation_tokens": output_tokens,
381398
"token_scores": scores,
399+
"sequence_score": sequence_score,
382400
"translation_text": self.tokenizer.decode(
383401
output_ids,
384402
skip_special_tokens=True,

machine/translation/translation_result.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def __init__(
1212
source_tokens: Iterable[str],
1313
target_tokens: Iterable[str],
1414
confidences: Iterable[float],
15+
sequence_confidence: float,
1516
sources: Iterable[TranslationSources],
1617
alignment: WordAlignmentMatrix,
1718
phrases: Iterable[Phrase],
@@ -20,6 +21,7 @@ def __init__(
2021
self._source_tokens = list(source_tokens)
2122
self._target_tokens = list(target_tokens)
2223
self._confidences = list(confidences)
24+
self._sequence_confidence = sequence_confidence
2325
self._sources = list(sources)
2426
self._alignment = alignment
2527
self._phrases = list(phrases)
@@ -49,6 +51,10 @@ def target_tokens(self) -> Sequence[str]:
4951
def confidences(self) -> Sequence[float]:
5052
return self._confidences
5153

54+
@property
55+
def sequence_confidence(self) -> float:
56+
return self._sequence_confidence
57+
5258
@property
5359
def sources(self) -> Sequence[TranslationSources]:
5460
return self._sources

machine/translation/translation_result_builder.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(
2828
self._confidences: List[float] = []
2929
self._sources: List[TranslationSources] = []
3030
self._phrases: List[PhraseInfo] = []
31+
self._sequence_confidence: float = -1.0
3132

3233
@property
3334
def source_tokens(self) -> Sequence[str]:
@@ -49,6 +50,10 @@ def sources(self) -> Sequence[TranslationSources]:
4950
def phrases(self) -> Sequence[PhraseInfo]:
5051
return self._phrases
5152

53+
@property
54+
def sequence_confidence(self) -> float:
55+
return self._sequence_confidence
56+
5257
def append_token(self, token: str, source: TranslationSources, confidence: float) -> None:
5358
self._target_tokens.append(token)
5459
self._sources.append(source)
@@ -60,6 +65,9 @@ def mark_phrase(self, source_segment_range: Range[int], alignment: WordAlignment
6065
def set_confidence(self, index: int, confidence: float) -> None:
6166
self._confidences[index] = confidence
6267

68+
def set_sequence_confidence(self, sequence_confidence: float) -> None:
69+
self._sequence_confidence = sequence_confidence
70+
6371
def correct_prefix(
6472
self,
6573
word_ops: Iterable[EditOperation],
@@ -165,6 +173,7 @@ def to_result(self, translation: Optional[str] = None) -> TranslationResult:
165173
self._source_tokens,
166174
self._target_tokens,
167175
self._confidences,
176+
self._sequence_confidence,
168177
sources,
169178
alignment,
170179
phrases,

machine/translation/truecaser.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def truecase_translation_result(
2929
result.source_tokens,
3030
target_tokens,
3131
result.confidences,
32+
result.sequence_confidence,
3233
result.sources,
3334
result.alignment,
3435
result.phrases,

tests/jobs/test_nmt_engine_build_job.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,12 @@ def test_run(decoy: Decoy) -> None:
5050
]
5151
assert pretranslations[0]["translationTokens"] == ["Please", ",", "I", "have", "booked", "a", "room", "."]
5252
assert len(pretranslations[0]["alignment"]) > 0
53+
assert pretranslations[0]["sequenceConfidence"] == 0.5
5354
else:
5455
assert pretranslations[0]["sourceTokens"] == []
5556
assert pretranslations[0]["translationTokens"] == []
5657
assert len(pretranslations[0]["alignment"]) == 0
58+
assert pretranslations[0]["sequenceConfidence"] == 0.5
5759
decoy.verify(env.translation_file_service.save_model(Path("model.tar.gz"), "models/save-model.tar.gz"), times=1)
5860

5961

@@ -86,6 +88,7 @@ def __init__(self, decoy: Decoy) -> None:
8688
source_tokens="Por favor , tengo reservada una habitación .".split(),
8789
target_tokens="Please , I have booked a room .".split(),
8890
confidences=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
91+
sequence_confidence=0.5,
8992
sources=[
9093
TranslationSources.NMT,
9194
TranslationSources.NMT,
@@ -135,6 +138,7 @@ def __init__(self, decoy: Decoy) -> None:
135138
sourceTokens=[],
136139
translationTokens=[],
137140
alignment="",
141+
sequenceConfidence=0.5,
138142
)
139143
]
140144
)

tests/jobs/test_smt_engine_build_job.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(self, decoy: Decoy) -> None:
6565
source_tokens="Por favor , tengo reservada una habitación .".split(),
6666
target_tokens="Please , I have booked a room .".split(),
6767
confidences=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
68+
sequence_confidence=0.5,
6869
sources=[
6970
TranslationSources.SMT,
7071
TranslationSources.SMT,
@@ -140,6 +141,7 @@ def __init__(self, decoy: Decoy) -> None:
140141
sourceTokens=[],
141142
translationTokens=[],
142143
alignment="",
144+
sequenceConfidence=0.5,
143145
)
144146
]
145147
)

tests/translation/huggingface/test_hugging_face_nmt_engine.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55

66
skip("skipping Hugging Face tests on MacOS", allow_module_level=True)
77

8+
from math import exp, log
9+
810
from pytest import approx, mark, raises
911

1012
from machine.translation.huggingface import HuggingFaceNmtEngine
13+
from machine.translation.translation_result import TranslationResult
1114

1215

1316
@mark.parametrize("output_attentions", [True, False])
@@ -26,16 +29,23 @@ def test_translate_n_batch_beam(output_attentions: bool) -> None:
2629
)
2730
assert results[0][0].translation == "skaberskaber Dollar Dollar ፤ ፤ gerekir gerekir"
2831
assert results[0][0].confidences[0] == approx(1.08e-05, 0.01)
32+
assert results[0][0].sequence_confidence == approx(_get_sequence_confidence(results[0][0]), 0.01)
2933
assert str(results[0][0].alignment) == ("2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7" if output_attentions else "")
34+
3035
assert results[0][1].translation == "skaberskaber Dollar Dollar ፤ ፤ ፤ gerekir"
3136
assert results[0][1].confidences[0] == approx(1.08e-05, 0.01)
37+
assert results[0][1].sequence_confidence == approx(_get_sequence_confidence(results[0][1]), 0.01)
3238
assert str(results[0][1].alignment) == ("2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7" if output_attentions else "")
39+
3340
assert results[1][0].translation == "skaberskaber Dollar Dollar ፤ ፤ gerekir gerekir"
3441
assert results[1][0].confidences[0] == approx(1.08e-05, 0.01)
3542
assert str(results[1][0].alignment) == ("0-1 0-2 0-7 1-0 3-3 3-4 3-5 3-6" if output_attentions else "")
43+
assert results[1][0].sequence_confidence == approx(_get_sequence_confidence(results[1][0]), 0.01)
44+
3645
assert results[1][1].translation == "skaberskaber Dollar Dollar ፤ ፤ ፤ gerekir"
3746
assert results[1][1].confidences[0] == approx(1.08e-05, 0.01)
3847
assert str(results[1][1].alignment) == ("0-1 0-2 0-7 1-0 3-3 3-4 3-5 3-6" if output_attentions else "")
48+
assert results[1][1].sequence_confidence == approx(_get_sequence_confidence(results[1][1]), 0.01)
3949

4050

4151
@mark.parametrize("output_attentions", [True, False])
@@ -46,10 +56,16 @@ def test_translate_greedy(output_attentions: bool) -> None:
4656
result = engine.translate("This is a test string")
4757
assert result.translation == "skaberskaber Dollar Dollar Dollar ፤ gerekir gerekir"
4858
assert result.confidences[0] == approx(1.08e-05, 0.01)
59+
assert result.sequence_confidence == -1.0
4960
assert str(result.alignment) == ("2-0 2-1 2-2 2-3 4-4 4-5 4-6 4-7" if output_attentions else "")
5061

5162

5263
@mark.parametrize("output_attentions", [True, False])
5364
def test_construct_invalid_lang(output_attentions: bool) -> None:
5465
with raises(ValueError):
5566
HuggingFaceNmtEngine("stas/tiny-m2m_100", src_lang="qaa", tgt_lang="es", output_attentions=output_attentions)
67+
68+
69+
def _get_sequence_confidence(result: TranslationResult) -> float:
70+
# Inject a 0 score for the BOS token
71+
return exp(sum([log(c) for c in result.confidences] + [0]) / (len(result.confidences) + 1))

0 commit comments

Comments
 (0)