Skip to content

Commit dddb66d

Browse files
committed
refactor(arxiv_download.py): download paper files for the given paper paths
1 parent 96b4424 commit dddb66d

1 file changed

Lines changed: 49 additions & 43 deletions

File tree

scripts/arxiv_download.py

Lines changed: 49 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33
from typing import List, Dict
44
from tqdm import tqdm
55
import tarfile
6-
import csv
7-
import random
6+
7+
8+
from vrdu import utils
9+
from vrdu import logger
10+
11+
log = logger.setup_app_level_logger(file_name="arxiv_download.log")
812

913

1014
def arxiv_download(data: List[Dict], path: str) -> None:
@@ -26,54 +30,56 @@ def arxiv_download(data: List[Dict], path: str) -> None:
2630
Returns:
2731
None
2832
"""
33+
client = arxiv.Client()
2934
for row in tqdm(data):
30-
category, count = row["categories"], int(row["count"])
31-
print(f"category: {category}, count: {count}")
32-
sub_directory = os.path.join(path, category)
33-
os.makedirs(sub_directory, exist_ok=True)
34-
35-
search = arxiv.Search(
36-
query=category,
37-
max_results=count,
38-
sort_by=arxiv.SortCriterion.SubmittedDate,
39-
)
40-
41-
for result in search.results():
42-
file_name = result._get_default_filename()
43-
if os.path.exists(os.path.join(sub_directory, file_name)):
35+
if row["auto_annotated_paper_path"]:
36+
continue
37+
discipline = row["discipline"]
38+
discipline_path = os.path.join(path, discipline)
39+
os.makedirs(discipline_path, exist_ok=True)
40+
41+
if os.path.exists(os.path.join(discipline_path, row["paper_id"])):
42+
log.debug(f'{os.path.join(discipline_path, row["paper_id"])} exists')
43+
continue
44+
45+
if os.path.exists(os.path.join(discipline_path, row["paper_id"], ".tar.gz")):
46+
log.debug(
47+
f'{os.path.join(discipline_path, row["paper_id"], ".tar.gz")} exists'
48+
)
49+
continue
50+
51+
search_results = client.results(arxiv.Search(id_list=[row["paper_id"]]))
52+
53+
for result in search_results:
54+
tar_file_path = result.download_source(dirpath=discipline_path)
55+
log.debug(f"Downloading tar file {tar_file_path}")
56+
paper_path = os.path.join(discipline_path, row["paper_id"])
57+
try:
58+
with tarfile.open(tar_file_path, "r:gz") as tar:
59+
tar.extractall(paper_path)
60+
except tarfile.ReadError:
61+
log.error(f"{tar_file_path} is not a tar.gz file")
4462
continue
4563

46-
result.download_source(dirpath=sub_directory)
4764

65+
def main():
66+
import argparse
4867

49-
def extract_all_tar_gz(directory):
50-
for root, dirs, files in os.walk(directory):
51-
for file in files:
52-
if not file.endswith(".tar.gz"):
53-
continue
54-
file_path = os.path.join(root, file)
55-
extract_path = os.path.splitext(file_path)[0]
56-
extract_path = os.path.splitext(extract_path)[0]
57-
if os.path.exists(extract_path):
58-
continue
59-
extract_tar_gz(file_path, extract_path)
68+
parser = argparse.ArgumentParser()
69+
parser.add_argument(
70+
"-p", "--path", type=str, required=True, help="Path to save result"
71+
)
72+
parser.add_argument(
73+
"-f", "--file", type=str, required=True, help="json file for saving result"
74+
)
75+
76+
args = parser.parse_args()
77+
output_path, json_file = args.path, args.file
6078

79+
json_data = utils.load_json(json_file)
6180

62-
def extract_tar_gz(file_path, extract_path):
63-
with tarfile.open(file_path, "r:gz") as tar:
64-
tar.extractall(extract_path)
81+
arxiv_download(json_data, output_path)
6582

6683

6784
if __name__ == "__main__":
68-
path = os.path.expanduser("/cpfs01/shared/ADLab/datasets/vrdu_arxiv")
69-
data = []
70-
with open("scripts/category_count.csv", "r") as f:
71-
reader = csv.DictReader(f)
72-
for row in reader:
73-
data.append(row)
74-
75-
random.shuffle(data)
76-
arxiv_download(data=data, path=path)
77-
for root, dirs, files in os.walk(path):
78-
for dir_ in dirs:
79-
extract_all_tar_gz(os.path.join(root, dir_))
85+
main()

0 commit comments

Comments
 (0)