Skip to content

Commit 1febd65

Browse files
committed
adds initial code for tree conversion from cvxpy
1 parent 4698872 commit 1febd65

3 files changed

Lines changed: 499 additions & 1 deletion

File tree

python/bindings.c

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,13 @@
1010
// Capsule name for expr* pointers
1111
#define EXPR_CAPSULE_NAME "DNLP_EXPR"
1212

13+
static int numpy_initialized = 0;
14+
1315
static int ensure_numpy(void)
1416
{
15-
import_array();
17+
if (numpy_initialized) return 0;
18+
import_array1(-1);
19+
numpy_initialized = 1;
1620
return 0;
1721
}
1822

@@ -65,6 +69,77 @@ static PyObject *py_make_log(PyObject *self, PyObject *args)
6569
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
6670
}
6771

72+
static PyObject *py_make_exp(PyObject *self, PyObject *args)
73+
{
74+
PyObject *child_capsule;
75+
if (!PyArg_ParseTuple(args, "O", &child_capsule))
76+
{
77+
return NULL;
78+
}
79+
expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
80+
if (!child)
81+
{
82+
PyErr_SetString(PyExc_ValueError, "invalid child capsule");
83+
return NULL;
84+
}
85+
86+
expr *node = new_exp(child);
87+
if (!node)
88+
{
89+
PyErr_SetString(PyExc_RuntimeError, "failed to create exp node");
90+
return NULL;
91+
}
92+
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
93+
}
94+
95+
static PyObject *py_make_add(PyObject *self, PyObject *args)
96+
{
97+
PyObject *left_capsule, *right_capsule;
98+
if (!PyArg_ParseTuple(args, "OO", &left_capsule, &right_capsule))
99+
{
100+
return NULL;
101+
}
102+
expr *left = (expr *) PyCapsule_GetPointer(left_capsule, EXPR_CAPSULE_NAME);
103+
expr *right = (expr *) PyCapsule_GetPointer(right_capsule, EXPR_CAPSULE_NAME);
104+
if (!left || !right)
105+
{
106+
PyErr_SetString(PyExc_ValueError, "invalid child capsule");
107+
return NULL;
108+
}
109+
110+
expr *node = new_add(left, right);
111+
if (!node)
112+
{
113+
PyErr_SetString(PyExc_RuntimeError, "failed to create add node");
114+
return NULL;
115+
}
116+
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
117+
}
118+
119+
static PyObject *py_make_sum(PyObject *self, PyObject *args)
120+
{
121+
PyObject *child_capsule;
122+
int axis;
123+
if (!PyArg_ParseTuple(args, "Oi", &child_capsule, &axis))
124+
{
125+
return NULL;
126+
}
127+
expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
128+
if (!child)
129+
{
130+
PyErr_SetString(PyExc_ValueError, "invalid child capsule");
131+
return NULL;
132+
}
133+
134+
expr *node = new_sum(child, axis);
135+
if (!node)
136+
{
137+
PyErr_SetString(PyExc_RuntimeError, "failed to create sum node");
138+
return NULL;
139+
}
140+
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
141+
}
142+
68143
static PyObject *py_forward(PyObject *self, PyObject *args)
69144
{
70145
PyObject *node_capsule;
@@ -106,6 +181,9 @@ static PyObject *py_forward(PyObject *self, PyObject *args)
106181
static PyMethodDef DNLPMethods[] = {
107182
{"make_variable", py_make_variable, METH_VARARGS, "Create variable node"},
108183
{"make_log", py_make_log, METH_VARARGS, "Create log node"},
184+
{"make_exp", py_make_exp, METH_VARARGS, "Create exp node"},
185+
{"make_add", py_make_add, METH_VARARGS, "Create add node"},
186+
{"make_sum", py_make_sum, METH_VARARGS, "Create sum node"},
109187
{"forward", py_forward, METH_VARARGS, "Run forward pass and return values"},
110188
{NULL, NULL, 0, NULL}};
111189

0 commit comments

Comments
 (0)