11from contextlib import contextmanager
22from pathlib import Path
3- from typing import Any , Generator , Iterator , List , TypedDict
3+ from typing import Any , Generator , Iterator , List , Optional , TypedDict , Union
44
55import json_stream
66
@@ -21,40 +21,56 @@ class PretranslationInfo(TypedDict):
2121 alignment : str
2222
2323
24- SOURCE_FILENAME = "train.src.txt"
25- TARGET_FILENAME = "train.trg.txt"
26- SOURCE_PRETRANSLATION_FILENAME = "pretranslate.src.json"
27- TARGET_PRETRANSLATION_FILENAME = "pretranslate.trg.json"
28-
29-
3024class TranslationFileService :
def __init__(
    self,
    type: SharedFileServiceType,
    config: Any,
    source_filenames: Optional[Union[str, List[str]]] = None,
    target_filenames: Optional[Union[str, List[str]]] = None,
    source_pretranslation_filename: str = "pretranslate.src.json",
    target_pretranslation_filename: str = "pretranslate.trg.json",
) -> None:
    """Configure the file names used for a build and create the shared file service.

    Args:
        type: Kind of shared file service, forwarded to ``get_shared_file_service``.
        config: Configuration object, forwarded to ``get_shared_file_service``.
        source_filenames: One file name or a list of file names for the source
            corpus. ``None`` selects ``train.src.txt`` and
            ``train.key-terms.src.txt``.
        target_filenames: One file name or a list of file names for the target
            corpus. ``None`` selects ``train.trg.txt`` and
            ``train.key-terms.trg.txt``.
        source_pretranslation_filename: Name of the source pretranslation file.
        target_pretranslation_filename: Name of the target pretranslation file.
    """
    if source_filenames is None:
        source_filenames = ["train.src.txt", "train.key-terms.src.txt"]
    if target_filenames is None:
        target_filenames = ["train.trg.txt", "train.key-terms.trg.txt"]

    # A bare string is wrapped in a list so it is not iterated character by
    # character; any other value is copied so that later mutation of the
    # caller's list does not affect this instance.
    self._source_filenames = [source_filenames] if isinstance(source_filenames, str) else list(source_filenames)
    self._target_filenames = [target_filenames] if isinstance(target_filenames, str) else list(target_filenames)
    self._source_pretranslation_filename = source_pretranslation_filename
    self._target_pretranslation_filename = target_pretranslation_filename

    self.shared_file_service: SharedFileServiceBase = get_shared_file_service(type, config)
3846
def create_source_corpus(self) -> TextCorpus:
    """Download every configured source file and wrap the results in a text corpus.

    Each name in ``self._source_filenames`` is resolved under the shared file
    service's ``build_path`` and passed to ``download_file``.
    """
    # NOTE(review): the generator expression is passed as a single positional
    # argument; confirm that TextFileTextCorpus accepts an iterable of paths.
    return TextFileTextCorpus(
        self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{source_filename}")
        for source_filename in self._source_filenames
    )
4352
def create_target_corpus(self) -> TextCorpus:
    """Download every configured target file and wrap the results in a text corpus.

    Each name in ``self._target_filenames`` is resolved under the shared file
    service's ``build_path`` and passed to ``download_file``.
    """
    # NOTE(review): the generator expression is passed as a single positional
    # argument; confirm that TextFileTextCorpus accepts an iterable of paths.
    return TextFileTextCorpus(
        self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{target_filename}")
        for target_filename in self._target_filenames
    )
4858
def exists_source_corpus(self) -> bool:
    """Return True only when every configured source file exists.

    Each name in ``self._source_filenames`` is checked under the shared file
    service's ``build_path``. Checking stops at the first missing file. An
    empty file name list yields True, because ``all`` of nothing is True.
    """
    return all(
        self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{source_filename}")
        for source_filename in self._source_filenames
    )
5164
def exists_target_corpus(self) -> bool:
    """Return True only when every configured target file exists.

    Each name in ``self._target_filenames`` is checked under the shared file
    service's ``build_path``. Checking stops at the first missing file. An
    empty file name list yields True, because ``all`` of nothing is True.
    """
    return all(
        self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{target_filename}")
        for target_filename in self._target_filenames
    )
5470
5571 def get_source_pretranslations (self ) -> ContextManagedGenerator [PretranslationInfo , None , None ]:
5672 src_pretranslate_path = self .shared_file_service .download_file (
57- f"{ self .shared_file_service .build_path } /{ SOURCE_PRETRANSLATION_FILENAME } "
73+ f"{ self .shared_file_service .build_path } /{ self . _source_pretranslation_filename } "
5874 )
5975
6076 def generator () -> Generator [PretranslationInfo , None , None ]:
@@ -77,4 +93,4 @@ def save_model(self, model_path: Path, destination: str) -> None:
7793
@contextmanager
def open_target_pretranslation_writer(self) -> Iterator[DictToJsonWriter]:
    """Open a writer for the configured target pretranslation file.

    Delegates to the shared file service's ``open_target_writer`` with
    ``self._target_pretranslation_filename``.
    """
    # NOTE(review): this function returns rather than yields, so
    # ``contextmanager`` drives whatever ``open_target_writer`` returns as the
    # underlying generator. That only works if ``open_target_writer`` returns
    # a plain generator/iterator -- confirm it is not itself decorated with
    # ``@contextmanager``.
    return self.shared_file_service.open_target_writer(self._target_pretranslation_filename)
0 commit comments