Skip to content

Commit ca03bcc

Browse files
committed
adds basic jacobian bindings and tests
1 parent c5eca4a commit ca03bcc

3 files changed

Lines changed: 250 additions & 144 deletions

File tree

python/bindings.c

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,13 +178,73 @@ static PyObject *py_forward(PyObject *self, PyObject *args)
178178
return out;
179179
}
180180

181+
static PyObject *py_jacobian(PyObject *self, PyObject *args)
182+
{
183+
PyObject *node_capsule;
184+
PyObject *u_obj;
185+
if (!PyArg_ParseTuple(args, "OO", &node_capsule, &u_obj))
186+
{
187+
return NULL;
188+
}
189+
190+
expr *node = (expr *) PyCapsule_GetPointer(node_capsule, EXPR_CAPSULE_NAME);
191+
if (!node)
192+
{
193+
PyErr_SetString(PyExc_ValueError, "invalid node capsule");
194+
return NULL;
195+
}
196+
197+
PyArrayObject *u_array =
198+
(PyArrayObject *) PyArray_FROM_OTF(u_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
199+
if (!u_array)
200+
{
201+
return NULL;
202+
}
203+
204+
// Run forward pass first (required before jacobian)
205+
node->forward(node, (const double *) PyArray_DATA(u_array));
206+
207+
// Initialize and evaluate jacobian
208+
node->jacobian_init(node);
209+
node->eval_jacobian(node);
210+
211+
CSR_Matrix *jac = node->jacobian;
212+
213+
// Create numpy arrays for CSR components
214+
npy_intp nnz = jac->nnz;
215+
npy_intp m_plus_1 = jac->m + 1;
216+
217+
PyObject *data = PyArray_SimpleNew(1, &nnz, NPY_DOUBLE);
218+
PyObject *indices = PyArray_SimpleNew(1, &nnz, NPY_INT32);
219+
PyObject *indptr = PyArray_SimpleNew(1, &m_plus_1, NPY_INT32);
220+
221+
if (!data || !indices || !indptr)
222+
{
223+
Py_XDECREF(data);
224+
Py_XDECREF(indices);
225+
Py_XDECREF(indptr);
226+
Py_DECREF(u_array);
227+
return NULL;
228+
}
229+
230+
memcpy(PyArray_DATA((PyArrayObject *) data), jac->x, nnz * sizeof(double));
231+
memcpy(PyArray_DATA((PyArrayObject *) indices), jac->i, nnz * sizeof(int));
232+
memcpy(PyArray_DATA((PyArrayObject *) indptr), jac->p, m_plus_1 * sizeof(int));
233+
234+
Py_DECREF(u_array);
235+
236+
// Return tuple: (data, indices, indptr, shape)
237+
return Py_BuildValue("(OOO(ii))", data, indices, indptr, jac->m, jac->n);
238+
}
239+
181240
static PyMethodDef DNLPMethods[] = {
182241
{"make_variable", py_make_variable, METH_VARARGS, "Create variable node"},
183242
{"make_log", py_make_log, METH_VARARGS, "Create log node"},
184243
{"make_exp", py_make_exp, METH_VARARGS, "Create exp node"},
185244
{"make_add", py_make_add, METH_VARARGS, "Create add node"},
186245
{"make_sum", py_make_sum, METH_VARARGS, "Create sum node"},
187246
{"forward", py_forward, METH_VARARGS, "Run forward pass and return values"},
247+
{"jacobian", py_jacobian, METH_VARARGS, "Compute jacobian and return CSR components"},
188248
{NULL, NULL, 0, NULL}};
189249

190250
static struct PyModuleDef dnlp_module = {PyModuleDef_HEAD_INIT, "DNLP_diff_engine",

0 commit comments

Comments
 (0)