33
44#include "bivariate.h"
55#include "common.h"
6+ #include "subexpr.h"
67
7- /* Right matrix multiplication: f(x) @ A where A is a constant matrix */
8+ /* Right matrix multiplication: f(x) @ A where A is a constant or parameter matrix.
9+ *
10+ * Unified binding for both fixed-constant and updatable-parameter cases.
11+ * Python signature:
12+ * make_right_matmul(param_or_none, child, data, indices, indptr, m, n)
13+ *
14+ * - param_or_none: None for fixed constants (a PARAM_FIXED parameter is created
15+ * internally), or an existing parameter capsule for updatable parameters.
16+ * - child: the child expression capsule f(x).
17+ * - data, indices, indptr, m, n: CSR arrays defining the sparsity pattern and
18+ * initial values of the matrix A. */
819static PyObject * py_make_right_matmul (PyObject * self , PyObject * args )
920{
21+ PyObject * param_obj ;
1022 PyObject * child_capsule ;
1123 PyObject * data_obj , * indices_obj , * indptr_obj ;
1224 int m , n ;
13- if (!PyArg_ParseTuple (args , "OOOOii " , & child_capsule , & data_obj , & indices_obj ,
14- & indptr_obj , & m , & n ))
25+ if (!PyArg_ParseTuple (args , "OOOOOii " , & param_obj , & child_capsule , & data_obj ,
26+ & indices_obj , & indptr_obj , & m , & n ))
1527 {
1628 return NULL ;
1729 }
@@ -38,18 +50,55 @@ static PyObject *py_make_right_matmul(PyObject *self, PyObject *args)
3850 return NULL ;
3951 }
4052
41- int nnz = (int ) PyArray_SIZE (data_array );
53+ double * csr_data = (double * ) PyArray_DATA (data_array );
54+ int * csr_indices = (int * ) PyArray_DATA (indices_array );
55+ int * csr_indptr = (int * ) PyArray_DATA (indptr_array );
56+ int nnz = csr_indptr [m ];
57+
58+ /* Build CSR matrix from Python arrays */
4259 CSR_Matrix * A = new_csr_matrix (m , n , nnz );
43- memcpy (A -> x , PyArray_DATA (data_array ), nnz * sizeof (double ));
44- memcpy (A -> i , PyArray_DATA (indices_array ), nnz * sizeof (int ));
45- memcpy (A -> p , PyArray_DATA (indptr_array ), (m + 1 ) * sizeof (int ));
60+ memcpy (A -> p , csr_indptr , (m + 1 ) * sizeof (int ));
61+ memcpy (A -> i , csr_indices , nnz * sizeof (int ));
62+ memcpy (A -> x , csr_data , nnz * sizeof (double ));
63+
64+ /* Determine param_node: use passed capsule, or create PARAM_FIXED internally */
65+ expr * param_node ;
66+ if (param_obj == Py_None )
67+ {
68+ /* Fixed constant: pass CSR data directly (values are already in CSR order) */
69+ param_node = new_parameter (nnz , 1 , PARAM_FIXED , child -> n_vars , csr_data );
70+
71+ if (!param_node )
72+ {
73+ free_csr_matrix (A );
74+ Py_DECREF (data_array );
75+ Py_DECREF (indices_array );
76+ Py_DECREF (indptr_array );
77+ PyErr_SetString (PyExc_RuntimeError ,
78+ "failed to create matrix parameter node" );
79+ return NULL ;
80+ }
81+ }
82+ else
83+ {
84+ param_node = (expr * ) PyCapsule_GetPointer (param_obj , EXPR_CAPSULE_NAME );
85+ if (!param_node )
86+ {
87+ free_csr_matrix (A );
88+ Py_DECREF (data_array );
89+ Py_DECREF (indices_array );
90+ Py_DECREF (indptr_array );
91+ PyErr_SetString (PyExc_ValueError , "invalid param capsule" );
92+ return NULL ;
93+ }
94+ }
4695
4796 Py_DECREF (data_array );
4897 Py_DECREF (indices_array );
4998 Py_DECREF (indptr_array );
5099
51- expr * node = new_right_matmul (child , A );
52- free_csr_matrix (A );
100+ expr * node = new_right_matmul (param_node , child , A );
101+ free_csr_matrix (A ); /* constructor copies it */
53102
54103 if (!node )
55104 {
0 commit comments