Skip to content

Commit 90bf2c1

Browse files
Transurgeonclaude
andcommitted
Add parameter support to Python bindings
- Add parameter, scalar_mult, vector_mult bindings for updatable params - Add register_params and update_params problem bindings - Update constant.h to use unified new_parameter with PARAM_FIXED - Update const_scalar_mult.h and const_vector_mult.h for new parameter API - Unify left_matmul and left_param_matmul into single binding that accepts (param_or_none, child, csr_data, csr_indices, csr_indptr, m, n) - Update SparseDiffEngine submodule to parameter-support branch - Register all new methods in bindings.c Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ba3a5a2 commit 90bf2c1

11 files changed

Lines changed: 323 additions & 21 deletions

File tree

sparsediffpy/_bindings/atoms/const_scalar_mult.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33

44
#include "bivariate.h"
55
#include "common.h"
6+
#include "subexpr.h"
67

7-
/* Constant scalar multiplication: a * f(x) where a is a constant double */
8+
/* Constant scalar multiplication: a * f(x) where a is a constant double.
9+
* Creates a fixed parameter node for the scalar and calls new_scalar_mult. */
810
static PyObject *py_make_const_scalar_mult(PyObject *self, PyObject *args)
911
{
1012
PyObject *child_capsule;
@@ -22,11 +24,20 @@ static PyObject *py_make_const_scalar_mult(PyObject *self, PyObject *args)
2224
return NULL;
2325
}
2426

25-
expr *node = new_const_scalar_mult(a, child);
27+
/* Create a 1x1 fixed parameter for the scalar value */
28+
expr *a_node = new_parameter(1, 1, PARAM_FIXED, child->n_vars, &a);
29+
if (!a_node)
30+
{
31+
PyErr_SetString(PyExc_RuntimeError, "failed to create scalar parameter node");
32+
return NULL;
33+
}
34+
35+
expr *node = new_scalar_mult(a_node, child);
36+
2637
if (!node)
2738
{
2839
PyErr_SetString(PyExc_RuntimeError,
29-
"failed to create const_scalar_mult node");
40+
"failed to create scalar_mult node");
3041
return NULL;
3142
}
3243
expr_retain(node); /* Capsule owns a reference */

sparsediffpy/_bindings/atoms/const_vector_mult.h

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33

44
#include "bivariate.h"
55
#include "common.h"
6+
#include "subexpr.h"
67

7-
/* Constant vector elementwise multiplication: a ∘ f(x) where a is a constant vector
8-
*/
8+
/* Constant vector elementwise multiplication: a ∘ f(x) where a is a constant vector.
9+
* Creates a fixed parameter node for the vector and calls new_vector_mult. */
910
static PyObject *py_make_const_vector_mult(PyObject *self, PyObject *args)
1011
{
1112
PyObject *child_capsule;
@@ -42,14 +43,24 @@ static PyObject *py_make_const_vector_mult(PyObject *self, PyObject *args)
4243

4344
double *a_data = (double *) PyArray_DATA(a_array);
4445

45-
expr *node = new_const_vector_mult(a_data, child);
46-
46+
/* Create a fixed parameter node for the vector */
47+
expr *a_node = new_parameter(child->d1, child->d2, PARAM_FIXED, child->n_vars,
48+
a_data);
4749
Py_DECREF(a_array);
4850

51+
if (!a_node)
52+
{
53+
PyErr_SetString(PyExc_RuntimeError,
54+
"failed to create vector parameter node");
55+
return NULL;
56+
}
57+
58+
expr *node = new_vector_mult(a_node, child);
59+
4960
if (!node)
5061
{
5162
PyErr_SetString(PyExc_RuntimeError,
52-
"failed to create const_vector_mult node");
63+
"failed to create vector_mult node");
5364
return NULL;
5465
}
5566
expr_retain(node); /* Capsule owns a reference */

sparsediffpy/_bindings/atoms/constant.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define ATOM_CONSTANT_H
33

44
#include "common.h"
5+
#include "subexpr.h"
56

67
static PyObject *py_make_constant(PyObject *self, PyObject *args)
78
{
@@ -19,8 +20,8 @@ static PyObject *py_make_constant(PyObject *self, PyObject *args)
1920
return NULL;
2021
}
2122

