Skip to content

Commit 65cee7c

Browse files
committed
Fix sorted map sentinel reference count
1 parent f7b703c commit 65cee7c

2 files changed

Lines changed: 27 additions & 12 deletions

File tree

src/rbtree.c

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -568,20 +568,14 @@ static void RBTree_RemoveNode(CtsRBTree *tree, CtsRBTreeNode *node) {
568568
assert(root != sentinel);
569569
assert(node != sentinel);
570570
if (root == node && tree->length == 1) {
571-
Py_INCREF(sentinel);
572571
tree->root = sentinel;
573572
} else {
574573
rbtree_delete(tree, node);
575574
}
576575

576+
Py_DECREF(sentinel);
577577
Py_DECREF(node->key);
578578
Py_DECREF(node->value);
579-
if (node->left == tree->sentinel) {
580-
Py_DECREF(tree->sentinel);
581-
}
582-
if (node->right == tree->sentinel) {
583-
Py_DECREF(tree->sentinel);
584-
}
585579
node->left = NULL;
586580
node->right = NULL;
587581
node->key = NULL;

tests/test_sortedmap.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ class TestSortedMap(unittest.TestCase):
3434
def assert_ref(self, o1, o2, msg=None):
3535
self.assertEqual(sys.getrefcount(o1), sys.getrefcount(o2), msg=msg)
3636

37+
def _create_sorted_map(self, cmp=None):
38+
return ctools.SortedMap(cmp)
39+
3740
def _build_v(self, seq):
3841
sorted_map = ctools.SortedMap()
3942
mapping = dict()
@@ -133,11 +136,6 @@ def test_contains(self):
133136
self.assertTrue(key2 in mapping)
134137
self.assert_ref(key2, key1)
135138

136-
def test_len(self):
137-
seq = list(range(1024))
138-
keys1, keys2, sorted_map, mapping = self._build_v(seq)
139-
self.assertEqual(len(keys1), len(sorted_map))
140-
141139
def _test_iter(self, name):
142140
seq = list(range(1024))
143141
random.shuffle(seq)
@@ -439,6 +437,29 @@ def test_max_min(self):
439437
k2 = keys2[i]
440438
self.assert_ref(k2, k1)
441439

440+
def test_len(self):
441+
seq = list(range(1024))
442+
s = self._create_sorted_map()
443+
for i, v in enumerate(seq):
444+
s[v] = v
445+
self.assertEqual(i + 1, len(s))
446+
447+
total = len(seq)
448+
for i, v in enumerate(seq):
449+
del s[v]
450+
self.assertEqual(total - i - 1, len(s))
451+
452+
def test_sentinel(self):
453+
seq = list(range(1024))
454+
s = self._create_sorted_map()
455+
sentinel = _ctools.SortedMapSentinel
456+
ref_count = sys.getrefcount(sentinel)
457+
for i, v in enumerate(seq):
458+
s[v] = v
459+
for i, v in enumerate(seq):
460+
del s[v]
461+
self.assertEqual(ref_count, sys.getrefcount(sentinel))
462+
442463

443464
if __name__ == "__main__":
444465
unittest.main()

0 commit comments

Comments
 (0)