Commit ba0b914

Address reviewer comment; add key terms as partial words
1 parent 3c3b3f5

13 files changed

Lines changed: 177 additions & 30 deletions

machine/corpora/paratext_backup_terms_corpus.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@

 from ..utils.typeshed import StrPath
 from .dictionary_text_corpus import DictionaryTextCorpus
-from .key_term_row import KeyTerm
+from .key_term import KeyTerm
 from .memory_text import MemoryText
 from .text_row import TextRow
 from .zip_paratext_project_settings_parser import ZipParatextProjectSettingsParser
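
Note: the module was renamed from key_term_row to key_term, so downstream imports of KeyTerm change accordingly. For code outside the package, the absolute form would be (module path assumed from this repo's layout):

    from machine.corpora.key_term import KeyTerm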

machine/corpora/paratext_project_terms_parser_base.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@

 from ..scripture.constants import ORIGINAL_VERSIFICATION
 from ..scripture.verse_ref import VerseRef
-from .key_term_row import KeyTerm
+from .key_term import KeyTerm
 from .paratext_project_file_handler import ParatextProjectFileHandler
 from .paratext_project_settings import ParatextProjectSettings
 from .paratext_project_settings_parser_base import ParatextProjectSettingsParserBase

machine/jobs/huggingface/hugging_face_nmt_model_factory.py

Lines changed: 2 additions & 1 deletion
@@ -78,11 +78,12 @@ def create_source_tokenizer_trainer(self, corpus: TextCorpus) -> Trainer:
     def create_target_tokenizer_trainer(self, corpus: TextCorpus) -> Trainer:
         return NullTrainer()

-    def create_model_trainer(self, corpus: ParallelTextCorpus) -> Trainer:
+    def create_model_trainer(self, corpus: ParallelTextCorpus, terms_corpus: ParallelTextCorpus) -> Trainer:
         return HuggingFaceNmtModelTrainer(
             self._model,
             self._training_args,
             corpus,
+            terms_corpus,
             src_lang=self._config.src_lang,
             tgt_lang=self._config.trg_lang,
             add_unk_src_tokens=self._config.huggingface.tokenizer.add_unk_src_tokens,

machine/jobs/nmt_engine_build_job.py

Lines changed: 2 additions & 1 deletion
@@ -59,6 +59,7 @@ def _train_model(
         source_corpus: TextCorpus,
         target_corpus: TextCorpus,
         parallel_corpus: ParallelTextCorpus,
+        parallel_terms_corpus: ParallelTextCorpus,
         progress_reporter: PhasedProgressReporter,
         check_canceled: Optional[Callable[[], None]],
     ) -> Tuple[int, float]:
@@ -85,7 +86,7 @@ def _train_model(
         logger.info("Training NMT model")
         with (
             progress_reporter.start_next_phase() as phase_progress,
-            self._nmt_model_factory.create_model_trainer(parallel_corpus) as model_trainer,
+            self._nmt_model_factory.create_model_trainer(parallel_corpus, parallel_terms_corpus) as model_trainer,
         ):
             model_trainer.train(progress=phase_progress, check_canceled=check_canceled)
             model_trainer.save()

machine/jobs/nmt_model_factory.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ def create_source_tokenizer_trainer(self, corpus: TextCorpus) -> Trainer: ...
     def create_target_tokenizer_trainer(self, corpus: TextCorpus) -> Trainer: ...

     @abstractmethod
-    def create_model_trainer(self, corpus: ParallelTextCorpus) -> Trainer: ...
+    def create_model_trainer(self, corpus: ParallelTextCorpus, terms_corpus: ParallelTextCorpus) -> Trainer: ...

     @abstractmethod
     def create_engine(self) -> TranslationEngine: ...
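
For implementers of this abstract factory, a minimal conforming stub is sketched below. Only the members visible in this diff are stubbed (the real base class may declare more), DummyNmtModelFactory is an illustrative name, the import paths are assumed from this repo's layout, and NullTrainer is the same no-op trainer the HuggingFace factory returns for its tokenizer trainers.

    from machine.corpora import ParallelTextCorpus, TextCorpus
    from machine.jobs.nmt_model_factory import NmtModelFactory
    from machine.translation import NullTrainer, Trainer, TranslationEngine


    class DummyNmtModelFactory(NmtModelFactory):
        def create_source_tokenizer_trainer(self, corpus: TextCorpus) -> Trainer:
            return NullTrainer()

        def create_target_tokenizer_trainer(self, corpus: TextCorpus) -> Trainer:
            return NullTrainer()

        def create_model_trainer(self, corpus: ParallelTextCorpus, terms_corpus: ParallelTextCorpus) -> Trainer:
            # The terms corpus is now part of the contract, even for trainers that ignore it.
            return NullTrainer()

        def create_engine(self) -> TranslationEngine:
            raise NotImplementedError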

machine/jobs/smt_engine_build_job.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ def _train_model(
         source_corpus: TextCorpus,
         target_corpus: TextCorpus,
         parallel_corpus: ParallelTextCorpus,
+        parallel_terms_corpus: ParallelTextCorpus,
         progress_reporter: PhasedProgressReporter,
         check_canceled: Optional[Callable[[], None]],
     ) -> Tuple[int, float]:

machine/jobs/translation_engine_build_job.py

Lines changed: 9 additions & 2 deletions
@@ -28,14 +28,20 @@ def run(
         target_corpus = self._translation_file_service.create_target_corpus()
         parallel_corpus: ParallelTextCorpus = source_corpus.align_rows(target_corpus)

-        parallel_corpus_size = parallel_corpus.count(include_empty=False)
+        source_terms_corpus = self._translation_file_service.create_source_terms_corpus()
+        target_terms_corpus = self._translation_file_service.create_target_terms_corpus()
+        parallel_terms_corpus: ParallelTextCorpus = source_terms_corpus.align_rows(target_terms_corpus)
+
+        parallel_corpus_size = parallel_corpus.count(include_empty=False) + parallel_terms_corpus.count(
+            include_empty=False
+        )
         progress_reporter = self._get_progress_reporter(progress, parallel_corpus_size)

         if parallel_corpus_size == 0:
             train_corpus_size, confidence = self._respond_to_no_training_corpus()
         else:
             train_corpus_size, confidence = self._train_model(
-                source_corpus, target_corpus, parallel_corpus, progress_reporter, check_canceled
+                source_corpus, target_corpus, parallel_corpus, parallel_terms_corpus, progress_reporter, check_canceled
             )

         if check_canceled is not None:
@@ -63,6 +69,7 @@ def _train_model(
         source_corpus: TextCorpus,
         target_corpus: TextCorpus,
         parallel_corpus: ParallelTextCorpus,
+        parallel_terms_corpus: ParallelTextCorpus,
         progress_reporter: PhasedProgressReporter,
         check_canceled: Optional[Callable[[], None]],
     ) -> Tuple[int, float]: ...
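
Because the terms corpus is added into parallel_corpus_size, a project whose only data is key terms still reaches _train_model. A standalone sketch of the align-and-count step, assuming two local line-aligned text files (in the job these arrive via the shared file service):

    from machine.corpora import TextFileTextCorpus

    source_terms = TextFileTextCorpus("train.key-terms.src.txt")
    target_terms = TextFileTextCorpus("train.key-terms.trg.txt")

    # Row-align the two monolingual corpora into a parallel corpus, as run() does.
    parallel_terms = source_terms.align_rows(target_terms)

    # Only non-empty rows count toward the size check.
    print(parallel_terms.count(include_empty=False))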

machine/jobs/translation_file_service.py

Lines changed: 36 additions & 21 deletions
@@ -1,6 +1,6 @@
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Generator, Iterator, List, Optional, TypedDict, Union
+from typing import Any, Generator, Iterator, List, Optional, TypedDict

 import json_stream

@@ -26,46 +26,61 @@ def __init__(
         self,
         type: SharedFileServiceType,
         config: Any,
-        source_filenames: Optional[Union[str, List[str]]] = None,
-        target_filenames: Optional[Union[str, List[str]]] = None,
+        source_filename: Optional[str] = "train.src.txt",
+        target_filename: Optional[str] = "train.trg.txt",
+        source_terms_filename: Optional[str] = "train.key-terms.src.txt",
+        target_terms_filename: Optional[str] = "train.key-terms.trg.txt",
         source_pretranslation_filename: str = "pretranslate.src.json",
         target_pretranslation_filename: str = "pretranslate.trg.json",
     ) -> None:

-        if source_filenames is None:
-            source_filenames = ["train.src.txt", "train.key-terms.src.txt"]
-        if target_filenames is None:
-            target_filenames = ["train.trg.txt", "train.key-terms.trg.txt"]
-
-        self._source_filenames = [source_filenames] if isinstance(source_filenames, str) else list(source_filenames)
-        self._target_filenames = [target_filenames] if isinstance(target_filenames, str) else list(target_filenames)
+        self._source_filename = source_filename
+        self._target_filename = target_filename
+        self._source_terms_filename = source_terms_filename
+        self._target_terms_filename = target_terms_filename
         self._source_pretranslation_filename = source_pretranslation_filename
         self._target_pretranslation_filename = target_pretranslation_filename

         self.shared_file_service: SharedFileServiceBase = get_shared_file_service(type, config)

     def create_source_corpus(self) -> TextCorpus:
         return TextFileTextCorpus(
-            self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{source_filename}")
-            for source_filename in self._source_filenames
+            self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{self._source_filename}")
         )

     def create_target_corpus(self) -> TextCorpus:
         return TextFileTextCorpus(
-            self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{target_filename}")
-            for target_filename in self._target_filenames
+            self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{self._target_filename}")
         )

     def exists_source_corpus(self) -> bool:
-        return all(
-            self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{source_filename}")
-            for source_filename in self._source_filenames
-        )
+        return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{self._source_filename}")

     def exists_target_corpus(self) -> bool:
-        return all(
-            self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{target_filename}")
-            for target_filename in self._target_filenames
+        return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{self._target_filename}")
+
+    def create_source_terms_corpus(self) -> TextCorpus:
+        return TextFileTextCorpus(
+            self.shared_file_service.download_file(
+                f"{self.shared_file_service.build_path}/{self._source_terms_filename}"
+            )
+        )
+
+    def create_target_terms_corpus(self) -> TextCorpus:
+        return TextFileTextCorpus(
+            self.shared_file_service.download_file(
+                f"{self.shared_file_service.build_path}/{self._target_terms_filename}"
+            )
+        )
+
+    def exists_source_terms_corpus(self) -> bool:
+        return self.shared_file_service._exists_file(
+            f"{self.shared_file_service.build_path}/{self._source_terms_filename}"
+        )
+
+    def exists_target_terms_corpus(self) -> bool:
+        return self.shared_file_service._exists_file(
+            f"{self.shared_file_service.build_path}/{self._target_terms_filename}"
         )

     def get_source_pretranslations(self) -> ContextManagedGenerator[PretranslationInfo, None, None]:
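
A sketch of consuming the new per-corpus API. The class name TranslationFileService, the import path, and a LOCAL member on SharedFileServiceType are assumptions (none are shown in this diff), and the empty config dict is a placeholder for whatever the shared file service expects:

    from machine.jobs import SharedFileServiceType, TranslationFileService

    service = TranslationFileService(
        SharedFileServiceType.LOCAL,  # assumed enum member
        config={},                    # placeholder config
        source_terms_filename="train.key-terms.src.txt",
        target_terms_filename="train.key-terms.trg.txt",
    )

    # Terms files are optional in a build, so probe before creating the corpora.
    if service.exists_source_terms_corpus() and service.exists_target_terms_corpus():
        source_terms = service.create_source_terms_corpus()
        target_terms = service.create_target_terms_corpus()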

machine/translation/huggingface/hugging_face_nmt_model_trainer.py

Lines changed: 89 additions & 1 deletion
@@ -9,6 +9,7 @@
 import torch  # pyright: ignore[reportMissingImports]
 from accelerate import Accelerator  # pyright: ignore[reportMissingImports]
 from accelerate.utils.memory import should_reduce_batch_size  # pyright: ignore[reportMissingImports]
+from datasets import concatenate_datasets
 from datasets.arrow_dataset import Dataset
 from sacremoses import MosesPunctNormalizer
 from torch import Tensor  # pyright: ignore[reportMissingImports]
@@ -36,9 +37,11 @@
     PreTrainedTokenizerFast,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
+    TensorType,
     TrainerCallback,
     set_seed,
 )
+from transformers.tokenization_utils import BatchEncoding
 from transformers.trainer_callback import TrainerControl, TrainerState
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.training_args import TrainingArguments
@@ -88,6 +91,7 @@ def __init__(
         model: Union[PreTrainedModel, str],
         training_args: Seq2SeqTrainingArguments,
         corpus: Union[ParallelTextCorpus, Dataset],
+        terms_corpus: Optional[Union[ParallelTextCorpus, Dataset]] = None,
         src_lang: Optional[str] = None,
         tgt_lang: Optional[str] = None,
         max_src_length: Optional[int] = None,
@@ -98,6 +102,7 @@
         self._model = model
         self._training_args = training_args
         self._corpus = corpus
+        self._terms_corpus = terms_corpus
         self._src_lang = src_lang
         self._tgt_lang = tgt_lang
         self._trainer: Optional[Seq2SeqTrainer] = None
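
With the new keyword, constructing the trainer with a terms corpus looks roughly like this; the checkpoint, language codes, file names, and import path are illustrative assumptions, and omitting terms_corpus (or passing None) preserves the old behavior:

    from transformers import Seq2SeqTrainingArguments

    from machine.corpora import TextFileTextCorpus
    from machine.translation.huggingface import HuggingFaceNmtModelTrainer

    parallel_corpus = TextFileTextCorpus("train.src.txt").align_rows(TextFileTextCorpus("train.trg.txt"))
    terms_corpus = TextFileTextCorpus("train.key-terms.src.txt").align_rows(
        TextFileTextCorpus("train.key-terms.trg.txt")
    )

    trainer = HuggingFaceNmtModelTrainer(
        "facebook/nllb-200-distilled-600M",
        Seq2SeqTrainingArguments(output_dir="out"),
        parallel_corpus,
        terms_corpus=terms_corpus,
        src_lang="eng_Latn",
        tgt_lang="fra_Latn",
    )
    with trainer:
        trainer.train()
        trainer.save()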
@@ -170,6 +175,13 @@ def train(
         else:
             train_dataset = self._corpus.filter_nonempty().to_hf_dataset(src_lang, tgt_lang)

+        train_terms_dataset = None
+        if self._terms_corpus is not None:
+            if isinstance(self._terms_corpus, Dataset):
+                train_terms_dataset = self._terms_corpus
+            else:
+                train_terms_dataset = self._terms_corpus.filter_nonempty().to_hf_dataset(src_lang, tgt_lang)
+
         def find_missing_characters(tokenizer: Any, train_dataset: Dataset, lang_codes: List[str]) -> List[str]:
             vocab = tokenizer.get_vocab().keys()
             charset = set()
@@ -222,7 +234,15 @@ def add_tokens(tokenizer: Any, missing_tokens: List[str]) -> Any:
                 lang_codes.append(src_lang)
             if self._add_unk_tgt_tokens:
                 lang_codes.append(tgt_lang)
-            missing_tokens = find_missing_characters(tokenizer, train_dataset, lang_codes)
+            missing_tokens = find_missing_characters(
+                tokenizer,
+                (
+                    concatenate_datasets([train_dataset, train_terms_dataset])
+                    if train_terms_dataset is not None
+                    else train_dataset
+                ),
+                lang_codes,
+            )
             if missing_tokens:
                 tokenizer = add_tokens(tokenizer, missing_tokens)
@@ -291,6 +311,22 @@ def add_tokens(tokenizer: Any, missing_tokens: List[str]) -> Any:
                     "memory"
                 )

+        def batch_prepare_for_model(
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+            batch_tokens: List[List[str]],
+            return_tensors: Optional[Union[str, TensorType]] = None,
+        ) -> BatchEncoding:
+            batch_outputs: Dict[str, Any] = {}
+            for tokens in batch_tokens:
+                ids = cast(List[int], tokenizer.convert_tokens_to_ids(tokens))
+                outputs = tokenizer.prepare_for_model(ids, add_special_tokens=False)
+
+                for key, value in outputs.items():
+                    if key not in batch_outputs:
+                        batch_outputs[key] = []
+                    batch_outputs[key].append(value)
+            return BatchEncoding(batch_outputs, tensor_type=return_tensors)
+
         def preprocess_function(examples):
             if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)):
                 inputs = [self._mpn.normalize(prefix + ex[src_lang]) for ex in examples["translation"]]
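
batch_prepare_for_model mirrors the encoding tokenizer.__call__ produces, but starts from token strings instead of raw text, so the hand-assembled token lists built below are encoded exactly as given rather than re-segmented. Per row it reduces to roughly this standalone sketch (checkpoint and token strings are illustrative):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

    ids = tokenizer.convert_tokens_to_ids(["▁coven", "ant"])  # pre-chosen pieces
    enc = tokenizer.prepare_for_model(ids, add_special_tokens=False)
    # enc["input_ids"] == ids, plus an attention mask; nothing is re-tokenized.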
@@ -306,6 +342,42 @@ def preprocess_function(examples):
             model_inputs["labels"] = labels["input_ids"]
             return model_inputs

+        def preprocess_terms_function(examples):
+            if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)):
+                inputs = [self._mpn.normalize(ex[src_lang]) for ex in examples["translation"]]
+                targets = [self._mpn.normalize(ex[tgt_lang]) for ex in examples["translation"]]
+            else:
+                inputs = [ex[src_lang] for ex in examples["translation"]]
+                targets = [ex[tgt_lang] for ex in examples["translation"]]
+
+            src_term_tokens = tokenizer(
+                [prefix + i for i in inputs], max_length=max_src_length, truncation=True
+            ).tokens()
+            trg_term_tokens = tokenizer(text_target=targets, max_length=max_tgt_length, truncation=True).tokens()
+
+            src_term_partial_word_tokens = tokenizer(
+                [prefix + "\ufffc" + i for i in inputs], max_length=max_src_length + 2, truncation=True
+            ).tokens()
+            src_term_partial_word_tokens.remove("▁")
+            src_term_partial_word_tokens.remove("\ufffc")
+
+            trg_term_partial_word_tokens = tokenizer(
+                text_target=["\ufffc" + t for t in targets], max_length=max_tgt_length + 2, truncation=True
+            ).tokens()
+            trg_term_partial_word_tokens.remove("▁")
+            trg_term_partial_word_tokens.remove("\ufffc")
+
+            model_inputs = batch_prepare_for_model(
+                tokenizer, [[ex.strip() for ex in src_term_tokens + src_term_partial_word_tokens]]
+            )
+            # Tokenize targets with the `text_target` keyword argument
+            labels = batch_prepare_for_model(
+                tokenizer, [[ex.strip() for ex in trg_term_tokens + trg_term_partial_word_tokens]]
+            )
+
+            model_inputs["labels"] = labels["input_ids"]
+            return model_inputs
+
         logger.info("Run tokenizer")
         train_dataset = train_dataset.map(
             preprocess_function,
@@ -315,6 +387,22 @@ def preprocess_function(examples):
             desc="Running tokenizer on train dataset",
         )

+        if train_terms_dataset is not None:
+            if not isinstance(tokenizer, PreTrainedTokenizerFast):
+                logger.warning(
+                    f"Adding key terms as partial words is not possible when using the non-fast tokenizer '{type(tokenizer)}'."
+                )
+            train_terms_dataset = train_terms_dataset.map(
+                preprocess_terms_function if isinstance(tokenizer, PreTrainedTokenizerFast) else preprocess_function,
+                batched=True,
+                remove_columns=train_terms_dataset.column_names,
+                load_from_cache_file=True,
+                desc="Running tokenizer on train terms dataset",
+            )
+
+            # combine terms and non-terms datasets
+            train_dataset = concatenate_datasets([train_dataset, train_terms_dataset])
+
         data_collator = DataCollatorForSeq2Seq(
             tokenizer,
             model=model,
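
The core trick in preprocess_terms_function: tokenizing a term behind a throwaway U+FFFC character yields the pieces the term would have mid-word, without the SentencePiece word-boundary marker "▁". A standalone sketch, assuming an NLLB checkpoint whose vocabulary keeps "▁" and "\ufffc" as separate pieces (the .remove calls above rely on the same assumption):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")

    term = "covenant"  # illustrative key term

    # Word-initial tokenization: the first piece carries the "▁" boundary marker.
    word_tokens = tokenizer(term).tokens()

    # Prefix U+FFFC so the term is split as word-internal pieces, then drop the
    # "▁" and "\ufffc" tokens, leaving only the partial-word pieces.
    partial_tokens = tokenizer("\ufffc" + term).tokens()
    partial_tokens.remove("▁")
    partial_tokens.remove("\ufffc")

    print(word_tokens)     # e.g. ['eng_Latn', '▁coven', 'ant', '</s>'] (exact split varies)
    print(partial_tokens)  # e.g. ['eng_Latn', 'coven', 'ant', '</s>']

Training on both the whole-word and partial-word encodings teaches the model the term in either position.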