22-
expr *node =
23-
new_constant(d1, d2, n_vars, (const double *) PyArray_DATA(values_array));
23+
expr *node = new_parameter(d1, d2, PARAM_FIXED, n_vars,
24+
(const double *) PyArray_DATA(values_array));
2425
Py_DECREF(values_array);
2526

2627
if (!node)

sparsediffpy/_bindings/atoms/left_matmul.h

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,27 @@
33

44
#include "bivariate.h"
55
#include "common.h"
6+
#include "subexpr.h"
67

7-
/* Left matrix multiplication: A @ f(x) where A is a constant matrix */
8+
/* Left matrix multiplication: A @ f(x).
9+
*
10+
* Unified binding for both fixed-constant and updatable-parameter cases.
11+
* Python signature:
12+
* make_left_matmul(param_or_none, child, data, indices, indptr, m, n)
13+
*
14+
* - param_or_none: None for fixed constants (a PARAM_FIXED parameter is created
15+
* internally), or an existing parameter capsule for updatable parameters.
16+
* - child: the child expression capsule f(x).
17+
* - data, indices, indptr, m, n: CSR arrays defining the sparsity pattern and
18+
* initial values of the matrix A. */
819
static PyObject *py_make_left_matmul(PyObject *self, PyObject *args)
920
{
21+
PyObject *param_obj;
1022
PyObject *child_capsule;
1123
PyObject *data_obj, *indices_obj, *indptr_obj;
1224
int m, n;
13-
if (!PyArg_ParseTuple(args, "OOOOii", &child_capsule, &data_obj, &indices_obj,
14-
&indptr_obj, &m, &n))
25+
if (!PyArg_ParseTuple(args, "OOOOOii", &param_obj, &child_capsule, &data_obj,
26+
&indices_obj, &indptr_obj, &m, &n))
1527
{
1628
return NULL;
1729
}
@@ -38,18 +50,61 @@ static PyObject *py_make_left_matmul(PyObject *self, PyObject *args)
3850
return NULL;
3951
}
4052

41-
int nnz = (int) PyArray_SIZE(data_array);
53+
double *csr_data = (double *) PyArray_DATA(data_array);
54+
int *csr_indices = (int *) PyArray_DATA(indices_array);
55+
int *csr_indptr = (int *) PyArray_DATA(indptr_array);
56+
int nnz = csr_indptr[m];
57+
58+
/* Build CSR matrix from Python arrays */
4259
CSR_Matrix *A = new_csr_matrix(m, n, nnz);
43-
memcpy(A->x, PyArray_DATA(data_array), nnz * sizeof(double));
44-
memcpy(A->i, PyArray_DATA(indices_array), nnz * sizeof(int));
45-
memcpy(A->p, PyArray_DATA(indptr_array), (m + 1) * sizeof(int));
60+
memcpy(A->p, csr_indptr, (m + 1) * sizeof(int));
61+
memcpy(A->i, csr_indices, nnz * sizeof(int));
62+
memcpy(A->x, csr_data, nnz * sizeof(double));
63+
64+
/* Determine param_node: use passed capsule, or create PARAM_FIXED internally */
65+
expr *param_node;
66+
if (param_obj == Py_None)
67+
{
68+
/* Fixed constant: create column-major values for the parameter node */
69+
double *col_major = (double *) calloc(m * n, sizeof(double));
70+
for (int row = 0; row < m; row++)
71+
for (int k = csr_indptr[row]; k < csr_indptr[row + 1]; k++)
72+
col_major[row + csr_indices[k] * m] = csr_data[k];
73+
74+
param_node = new_parameter(m, n, PARAM_FIXED, child->n_vars, col_major);
75+
free(col_major);
76+
77+
if (!param_node)
78+
{
79+
free_csr_matrix(A);
80+
Py_DECREF(data_array);
81+
Py_DECREF(indices_array);
82+
Py_DECREF(indptr_array);
83+
PyErr_SetString(PyExc_RuntimeError,
84+
"failed to create matrix parameter node");
85+
return NULL;
86+
}
87+
}
88+
else
89+
{
90+
param_node = (expr *) PyCapsule_GetPointer(param_obj, EXPR_CAPSULE_NAME);
91+
if (!param_node)
92+
{
93+
free_csr_matrix(A);
94+
Py_DECREF(data_array);
95+
Py_DECREF(indices_array);
96+
Py_DECREF(indptr_array);
97+
PyErr_SetString(PyExc_ValueError, "invalid param capsule");
98+
return NULL;
99+
}
100+
}
46101

