Skip to content

Commit 0fa643c

Browse files
Transurgeonclaude
andcommitted
Bindings for diag_mat, kron_left, upper_tri, and vstack atoms
Add Python C extension bindings for four new SparseDiffEngine atoms: - diag_mat: extract diagonal from square matrix - upper_tri: extract strict upper triangular elements - kron_left: Kronecker product kron(C, X) with constant sparse C - vstack: vertical stack of expressions (via transpose-hstack composition) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c316612 commit 0fa643c

5 files changed

Lines changed: 197 additions & 0 deletions

File tree

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#ifndef ATOM_DIAG_MAT_H
2+
#define ATOM_DIAG_MAT_H
3+
4+
#include "common.h"
5+
6+
static PyObject *py_make_diag_mat(PyObject *self, PyObject *args)
7+
{
8+
PyObject *child_capsule;
9+
10+
if (!PyArg_ParseTuple(args, "O", &child_capsule))
11+
{
12+
return NULL;
13+
}
14+
15+
expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
16+
if (!child)
17+
{
18+
return NULL;
19+
}
20+
21+
expr *node = new_diag_mat(child);
22+
if (!node)
23+
{
24+
PyErr_SetString(PyExc_RuntimeError, "failed to create diag_mat node");
25+
return NULL;
26+
}
27+
28+
expr_retain(node);
29+
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
30+
}
31+
32+
#endif /* ATOM_DIAG_MAT_H */
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#ifndef ATOM_KRON_LEFT_H
2+
#define ATOM_KRON_LEFT_H
3+
4+
#include "bivariate.h"
5+
#include "common.h"
6+
7+
/*
8+
* Python signature:
9+
* make_kron_left(child_capsule, C_data, C_indices, C_indptr, m, n, p, q)
10+
*
11+
* Creates kron(C, X) where C is (m x n) constant CSR matrix and X is the
12+
* child expression of shape (p x q).
13+
*/
14+
static PyObject *py_make_kron_left(PyObject *self, PyObject *args)
15+
{
16+
(void) self;
17+
PyObject *child_capsule;
18+
PyObject *data_obj, *indices_obj, *indptr_obj;
19+
int m, n, p, q;
20+
21+
if (!PyArg_ParseTuple(args, "OOOOiiii", &child_capsule, &data_obj,
22+
&indices_obj, &indptr_obj, &m, &n, &p, &q))
23+
{
24+
return NULL;
25+
}
26+
27+
expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
28+
if (!child)
29+
{
30+
PyErr_SetString(PyExc_ValueError, "invalid child capsule");
31+
return NULL;
32+
}
33+
34+
PyArrayObject *data_array =
35+
(PyArrayObject *) PyArray_FROM_OTF(data_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
36+
PyArrayObject *indices_array =
37+
(PyArrayObject *) PyArray_FROM_OTF(indices_obj, NPY_INT32, NPY_ARRAY_IN_ARRAY);
38+
PyArrayObject *indptr_array =
39+
(PyArrayObject *) PyArray_FROM_OTF(indptr_obj, NPY_INT32, NPY_ARRAY_IN_ARRAY);
40+
41+
if (!data_array || !indices_array || !indptr_array)
42+
{
43+
Py_XDECREF(data_array);
44+
Py_XDECREF(indices_array);
45+
Py_XDECREF(indptr_array);
46+
return NULL;
47+
}
48+
49+
int nnz = (int) PyArray_SIZE(data_array);
50+
CSR_Matrix *C = new_csr_matrix(m, n, nnz);
51+
memcpy(C->x, PyArray_DATA(data_array), (size_t) nnz * sizeof(double));
52+
memcpy(C->i, PyArray_DATA(indices_array), (size_t) nnz * sizeof(int));
53+
memcpy(C->p, PyArray_DATA(indptr_array), (size_t)(m + 1) * sizeof(int));
54+
55+
Py_DECREF(data_array);
56+
Py_DECREF(indices_array);
57+
Py_DECREF(indptr_array);
58+
59+
expr *node = new_kron_left(child, C, p, q);
60+
free_csr_matrix(C);
61+
62+
if (!node)
63+
{
64+
PyErr_SetString(PyExc_RuntimeError, "failed to create kron_left node");
65+
return NULL;
66+
}
67+
68+
expr_retain(node);
69+
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
70+
}
71+
72+
#endif /* ATOM_KRON_LEFT_H */
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#ifndef ATOM_UPPER_TRI_H
2+
#define ATOM_UPPER_TRI_H
3+
4+
#include "common.h"
5+
6+
static PyObject *py_make_upper_tri(PyObject *self, PyObject *args)
7+
{
8+
PyObject *child_capsule;
9+
10+
if (!PyArg_ParseTuple(args, "O", &child_capsule))
11+
{
12+
return NULL;
13+
}
14+
15+
expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
16+
if (!child)
17+
{
18+
return NULL;
19+
}
20+
21+
expr *node = new_upper_tri(child);
22+
if (!node)
23+
{
24+
PyErr_SetString(PyExc_RuntimeError, "failed to create upper_tri node");
25+
return NULL;
26+
}
27+
28+
expr_retain(node);
29+
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
30+
}
31+
32+
#endif /* ATOM_UPPER_TRI_H */
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#ifndef ATOM_VSTACK_H
2+
#define ATOM_VSTACK_H
3+
4+
#include "common.h"
5+
6+
static PyObject *py_make_vstack(PyObject *self, PyObject *args)
7+
{
8+
(void) self;
9+
PyObject *list_obj;
10+
if (!PyArg_ParseTuple(args, "O", &list_obj))
11+
{
12+
return NULL;
13+
}
14+
if (!PyList_Check(list_obj))
15+
{
16+
PyErr_SetString(PyExc_TypeError,
17+
"First argument must be a list of expr capsules");
18+
return NULL;
19+
}
20+
Py_ssize_t n_args = PyList_Size(list_obj);
21+
if (n_args == 0)
22+
{
23+
PyErr_SetString(PyExc_ValueError, "List of expr capsules cannot be empty");
24+
return NULL;
25+
}
26+
expr **expr_args = (expr **) calloc(n_args, sizeof(expr *));
27+
for (Py_ssize_t i = 0; i < n_args; ++i)
28+
{
29+
PyObject *item = PyList_GetItem(list_obj, i);
30+
expr *e = (expr *) PyCapsule_GetPointer(item, EXPR_CAPSULE_NAME);
31+
if (!e)
32+
{
33+
free(expr_args);
34+
PyErr_SetString(PyExc_ValueError, "Invalid expr capsule in list");
35+
return NULL;
36+
}
37+
expr_args[i] = e;
38+
}
39+
int n_vars = expr_args[0]->n_vars;
40+
expr *node = new_vstack(expr_args, (int) n_args, n_vars);
41+
free(expr_args);
42+
if (!node)
43+
{
44+
PyErr_SetString(PyExc_RuntimeError, "failed to create vstack node");
45+
return NULL;
46+
}
47+
expr_retain(node);
48+
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
49+
}
50+
51+
#endif // ATOM_VSTACK_H

sparsediffpy/_bindings/bindings.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
#include "atoms/constant.h"
1313
#include "atoms/cos.h"
1414
#include "atoms/dense_matmul.h"
15+
#include "atoms/diag_mat.h"
1516
#include "atoms/diag_vec.h"
1617
#include "atoms/entr.h"
1718
#include "atoms/exp.h"
1819
#include "atoms/getters.h"
1920
#include "atoms/hstack.h"
2021
#include "atoms/index.h"
22+
#include "atoms/kron_left.h"
2123
#include "atoms/left_matmul.h"
2224
#include "atoms/linear.h"
2325
#include "atoms/log.h"
@@ -45,7 +47,9 @@
4547
#include "atoms/tanh.h"
4648
#include "atoms/trace.h"
4749
#include "atoms/transpose.h"
50+
#include "atoms/upper_tri.h"
4851
#include "atoms/variable.h"
52+
#include "atoms/vstack.h"
4953
#include "atoms/xexp.h"
5054

5155
/* Include problem bindings */
@@ -82,6 +86,8 @@ static PyMethodDef DNLPMethods[] = {
8286
{"make_hstack", py_make_hstack, METH_VARARGS,
8387
"Create hstack node from list of expr capsules and n_vars (make_hstack([e1, "
8488
"e2, ...], n_vars))"},
89+
{"make_vstack", py_make_vstack, METH_VARARGS,
90+
"Create vstack node from list of expr capsules (make_vstack([e1, e2, ...]))"},
8591
{"make_sum", py_make_sum, METH_VARARGS, "Create sum node"},
8692
{"make_neg", py_make_neg, METH_VARARGS, "Create neg node"},
8793
{"make_normal_cdf", py_make_normal_cdf, METH_VARARGS, "Create normal_cdf node"},
@@ -102,16 +108,20 @@ static PyMethodDef DNLPMethods[] = {
102108
"Create prod_axis_one node"},
103109
{"make_sin", py_make_sin, METH_VARARGS, "Create sin node"},
104110
{"make_cos", py_make_cos, METH_VARARGS, "Create cos node"},
111+
{"make_diag_mat", py_make_diag_mat, METH_VARARGS, "Create diag_mat node"},
105112
{"make_diag_vec", py_make_diag_vec, METH_VARARGS, "Create diag_vec node"},
106113
{"make_tan", py_make_tan, METH_VARARGS, "Create tan node"},
107114
{"make_sinh", py_make_sinh, METH_VARARGS, "Create sinh node"},
108115
{"make_tanh", py_make_tanh, METH_VARARGS, "Create tanh node"},
109116
{"make_asinh", py_make_asinh, METH_VARARGS, "Create asinh node"},
110117
{"make_atanh", py_make_atanh, METH_VARARGS, "Create atanh node"},
118+
{"make_upper_tri", py_make_upper_tri, METH_VARARGS, "Create upper_tri node"},
111119
{"make_broadcast", py_make_broadcast, METH_VARARGS, "Create broadcast node"},
112120
{"make_entr", py_make_entr, METH_VARARGS, "Create entr node"},
113121
{"make_logistic", py_make_logistic, METH_VARARGS, "Create logistic node"},
114122
{"make_xexp", py_make_xexp, METH_VARARGS, "Create xexp node"},
123+
{"make_kron_left", py_make_kron_left, METH_VARARGS,
124+
"Create kron(C, X) node where C is constant sparse matrix"},
115125
{"make_sparse_left_matmul", py_make_sparse_left_matmul, METH_VARARGS,
116126
"Create sparse left matmul node (A @ f(x))"},
117127
{"make_dense_left_matmul", py_make_dense_left_matmul, METH_VARARGS,

0 commit comments

Comments
 (0)