 import torch  # pyright: ignore[reportMissingImports]
 from accelerate import Accelerator  # pyright: ignore[reportMissingImports]
 from accelerate.utils.memory import should_reduce_batch_size  # pyright: ignore[reportMissingImports]
+from datasets import concatenate_datasets
 from datasets.arrow_dataset import Dataset
 from sacremoses import MosesPunctNormalizer
 from torch import Tensor  # pyright: ignore[reportMissingImports]
     PreTrainedTokenizerFast,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
+    TensorType,
     TrainerCallback,
     set_seed,
 )
+from transformers.tokenization_utils import BatchEncoding
 from transformers.trainer_callback import TrainerControl, TrainerState
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.training_args import TrainingArguments
@@ -88,6 +91,7 @@ def __init__(
         model: Union[PreTrainedModel, str],
         training_args: Seq2SeqTrainingArguments,
         corpus: Union[ParallelTextCorpus, Dataset],
+        terms_corpus: Optional[Union[ParallelTextCorpus, Dataset]] = None,
         src_lang: Optional[str] = None,
         tgt_lang: Optional[str] = None,
         max_src_length: Optional[int] = None,
@@ -98,6 +102,7 @@ def __init__(
         self._model = model
         self._training_args = training_args
         self._corpus = corpus
+        self._terms_corpus = terms_corpus
         self._src_lang = src_lang
         self._tgt_lang = tgt_lang
         self._trainer: Optional[Seq2SeqTrainer] = None
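
For orientation, a rough usage sketch of the new terms_corpus argument. The trainer class name used below is an assumption (it is not shown in this diff), and the example data and checkpoint are illustrative only; the diff itself only establishes that corpus and terms_corpus accept either a ParallelTextCorpus or a Hugging Face Dataset with a "translation" column.

from datasets import Dataset
from transformers import Seq2SeqTrainingArguments

# Tiny datasets in the {"translation": [{src_lang: ..., tgt_lang: ...}, ...]} layout
# that the preprocess functions below expect.
sentences = Dataset.from_dict(
    {"translation": [{"eng_Latn": "The priest entered the temple.", "fra_Latn": "Le prêtre entra dans le temple."}]}
)
key_terms = Dataset.from_dict(
    {"translation": [{"eng_Latn": "priest", "fra_Latn": "prêtre"}, {"eng_Latn": "temple", "fra_Latn": "temple"}]}
)

args = Seq2SeqTrainingArguments(output_dir="out", max_steps=10)

# HuggingFaceNmtModelTrainer is an assumed name for the trainer class defined in this file.
trainer = HuggingFaceNmtModelTrainer(
    "facebook/nllb-200-distilled-600M",  # example checkpoint
    args,
    corpus=sentences,
    terms_corpus=key_terms,
    src_lang="eng_Latn",
    tgt_lang="fra_Latn",
)
trainer.train()

The tokenized key terms are concatenated onto the sentence dataset before training, so each term pair effectively becomes an extra, very short training example.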
@@ -170,6 +175,13 @@ def train(
         else:
             train_dataset = self._corpus.filter_nonempty().to_hf_dataset(src_lang, tgt_lang)
 
+        train_terms_dataset = None
+        if self._terms_corpus is not None:
+            if isinstance(self._terms_corpus, Dataset):
+                train_terms_dataset = self._terms_corpus
+            else:
+                train_terms_dataset = self._terms_corpus.filter_nonempty().to_hf_dataset(src_lang, tgt_lang)
+
         def find_missing_characters(tokenizer: Any, train_dataset: Dataset, lang_codes: List[str]) -> List[str]:
             vocab = tokenizer.get_vocab().keys()
             charset = set()
@@ -222,7 +234,15 @@ def add_tokens(tokenizer: Any, missing_tokens: List[str]) -> Any:
             lang_codes.append(src_lang)
         if self._add_unk_tgt_tokens:
             lang_codes.append(tgt_lang)
-        missing_tokens = find_missing_characters(tokenizer, train_dataset, lang_codes)
+        missing_tokens = find_missing_characters(
+            tokenizer,
+            (
+                concatenate_datasets([train_dataset, train_terms_dataset])
+                if train_terms_dataset is not None
+                else train_dataset
+            ),
+            lang_codes,
+        )
         if missing_tokens:
             tokenizer = add_tokens(tokenizer, missing_tokens)
 
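
For reference, a standalone sketch of the character scan that find_missing_characters appears to perform, now run over the concatenation of the sentence and terms datasets. Only the first two lines of the helper are visible in this hunk, so the loop and the membership test below are assumptions based on its name and call site; the checkpoint is illustrative.

from datasets import Dataset, concatenate_datasets
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")  # example checkpoint

sentences = Dataset.from_dict({"translation": [{"eng_Latn": "hello", "fra_Latn": "bonjour"}]})
terms = Dataset.from_dict({"translation": [{"eng_Latn": "ȝogh", "fra_Latn": "yogh"}]})

vocab = tokenizer.get_vocab().keys()
charset = set()
# Collect every character that occurs in either dataset for the selected languages,
# then keep the ones the tokenizer has no vocabulary entry for.
for example in concatenate_datasets([sentences, terms])["translation"]:
    for lang_code in ["eng_Latn", "fra_Latn"]:
        charset.update(example[lang_code])
missing = sorted(c for c in charset if c not in vocab and "▁" + c not in vocab)
print(missing)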
@@ -291,6 +311,22 @@ def add_tokens(tokenizer: Any, missing_tokens: List[str]) -> Any:
                 "memory"
             )
 
+        def batch_prepare_for_model(
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+            batch_tokens: List[List[str]],
+            return_tensors: Optional[Union[str, TensorType]] = None,
+        ) -> BatchEncoding:
+            batch_outputs: Dict[str, Any] = {}
+            for tokens in batch_tokens:
+                ids = cast(List[int], tokenizer.convert_tokens_to_ids(tokens))
+                outputs = tokenizer.prepare_for_model(ids, add_special_tokens=False)
+
+                for key, value in outputs.items():
+                    if key not in batch_outputs:
+                        batch_outputs[key] = []
+                    batch_outputs[key].append(value)
+            return BatchEncoding(batch_outputs, tensor_type=return_tensors)
+
         def preprocess_function(examples):
             if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)):
                 inputs = [self._mpn.normalize(prefix + ex[src_lang]) for ex in examples["translation"]]
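
The batch_prepare_for_model helper added above builds model inputs from sequences that are already tokenized, rather than re-running the tokenizer on raw text. A minimal sketch of the two tokenizer calls it relies on (the checkpoint is only an example):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")  # example checkpoint

# Start from a list of tokens instead of a raw string; tokens() already includes
# the special tokens the tokenizer inserted.
tokens = tokenizer("priest", truncation=True).tokens()

# convert_tokens_to_ids maps tokens to vocabulary ids without re-tokenizing, and
# prepare_for_model then assembles input_ids and attention_mask; special tokens are
# not added a second time.
ids = tokenizer.convert_tokens_to_ids(tokens)
encoding = tokenizer.prepare_for_model(ids, add_special_tokens=False)
print(encoding["input_ids"], encoding["attention_mask"])

Collecting such per-sequence outputs into lists and wrapping them in a BatchEncoding, as the helper does, gives the dataset map the same column structure it would get from a direct tokenizer call.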
@@ -306,6 +342,42 @@ def preprocess_function(examples):
             model_inputs["labels"] = labels["input_ids"]
             return model_inputs
 
+        def preprocess_terms_function(examples):
+            if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)):
+                inputs = [self._mpn.normalize(ex[src_lang]) for ex in examples["translation"]]
+                targets = [self._mpn.normalize(ex[tgt_lang]) for ex in examples["translation"]]
+            else:
+                inputs = [ex[src_lang] for ex in examples["translation"]]
+                targets = [ex[tgt_lang] for ex in examples["translation"]]
+
+            src_term_tokens = tokenizer(
+                [prefix + i for i in inputs], max_length=max_src_length, truncation=True
+            ).tokens()
+            trg_term_tokens = tokenizer(text_target=targets, max_length=max_tgt_length, truncation=True).tokens()
+
+            src_term_partial_word_tokens = tokenizer(
+                [prefix + "\ufffc" + i for i in inputs], max_length=max_src_length + 2, truncation=True
+            ).tokens()
+            src_term_partial_word_tokens.remove("▁")
+            src_term_partial_word_tokens.remove("\ufffc")
+
+            trg_term_partial_word_tokens = tokenizer(
+                text_target=["\ufffc" + t for t in targets], max_length=max_tgt_length + 2, truncation=True
+            ).tokens()
+            trg_term_partial_word_tokens.remove("▁")
+            trg_term_partial_word_tokens.remove("\ufffc")
+
+            model_inputs = batch_prepare_for_model(
+                tokenizer, [[ex.strip() for ex in src_term_tokens + src_term_partial_word_tokens]]
+            )
+            # Prepare the pre-tokenized target tokens as the labels
+            labels = batch_prepare_for_model(
+                tokenizer, [[ex.strip() for ex in trg_term_tokens + trg_term_partial_word_tokens]]
+            )
+
+            model_inputs["labels"] = labels["input_ids"]
+            return model_inputs
+
         logger.info("Run tokenizer")
         train_dataset = train_dataset.map(
             preprocess_function,
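
To see why preprocess_terms_function tokenizes each term a second time behind a placeholder character: gluing an unknown character directly in front of the term makes the term start mid-word, so it is typically split into word-internal pieces rather than word-initial ones. Those are the "partial word" tokens the code keeps after dropping "▁" and the placeholder. A small sketch (the checkpoint is only an example; the exact pieces depend on the tokenizer):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")  # example checkpoint
term = "priest"

# Word-initial tokenization of the term on its own.
word_initial = tokenizer(term, truncation=True).tokens()

# The same term glued behind the placeholder: it now continues a "word",
# so it is usually broken into word-internal subword pieces instead.
glued = tokenizer("\ufffc" + term, truncation=True).tokens()

print(word_initial)
print(glued)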
@@ -315,6 +387,22 @@ def preprocess_function(examples):
             desc="Running tokenizer on train dataset",
         )
 
+        if train_terms_dataset is not None:
+            if not isinstance(tokenizer, PreTrainedTokenizerFast):
+                logger.warning(
+                    f"Adding key terms as partial words is not possible when using the non-fast tokenizer '{type(tokenizer)}'."
+                )
+            train_terms_dataset = train_terms_dataset.map(
+                preprocess_terms_function if isinstance(tokenizer, PreTrainedTokenizerFast) else preprocess_function,
+                batched=True,
+                remove_columns=train_terms_dataset.column_names,
+                load_from_cache_file=True,
+                desc="Running tokenizer on train terms dataset",
+            )
+
+            # combine terms and non-terms datasets
+            train_dataset = concatenate_datasets([train_dataset, train_terms_dataset])
+
         data_collator = DataCollatorForSeq2Seq(
             tokenizer,
             model=model,