Skip to content

Commit c848f64

Browse files
authored
Add support for separate key terms training data files (#254)
1 parent f67de52 commit c848f64

4 files changed

Lines changed: 67 additions & 23 deletions

File tree

machine/corpora/corpora_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,11 @@ def get_files(file_patterns: Iterable[str]) -> Iterable[Tuple[str, str]]:
5454
if len(file_patterns) == 1 and os.path.isfile(file_patterns[0]):
5555
yield ("*all*", file_patterns[0])
5656
else:
57-
for file_pattern in file_patterns:
57+
for i, file_pattern in enumerate(file_patterns):
58+
if os.path.isfile(file_pattern):
59+
yield (str(i), file_pattern)
60+
continue
61+
5862
if "*" not in file_pattern and "?" not in file_pattern and not os.path.exists(file_pattern):
5963
raise FileNotFoundError(f"The specified path does not exist: {file_pattern}.")
6064

machine/jobs/translation_file_service.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from contextlib import contextmanager
22
from pathlib import Path
3-
from typing import Any, Generator, Iterator, List, TypedDict
3+
from typing import Any, Generator, Iterator, List, Optional, TypedDict, Union
44

55
import json_stream
66

@@ -21,40 +21,56 @@ class PretranslationInfo(TypedDict):
2121
alignment: str
2222

2323

24-
SOURCE_FILENAME = "train.src.txt"
25-
TARGET_FILENAME = "train.trg.txt"
26-
SOURCE_PRETRANSLATION_FILENAME = "pretranslate.src.json"
27-
TARGET_PRETRANSLATION_FILENAME = "pretranslate.trg.json"
28-
29-
3024
class TranslationFileService:
3125
def __init__(
3226
self,
3327
type: SharedFileServiceType,
3428
config: Any,
29+
source_filenames: Optional[Union[str, List[str]]] = None,
30+
target_filenames: Optional[Union[str, List[str]]] = None,
31+
source_pretranslation_filename: str = "pretranslate.src.json",
32+
target_pretranslation_filename: str = "pretranslate.trg.json",
3533
) -> None:
3634

35+
if source_filenames is None:
36+
source_filenames = ["train.src.txt", "train.key-terms.src.txt"]
37+
if target_filenames is None:
38+
target_filenames = ["train.trg.txt", "train.key-terms.trg.txt"]
39+
40+
self._source_filenames = [source_filenames] if isinstance(source_filenames, str) else list(source_filenames)
41+
self._target_filenames = [target_filenames] if isinstance(target_filenames, str) else list(target_filenames)
42+
self._source_pretranslation_filename = source_pretranslation_filename
43+
self._target_pretranslation_filename = target_pretranslation_filename
44+
3745
self.shared_file_service: SharedFileServiceBase = get_shared_file_service(type, config)
3846

3947
def create_source_corpus(self) -> TextCorpus:
4048
return TextFileTextCorpus(
41-
self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{SOURCE_FILENAME}")
49+
self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{source_filename}")
50+
for source_filename in self._source_filenames
4251
)
4352

4453
def create_target_corpus(self) -> TextCorpus:
4554
return TextFileTextCorpus(
46-
self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{TARGET_FILENAME}")
55+
self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{target_filename}")
56+
for target_filename in self._target_filenames
4757
)
4858

4959
def exists_source_corpus(self) -> bool:
50-
return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{SOURCE_FILENAME}")
60+
return all(
61+
self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{source_filename}")
62+
for source_filename in self._source_filenames
63+
)
5164

5265
def exists_target_corpus(self) -> bool:
53-
return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{TARGET_FILENAME}")
66+
return all(
67+
self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{target_filename}")
68+
for target_filename in self._target_filenames
69+
)
5470

5571
def get_source_pretranslations(self) -> ContextManagedGenerator[PretranslationInfo, None, None]:
5672
src_pretranslate_path = self.shared_file_service.download_file(
57-
f"{self.shared_file_service.build_path}/{SOURCE_PRETRANSLATION_FILENAME}"
73+
f"{self.shared_file_service.build_path}/{self._source_pretranslation_filename}"
5874
)
5975

6076
def generator() -> Generator[PretranslationInfo, None, None]:
@@ -77,4 +93,4 @@ def save_model(self, model_path: Path, destination: str) -> None:
7793

7894
@contextmanager
7995
def open_target_pretranslation_writer(self) -> Iterator[DictToJsonWriter]:
80-
return self.shared_file_service.open_target_writer(TARGET_PRETRANSLATION_FILENAME)
96+
return self.shared_file_service.open_target_writer(self._target_pretranslation_filename)

machine/jobs/word_alignment_file_service.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from contextlib import contextmanager
22
from pathlib import Path
3-
from typing import Any, Iterator, List, TypedDict
3+
from typing import Any, Iterator, List, Optional, TypedDict, Union
44

55
import json_stream
66

@@ -23,27 +23,34 @@ def __init__(
2323
self,
2424
type: SharedFileServiceType,
2525
config: Any,
26-
source_filename: str = "train.src.txt",
27-
target_filename: str = "train.trg.txt",
26+
source_filenames: Optional[Union[str, List[str]]] = None,
27+
target_filenames: Optional[Union[str, List[str]]] = None,
2828
word_alignment_input_filename: str = "word_alignments.inputs.json",
2929
word_alignment_output_filename: str = "word_alignments.outputs.json",
3030
) -> None:
3131

32-
self._source_filename = source_filename
33-
self._target_filename = target_filename
32+
if source_filenames is None:
33+
source_filenames = ["train.src.txt", "train.key-terms.src.txt"]
34+
if target_filenames is None:
35+
target_filenames = ["train.trg.txt", "train.key-terms.trg.txt"]
36+
37+
self._source_filenames = [source_filenames] if isinstance(source_filenames, str) else list(source_filenames)
38+
self._target_filenames = [target_filenames] if isinstance(target_filenames, str) else list(target_filenames)
3439
self._word_alignment_input_filename = word_alignment_input_filename
3540
self._word_alignment_output_filename = word_alignment_output_filename
3641

3742
self.shared_file_service: SharedFileServiceBase = get_shared_file_service(type, config)
3843

3944
def create_source_corpus(self) -> TextCorpus:
4045
return TextFileTextCorpus(
41-
self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{self._source_filename}")
46+
self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{source_filename}")
47+
for source_filename in self._source_filenames
4248
)
4349

4450
def create_target_corpus(self) -> TextCorpus:
4551
return TextFileTextCorpus(
46-
self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{self._target_filename}")
52+
self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{target_filename}")
53+
for target_filename in self._target_filenames
4754
)
4855

4956
def get_word_alignment_inputs(self) -> List[WordAlignmentInput]:
@@ -64,10 +71,16 @@ def get_word_alignment_inputs(self) -> List[WordAlignmentInput]:
6471
return wa_inputs
6572

6673
def exists_source_corpus(self) -> bool:
67-
return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{self._source_filename}")
74+
return all(
75+
self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{source_filename}")
76+
for source_filename in self._source_filenames
77+
)
6878

6979
def exists_target_corpus(self) -> bool:
70-
return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{self._target_filename}")
80+
return all(
81+
self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{target_filename}")
82+
for target_filename in self._target_filenames
83+
)
7184

7285
def exists_word_alignment_inputs(self) -> bool:
7386
return self.shared_file_service._exists_file(

tests/corpora/test_text_file_text_corpus.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,17 @@ def test_folder() -> None:
1414
assert [t.id for t in corpus.texts] == ["Test1", "Test2", "Test3"]
1515

1616

17+
def test_multiple_files() -> None:
18+
corpus = TextFileTextCorpus(
19+
[
20+
TEXT_TEST_PROJECT_PATH / "Test1.txt",
21+
TEXT_TEST_PROJECT_PATH / "Test2.txt",
22+
TEXT_TEST_PROJECT_PATH / "Test3.txt",
23+
]
24+
)
25+
assert [t.id for t in corpus.texts] == ["0", "1", "2"]
26+
27+
1728
def test_single_file() -> None:
1829
corpus = TextFileTextCorpus(TEXT_TEST_PROJECT_PATH / "Test1.txt")
1930
assert [t.id for t in corpus.texts] == ["*all*"]

0 commit comments

Comments
 (0)