Skip to content

Commit 4a4d1c0

Browse files
committed
tutorial 12
1 parent 3987b9b commit 4a4d1c0

4 files changed

Lines changed: 1049 additions & 0 deletions

File tree

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,11 @@ Feel free to join our weekly online tutorial, for more details, have a look at t
2525

2626
* Tutorial8: Graph Generation.
2727

28+
* Tutorial9: Recurrent Graph Neural Networks.
29+
30+
* Tutorial10: DeepWalk and Node2Vec (Theory).
31+
32+
* Tutorial11: DeepWalk and Node2Vec (Practice).
33+
34+
* Tutorial12: Edge analysis.
35+
Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 6,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import os.path as osp\n",
10+
"\n",
11+
"import torch\n",
12+
"import torch.nn.functional as F\n",
13+
"from sklearn.metrics import roc_auc_score\n",
14+
"\n",
15+
"from torch_geometric.utils import negative_sampling\n",
16+
"from torch_geometric.datasets import Planetoid\n",
17+
"import torch_geometric.transforms as T\n",
18+
"from torch_geometric.nn import GCNConv\n",
19+
"from torch_geometric.utils import train_test_split_edges"
20+
]
21+
},
22+
{
23+
"cell_type": "markdown",
24+
"metadata": {},
25+
"source": [
26+
"# GAE for link prediction\n",
27+
"\n",
28+
"[code](https://github.com/rusty1s/pytorch_geometric/blob/master/examples/link_pred.py)\n"
29+
]
30+
},
31+
{
32+
"cell_type": "code",
33+
"execution_count": 7,
34+
"metadata": {},
35+
"outputs": [],
36+
"source": [
37+
"\n",
38+
"\n",
39+
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
40+
"device = \"cpu\""
41+
]
42+
},
43+
{
44+
"cell_type": "code",
45+
"execution_count": 8,
46+
"metadata": {},
47+
"outputs": [
48+
{
49+
"name": "stdout",
50+
"output_type": "stream",
51+
"text": [
52+
"Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])\n"
53+
]
54+
}
55+
],
56+
"source": [
57+
"# load the Cora dataset\n",
58+
"dataset = 'Cora'\n",
59+
"path = osp.join('.', 'data', dataset)\n",
60+
"dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())\n",
61+
"data = dataset[0]\n",
62+
"print(dataset.data)"
63+
]
64+
},
65+
{
66+
"cell_type": "code",
67+
"execution_count": 9,
68+
"metadata": {},
69+
"outputs": [
70+
{
71+
"name": "stdout",
72+
"output_type": "stream",
73+
"text": [
74+
"Data(test_neg_edge_index=[2, 527], test_pos_edge_index=[2, 527], train_neg_adj_mask=[2708, 2708], train_pos_edge_index=[2, 8976], val_neg_edge_index=[2, 263], val_pos_edge_index=[2, 263], x=[2708, 1433])\n"
75+
]
76+
}
77+
],
78+
"source": [
79+
"# use train_test_split_edges to create neg and positive edges\n",
80+
"data.train_mask = data.val_mask = data.test_mask = data.y = None\n",
81+
"data = train_test_split_edges(data)\n",
82+
"print(data)"
83+
]
84+
},
85+
{
86+
"cell_type": "code",
87+
"execution_count": null,
88+
"metadata": {},
89+
"outputs": [],
90+
"source": []
91+
},
92+
{
93+
"cell_type": "markdown",
94+
"metadata": {},
95+
"source": [
96+
"#### Simple autoencoder model"
97+
]
98+
},
99+
{
100+
"cell_type": "code",
101+
"execution_count": 11,
102+
"metadata": {},
103+
"outputs": [],
104+
"source": [
105+
"class Net(torch.nn.Module):\n",
106+
" def __init__(self):\n",
107+
" super(Net, self).__init__()\n",
108+
" self.conv1 = GCNConv(dataset.num_features, 128)\n",
109+
" self.conv2 = GCNConv(128, 64)\n",
110+
"\n",
111+
" def encode(self):\n",
112+
" x = self.conv1(data.x, data.train_pos_edge_index) # convolution 1\n",
113+
" x = x.relu()\n",
114+
" return self.conv2(x, data.train_pos_edge_index) # convolution 2\n",
115+
"\n",
116+
" def decode(self, z, pos_edge_index, neg_edge_index): # only pos and neg edges\n",
117+
" edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1) # concatenate pos and neg edges\n",
118+
" logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1) # dot product \n",
119+
" return logits\n",
120+
"\n",
121+
" def decode_all(self, z): \n",
122+
" prob_adj = z @ z.t() # get adj NxN\n",
123+
" return (prob_adj > 0).nonzero(as_tuple=False).t() # get predicted edge_list "
124+
]
125+
},
126+
{
127+
"cell_type": "code",
128+
"execution_count": 12,
129+
"metadata": {},
130+
"outputs": [],
131+
"source": [
132+
"\n",
133+
"model, data = Net().to(device), data.to(device)\n",
134+
"optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)"
135+
]
136+
},
137+
{
138+
"cell_type": "code",
139+
"execution_count": null,
140+
"metadata": {},
141+
"outputs": [],
142+
"source": []
143+
},
144+
{
145+
"cell_type": "code",
146+
"execution_count": 13,
147+
"metadata": {},
148+
"outputs": [],
149+
"source": [
150+
"\n",
151+
"def get_link_labels(pos_edge_index, neg_edge_index):\n",
152+
" # returns a tensor:\n",
153+
" # [1,1,1,1,...,0,0,0,0,0,..] where the number of ones is equal to the length of pos_edge_index\n",
154+
" # and the number of zeros is equal to the length of neg_edge_index\n",
155+
" E = pos_edge_index.size(1) + neg_edge_index.size(1)\n",
156+
" link_labels = torch.zeros(E, dtype=torch.float, device=device)\n",
157+
" link_labels[:pos_edge_index.size(1)] = 1.\n",
158+
" return link_labels\n",
159+
"\n",
160+
"\n",
161+
"def train():\n",
162+
" model.train()\n",
163+
"\n",
164+
" neg_edge_index = negative_sampling(\n",
165+
" edge_index=data.train_pos_edge_index, #positive edges\n",
166+
" num_nodes=data.num_nodes, # number of nodes\n",
167+
" num_neg_samples=data.train_pos_edge_index.size(1)) # number of neg_sample equal to number of pos_edges\n",
168+
"\n",
169+
" optimizer.zero_grad()\n",
170+
" \n",
171+
" z = model.encode() #encode\n",
172+
" link_logits = model.decode(z, data.train_pos_edge_index, neg_edge_index) # decode\n",
173+
" \n",
174+
" link_labels = get_link_labels(data.train_pos_edge_index, neg_edge_index)\n",
175+
" loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)\n",
176+
" loss.backward()\n",
177+
" optimizer.step()\n",
178+
"\n",
179+
" return loss\n",
180+
"\n",
181+
"\n",
182+
"@torch.no_grad()\n",
183+
"def test():\n",
184+
" model.eval()\n",
185+
" perfs = []\n",
186+
" for prefix in [\"val\", \"test\"]:\n",
187+
" pos_edge_index = data[f'{prefix}_pos_edge_index']\n",
188+
" neg_edge_index = data[f'{prefix}_neg_edge_index']\n",
189+
"\n",
190+
" z = model.encode() # encode train\n",
191+
" link_logits = model.decode(z, pos_edge_index, neg_edge_index) # decode test or val\n",
192+
" link_probs = link_logits.sigmoid() # apply sigmoid\n",
193+
" \n",
194+
" link_labels = get_link_labels(pos_edge_index, neg_edge_index) # get link labels\n",
195+
" \n",
196+
" perfs.append(roc_auc_score(link_labels.cpu(), link_probs.cpu())) #compute roc_auc score\n",
197+
" return perfs\n"
198+
]
199+
},
200+
{
201+
"cell_type": "code",
202+
"execution_count": 14,
203+
"metadata": {},
204+
"outputs": [
205+
{
206+
"name": "stdout",
207+
"output_type": "stream",
208+
"text": [
209+
"Epoch: 010, Loss: 0.6837, Val: 0.7552, Test: 0.7562\n",
210+
"Epoch: 020, Loss: 0.6423, Val: 0.7552, Test: 0.7562\n",
211+
"Epoch: 030, Loss: 0.5490, Val: 0.7935, Test: 0.8021\n",
212+
"Epoch: 040, Loss: 0.5108, Val: 0.8210, Test: 0.8486\n",
213+
"Epoch: 050, Loss: 0.4894, Val: 0.8455, Test: 0.8712\n",
214+
"Epoch: 060, Loss: 0.4656, Val: 0.8637, Test: 0.8966\n",
215+
"Epoch: 070, Loss: 0.4585, Val: 0.8808, Test: 0.9000\n",
216+
"Epoch: 080, Loss: 0.4518, Val: 0.8864, Test: 0.9084\n",
217+
"Epoch: 090, Loss: 0.4458, Val: 0.8905, Test: 0.9093\n",
218+
"Epoch: 100, Loss: 0.4501, Val: 0.8920, Test: 0.9111\n"
219+
]
220+
}
221+
],
222+
"source": [
223+
"\n",
224+
"best_val_perf = test_perf = 0\n",
225+
"for epoch in range(1, 101):\n",
226+
" train_loss = train()\n",
227+
" val_perf, tmp_test_perf = test()\n",
228+
" if val_perf > best_val_perf:\n",
229+
" best_val_perf = val_perf\n",
230+
" test_perf = tmp_test_perf\n",
231+
" log = 'Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}'\n",
232+
" if epoch % 10 == 0:\n",
233+
" print(log.format(epoch, train_loss, best_val_perf, test_perf))\n",
234+
"\n"
235+
]
236+
},
237+
{
238+
"cell_type": "code",
239+
"execution_count": null,
240+
"metadata": {},
241+
"outputs": [],
242+
"source": []
243+
},
244+
{
245+
"cell_type": "code",
246+
"execution_count": 15,
247+
"metadata": {},
248+
"outputs": [],
249+
"source": [
250+
"z = model.encode()\n",
251+
"final_edge_index = model.decode_all(z)"
252+
]
253+
},
254+
{
255+
"cell_type": "code",
256+
"execution_count": null,
257+
"metadata": {},
258+
"outputs": [],
259+
"source": []
260+
},
261+
{
262+
"cell_type": "code",
263+
"execution_count": null,
264+
"metadata": {},
265+
"outputs": [],
266+
"source": []
267+
}
268+
],
269+
"metadata": {
270+
"kernelspec": {
271+
"display_name": "Python 3",
272+
"language": "python",
273+
"name": "python3"
274+
},
275+
"language_info": {
276+
"codemirror_mode": {
277+
"name": "ipython",
278+
"version": 3
279+
},
280+
"file_extension": ".py",
281+
"mimetype": "text/x-python",
282+
"name": "python",
283+
"nbconvert_exporter": "python",
284+
"pygments_lexer": "ipython3",
285+
"version": "3.8.5"
286+
}
287+
},
288+
"nbformat": 4,
289+
"nbformat_minor": 4
290+
}

0 commit comments

Comments
 (0)