Skip to content

Commit 15cb7c9

Browse files
Transurgeonclaude
andcommitted
Add parameter support to Python bindings
Expose the SparseDiffEngine parameter API to Python, enabling updatable problem parameters that can be changed between solves without reconstructing the expression tree. New bindings: make_parameter, make_param_scalar_mult, make_param_vector_mult, problem_register_params, problem_update_params. Existing constant/matmul bindings updated to use the unified parameter API internally (backward compatible). Also fixes stale elementwise_univariate.h includes renamed in the engine. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d9b72e9 commit 15cb7c9

13 files changed

Lines changed: 412 additions & 21 deletions

SparseDiffEngine

Submodule SparseDiffEngine updated 71 files

sparsediffpy/_bindings/atoms/const_scalar_mult.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,18 @@ static PyObject *py_make_const_scalar_mult(PyObject *self, PyObject *args)
2222
return NULL;
2323
}
2424

25-
expr *node = new_const_scalar_mult(a, child);
25+
expr *a_node = new_parameter(1, 1, PARAM_FIXED, child->n_vars, &a);
26+
if (!a_node)
27+
{
28+
PyErr_SetString(PyExc_RuntimeError,
29+
"failed to create parameter node for scalar");
30+
return NULL;
31+
}
32+
33+
expr *node = new_scalar_mult(a_node, child);
2634
if (!node)
2735
{
36+
free_expr(a_node);
2837
PyErr_SetString(PyExc_RuntimeError,
2938
"failed to create const_scalar_mult node");
3039
return NULL;

sparsediffpy/_bindings/atoms/const_vector_mult.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,22 @@ static PyObject *py_make_const_vector_mult(PyObject *self, PyObject *args)
4242

4343
double *a_data = (double *) PyArray_DATA(a_array);
4444

45-
expr *node = new_const_vector_mult(a_data, child);
45+
expr *a_node =
46+
new_parameter(a_size, 1, PARAM_FIXED, child->n_vars, a_data);
4647

4748
Py_DECREF(a_array);
4849

50+
if (!a_node)
51+
{
52+
PyErr_SetString(PyExc_RuntimeError,
53+
"failed to create parameter node for vector");
54+
return NULL;
55+
}
56+
57+
expr *node = new_vector_mult(a_node, child);
4958
if (!node)
5059
{
60+
free_expr(a_node);
5161
PyErr_SetString(PyExc_RuntimeError,
5262
"failed to create const_vector_mult node");
5363
return NULL;

sparsediffpy/_bindings/atoms/constant.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ static PyObject *py_make_constant(PyObject *self, PyObject *args)
1919
return NULL;
2020
}
2121

22-
expr *node =
23-
new_constant(d1, d2, n_vars, (const double *) PyArray_DATA(values_array));
22+
expr *node = new_parameter(d1, d2, PARAM_FIXED, n_vars,
23+
(const double *) PyArray_DATA(values_array));
2424
Py_DECREF(values_array);
2525

2626
if (!node)

sparsediffpy/_bindings/atoms/dense_matmul.h

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,20 @@
77
/* Dense left matrix multiplication: A @ f(x) where A is a dense matrix.
88
*
99
* Python signature:
10-
* make_dense_left_matmul(child, A_data_flat, m, n)
10+
* make_dense_left_matmul(param_or_none, child, A_data_flat, m, n)
1111
*
12+
* - param_or_none: None for constant matrix, or a parameter capsule.
1213
* - child: the child expression capsule f(x).
1314
* - A_data_flat: contiguous row-major numpy float64 array of size m*n.
1415
* - m, n: dimensions of matrix A. */
1516
static PyObject *py_make_dense_left_matmul(PyObject *self, PyObject *args)
1617
{
18+
PyObject *param_obj;
1719
PyObject *child_capsule;
1820
PyObject *data_obj;
1921
int m, n;
20-
if (!PyArg_ParseTuple(args, "OOii", &child_capsule, &data_obj, &m, &n))
22+
if (!PyArg_ParseTuple(args, "OOOii", &param_obj, &child_capsule,
23+
&data_obj, &m, &n))
2124
{
2225
return NULL;
2326
}
@@ -38,11 +41,38 @@ static PyObject *py_make_dense_left_matmul(PyObject *self, PyObject *args)
3841

3942
double *A_data = (double *) PyArray_DATA(data_array);
4043

41-
expr *node = new_left_matmul_dense(child, m, n, A_data);
44+
/* Build the parameter node: use provided capsule or create PARAM_FIXED */
45+
expr *param_node = NULL;
46+
if (param_obj == Py_None)
47+
{
48+
param_node =
49+
new_parameter(m * n, 1, PARAM_FIXED, child->n_vars, A_data);
50+
if (!param_node)
51+
{
52+
Py_DECREF(data_array);
53+
PyErr_SetString(PyExc_RuntimeError,
54+
"failed to create parameter node for dense matrix");
55+
return NULL;
56+
}
57+
}
58+
else
59+
{
60+
param_node =
61+
(expr *) PyCapsule_GetPointer(param_obj, EXPR_CAPSULE_NAME);
62+
if (!param_node)
63+
{
64+
Py_DECREF(data_array);
65+
PyErr_SetString(PyExc_ValueError, "invalid parameter capsule");
66+
return NULL;
67+
}
68+
}
69+
70+
expr *node = new_left_matmul_dense(param_node, child, m, n, A_data);
4271
Py_DECREF(data_array);
4372

4473
if (!node)
4574
{
75+
if (param_obj == Py_None) free_expr(param_node);
4676
PyErr_SetString(PyExc_RuntimeError,
4777
"failed to create dense_left_matmul node");
4878
return NULL;
@@ -54,17 +84,20 @@ static PyObject *py_make_dense_left_matmul(PyObject *self, PyObject *args)
5484
/* Dense right matrix multiplication: f(x) @ A where A is a dense matrix.
5585
*
5686
* Python signature:
57-
* make_dense_right_matmul(child, A_data_flat, m, n)
87+
* make_dense_right_matmul(param_or_none, child, A_data_flat, m, n)
5888
*
89+
* - param_or_none: None for constant matrix, or a parameter capsule.
5990
* - child: the child expression capsule f(x).
6091
* - A_data_flat: contiguous row-major numpy float64 array of size m*n.
6192
* - m, n: dimensions of matrix A. */
6293
static PyObject *py_make_dense_right_matmul(PyObject *self, PyObject *args)
6394
{
95+
PyObject *param_obj;
6496
PyObject *child_capsule;
6597
PyObject *data_obj;
6698
int m, n;
67-
if (!PyArg_ParseTuple(args, "OOii", &child_capsule, &data_obj, &m, &n))
99+
if (!PyArg_ParseTuple(args, "OOOii", &param_obj, &child_capsule,
100+
&data_obj, &m, &n))
68101
{
69102
return NULL;
70103
}
@@ -85,11 +118,38 @@ static PyObject *py_make_dense_right_matmul(PyObject *self, PyObject *args)
85118

86119
double *A_data = (double *) PyArray_DATA(data_array);
87120

88-
expr *node = new_right_matmul_dense(child, m, n, A_data);
121+
/* Build the parameter node: use provided capsule or create PARAM_FIXED */
122+
expr *param_node = NULL;
123+
if (param_obj == Py_None)
124+
{
125+
param_node =
126+
new_parameter(m * n, 1, PARAM_FIXED, child->n_vars, A_data);
127+
if (!param_node)
128+
{
129+
Py_DECREF(data_array);
130+
PyErr_SetString(PyExc_RuntimeError,
131+
"failed to create parameter node for dense matrix");
132+
return NULL;
133+
}
134+
}
135+
else
136+
{
137+
param_node =
138+
(expr *) PyCapsule_GetPointer(param_obj, EXPR_CAPSULE_NAME);
139+
if (!param_node)
140+
{
141+
Py_DECREF(data_array);
142+
PyErr_SetString(PyExc_ValueError, "invalid parameter capsule");
143+
return NULL;
144+
}
145+
}
146+
147+
expr *node = new_right_matmul_dense(param_node, child, m, n, A_data);
89148
Py_DECREF(data_array);
90149

91150
if (!node)
92151
{
152+
if (param_obj == Py_None) free_expr(param_node);
93153
PyErr_SetString(PyExc_RuntimeError,
94154
"failed to create dense_right_matmul node");
95155
return NULL;
@@ -98,4 +158,4 @@ static PyObject *py_make_dense_right_matmul(PyObject *self, PyObject *args)
98158
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
99159
}
100160

101-
#endif /* ATOM_DENSE_MATMUL_H */
161+
#endif /* ATOM_DENSE_MATMUL_H */

sparsediffpy/_bindings/atoms/left_matmul.h

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,24 @@
44
#include "bivariate_full_dom.h"
55
#include "common.h"
66

7-
/* Left matrix multiplication: A @ f(x) where A is a constant matrix */
7+
/* Left matrix multiplication: A @ f(x) where A is a constant or parameter
8+
* sparse matrix.
9+
*
10+
* Python signature:
11+
* make_sparse_left_matmul(param_or_none, child, data, indices, indptr, m, n)
12+
*
13+
* - param_or_none: None for constant matrix, or a parameter capsule.
14+
* - child: the child expression capsule f(x).
15+
* - data, indices, indptr: CSR arrays for matrix A.
16+
* - m, n: dimensions of matrix A. */
817
static PyObject *py_make_sparse_left_matmul(PyObject *self, PyObject *args)
918
{
19+
PyObject *param_obj;
1020
PyObject *child_capsule;
1121
PyObject *data_obj, *indices_obj, *indptr_obj;
1222
int m, n;
13-
if (!PyArg_ParseTuple(args, "OOOOii", &child_capsule, &data_obj, &indices_obj,
14-
&indptr_obj, &m, &n))
23+
if (!PyArg_ParseTuple(args, "OOOOOii", &param_obj, &child_capsule,
24+
&data_obj, &indices_obj, &indptr_obj, &m, &n))
1525
{
1626
return NULL;
1727
}
@@ -39,6 +49,37 @@ static PyObject *py_make_sparse_left_matmul(PyObject *self, PyObject *args)
3949
}
4050

4151
int nnz = (int) PyArray_SIZE(data_array);
52+
53+
/* Build the parameter node: use provided capsule or create PARAM_FIXED */
54+
expr *param_node = NULL;
55+
if (param_obj == Py_None)
56+
{
57+
param_node = new_parameter(nnz, 1, PARAM_FIXED, child->n_vars,
58+
(const double *) PyArray_DATA(data_array));
59+
if (!param_node)
60+
{
61+
Py_DECREF(data_array);
62+
Py_DECREF(indices_array);
63+
Py_DECREF(indptr_array);
64+
PyErr_SetString(PyExc_RuntimeError,
65+
"failed to create parameter node for matrix");
66+
return NULL;
67+
}
68+
}
69+
else
70+
{
71+
param_node =
72+
(expr *) PyCapsule_GetPointer(param_obj, EXPR_CAPSULE_NAME);
73+
if (!param_node)
74+
{
75+
Py_DECREF(data_array);
76+
Py_DECREF(indices_array);
77+
Py_DECREF(indptr_array);
78+
PyErr_SetString(PyExc_ValueError, "invalid parameter capsule");
79+
return NULL;
80+
}
81+
}
82+
4283
CSR_Matrix *A = new_csr_matrix(m, n, nnz);
4384
memcpy(A->x, PyArray_DATA(data_array), nnz * sizeof(double));
4485
memcpy(A->i, PyArray_DATA(indices_array), nnz * sizeof(int));
@@ -48,11 +89,12 @@ static PyObject *py_make_sparse_left_matmul(PyObject *self, PyObject *args)
4889
Py_DECREF(indices_array);
4990
Py_DECREF(indptr_array);
5091

51-
expr *node = new_left_matmul(child, A);
92+
expr *node = new_left_matmul(param_node, child, A);
5293
free_csr_matrix(A);
5394

5495
if (!node)
5596
{
97+
if (param_obj == Py_None) free_expr(param_node);
5698
PyErr_SetString(PyExc_RuntimeError, "failed to create left_matmul node");
5799
return NULL;
58100
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#ifndef ATOM_PARAMETER_H
2+
#define ATOM_PARAMETER_H
3+
4+
#include "common.h"
5+
6+
static PyObject *py_make_parameter(PyObject *self, PyObject *args)
7+
{
8+
int d1, d2, param_id, n_vars;
9+
if (!PyArg_ParseTuple(args, "iiii", &d1, &d2, &param_id, &n_vars))
10+
{
11+
return NULL;
12+
}
13+
14+
expr *node = new_parameter(d1, d2, param_id, n_vars, NULL);
15+
if (!node)
16+
{
17+
PyErr_SetString(PyExc_RuntimeError, "failed to create parameter node");
18+
return NULL;
19+
}
20+
expr_retain(node); /* Capsule owns a reference */
21+
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
22+
}
23+
24+
#endif /* ATOM_PARAMETER_H */

sparsediffpy/_bindings/atoms/right_matmul.h

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,24 @@
44
#include "bivariate_full_dom.h"
55
#include "common.h"
66

7-
/* Right matrix multiplication: f(x) @ A where A is a constant matrix */
7+
/* Right matrix multiplication: f(x) @ A where A is a constant or parameter
8+
* sparse matrix.
9+
*
10+
* Python signature:
11+
* make_sparse_right_matmul(param_or_none, child, data, indices, indptr, m, n)
12+
*
13+
* - param_or_none: None for constant matrix, or a parameter capsule.
14+
* - child: the child expression capsule f(x).
15+
* - data, indices, indptr: CSR arrays for matrix A.
16+
* - m, n: dimensions of matrix A. */
817
static PyObject *py_make_sparse_right_matmul(PyObject *self, PyObject *args)
918
{
19+
PyObject *param_obj;
1020
PyObject *child_capsule;
1121
PyObject *data_obj, *indices_obj, *indptr_obj;
1222
int m, n;
13-
if (!PyArg_ParseTuple(args, "OOOOii", &child_capsule, &data_obj, &indices_obj,
14-
&indptr_obj, &m, &n))
23+
if (!PyArg_ParseTuple(args, "OOOOOii", &param_obj, &child_capsule,
24+
&data_obj, &indices_obj, &indptr_obj, &m, &n))
1525
{
1626
return NULL;
1727
}
@@ -39,6 +49,37 @@ static PyObject *py_make_sparse_right_matmul(PyObject *self, PyObject *args)
3949
}
4050

4151
int nnz = (int) PyArray_SIZE(data_array);
52+
53+
/* Build the parameter node: use provided capsule or create PARAM_FIXED */
54+
expr *param_node = NULL;
55+
if (param_obj == Py_None)
56+
{
57+
param_node = new_parameter(nnz, 1, PARAM_FIXED, child->n_vars,
58+
(const double *) PyArray_DATA(data_array));
59+
if (!param_node)
60+
{
61+
Py_DECREF(data_array);
62+
Py_DECREF(indices_array);
63+
Py_DECREF(indptr_array);
64+
PyErr_SetString(PyExc_RuntimeError,
65+
"failed to create parameter node for matrix");
66+
return NULL;
67+
}
68+
}
69+
else
70+
{
71+
param_node =
72+
(expr *) PyCapsule_GetPointer(param_obj, EXPR_CAPSULE_NAME);
73+
if (!param_node)
74+
{
75+
Py_DECREF(data_array);
76+
Py_DECREF(indices_array);
77+
Py_DECREF(indptr_array);
78+
PyErr_SetString(PyExc_ValueError, "invalid parameter capsule");
79+
return NULL;
80+
}
81+
}
82+
4283
CSR_Matrix *A = new_csr_matrix(m, n, nnz);
4384
memcpy(A->x, PyArray_DATA(data_array), nnz * sizeof(double));
4485
memcpy(A->i, PyArray_DATA(indices_array), nnz * sizeof(int));
@@ -48,12 +89,14 @@ static PyObject *py_make_sparse_right_matmul(PyObject *self, PyObject *args)
4889
Py_DECREF(indices_array);
4990
Py_DECREF(indptr_array);
5091

51-
expr *node = new_right_matmul(child, A);
92+
expr *node = new_right_matmul(param_node, child, A);
5293
free_csr_matrix(A);
5394

5495
if (!node)
5596
{
56-
PyErr_SetString(PyExc_RuntimeError, "failed to create right_matmul node");
97+
if (param_obj == Py_None) free_expr(param_node);
98+
PyErr_SetString(PyExc_RuntimeError,
99+
"failed to create right_matmul node");
57100
return NULL;
58101
}
59102
expr_retain(node); /* Capsule owns a reference */

0 commit comments

Comments
 (0)