@@ -178,13 +178,73 @@ static PyObject *py_forward(PyObject *self, PyObject *args)
178178 return out ;
179179}
180180
181+ static PyObject * py_jacobian (PyObject * self , PyObject * args )
182+ {
183+ PyObject * node_capsule ;
184+ PyObject * u_obj ;
185+ if (!PyArg_ParseTuple (args , "OO" , & node_capsule , & u_obj ))
186+ {
187+ return NULL ;
188+ }
189+
190+ expr * node = (expr * ) PyCapsule_GetPointer (node_capsule , EXPR_CAPSULE_NAME );
191+ if (!node )
192+ {
193+ PyErr_SetString (PyExc_ValueError , "invalid node capsule" );
194+ return NULL ;
195+ }
196+
197+ PyArrayObject * u_array =
198+ (PyArrayObject * ) PyArray_FROM_OTF (u_obj , NPY_DOUBLE , NPY_ARRAY_IN_ARRAY );
199+ if (!u_array )
200+ {
201+ return NULL ;
202+ }
203+
204+ // Run forward pass first (required before jacobian)
205+ node -> forward (node , (const double * ) PyArray_DATA (u_array ));
206+
207+ // Initialize and evaluate jacobian
208+ node -> jacobian_init (node );
209+ node -> eval_jacobian (node );
210+
211+ CSR_Matrix * jac = node -> jacobian ;
212+
213+ // Create numpy arrays for CSR components
214+ npy_intp nnz = jac -> nnz ;
215+ npy_intp m_plus_1 = jac -> m + 1 ;
216+
217+ PyObject * data = PyArray_SimpleNew (1 , & nnz , NPY_DOUBLE );
218+ PyObject * indices = PyArray_SimpleNew (1 , & nnz , NPY_INT32 );
219+ PyObject * indptr = PyArray_SimpleNew (1 , & m_plus_1 , NPY_INT32 );
220+
221+ if (!data || !indices || !indptr )
222+ {
223+ Py_XDECREF (data );
224+ Py_XDECREF (indices );
225+ Py_XDECREF (indptr );
226+ Py_DECREF (u_array );
227+ return NULL ;
228+ }
229+
230+ memcpy (PyArray_DATA ((PyArrayObject * ) data ), jac -> x , nnz * sizeof (double ));
231+ memcpy (PyArray_DATA ((PyArrayObject * ) indices ), jac -> i , nnz * sizeof (int ));
232+ memcpy (PyArray_DATA ((PyArrayObject * ) indptr ), jac -> p , m_plus_1 * sizeof (int ));
233+
234+ Py_DECREF (u_array );
235+
236+ // Return tuple: (data, indices, indptr, shape)
237+ return Py_BuildValue ("(OOO(ii))" , data , indices , indptr , jac -> m , jac -> n );
238+ }
239+
181240static PyMethodDef DNLPMethods [] = {
182241 {"make_variable" , py_make_variable , METH_VARARGS , "Create variable node" },
183242 {"make_log" , py_make_log , METH_VARARGS , "Create log node" },
184243 {"make_exp" , py_make_exp , METH_VARARGS , "Create exp node" },
185244 {"make_add" , py_make_add , METH_VARARGS , "Create add node" },
186245 {"make_sum" , py_make_sum , METH_VARARGS , "Create sum node" },
187246 {"forward" , py_forward , METH_VARARGS , "Run forward pass and return values" },
247+ {"jacobian" , py_jacobian , METH_VARARGS , "Compute jacobian and return CSR components" },
188248 {NULL , NULL , 0 , NULL }};
189249
190250static struct PyModuleDef dnlp_module = {PyModuleDef_HEAD_INIT , "DNLP_diff_engine" ,
0 commit comments