Skip to content

Commit d52b222

Browse files
committed
Cleaning up the immutability code
Signed-off-by: Matthew A Johnson <matjoh@microsoft.com>
1 parent 3b832e6 commit d52b222

4 files changed

Lines changed: 150 additions & 53 deletions

File tree

Include/internal/pycore_object.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ static inline void _Py_SetImmutable(PyObject *op)
101101
#define _Py_SetImmutable(op) _Py_SetImmutable(_PyObject_CAST(op))
102102

103103
// Check whether an object is writeable.
104-
// Note that during runtime finalization, all objects must be mutable
104+
// This check will always succeed during runtime finalization.
105105
#define Py_CHECKWRITE(op) ((op) && (!_Py_IsImmutable(op) || _Py_IsFinalizing()))
106106
#define Py_REQUIREWRITE(op, msg) {if (Py_CHECKWRITE(op)) { _PyObject_ASSERT_FAILED_MSG(op, msg); }}
107107

Lib/test/test_freeze.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import unittest
2+
from collections import deque
3+
from array import array
24

35
# This is a canary to check that global variables are not made immutable
46
# when others are made immutable
@@ -97,6 +99,83 @@ def test_reverse(self):
9799
with self.assertRaises(NotWriteableError):
98100
self.obj.reverse()
99101

102+
def test_inplace_repeat(self):
103+
with self.assertRaises(NotWriteableError):
104+
self.obj *= 2
105+
106+
def test_inplace_concat(self):
107+
with self.assertRaises(NotWriteableError):
108+
self.obj += [TestList.C()]
109+
110+
def test_clear(self):
111+
with self.assertRaises(NotWriteableError):
112+
self.obj.clear()
113+
114+
def test_sort(self):
115+
with self.assertRaises(NotWriteableError):
116+
self.obj.sort()
117+
118+
119+
class TestDeque(BaseObjectTest):
120+
class C:
121+
pass
122+
123+
def __init__(self, *args, **kwargs):
124+
obj = deque([self.C(), self.C(), 1, "two", None])
125+
BaseObjectTest.__init__(self, *args, obj=obj, **kwargs)
126+
127+
def test_set_item(self):
128+
with self.assertRaises(NotWriteableError):
129+
self.obj[0] = None
130+
131+
def test_set_slice(self):
132+
with self.assertRaises(NotWriteableError):
133+
self.obj[1:3] = [None, None]
134+
135+
def test_append(self):
136+
with self.assertRaises(NotWriteableError):
137+
self.obj.append(TestList.C())
138+
139+
def test_appendleft(self):
140+
with self.assertRaises(NotWriteableError):
141+
self.obj.appendleft(TestList.C())
142+
143+
def test_extend(self):
144+
with self.assertRaises(NotWriteableError):
145+
self.obj.extend([TestList.C()])
146+
147+
def test_extendleft(self):
148+
with self.assertRaises(NotWriteableError):
149+
self.obj.extendleft([TestList.C()])
150+
151+
def test_insert(self):
152+
with self.assertRaises(NotWriteableError):
153+
self.obj.insert(0, TestList.C())
154+
155+
def test_pop(self):
156+
with self.assertRaises(NotWriteableError):
157+
self.obj.pop()
158+
159+
def test_popleft(self):
160+
with self.assertRaises(NotWriteableError):
161+
self.obj.popleft()
162+
163+
def test_remove(self):
164+
with self.assertRaises(NotWriteableError):
165+
self.obj.remove(1)
166+
167+
def test_inplace_repeat(self):
168+
with self.assertRaises(NotWriteableError):
169+
self.obj *= 2
170+
171+
def test_inplace_concat(self):
172+
with self.assertRaises(NotWriteableError):
173+
self.obj += [TestList.C()]
174+
175+
def test_reverse(self):
176+
with self.assertRaises(NotWriteableError):
177+
self.obj.reverse()
178+
100179
def test_clear(self):
101180
with self.assertRaises(NotWriteableError):
102181
self.obj.clear()

Objects/listobject.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,12 @@ list_inplace_repeat(PyListObject *self, Py_ssize_t n)
774774
if (input_size > PY_SSIZE_T_MAX / n) {
775775
return PyErr_NoMemory();
776776
}
777+
778+
if(!Py_CHECKWRITE(self))
779+
{
780+
return PyErr_WriteToImmutable(self);
781+
}
782+
777783
Py_ssize_t output_size = input_size * n;
778784

