-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy path0_generate_snowball_start.py
More file actions
78 lines (63 loc) · 2.31 KB
/
0_generate_snowball_start.py
File metadata and controls
78 lines (63 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/usr/bin/env python3
"""
Script to generate snowball sampling starting points (iteration 0) from a list of papers.
Extracts titles from an input file (JSON or text), looks them up with the selected
search method (e.g. Google Scholar), and stores the resulting publications and seen
titles in the snowball database.
"""
import argparse
import hashlib
import json
import re
import requests
import time
from dotenv import load_dotenv
from enum import Enum
from scholarly import scholarly
from tqdm import tqdm
from typing import List, Dict, Optional
from urllib.parse import quote_plus
from utils.proxy_generator import get_proxy
from utils.db_management import (
DBManager,
initialize_db,
SelectionStage
)
from utils.pipeline.generate_snowball_start_utils import (
generate_snowball_start,
extract_titles_from_file
)
from utils.article_search.article_search_method import SearchMethod
# Iteration index assigned to the starting (seed) publications.
ITERATION_0 = 0

# Populate os.environ from a .env file, if one is present.
load_dotenv()

# Shared search configuration, read once at import time. Its values provide
# the CLI defaults and are passed to initialize_db(). JSON is UTF-8 by spec,
# so decode it explicitly instead of relying on the platform locale encoding.
with open("confs/search_conf.json", "r", encoding="utf-8") as f:
    search_conf = json.load(f)

# NOTE(review): `pg` is not referenced anywhere in this script; get_proxy() is
# presumably called for a side effect (e.g. configuring a proxy for the search
# backend) — confirm before removing.
pg = get_proxy(search_conf["proxy_key"])
def parse_args(argv=None):
    """Parse command-line options for the snowball-start generator.

    Defaults for ``--input_file``, ``--db_path`` and ``--search_method`` come
    from the module-level ``search_conf`` dictionary.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) parses the real command
            line, matching the previous behaviour.

    Returns:
        argparse.Namespace with ``input_file``, ``delay``, ``db_path`` and
        ``search_method`` attributes.
    """
    parser = argparse.ArgumentParser(description='Generate snowball sampling starting points from file')
    parser.add_argument('--input_file', help='Path to the input file (json or text)', default=search_conf["initial_file"])
    parser.add_argument('--delay', type=float, default=1.0, help='Delay between API requests in seconds (default: 1.0)')
    parser.add_argument('--db_path', help='db path', type=str, default=search_conf["db_path"])
    parser.add_argument(
        '--search_method',
        help='Search method to use',
        type=str,
        default=search_conf["search_method"],
        # Explicit CLI values must be one of the SearchMethod enum values.
        # argparse does not check a string *default* against `choices`, so an
        # invalid value in search_conf is not rejected here.
        choices=[method.value for method in SearchMethod],
    )
    return parser.parse_args(argv)
def main():
    """Generate the iteration-0 starting points and store them in the database.

    Parses CLI arguments, converts ``--search_method`` to a ``SearchMethod``,
    runs ``generate_snowball_start`` on the input file, and inserts the
    resulting publications and seen titles through the ``DBManager`` returned
    by ``initialize_db``. The database cursor and connection are closed even
    if generation or insertion raises.
    """
    args = parse_args()
    try:
        search_method = SearchMethod(args.search_method)
    except ValueError:
        # Reachable when the default from search_conf is not a valid enum
        # value (argparse does not validate defaults against `choices`).
        print(f"Error: Invalid search method '{args.search_method}'. Available options: {[method.value for method in SearchMethod]}")
        return

    db_manager = initialize_db(args.db_path, search_conf)
    try:
        initial_pubs, seen_titles = generate_snowball_start(
            args.input_file, ITERATION_0, args.delay, search_method
        )
        # NOTE(review): no explicit commit here; assumes the insert_* methods
        # commit their own writes — confirm in DBManager.
        db_manager.insert_iteration_data(initial_pubs)
        db_manager.insert_seen_titles_data(seen_titles)
    finally:
        # Always release DB resources; the nested finally ensures the
        # connection is closed even if closing the cursor fails.
        try:
            db_manager.cursor.close()
        finally:
            db_manager.conn.close()
if __name__ == "__main__":
    # Script entry point: importing this module must not start a run.
    main()