Skip to content

Commit 633ee43

Browse files
Transurgeonclaude
andcommitted
Add parameter support to Python bindings
Expose the SparseDiffEngine parameter API to Python, enabling updatable problem parameters that can be changed between solves without reconstructing the expression tree. New bindings: make_parameter, make_param_scalar_mult, make_param_vector_mult, problem_register_params, problem_update_params. Existing constant/matmul bindings updated to use the unified parameter API internally (backward compatible). Also fixes stale elementwise_univariate.h includes renamed in the engine. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a3b4668 commit 633ee43

25 files changed

Lines changed: 425 additions & 33 deletions

SparseDiffEngine

Submodule SparseDiffEngine updated 135 files

sparsediffpy/_bindings/atoms/asinh.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#define ATOM_ASINH_H
33

44
#include "common.h"
5-
#include "elementwise_univariate.h"
5+
#include "elementwise_full_dom.h"
66

77
static PyObject *py_make_asinh(PyObject *self, PyObject *args)
88
{

sparsediffpy/_bindings/atoms/atanh.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#define ATOM_ATANH_H
33

44
#include "common.h"
5-
#include "elementwise_univariate.h"
5+
#include "elementwise_restricted_dom.h"
66

77
static PyObject *py_make_atanh(PyObject *self, PyObject *args)
88
{

sparsediffpy/_bindings/atoms/common.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
#include <numpy/arrayobject.h>
77

88
#include "affine.h"
9-
#include "elementwise_univariate.h"
9+
#include "elementwise_full_dom.h"
10+
#include "elementwise_restricted_dom.h"
1011
#include "expr.h"
1112

1213
#define EXPR_CAPSULE_NAME "DNLP_EXPR"

sparsediffpy/_bindings/atoms/const_scalar_mult.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,18 @@ static PyObject *py_make_const_scalar_mult(PyObject *self, PyObject *args)
2222
return NULL;
2323
}
2424

25-
expr *node = new_const_scalar_mult(a, child);
25+
expr *a_node = new_parameter(1, 1, PARAM_FIXED, child->n_vars, &a);
26+
if (!a_node)
27+
{
28+
PyErr_SetString(PyExc_RuntimeError,
29+
"failed to create parameter node for scalar");
30+
return NULL;
31+
}
32+
33+
expr *node = new_scalar_mult(a_node, child);
2634
if (!node)
2735
{
36+
free_expr(a_node);
2837
PyErr_SetString(PyExc_RuntimeError,
2938
"failed to create const_scalar_mult node");
3039
return NULL;

sparsediffpy/_bindings/atoms/const_vector_mult.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,22 @@ static PyObject *py_make_const_vector_mult(PyObject *self, PyObject *args)
4242

4343
double *a_data = (double *) PyArray_DATA(a_array);
4444

45-
expr *node = new_const_vector_mult(a_data, child);
45+
expr *a_node =
46+
new_parameter(a_size, 1, PARAM_FIXED, child->n_vars, a_data);
4647

4748
Py_DECREF(a_array);
4849

50+
if (!a_node)
51+
{
52+
PyErr_SetString(PyExc_RuntimeError,
53+
"failed to create parameter node for vector");
54+
return NULL;
55+
}
56+
57+
expr *node = new_vector_mult(a_node, child);
4958
if (!node)
5059
{
60+
free_expr(a_node);
5161
PyErr_SetString(PyExc_RuntimeError,
5262
"failed to create const_vector_mult node");
5363
return NULL;

sparsediffpy/_bindings/atoms/constant.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ static PyObject *py_make_constant(PyObject *self, PyObject *args)
1919
return NULL;
2020
}
2121

22-
expr *node =
23-
new_constant(d1, d2, n_vars, (const double *) PyArray_DATA(values_array));
22+
expr *node = new_parameter(d1, d2, PARAM_FIXED, n_vars,
23+
(const double *) PyArray_DATA(values_array));
2424
Py_DECREF(values_array);
2525

2626
if (!node)

sparsediffpy/_bindings/atoms/cos.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#define ATOM_COS_H
33

44
#include "common.h"
5-
#include "elementwise_univariate.h"
5+
#include "elementwise_full_dom.h"
66

77
static PyObject *py_make_cos(PyObject *self, PyObject *args)
88
{

sparsediffpy/_bindings/atoms/dense_matmul.h

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,20 @@
77
/* Dense left matrix multiplication: A @ f(x) where A is a dense matrix.
88
*
99
* Python signature:
10-
* make_dense_left_matmul(child, A_data_flat, m, n)
10+
* make_dense_left_matmul(param_or_none, child, A_data_flat, m, n)
1111
*
12+
* - param_or_none: None for constant matrix, or a parameter capsule.
1213
* - child: the child expression capsule f(x).
1314
* - A_data_flat: contiguous row-major numpy float64 array of size m*n.
1415
* - m, n: dimensions of matrix A. */
1516
static PyObject *py_make_dense_left_matmul(PyObject *self, PyObject *args)
1617
{
18+
PyObject *param_obj;
1719
PyObject *child_capsule;
1820
PyObject *data_obj;
1921
int m, n;
20-
if (!PyArg_ParseTuple(args, "OOii", &child_capsule, &data_obj, &m, &n))
22+
if (!PyArg_ParseTuple(args, "OOOii", &param_obj, &child_capsule,
23+
&data_obj, &m, &n))
2124
{
2225
return NULL;
2326
}
@@ -38,11 +41,38 @@ static PyObject *py_make_dense_left_matmul(PyObject *self, PyObject *args)
3841

3942
double *A_data = (double *) PyArray_DATA(data_array);
4043

41-
expr *node = new_left_matmul_dense(child, m, n, A_data);
44+
/* Build the parameter node: use provided capsule or create PARAM_FIXED */
45+
expr *param_node = NULL;
46+
if (param_obj == Py_None)
47+
{
48+
param_node =
49+
new_parameter(m * n, 1, PARAM_FIXED, child->n_vars, A_data);
50+
if (!param_node)
51+
{
52+
Py_DECREF(data_array);
53+
PyErr_SetString(PyExc_RuntimeError,
54+
"failed to create parameter node for dense matrix");
55+
return NULL;
56+
}
57+
}
58+
else
59+
{
60+
param_node =
61+
(expr *) PyCapsule_GetPointer(param_obj, EXPR_CAPSULE_NAME);
62+
if (!param_node)
63+
{
64+
Py_DECREF(data_array);
65+
PyErr_SetString(PyExc_ValueError, "invalid parameter capsule");
66+
return NULL;
67+
}
68+
}
69+
70+
expr *node = new_left_matmul_dense(param_node, child, m, n, A_data);
4271
Py_DECREF(data_array);
4372

4473
if (!node)
4574
{
75+
if (param_obj == Py_None) free_expr(param_node);
4676
PyErr_SetString(PyExc_RuntimeError,
4777
"failed to create dense_left_matmul node");
4878
return NULL;
@@ -54,17 +84,20 @@ static PyObject *py_make_dense_left_matmul(PyObject *self, PyObject *args)
5484
/* Dense right matrix multiplication: f(x) @ A where A is a dense matrix.
5585
*
5686
* Python signature:
57-
* make_dense_right_matmul(child, A_data_flat, m, n)
87+
* make_dense_right_matmul(param_or_none, child, A_data_flat, m, n)
5888
*
89+
* - param_or_none: None for constant matrix, or a parameter capsule.
5990
* - child: the child expression capsule f(x).
6091
* - A_data_flat: contiguous row-major numpy float64 array of size m*n.
6192
* - m, n: dimensions of matrix A. */
6293
static PyObject *py_make_dense_right_matmul(PyObject *self, PyObject *args)
6394
{
95+
PyObject *param_obj;
6496
PyObject *child_capsule;
6597
PyObject *data_obj;
6698
int m, n;
67-
if (!PyArg_ParseTuple(args, "OOii", &child_capsule, &data_obj, &m, &n))
99+
if (!PyArg_ParseTuple(args, "OOOii", &param_obj, &child_capsule,
100+
&data_obj, &m, &n))
68101
{
69102
return NULL;
70103
}
@@ -85,11 +118,38 @@ static PyObject *py_make_dense_right_matmul(PyObject *self, PyObject *args)
85118

86119
double *A_data = (double *) PyArray_DATA(data_array);
87120

88-
expr *node = new_right_matmul_dense(child, m, n, A_data);
121+
/* Build the parameter node: use provided capsule or create PARAM_FIXED */
122+
expr *param_node = NULL;
123+
if (param_obj == Py_None)
124+
{
125+
param_node =
126+
new_parameter(m * n, 1, PARAM_FIXED, child->n_vars, A_data);
127+
if (!param_node)
128+
{
129+
Py_DECREF(data_array);
130+
PyErr_SetString(PyExc_RuntimeError,
131+
"failed to create parameter node for dense matrix");
132+
return NULL;
133+
}
134+
}
135+
else
136+
{
137+
param_node =
138+
(expr *) PyCapsule_GetPointer(param_obj, EXPR_CAPSULE_NAME);
139+
if (!param_node)
140+
{
141+
Py_DECREF(data_array);
142+
PyErr_SetString(PyExc_ValueError, "invalid parameter capsule");
143+
return NULL;
144+
}
145+
}
146+
147+
expr *node = new_right_matmul_dense(param_node, child, m, n, A_data);
89148
Py_DECREF(data_array);
90149

91150
if (!node)
92151
{
152+
if (param_obj == Py_None) free_expr(param_node);
93153
PyErr_SetString(PyExc_RuntimeError,
94154
"failed to create dense_right_matmul node");
95155
return NULL;
@@ -98,4 +158,4 @@ static PyObject *py_make_dense_right_matmul(PyObject *self, PyObject *args)
98158
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
99159
}
100160

101-
#endif /* ATOM_DENSE_MATMUL_H */
161+
#endif /* ATOM_DENSE_MATMUL_H */

sparsediffpy/_bindings/atoms/entr.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#define ATOM_ENTR_H
33

44
#include "common.h"
5-
#include "elementwise_univariate.h"
5+
#include "elementwise_restricted_dom.h"
66

77
static PyObject *py_make_entr(PyObject *self, PyObject *args)
88
{

0 commit comments

Comments
 (0)