779785
if (list_resize(self, output_size) < 0)

Python/immutability.c

Lines changed: 64 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,32 @@
88
#include "pycore_immutability.h"
99

1010

11+
static int push(PyObject* s, PyObject* item){
12+
if(item == NULL){
13+
return 0;
14+
}
15+
16+
if(!PyList_Check(s)){
17+
PyErr_SetString(PyExc_TypeError, "Expected a list");
18+
return -1;
19+
}
20+
21+
return _PyList_AppendTakeRef(_PyList_CAST(s), item);
22+
}
23+
1124
static PyObject* pop(PyObject* s){
25+
PyObject* item;
1226
Py_ssize_t size = PyList_Size(s);
1327
if(size == 0){
1428
return NULL;
1529
}
1630

17-
PyObject* item = PyList_GetItem(s, size - 1);
31+
item = PyList_GetItem(s, size - 1);
1832
if(item == NULL){
1933
return NULL;
2034
}
2135

2236
Py_INCREF(item);
23-
2437
if(PyList_SetSlice(s, size - 1, size, NULL)){
2538
Py_DECREF(item);
2639
return NULL;
@@ -35,7 +48,7 @@ static bool is_c_wrapper(PyObject* obj){
3548

3649
#define _Py_VISIT_FUNC_ATTR(attr, frontier) do { \
3750
if(attr != NULL && !_Py_IsImmutable(attr)){ \
38-
if(PyList_Append(frontier, attr)){ \
51+
if(push((frontier), (attr))){ \
3952
return PyErr_NoMemory(); \
4053
} \
4154
} \
@@ -118,7 +131,7 @@ static PyObject* walk_function(PyObject* op, PyObject* frontier)
118131
}
119132

120133
f_ptr = f->func_code;
121-
if(PyList_Append(f_stack, f_ptr)){
134+
if(push(f_stack, f_ptr)){
122135
goto nomemory;
123136
}
124137

@@ -174,7 +187,7 @@ static PyObject* walk_function(PyObject* op, PyObject* frontier)
174187
_PyDict_SetKeyImmutable((PyDictObject*)module_dict, name);
175188

176189
if(!_Py_IsImmutable(value)){
177-
if(PyList_Append(frontier, value)){
190+
if(push(frontier, value)){
178191
goto nomemory;
179192
}
180193
}
@@ -188,11 +201,11 @@ static PyObject* walk_function(PyObject* op, PyObject* frontier)
188201
if(PyCode_Check(value)){
189202
_Py_SetImmutable(value);
190203

191-
if(PyList_Append(f_stack, value)){
204+
if(push(f_stack, value)){
192205
goto nomemory;
193206
}
194207
}else{
195-
if(PyList_Append(frontier, value)){
208+
if(push(frontier, value)){
196209
goto nomemory;
197210
}
198211
}
@@ -250,14 +263,14 @@ static PyObject* walk_function(PyObject* op, PyObject* frontier)
250263
}
251264
}
252265

253-
if(PyList_Append(frontier, frozen_globals)){
266+
if(push(frontier, frozen_globals)){
254267
goto nomemory;
255268
}
256269

257270
f->func_globals = frozen_globals;
258271
Py_DECREF(globals);
259272

260-
if(PyList_Append(frontier, frozen_builtins)){
273+
if(push(frontier, frozen_builtins)){
261274
goto nomemory;
262275
}
263276

@@ -276,7 +289,7 @@ static PyObject* walk_function(PyObject* op, PyObject* frontier)
276289
static int freeze_visit(PyObject* obj, void* frontier)
277290
{
278291
if(!_Py_IsImmutable(obj)){
279-
if(PyList_Append((PyObject*)frontier, obj)){
292+
if(push(frontier, obj)){
280293
PyErr_NoMemory();
281294
return -1;
282295
}
@@ -287,44 +300,49 @@ static int freeze_visit(PyObject* obj, void* frontier)
287300

288301
PyObject* _Py_Freeze(PyObject* obj)
289302
{
303+
PyObject* frontier = NULL;
304+
PyObject* frozen_importlib = NULL;
305+
PyObject* blocking_on = NULL;
306+
PyObject* module_locks = NULL;
307+
PyObject* result = Py_None;
308+
290309
if(_Py_IsImmutable(obj)){
291-
Py_RETURN_NONE;
310+
return result;
292311
}
293312

294-
PyObject* frontier = PyList_New(0);
313+
frontier = PyList_New(0);
295314
if(frontier == NULL){
296-
return PyErr_NoMemory();
315+
result = PyErr_NoMemory();
316+
goto cleanup;
297317
}
298318

299-
if(PyList_Append(frontier, obj)){
300-
Py_DECREF(frontier);
301-
return PyErr_NoMemory();
319+
if(push(frontier, obj)){
320+
result = PyErr_NoMemory();
321+
goto cleanup;
302322
}
303323

304-
PyObject* frozen_importlib = PyImport_ImportModule("_frozen_importlib");
324+
frozen_importlib = PyImport_ImportModule("_frozen_importlib");
305325
if(frozen_importlib == NULL){
306-
Py_DECREF(frontier);
307-
return NULL;
326+
result = NULL;
327+
goto cleanup;
308328
}
309329

310-
PyObject* blocking_on = PyObject_GetAttrString(frozen_importlib, "_blocking_on");
330+
blocking_on = PyObject_GetAttrString(frozen_importlib, "_blocking_on");
311331
if(blocking_on == NULL){
312-
Py_DECREF(frozen_importlib);
313-
Py_DECREF(frontier);
314-
return NULL;
332+
result = NULL;
333+
goto cleanup;
315334
}
316335

317-
PyObject* module_locks = PyObject_GetAttrString(frozen_importlib, "_module_locks");
336+
module_locks = PyObject_GetAttrString(frozen_importlib, "_module_locks");
318337
if(module_locks == NULL){
319-
Py_DECREF(blocking_on);
320-
Py_DECREF(frozen_importlib);
321-
Py_DECREF(frontier);
322-
return NULL;
338+
result = NULL;
339+
goto cleanup;
323340
}
324341

325-
Py_DECREF(frozen_importlib);
326-
327342
while(PyList_Size(frontier) != 0){
343+
PyTypeObject* type;
344+
PyObject* type_op;
345+
traverseproc traverse;
328346
PyObject* item = pop(frontier);
329347

330348
if(item == blocking_on ||
@@ -334,9 +352,8 @@ PyObject* _Py_Freeze(PyObject* obj)
334352
continue;
335353
}
336354

337-
PyTypeObject* type = Py_TYPE(item);
338-
traverseproc traverse;
339-
PyObject* type_op = NULL;
355+
type = Py_TYPE(item);
356+
type_op = NULL;
340357

341358
if(_Py_IsImmutable(item)){
342359
continue;
@@ -350,42 +367,37 @@ PyObject* _Py_Freeze(PyObject* obj)
350367
}
351368

352369
if(PyFunction_Check(item)){
353-
PyObject* err = walk_function(item, frontier);
354-
if(!Py_IsNone(err)){
355-
Py_DECREF(blocking_on);
356-
Py_DECREF(module_locks);
357-
Py_DECREF(frontier);
358-
return err;
370+
result = walk_function(item, frontier);
371+
if(!Py_IsNone(result)){
372+
goto cleanup;
359373
}
360374
}
361375
else
362376
{
363377
traverse = type->tp_traverse;
364378
if(traverse != NULL){
365379
if(traverse(item, (visitproc)freeze_visit, frontier)){
366-
Py_DECREF(blocking_on);
367-
Py_DECREF(module_locks);
368-
Py_DECREF(frontier);
369-
return NULL;
380+
result = NULL;
381+
goto cleanup;
370382
}
371383
}
372384
}
373385

374386
type_op = _PyObject_CAST(item->ob_type);
375387
if (!_Py_IsImmutable(type_op)){
376-
if (PyList_Append(frontier, type_op))
388+
if(push(frontier, type_op))
377389
{
378-
Py_DECREF(blocking_on);
379-
Py_DECREF(module_locks);
380-
Py_DECREF(frontier);
381-
return PyErr_NoMemory();
390+
result = PyErr_NoMemory();
391+
goto cleanup;
382392
}
383393
}
384394
}
385395

386-
Py_DECREF(blocking_on);
387-
Py_DECREF(module_locks);
388-
Py_DECREF(frontier);
396+
cleanup:
397+
Py_XDECREF(blocking_on);
398+
Py_XDECREF(module_locks);
399+
Py_XDECREF(frozen_importlib);
400+
Py_XDECREF(frontier);
389401

390-
Py_RETURN_NONE;
402+
return result;
391403
}

0 commit comments

Comments
 (0)