|
10 | 10 | // Capsule name for expr* pointers |
11 | 11 | #define EXPR_CAPSULE_NAME "DNLP_EXPR" |
12 | 12 |
|
| 13 | +static int numpy_initialized = 0; |
| 14 | + |
13 | 15 | static int ensure_numpy(void) |
14 | 16 | { |
15 | | - import_array(); |
| 17 | + if (numpy_initialized) return 0; |
| 18 | + import_array1(-1); |
| 19 | + numpy_initialized = 1; |
16 | 20 | return 0; |
17 | 21 | } |
18 | 22 |
|
@@ -65,6 +69,77 @@ static PyObject *py_make_log(PyObject *self, PyObject *args) |
65 | 69 | return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); |
66 | 70 | } |
67 | 71 |
|
| 72 | +static PyObject *py_make_exp(PyObject *self, PyObject *args) |
| 73 | +{ |
| 74 | + PyObject *child_capsule; |
| 75 | + if (!PyArg_ParseTuple(args, "O", &child_capsule)) |
| 76 | + { |
| 77 | + return NULL; |
| 78 | + } |
| 79 | + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); |
| 80 | + if (!child) |
| 81 | + { |
| 82 | + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); |
| 83 | + return NULL; |
| 84 | + } |
| 85 | + |
| 86 | + expr *node = new_exp(child); |
| 87 | + if (!node) |
| 88 | + { |
| 89 | + PyErr_SetString(PyExc_RuntimeError, "failed to create exp node"); |
| 90 | + return NULL; |
| 91 | + } |
| 92 | + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); |
| 93 | +} |
| 94 | + |
| 95 | +static PyObject *py_make_add(PyObject *self, PyObject *args) |
| 96 | +{ |
| 97 | + PyObject *left_capsule, *right_capsule; |
| 98 | + if (!PyArg_ParseTuple(args, "OO", &left_capsule, &right_capsule)) |
| 99 | + { |
| 100 | + return NULL; |
| 101 | + } |
| 102 | + expr *left = (expr *) PyCapsule_GetPointer(left_capsule, EXPR_CAPSULE_NAME); |
| 103 | + expr *right = (expr *) PyCapsule_GetPointer(right_capsule, EXPR_CAPSULE_NAME); |
| 104 | + if (!left || !right) |
| 105 | + { |
| 106 | + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); |
| 107 | + return NULL; |
| 108 | + } |
| 109 | + |
| 110 | + expr *node = new_add(left, right); |
| 111 | + if (!node) |
| 112 | + { |
| 113 | + PyErr_SetString(PyExc_RuntimeError, "failed to create add node"); |
| 114 | + return NULL; |
| 115 | + } |
| 116 | + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); |
| 117 | +} |
| 118 | + |
| 119 | +static PyObject *py_make_sum(PyObject *self, PyObject *args) |
| 120 | +{ |
| 121 | + PyObject *child_capsule; |
| 122 | + int axis; |
| 123 | + if (!PyArg_ParseTuple(args, "Oi", &child_capsule, &axis)) |
| 124 | + { |
| 125 | + return NULL; |
| 126 | + } |
| 127 | + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); |
| 128 | + if (!child) |
| 129 | + { |
| 130 | + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); |
| 131 | + return NULL; |
| 132 | + } |
| 133 | + |
| 134 | + expr *node = new_sum(child, axis); |
| 135 | + if (!node) |
| 136 | + { |
| 137 | + PyErr_SetString(PyExc_RuntimeError, "failed to create sum node"); |
| 138 | + return NULL; |
| 139 | + } |
| 140 | + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); |
| 141 | +} |
| 142 | + |
68 | 143 | static PyObject *py_forward(PyObject *self, PyObject *args) |
69 | 144 | { |
70 | 145 | PyObject *node_capsule; |
@@ -106,6 +181,9 @@ static PyObject *py_forward(PyObject *self, PyObject *args) |
106 | 181 | static PyMethodDef DNLPMethods[] = { |
107 | 182 | {"make_variable", py_make_variable, METH_VARARGS, "Create variable node"}, |
108 | 183 | {"make_log", py_make_log, METH_VARARGS, "Create log node"}, |
| 184 | + {"make_exp", py_make_exp, METH_VARARGS, "Create exp node"}, |
| 185 | + {"make_add", py_make_add, METH_VARARGS, "Create add node"}, |
| 186 | + {"make_sum", py_make_sum, METH_VARARGS, "Create sum node"}, |
109 | 187 | {"forward", py_forward, METH_VARARGS, "Run forward pass and return values"}, |
110 | 188 | {NULL, NULL, 0, NULL}}; |
111 | 189 |
|
|
0 commit comments