-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
141 lines (128 loc) · 6.46 KB
/
main.py
File metadata and controls
141 lines (128 loc) · 6.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import sys
from load_datasets import *
from feature_extraction import *
from graph_structure import *
from utils import *
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from random import shuffle
from torch_geometric.data import DataLoader
from nn import *
import warnings
warnings.filterwarnings("ignore")
# Experiment settings; every later step of this script reads its parameters from here.
CONFIG = {
    # one out of ['twitter', 'mr', 'snippets', 'tag_my_news']
    "dataset": "tag_my_news",
    # True = do not (re)generate graphs, go straight to loading/training
    "skip_data_generation": True,
    # used for both the per-class train and the per-class eval sample count
    "train_eval_samples_per_class": 20,
    # shuffle flag of the training DataLoader
    "shuffle_train": True,
    # batch size of the train, eval and test DataLoaders
    "batch_size": 8,
    # hidden_channels of the selected GNN
    "hidden_dim": 128,
    # num_epochs handed to train_eval_model
    "epochs": 200,
    # forwarded unchanged to train_eval_model(eval_best=...)
    "eval_best": True,
    # number of independent load/train/evaluate repetitions that get averaged
    "num_runs": 5,
    # forwarded unchanged to train_eval_model(verbose=...)
    "verbose": 0,
    # False = use same training samples,
    # True = shuffle data and select random training samples
    "random_samples": False,
    # this will only be used if 'random_samples'=False
    "random_seed": 100,
    # only passed to the GAT model
    "num_attention_heads": 1,
    # this has to be set n_hops + 1, e.g. for 2-hop graph = 3
    # NOTE(review): by that rule the value 2 selects a 1-hop graph - confirm this is intended
    "n_hop_neighborhood": 2,
    # GAT or SAGE or GCN
    "gnn_type": "GAT",
}
# (1) generate graphs
def _generate_missing_graphs(samples, label, prefix, existing_graphs, tokenizer, llm, dataset):
    """Generate and save a graph for every sample that has no saved graph yet.

    Args:
        samples: Texts of one group (one polarity or one class), in corpus order.
        label: Label passed to ``generate_graph_from_text`` for every sample.
        prefix: File-name prefix; sample ``idx`` is saved as ``f'{prefix}_{idx}'``.
        existing_graphs: Collection returned by ``get_saved_graphs``; a sample is
            skipped when ``f'{prefix}_{idx}.pickle'`` is contained in it.
        tokenizer: Tokenizer forwarded to ``generate_graph_from_text``.
        llm: Language model forwarded to ``generate_graph_from_text``.
        dataset: Dataset name forwarded to ``save_graph``.
    """
    for idx, sample in enumerate(tqdm(samples)):
        # Skip samples whose graph file is already listed as saved.
        if f'{prefix}_{idx}.pickle' in existing_graphs:
            continue
        graph = generate_graph_from_text(sample, label=label, tokenizer=tokenizer, llm=llm)
        save_graph(graph, file_name=f'{prefix}_{idx}', dataset=dataset)


if not CONFIG['skip_data_generation']:
    # Model identifier handed to AutoTokenizer/AutoModel.from_pretrained per dataset.
    model_dict = {
        'twitter': 'Twitter/twhin-bert-base',
        'mr': 'google-bert/bert-base-uncased',
        'snippets': 'google-bert/bert-base-uncased',
        'tag_my_news': 'google-bert/bert-base-uncased',
    }
    bert_tokenizer = AutoTokenizer.from_pretrained(model_dict[CONFIG['dataset']])
    bert_model = AutoModel.from_pretrained(model_dict[CONFIG['dataset']])
    existing_graphs = get_saved_graphs(CONFIG['dataset'])

    # Each entry is (samples, label, file-name prefix). The two-polarity corpora
    # use 'pos' (label 1) and 'neg' (label 0); the multi-class corpora use the
    # class index as both label and prefix.
    sample_groups = []
    if CONFIG['dataset'] == 'twitter':
        positive_tweets, negative_tweets = load_twitter_corpus()
        sample_groups = [(positive_tweets, 1, 'pos'), (negative_tweets, 0, 'neg')]
    elif CONFIG['dataset'] == 'mr':
        positive_samples, negative_samples = load_mr_corpus()
        sample_groups = [(positive_samples, 1, 'pos'), (negative_samples, 0, 'neg')]
    elif CONFIG['dataset'] == 'snippets':
        samples = load_snippets_corpus()
        sample_groups = [(class_samples, cls, str(cls)) for cls, class_samples in enumerate(samples)]
    elif CONFIG['dataset'] == 'tag_my_news':
        samples = load_tagmynews_corpus()
        sample_groups = [(class_samples, cls, str(cls)) for cls, class_samples in enumerate(samples)]

    # Groups are processed in the order listed above, one progress bar per group.
    for group_samples, group_label, group_prefix in sample_groups:
        _generate_missing_graphs(
            group_samples,
            label=group_label,
            prefix=group_prefix,
            existing_graphs=existing_graphs,
            tokenizer=bert_tokenizer,
            llm=bert_model,
            dataset=CONFIG['dataset'],
        )
