"""Module setup: typing aliases, progress bar, archive handling, and app logger."""
from typing import Dict, List

import tarfile

from tqdm import tqdm

from vrdu import logger, utils

# Module-level application logger; writes to arxiv_download.log.
log = logger.setup_app_level_logger(file_name="arxiv_download.log")
def arxiv_download(data: List[Dict], path: str) -> None:
    """Download and extract arXiv source tarballs for each paper in *data*.

    For each row, the paper source is downloaded into
    ``<path>/<discipline>/`` and extracted into
    ``<path>/<discipline>/<paper_id>/``. Rows already annotated, or whose
    extracted directory or downloaded tarball already exists, are skipped.

    Args:
        data: Paper records; each row is assumed to provide the keys
            ``paper_id``, ``discipline`` and ``auto_annotated_paper_path``
            (assumption from visible usage — TODO confirm full schema).
        path: Root output directory.

    Returns:
        None
    """
    client = arxiv.Client()
    for row in tqdm(data):
        # Already processed downstream — nothing to download.
        if row["auto_annotated_paper_path"]:
            continue

        discipline_path = os.path.join(path, row["discipline"])
        os.makedirs(discipline_path, exist_ok=True)

        paper_path = os.path.join(discipline_path, row["paper_id"])
        if os.path.exists(paper_path):
            log.debug(f"{paper_path} exists")
            continue

        # BUG FIX: the original used os.path.join(discipline_path,
        # paper_id, ".tar.gz"), which builds "<paper_id>/.tar.gz" (a file
        # named ".tar.gz" inside a directory), so this skip check never
        # matched and tarballs were re-downloaded. Concatenate instead.
        tar_path = paper_path + ".tar.gz"
        if os.path.exists(tar_path):
            log.debug(f"{tar_path} exists")
            continue

        for result in client.results(arxiv.Search(id_list=[row["paper_id"]])):
            tar_file_path = result.download_source(dirpath=discipline_path)
            log.debug(f"Downloading tar file {tar_file_path}")
            try:
                with tarfile.open(tar_file_path, "r:gz") as tar:
                    # NOTE(review): extractall on a downloaded archive is
                    # vulnerable to path traversal; on Python >= 3.12 pass
                    # filter="data" — confirm target interpreter version.
                    tar.extractall(paper_path)
            except tarfile.ReadError:
                # Some arXiv sources are bare PDFs, not gzipped tars;
                # log and move on rather than aborting the whole run.
                log.error(f"{tar_file_path} is not a tar.gz file")
                continue
4563
46- result .download_source (dirpath = sub_directory )
4764
def main():
    """CLI entry point: parse arguments and download the listed papers."""
    import argparse

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "-p", "--path", type=str, required=True, help="Path to save result"
    )
    arg_parser.add_argument(
        "-f", "--file", type=str, required=True, help="json file for saving result"
    )
    cli_args = arg_parser.parse_args()

    # Load the paper manifest and download everything it lists.
    records = utils.load_json(cli_args.file)
    arxiv_download(records, cli_args.path)
6582
6683
6784if __name__ == "__main__" :
68- path = os .path .expanduser ("/cpfs01/shared/ADLab/datasets/vrdu_arxiv" )
69- data = []
70- with open ("scripts/category_count.csv" , "r" ) as f :
71- reader = csv .DictReader (f )
72- for row in reader :
73- data .append (row )
74-
75- random .shuffle (data )
76- arxiv_download (data = data , path = path )
77- for root , dirs , files in os .walk (path ):
78- for dir_ in dirs :
79- extract_all_tar_gz (os .path .join (root , dir_ ))
85+ main ()
0 commit comments