77/* Dense left matrix multiplication: A @ f(x) where A is a dense matrix.
88 *
99 * Python signature:
10- * make_dense_left_matmul(child, A_data_flat, m, n)
10+ * make_dense_left_matmul(param_or_none, child, A_data_flat, m, n)
1111 *
12+ * - param_or_none: None for constant matrix, or a parameter capsule.
1213 * - child: the child expression capsule f(x).
1314 * - A_data_flat: contiguous row-major numpy float64 array of size m*n.
1415 * - m, n: dimensions of matrix A. */
1516static PyObject * py_make_dense_left_matmul (PyObject * self , PyObject * args )
1617{
18+ PyObject * param_obj ;
1719 PyObject * child_capsule ;
1820 PyObject * data_obj ;
1921 int m , n ;
20- if (!PyArg_ParseTuple (args , "OOii" , & child_capsule , & data_obj , & m , & n ))
22+ if (!PyArg_ParseTuple (args , "OOOii" , & param_obj , & child_capsule ,
23+ & data_obj , & m , & n ))
2124 {
2225 return NULL ;
2326 }
@@ -38,11 +41,38 @@ static PyObject *py_make_dense_left_matmul(PyObject *self, PyObject *args)
3841
3942 double * A_data = (double * ) PyArray_DATA (data_array );
4043
41- expr * node = new_left_matmul_dense (child , m , n , A_data );
44+ /* Build the parameter node: use provided capsule or create PARAM_FIXED */
45+ expr * param_node = NULL ;
46+ if (param_obj == Py_None )
47+ {
48+ param_node =
49+ new_parameter (m * n , 1 , PARAM_FIXED , child -> n_vars , A_data );
50+ if (!param_node )
51+ {
52+ Py_DECREF (data_array );
53+ PyErr_SetString (PyExc_RuntimeError ,
54+ "failed to create parameter node for dense matrix" );
55+ return NULL ;
56+ }
57+ }
58+ else
59+ {
60+ param_node =
61+ (expr * ) PyCapsule_GetPointer (param_obj , EXPR_CAPSULE_NAME );
62+ if (!param_node )
63+ {
64+ Py_DECREF (data_array );
65+ PyErr_SetString (PyExc_ValueError , "invalid parameter capsule" );
66+ return NULL ;
67+ }
68+ }
69+
70+ expr * node = new_left_matmul_dense (param_node , child , m , n , A_data );
4271 Py_DECREF (data_array );
4372
4473 if (!node )
4574 {
75+ if (param_obj == Py_None ) free_expr (param_node );
4676 PyErr_SetString (PyExc_RuntimeError ,
4777 "failed to create dense_left_matmul node" );
4878 return NULL ;
@@ -54,17 +84,20 @@ static PyObject *py_make_dense_left_matmul(PyObject *self, PyObject *args)
5484/* Dense right matrix multiplication: f(x) @ A where A is a dense matrix.
5585 *
5686 * Python signature:
57- * make_dense_right_matmul(child, A_data_flat, m, n)
87+ * make_dense_right_matmul(param_or_none, child, A_data_flat, m, n)
5888 *
89+ * - param_or_none: None for constant matrix, or a parameter capsule.
5990 * - child: the child expression capsule f(x).
6091 * - A_data_flat: contiguous row-major numpy float64 array of size m*n.
6192 * - m, n: dimensions of matrix A. */
6293static PyObject * py_make_dense_right_matmul (PyObject * self , PyObject * args )
6394{
95+ PyObject * param_obj ;
6496 PyObject * child_capsule ;
6597 PyObject * data_obj ;
6698 int m , n ;
67- if (!PyArg_ParseTuple (args , "OOii" , & child_capsule , & data_obj , & m , & n ))
99+ if (!PyArg_ParseTuple (args , "OOOii" , & param_obj , & child_capsule ,
100+ & data_obj , & m , & n ))
68101 {
69102 return NULL ;
70103 }
@@ -85,11 +118,38 @@ static PyObject *py_make_dense_right_matmul(PyObject *self, PyObject *args)
85118
86119 double * A_data = (double * ) PyArray_DATA (data_array );
87120
88- expr * node = new_right_matmul_dense (child , m , n , A_data );
121+ /* Build the parameter node: use provided capsule or create PARAM_FIXED */
122+ expr * param_node = NULL ;
123+ if (param_obj == Py_None )
124+ {
125+ param_node =
126+ new_parameter (m * n , 1 , PARAM_FIXED , child -> n_vars , A_data );
127+ if (!param_node )
128+ {
129+ Py_DECREF (data_array );
130+ PyErr_SetString (PyExc_RuntimeError ,
131+ "failed to create parameter node for dense matrix" );
132+ return NULL ;
133+ }
134+ }
135+ else
136+ {
137+ param_node =
138+ (expr * ) PyCapsule_GetPointer (param_obj , EXPR_CAPSULE_NAME );
139+ if (!param_node )
140+ {
141+ Py_DECREF (data_array );
142+ PyErr_SetString (PyExc_ValueError , "invalid parameter capsule" );
143+ return NULL ;
144+ }
145+ }
146+
147+ expr * node = new_right_matmul_dense (param_node , child , m , n , A_data );
89148 Py_DECREF (data_array );
90149
91150 if (!node )
92151 {
152+ if (param_obj == Py_None ) free_expr (param_node );
93153 PyErr_SetString (PyExc_RuntimeError ,
94154 "failed to create dense_right_matmul node" );
95155 return NULL ;
@@ -98,4 +158,4 @@ static PyObject *py_make_dense_right_matmul(PyObject *self, PyObject *args)
98158 return PyCapsule_New (node , EXPR_CAPSULE_NAME , expr_capsule_destructor );
99159}
100160
101- #endif /* ATOM_DENSE_MATMUL_H */
161+ #endif /* ATOM_DENSE_MATMUL_H */
0 commit comments