# Metrics returned by train_eval_model, one entry per run; averaged after the loop.
acc_runs = []
precision_runs = []
recall_runs = []
f1_runs = []

# Number of output classes per dataset (loop-invariant, so built once).
num_classes_dict = {
    'twitter': 2,
    'mr': 2,
    'snippets': 8,
    'tag_my_news': 7,
}

for run_idx in range(CONFIG['num_runs']):
    # (2) load data
    train_graphs, eval_graph, test_graphs = load_train_eval_test(
        dataset=CONFIG['dataset'],
        num_per_class_train=CONFIG['train_eval_samples_per_class'],
        num_per_class_eval=CONFIG['train_eval_samples_per_class'],
        n_hop_neighborhood=CONFIG['n_hop_neighborhood'],
        random_samples=CONFIG['random_samples'],
        random_seed=CONFIG['random_seed']
    )
    print(len(train_graphs), len(eval_graph), len(test_graphs))
    # Only the training loader is (optionally) shuffled; eval/test keep a fixed order.
    train_loader = DataLoader(train_graphs, batch_size=CONFIG['batch_size'], shuffle=CONFIG['shuffle_train'])
    eval_loader = DataLoader(eval_graph, batch_size=CONFIG['batch_size'], shuffle=False)
    test_loader = DataLoader(test_graphs, batch_size=CONFIG['batch_size'], shuffle=False)

    # (3) prepare model - a fresh model and optimizer are created for every run
    if CONFIG['gnn_type'] == 'GAT':
        model = GAT(hidden_channels=CONFIG['hidden_dim'], out_channels=num_classes_dict[CONFIG['dataset']],
                    num_attention_heads=CONFIG['num_attention_heads'], use_hypergraph=False)
    elif CONFIG['gnn_type'] == 'GCN':
        model = GCN(hidden_channels=CONFIG['hidden_dim'], out_channels=num_classes_dict[CONFIG['dataset']])
    elif CONFIG['gnn_type'] == 'SAGE':
        model = SAGE(hidden_channels=CONFIG['hidden_dim'], out_channels=num_classes_dict[CONFIG['dataset']])
    else:
        # A bare string cannot be raised in Python 3 (it fails with an unrelated
        # TypeError), so raise a real exception that carries the intended message.
        raise ValueError('no valid GNN selected, must be one out of [GAT, GCN, SAGE]')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.00005)
    criterion = torch.nn.CrossEntropyLoss()

    # (4) train and evaluate, then record this run's metrics
    acc, precision, recall, f1 = train_eval_model(model=model, train_loader=train_loader, eval_loader=eval_loader,
                                                  test_loader=test_loader, loss_fct=criterion, optimizer=optimizer,
                                                  num_epochs=CONFIG['epochs'], verbose=CONFIG['verbose'],
                                                  eval_best=CONFIG['eval_best'])
    acc_runs.append(acc)
    precision_runs.append(precision)
    recall_runs.append(recall)
    f1_runs.append(f1)

# (5) report the per-run values followed by their mean over all runs
print('Accuracy:', acc_runs, sum(acc_runs)/CONFIG['num_runs'])
print('Precision:', precision_runs, sum(precision_runs)/CONFIG['num_runs'])
print('Recall:', recall_runs, sum(recall_runs)/CONFIG['num_runs'])
print('F1:', f1_runs, sum(f1_runs)/CONFIG['num_runs'])