Skip to content

Commit dc9de68

Browse files
committed
Merge branch 'main' of https://github.com/MaoSong2022/vrdu_data_process into main
2 parents fbcf5a1 + ce67aa3 commit dc9de68

2 files changed

Lines changed: 191 additions & 1 deletion

File tree

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ share/python-wheels/
2525
.installed.cfg
2626
*.egg
2727
MANIFEST
28+
.idea
29+
.vscode
2830

2931
# PyInstaller
3032
# Usually these files are written by a python script from a template
@@ -173,4 +175,5 @@ TexSoup/examples/
173175
TexSoup/tests/
174176
output.txt
175177

176-
vrdu_arxiv
178+
vrdu_arxiv
179+
COCO_datasets

scripts/collect_coco_dataset.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
2+
import os
3+
import re
4+
import cv2
5+
import time
6+
import json
7+
import shutil
8+
import random
9+
import argparse
10+
from tqdm import tqdm
11+
12+
13+
def extract_tex_files(path, target_pattern):
14+
tex_files = []
15+
16+
for root, dirs, files in os.walk(path):
17+
for file in files:
18+
if not file.endswith(".tex"):
19+
continue
20+
if file.startswith("paper_"):
21+
continue
22+
23+
tex_file = os.path.join(root, file)
24+
25+
try:
26+
with open(tex_file) as f:
27+
content = f.read()
28+
except UnicodeDecodeError:
29+
continue
30+
31+
if "\\begin{document}" not in content:
32+
continue
33+
34+
if not any(re.match(pattern, root.split('/')[-2]) for pattern in target_pattern):
35+
continue
36+
37+
if os.path.exists(f'{root}/output/result/layout_annotation.json'):
38+
tex_files.append(tex_file)
39+
return tex_files
40+
41+
42+
def main(path, target_pattern, ratio):
43+
now_time = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time()))
44+
coco_dataset_name = f'COCO_datasets/Multi-modal_COCO_dataset_{now_time}'
45+
46+
target_images_folder = f'{coco_dataset_name}/images'
47+
os.makedirs(coco_dataset_name, exist_ok=True)
48+
os.makedirs(target_images_folder, exist_ok=True)
49+
50+
tex_files = sorted(extract_tex_files(path, target_pattern))
51+
tex_files_length = len(tex_files)
52+
53+
random.seed(0)
54+
random.shuffle(tex_files)
55+
train_list = tex_files[:int(tex_files_length * ratio)]
56+
val_list = tex_files[int(tex_files_length * ratio):]
57+
dataset_dict = {
58+
"train": train_list,
59+
"val": val_list
60+
}
61+
62+
info = {
63+
"year": 2023,
64+
"version": "1.0",
65+
"description": "COCO format dataset converted form document genome",
66+
"contributor": "ADLab",
67+
"url": "",
68+
"date_created": f"{time.ctime()}"
69+
}
70+
licenses = [
71+
{
72+
"url": "http://creativecommons.org/licenses/by/2.0/",
73+
"id": 4,
74+
"name": "Attribution License"
75+
}
76+
]
77+
images = []
78+
annotations = []
79+
categories = [
80+
{"id": 0, "name": "Algorithm", "supercategory": "Algorithm"},
81+
{"id": 1, "name": "Caption", "supercategory": "Caption"},
82+
{"id": 2, "name": "Equation", "supercategory": "Equation"},
83+
{"id": 3, "name": "Figure", "supercategory": "Figure"},
84+
{"id": 4, "name": "Footnote", "supercategory": "Footnote"},
85+
{"id": 5, "name": "List", "supercategory": "List"},
86+
{"id": 6, "name": "Others", "supercategory": "Others"},
87+
{"id": 7, "name": "Table", "supercategory": "Table"},
88+
{"id": 8, "name": "Text", "supercategory": "Text"},
89+
{"id": 9, "name": "Text-EQ", "supercategory": "Text"},
90+
{"id": 10, "name": "Title", "supercategory": "Title"},
91+
{"id": 11, "name": "Reference", "supercategory": "Reference"},
92+
{"id": 12, "name": "PaperTitle", "supercategory": "Title"},
93+
{"id": 13, "name": "Code", "supercategory": "Algorithm"},
94+
{"id": 14, "name": "Abstract", "supercategory": "Text"}
95+
]
96+
97+
anno_id = 0
98+
image_id = 0
99+
pattern = r'\d+\.\d+(v\d+)?'
100+
for key, tex_files in dataset_dict.items():
101+
print(f"Processing {key} set...")
102+
103+
images = []
104+
annotations = []
105+
106+
for tex_file in tqdm(tex_files):
107+
coco_annotation_file = f'{os.path.dirname(tex_file)}/output/result/layout_annotation.json'
108+
images_path = f'{os.path.dirname(tex_file)}/output/colored'
109+
110+
if not re.search(pattern, tex_file): raise NotImplementedError
111+
arxiv_paper_id = re.search(pattern, tex_file).group()
112+
113+
with open(coco_annotation_file, 'r') as fp:
114+
coco_annotation = json.load(fp)
115+
sub_images = coco_annotation['images']
116+
sub_annotations_list = coco_annotation['annotations']
117+
118+
grouped_annotations = {}
119+
for annotation in sub_annotations_list:
120+
anno_image_id = annotation['image_id']
121+
# 检查image_id是否已经在字典中
122+
if anno_image_id not in grouped_annotations:
123+
# 如果不在,创建一个新的列表
124+
grouped_annotations[anno_image_id] = []
125+
# 将注释添加到相应的列表中
126+
grouped_annotations[anno_image_id].append(annotation)
127+
128+
grouped_annotations_key_list = sorted(grouped_annotations.keys())
129+
for idx in grouped_annotations_key_list:
130+
file_name = arxiv_paper_id.replace('.', '_') + f'-page_{idx:04d}.png'
131+
page_image = cv2.imread(f'{images_path}/{idx}.png')
132+
H, W, _ = page_image.shape
133+
page_annotations = grouped_annotations[idx]
134+
135+
images.append(
136+
{
137+
"id": image_id,
138+
"width": W,
139+
"height": H,
140+
"file_name": file_name,
141+
"coco_url": "https://github.com/MaoSong2022/vrdu_data_process",
142+
"date_captured": now_time,
143+
"flickr_url": "",
144+
"licenses": 4
145+
}
146+
)
147+
shutil.copyfile(f'{images_path}/{idx}.png', f'{target_images_folder}/{file_name}')
148+
149+
for anno in page_annotations:
150+
annotations.append(
151+
{
152+
"id": anno_id,
153+
"image_id": image_id,
154+
"category_id": anno["category_id"],
155+
"segmentation": anno["segmentation"],
156+
"bbox": anno["bbox"],
157+
"area": anno["area"],
158+
"iscrowd": anno["iscrowd"]
159+
}
160+
161+
)
162+
anno_id += 1
163+
image_id += 1
164+
165+
coco_json_content = {
166+
"info": info,
167+
"licenses": licenses,
168+
"images": images,
169+
"annotations": annotations,
170+
"categories": categories,
171+
}
172+
173+
with open(f'{coco_dataset_name}/{key}.json', 'w') as fp:
174+
json.dump(coco_json_content, fp, indent=4)
175+
176+
177+
if __name__ == "__main__":
178+
# parser = argparse.ArgumentParser()
179+
# parser.add_argument("-p", "--path", type=str, required=True)
180+
# parser.add_argument("-r", "--ratio", type=float, default=0.8)
181+
# args = parser.parse_args()
182+
# path = args.path
183+
184+
target_pattern = [r'^cs\.\w+$']
185+
path = os.path.expanduser("/cpfs01/user/penghaoyang/code/vrdu_data_process/vrdu_arxiv")
186+
ratio = 0.8
187+
main(path, target_pattern, ratio)

0 commit comments

Comments
 (0)