Skip to content

Commit bbb8c3f

Browse files
Transurgeonclaude
andcommitted
Unify right_matmul binding to accept param_or_none like left_matmul
Update py_make_right_matmul signature to (param_or_none, child, data, indices, indptr, m, n) matching left_matmul. When param_or_none is None, a PARAM_FIXED node is created internally; otherwise the passed parameter capsule is used for updatable sparse parameters. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 20655a0 commit bbb8c3f

2 files changed

Lines changed: 59 additions & 10 deletions

File tree

sparsediffpy/_bindings/atoms/right_matmul.h

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,27 @@
33

44
#include "bivariate.h"
55
#include "common.h"
6+
#include "subexpr.h"
67

7-
/* Right matrix multiplication: f(x) @ A where A is a constant matrix */
8+
/* Right matrix multiplication: f(x) @ A where A is a constant or parameter matrix.
9+
*
10+
* Unified binding for both fixed-constant and updatable-parameter cases.
11+
* Python signature:
12+
* make_right_matmul(param_or_none, child, data, indices, indptr, m, n)
13+
*
14+
* - param_or_none: None for fixed constants (a PARAM_FIXED parameter is created
15+
* internally), or an existing parameter capsule for updatable parameters.
16+
* - child: the child expression capsule f(x).
17+
* - data, indices, indptr, m, n: CSR arrays defining the sparsity pattern and
18+
* initial values of the matrix A. */
819
static PyObject *py_make_right_matmul(PyObject *self, PyObject *args)
920
{
21+
PyObject *param_obj;
1022
PyObject *child_capsule;
1123
PyObject *data_obj, *indices_obj, *indptr_obj;
1224
int m, n;
13-
if (!PyArg_ParseTuple(args, "OOOOii", &child_capsule, &data_obj, &indices_obj,
14-
&indptr_obj, &m, &n))
25+
if (!PyArg_ParseTuple(args, "OOOOOii", &param_obj, &child_capsule, &data_obj,
26+
&indices_obj, &indptr_obj, &m, &n))
1527
{
1628
return NULL;
1729
}
@@ -38,18 +50,55 @@ static PyObject *py_make_right_matmul(PyObject *self, PyObject *args)
3850
return NULL;
3951
}
4052

41-
int nnz = (int) PyArray_SIZE(data_array);
53+
double *csr_data = (double *) PyArray_DATA(data_array);
54+
int *csr_indices = (int *) PyArray_DATA(indices_array);
55+
int *csr_indptr = (int *) PyArray_DATA(indptr_array);
56+
int nnz = csr_indptr[m];
57+
58+
/* Build CSR matrix from Python arrays */
4259
CSR_Matrix *A = new_csr_matrix(m, n, nnz);
43-
memcpy(A->x, PyArray_DATA(data_array), nnz * sizeof(double));
44-
memcpy(A->i, PyArray_DATA(indices_array), nnz * sizeof(int));
45-
memcpy(A->p, PyArray_DATA(indptr_array), (m + 1) * sizeof(int));
60+
memcpy(A->p, csr_indptr, (m + 1) * sizeof(int));
61+
memcpy(A->i, csr_indices, nnz * sizeof(int));
62+
memcpy(A->x, csr_data, nnz * sizeof(double));
63+
64+
/* Determine param_node: use passed capsule, or create PARAM_FIXED internally */
65+
expr *param_node;
66+
if (param_obj == Py_None)
67+
{
68+
/* Fixed constant: pass CSR data directly (values are already in CSR order) */
69+
param_node = new_parameter(nnz, 1, PARAM_FIXED, child->n_vars, csr_data);
70+
71+
if (!param_node)
72+
{
73+
free_csr_matrix(A);
74+
Py_DECREF(data_array);
75+
Py_DECREF(indices_array);
76+
Py_DECREF(indptr_array);
77+
PyErr_SetString(PyExc_RuntimeError,
78+
"failed to create matrix parameter node");
79+
return NULL;
80+
}
81+
}
82+
else
83+
{
84+
param_node = (expr *) PyCapsule_GetPointer(param_obj, EXPR_CAPSULE_NAME);
85+
if (!param_node)
86+
{
87+
free_csr_matrix(A);
88+
Py_DECREF(data_array);
89+
Py_DECREF(indices_array);
90+
Py_DECREF(indptr_array);
91+
PyErr_SetString(PyExc_ValueError, "invalid param capsule");
92+
return NULL;
93+
}
94+
}
4695

4796
Py_DECREF(data_array);
4897
Py_DECREF(indices_array);
4998
Py_DECREF(indptr_array);
5099

51-
expr *node = new_right_matmul(child, A);
52-
free_csr_matrix(A);
100+
expr *node = new_right_matmul(param_node, child, A);
101+
free_csr_matrix(A); /* constructor copies it */
53102

54103
if (!node)
55104
{

sparsediffpy/_bindings/bindings.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ static PyMethodDef DNLPMethods[] = {
122122
{"make_param_vector_mult", py_make_param_vector_mult, METH_VARARGS,
123123
"Create vector mult from parameter (p ∘ f(x))"},
124124
{"make_right_matmul", py_make_right_matmul, METH_VARARGS,
125-
"Create right matmul node (f(x) @ A)"},
125+
"Create right matmul node (f(x) @ A): pass None or param capsule as first arg"},
126126
{"make_quad_form", py_make_quad_form, METH_VARARGS,
127127
"Create quadratic form node (x' * Q * x)"},
128128
{"make_quad_over_lin", py_make_quad_over_lin, METH_VARARGS,

0 commit comments

Comments
 (0)