|
1 | 1 | import os |
2 | 2 | import shutil |
3 | 3 | import uuid |
| 4 | +import datetime |
4 | 5 |
|
5 | 6 | from vrdu import utils |
6 | 7 | from vrdu.config import config |
7 | 8 |
|
8 | 9 |
|
9 | | -def extract_category(path, category_name, output_path): |
10 | | - print(f"extract category {category_name} from {path} to {output_path}") |
11 | | - json_file = os.path.join(path, "reading_annotation.json") |
12 | | - data = utils.load_json(json_file) |
13 | | - |
14 | | - category = config.name2category[category_name] |
15 | | - result_json = os.path.join(output_path, "reading_annotation.json") |
| 10 | +def extract_category(path, category, output_path): |
| 11 | + source_json_file = os.path.join(path, "reading_annotation.json") |
| 12 | + data = utils.load_json(source_json_file) |
16 | 13 |
|
17 | 14 | result = [] |
18 | 15 |
|
19 | | - for x, pairs in data.items(): |
20 | | - if not x.isnumeric(): |
| 16 | + for key, blocks in data.items(): |
| 17 | + # x must be page index |
| 18 | + if not key.isnumeric(): |
21 | 19 | continue |
22 | | - for p in pairs: |
23 | | - if "category" not in p: |
| 20 | + for block in blocks: |
| 21 | + if "category" not in block: |
24 | 22 | return |
25 | | - if p["category"] == category: |
26 | | - result.append(p) |
| 23 | + if block["category"] == category: |
| 24 | + result.append(block) |
27 | 25 |
|
28 | | - for x in result: |
| 26 | + for key in result: |
29 | 27 | output_image_name = f"{uuid.uuid4()}.png" |
30 | 28 | shutil.copyfile( |
31 | | - os.path.join(path, x["image_path"]), |
| 29 | + os.path.join(path, key["image_path"]), |
32 | 30 | os.path.join(output_path, output_image_name), |
33 | 31 | ) |
34 | | - x["image_path"] = output_image_name |
35 | | - x["paper_source"] = path |
| 32 | + key["image_path"] = output_image_name |
| 33 | + key["paper_source"] = path |
| 34 | + key["added_date"] = str(datetime.date.today()) |
36 | 35 |
|
37 | | - if os.path.exists(result_json): |
38 | | - data = utils.load_json(result_json) |
39 | | - result.extend(data) |
| 36 | + return result |
40 | 37 |
|
41 | | - utils.export_to_json(result, result_json) |
42 | 38 |
|
| 39 | +def main(category_name, input_directory, output_path): |
| 40 | + """extract all blocks that is of the given category to a given output directory""" |
| 41 | + if category_name not in config.name2category.keys(): |
| 42 | + raise KeyError( |
| 43 | + f"Unknown category name, avalaible category names: {list(config.name2category.keys())}" |
| 44 | + ) |
43 | 45 |
|
44 | | -if __name__ == "__main__": |
45 | | - category_name = "Table" |
46 | | - input_directory = os.path.expanduser("/home/PJLAB/maosong/vrdu_data") |
47 | | - output_path = os.path.expanduser(f"~/Desktop/sample_data/{category_name}") |
48 | | - if os.path.exists(output_path): |
49 | | - shutil.rmtree(output_path) |
50 | | - os.makedirs(output_path) |
| 46 | + category = config.name2category[category_name] |
| 47 | + result_json = os.path.join(output_path, "reading_annotation.json") |
| 48 | + |
| 49 | + existed_source = set() |
| 50 | + if os.path.exists(result_json): |
| 51 | + existed_source = set( |
| 52 | + item["paper_source"] for item in utils.load_json(result_json) |
| 53 | + ) |
51 | 54 |
|
52 | 55 | count = 0 |
53 | 56 | for root, dirs, files in os.walk(input_directory): |
54 | 57 | if "reading_annotation.json" not in files: |
55 | 58 | continue |
| 59 | + |
| 60 | + # if data of this folder has been extracted |
| 61 | + if root in existed_source: |
| 62 | + continue |
| 63 | + |
56 | 64 | count += 1 |
57 | 65 |
|
58 | | - extract_category(root, category_name, output_path) |
| 66 | + print(f"extract data from {root} to {output_path}") |
| 67 | + extract_category(root, category, output_path) |
| 68 | + |
| 69 | + # exclude the reading_annotation.json file |
| 70 | + num_of_samples = len(os.listdir(output_path)) - 1 |
| 71 | + print(f"extracted {count} files, {num_of_samples} samples obtained.") |
59 | 72 |
|
60 | | - num_of_sanples = len(os.listdir(output_path)) - 1 |
61 | 73 |
|
62 | | - print(f"extracted {count} files, {num_of_sanples} samples obtained.") |
| 74 | +if __name__ == "__main__": |
| 75 | + category_name = "Equation" |
| 76 | + input_directory = os.path.expanduser("/cpfs01/shared/ADLab/datasets/vrdu_arxiv") |
| 77 | + output_path = os.path.expanduser( |
| 78 | + f"/cpfs01/shared/ADLab/datasets/vrdu_{category_name.lower()}" |
| 79 | + ) |
| 80 | + |
| 81 | + main(category_name, input_directory, output_path) |
0 commit comments