47102
Py_DECREF(data_array);
48103
Py_DECREF(indices_array);
49104
Py_DECREF(indptr_array);
50105

51-
expr *node = new_left_matmul(child, A);
52-
free_csr_matrix(A);
106+
expr *node = new_left_matmul(param_node, child, A);
107+
free_csr_matrix(A); /* constructor copies it */
53108

54109
if (!node)
55110
{
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#ifndef ATOM_PARAMETER_H
2+
#define ATOM_PARAMETER_H
3+
4+
#include "common.h"
5+
6+
/* Updatable parameter: make_parameter(d1, d2, param_id, n_vars)
7+
* Values are set later via problem_update_params. */
8+
static PyObject *py_make_parameter(PyObject *self, PyObject *args)
9+
{
10+
int d1, d2, param_id, n_vars;
11+
if (!PyArg_ParseTuple(args, "iiii", &d1, &d2, &param_id, &n_vars))
12+
{
13+
return NULL;
14+
}
15+
16+
expr *node = new_parameter(d1, d2, param_id, n_vars, NULL);
17+
if (!node)
18+
{
19+
PyErr_SetString(PyExc_RuntimeError, "failed to create parameter node");
20+
return NULL;
21+
}
22+
expr_retain(node); /* Capsule owns a reference */
23+
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
24+
}
25+
26+
#endif /* ATOM_PARAMETER_H */
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#ifndef ATOM_SCALAR_MULT_H
2+
#define ATOM_SCALAR_MULT_H
3+
4+
#include "bivariate.h"
5+
#include "common.h"
6+
7+
/* Parameter scalar multiplication: param * f(x) where param is a parameter capsule.
8+
* Python name: make_param_scalar_mult(param_capsule, child_capsule) */
9+
static PyObject *py_make_param_scalar_mult(PyObject *self, PyObject *args)
10+
{
11+
PyObject *param_capsule;
12+
PyObject *child_capsule;
13+
14+
if (!PyArg_ParseTuple(args, "OO", &param_capsule, &child_capsule))
15+
{
16+
return NULL;
17+
}
18+
19+
expr *param_node =
20+
(expr *) PyCapsule_GetPointer(param_capsule, EXPR_CAPSULE_NAME);
21+
if (!param_node)
22+
{
23+
PyErr_SetString(PyExc_ValueError, "invalid param capsule");
24+
return NULL;
25+
}
26+
27+
expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
28+
if (!child)
29+
{
30+
PyErr_SetString(PyExc_ValueError, "invalid child capsule");
31+
return NULL;
32+
}
33+
34+
expr *node = new_scalar_mult(param_node, child);
35+
if (!node)
36+
{
37+
PyErr_SetString(PyExc_RuntimeError, "failed to create scalar_mult node");
38+
return NULL;
39+
}
40+
expr_retain(node);
41+
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
42+
}
43+
44+
#endif /* ATOM_SCALAR_MULT_H */
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#ifndef ATOM_VECTOR_MULT_H
2+
#define ATOM_VECTOR_MULT_H
3+
4+
#include "bivariate.h"
5+
#include "common.h"
6+
7+
/* Parameter vector multiplication: param ∘ f(x) where param is a parameter capsule.
8+
* Python name: make_param_vector_mult(param_capsule, child_capsule) */
9+
static PyObject *py_make_param_vector_mult(PyObject *self, PyObject *args)
10+
{
11+
PyObject *param_capsule;
12+
PyObject *child_capsule;
13+
14+
if (!PyArg_ParseTuple(args, "OO", &param_capsule, &child_capsule))
15+
{
16+
return NULL;
17+
}
18+
19+
expr *param_node =
20+
(expr *) PyCapsule_GetPointer(param_capsule, EXPR_CAPSULE_NAME);
21+
if (!param_node)
22+
{
23+
PyErr_SetString(PyExc_ValueError, "invalid param capsule");
24+
return NULL;
25+
}
26+
27+
expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
28+
if (!child)
29+
{
30+
PyErr_SetString(PyExc_ValueError, "invalid child capsule");
31+
return NULL;
32+
}
33+
34+
expr *node = new_vector_mult(param_node, child);
35+
if (!node)
36+
{
37+
PyErr_SetString(PyExc_RuntimeError, "failed to create vector_mult node");
38+
return NULL;
39+
}
40+
expr_retain(node);
41+
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
42+
}
43+
44+
#endif /* ATOM_VECTOR_MULT_H */

sparsediffpy/_bindings/bindings.c

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "atoms/matmul.h"
2525
#include "atoms/multiply.h"
2626
#include "atoms/neg.h"
27+
#include "atoms/parameter.h"
2728
#include "atoms/power.h"
2829
#include "atoms/prod.h"
2930
#include "atoms/prod_axis_one.h"
@@ -36,6 +37,7 @@
3637
#include "atoms/rel_entr_vector_scalar.h"
3738
#include "atoms/reshape.h"
3839
#include "atoms/right_matmul.h"
40+
#include "atoms/scalar_mult.h"
3941
#include "atoms/sin.h"
4042
#include "atoms/sinh.h"
4143
#include "atoms/sum.h"
@@ -44,6 +46,7 @@
4446
#include "atoms/trace.h"
4547
#include "atoms/transpose.h"
4648
#include "atoms/variable.h"
49+
#include "atoms/vector_mult.h"
4750
#include "atoms/xexp.h"
4851

4952
/* Include problem bindings */
@@ -56,6 +59,8 @@
5659
#include "problem/jacobian.h"
5760
#include "problem/make_problem.h"
5861
#include "problem/objective_forward.h"
62+
#include "problem/register_params.h"
63+
#include "problem/update_params.h"
5964

6065
static int numpy_initialized = 0;
6166

@@ -70,6 +75,7 @@ static int ensure_numpy(void)
7075
static PyMethodDef DNLPMethods[] = {
7176
{"make_variable", py_make_variable, METH_VARARGS, "Create variable node"},
7277
{"make_constant", py_make_constant, METH_VARARGS, "Create constant node"},
78+
{"make_parameter", py_make_parameter, METH_VARARGS, "Create parameter node"},
7379
{"make_linear", py_make_linear, METH_VARARGS, "Create linear op node"},
7480
{"make_log", py_make_log, METH_VARARGS, "Create log node"},
7581
{"make_exp", py_make_exp, METH_VARARGS, "Create exp node"},
@@ -110,7 +116,11 @@ static PyMethodDef DNLPMethods[] = {
110116
{"make_logistic", py_make_logistic, METH_VARARGS, "Create logistic node"},
111117
{"make_xexp", py_make_xexp, METH_VARARGS, "Create xexp node"},
112118
{"make_left_matmul", py_make_left_matmul, METH_VARARGS,
113-
"Create left matmul node (A @ f(x))"},
119+
"Create left matmul node (A @ f(x)): pass None or param capsule as first arg"},
120+
{"make_param_scalar_mult", py_make_param_scalar_mult, METH_VARARGS,
121+
"Create scalar mult from parameter (p * f(x))"},
122+
{"make_param_vector_mult", py_make_param_vector_mult, METH_VARARGS,
123+
"Create vector mult from parameter (p ∘ f(x))"},
114124
{"make_right_matmul", py_make_right_matmul, METH_VARARGS,
115125
"Create right matmul node (f(x) @ A)"},
116126
{"make_quad_form", py_make_quad_form, METH_VARARGS,
@@ -150,6 +160,10 @@ static PyMethodDef DNLPMethods[] = {
150160
"Compute Lagrangian Hessian"},
151161
{"get_hessian", py_get_hessian, METH_VARARGS,
152162
"Get Lagrangian Hessian without recomputing"},
163+
{"problem_register_params", py_problem_register_params, METH_VARARGS,
164+
"Register parameter nodes with the problem"},
165+
{"problem_update_params", py_problem_update_params, METH_VARARGS,
166+
"Update parameter values"},
153167
{NULL, NULL, 0, NULL}};
154168

155169
static struct PyModuleDef sparsediffpy_module = {

0 commit comments

Comments
 (0)