Skip to content

Commit fc84539

Browse files
fix: add special edge handling in bfs, dfs and base partitioner (#150)
* fix: add special edge handling in bfs, dfs and base partitioner * refactor: extract a helper method in the base class * fix: add List import to dfs_partitioner * fix: align edge type * fix: fix lint error --------- Co-authored-by: chenzihong <522023320011@smail.nju.edu.cn>
1 parent 97c3552 commit fc84539

6 files changed

Lines changed: 25 additions & 98 deletions

File tree

graphgen/bases/base_partitioner.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,17 @@ def community2batch(
3939
if node_data:
4040
nodes_data.append((node, node_data))
4141
edges_data = []
42-
for u, v in edges:
43-
edge_data = g.get_edge(u, v)
42+
for edge in edges:
43+
# Filter out self-loops and invalid edges
44+
if not isinstance(edge, tuple) or len(edge) != 2:
45+
continue
46+
u, v = edge
47+
if u == v:
48+
continue
49+
50+
edge_data = g.get_edge(u, v) or g.get_edge(v, u)
4451
if edge_data:
4552
edges_data.append((u, v, edge_data))
46-
else:
47-
edge_data = g.get_edge(v, u)
48-
if edge_data:
49-
edges_data.append((v, u, edge_data))
5053
return nodes_data, edges_data
5154

5255
@staticmethod
@@ -61,9 +64,11 @@ def _build_adjacency_list(
6164
"""
6265
adj: dict[str, List[str]] = {n[0]: [] for n in nodes}
6366
edge_set: set[tuple[str, str]] = set()
64-
for e in edges:
65-
adj[e[0]].append(e[1])
66-
adj[e[1]].append(e[0])
67-
edge_set.add((e[0], e[1]))
68-
edge_set.add((e[1], e[0]))
67+
for u, v, _ in edges:
68+
if u == v:
69+
continue
70+
adj[u].append(v)
71+
adj[v].append(u)
72+
edge_set.add((u, v))
73+
edge_set.add((v, u))
6974
return adj, edge_set

graphgen/models/partitioner/bfs_partitioner.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,7 @@ def partition(
6363
if it in used_e:
6464
continue
6565
used_e.add(it)
66-
67-
u, v = it
68-
comm_e.append((u, v))
66+
comm_e.append(tuple(sorted(it)))
6967
cnt += 1
7068
# push nodes that are not visited
7169
for n in it:

graphgen/models/partitioner/dfs_partitioner.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import random
22
from collections.abc import Iterable
3-
from typing import Any
3+
from typing import Any, List
44

55
from graphgen.bases import BaseGraphStorage, BasePartitioner
66
from graphgen.bases.datatypes import Community
@@ -42,7 +42,8 @@ def partition(
4242
):
4343
continue
4444

45-
comm_n, comm_e = [], []
45+
comm_n: List[str] = []
46+
comm_e: List[tuple[str, str]] = []
4647
stack = [(kind, seed)]
4748
cnt = 0
4849

@@ -63,7 +64,7 @@ def partition(
6364
if it in used_e:
6465
continue
6566
used_e.add(it)
66-
comm_e.append(tuple(it))
67+
comm_e.append(tuple(sorted(it)))
6768
cnt += 1
6869
# push neighboring nodes
6970
for n in it:

graphgen/models/partitioner/ece_partitioner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def _add_unit(u):
142142
return Community(
143143
id=seed_unit[1],
144144
nodes=list(community_nodes.keys()),
145-
edges=[tuple(edge) for edge in community_edges if isinstance(edge, frozenset) and len(edge)==2],
145+
edges=[tuple(sorted(e)) for e in community_edges]
146146
)
147147

148148
for unit in tqdm(all_units, desc="ECE partition"):

graphgen/utils/help_nltk.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
from typing import Dict, List, Final, Optional
44
import warnings
55
import nltk
6-
import jieba
7-
86
warnings.filterwarnings(
97
"ignore",
108
category=UserWarning,
119
module=r"jieba\._compat"
1210
)
11+
# pylint: disable=wrong-import-position
12+
import jieba
13+
1314

1415
class NLTKHelper:
1516
"""

tests/integration_tests/test_engine.py

Lines changed: 0 additions & 78 deletions
This file was deleted.

0 commit comments

Comments
 (0)