1+ #ifndef ATOM_DENSE_MATMUL_H
2+ #define ATOM_DENSE_MATMUL_H
3+
4+ #include "bivariate.h"
5+ #include "common.h"
6+
7+ /* Dense left matrix multiplication: A @ f(x) where A is a dense matrix.
8+ *
9+ * Python signature:
10+ * make_dense_left_matmul(child, A_data_flat, m, n)
11+ *
12+ * - child: the child expression capsule f(x).
13+ * - A_data_flat: contiguous row-major numpy float64 array of size m*n.
14+ * - m, n: dimensions of matrix A. */
15+ static PyObject * py_make_dense_left_matmul (PyObject * self , PyObject * args )
16+ {
17+ PyObject * child_capsule ;
18+ PyObject * data_obj ;
19+ int m , n ;
20+ if (!PyArg_ParseTuple (args , "OOii" , & child_capsule , & data_obj , & m , & n ))
21+ {
22+ return NULL ;
23+ }
24+
25+ expr * child = (expr * ) PyCapsule_GetPointer (child_capsule , EXPR_CAPSULE_NAME );
26+ if (!child )
27+ {
28+ PyErr_SetString (PyExc_ValueError , "invalid child capsule" );
29+ return NULL ;
30+ }
31+
32+ PyArrayObject * data_array =
33+ (PyArrayObject * ) PyArray_FROM_OTF (data_obj , NPY_DOUBLE , NPY_ARRAY_IN_ARRAY );
34+ if (!data_array )
35+ {
36+ return NULL ;
37+ }
38+
39+ double * A_data = (double * ) PyArray_DATA (data_array );
40+
41+ expr * node = new_left_matmul_dense (child , m , n , A_data );
42+ Py_DECREF (data_array );
43+
44+ if (!node )
45+ {
46+ PyErr_SetString (PyExc_RuntimeError ,
47+ "failed to create dense_left_matmul node" );
48+ return NULL ;
49+ }
50+ expr_retain (node );
51+ return PyCapsule_New (node , EXPR_CAPSULE_NAME , expr_capsule_destructor );
52+ }
53+
54+ /* Dense right matrix multiplication: f(x) @ A where A is a dense matrix.
55+ *
56+ * Python signature:
57+ * make_dense_right_matmul(child, A_data_flat, m, n)
58+ *
59+ * - child: the child expression capsule f(x).
60+ * - A_data_flat: contiguous row-major numpy float64 array of size m*n.
61+ * - m, n: dimensions of matrix A. */
62+ static PyObject * py_make_dense_right_matmul (PyObject * self , PyObject * args )
63+ {
64+ PyObject * child_capsule ;
65+ PyObject * data_obj ;
66+ int m , n ;
67+ if (!PyArg_ParseTuple (args , "OOii" , & child_capsule , & data_obj , & m , & n ))
68+ {
69+ return NULL ;
70+ }
71+
72+ expr * child = (expr * ) PyCapsule_GetPointer (child_capsule , EXPR_CAPSULE_NAME );
73+ if (!child )
74+ {
75+ PyErr_SetString (PyExc_ValueError , "invalid child capsule" );
76+ return NULL ;
77+ }
78+
79+ PyArrayObject * data_array =
80+ (PyArrayObject * ) PyArray_FROM_OTF (data_obj , NPY_DOUBLE , NPY_ARRAY_IN_ARRAY );
81+ if (!data_array )
82+ {
83+ return NULL ;
84+ }
85+
86+ double * A_data = (double * ) PyArray_DATA (data_array );
87+
88+ expr * node = new_right_matmul_dense (child , m , n , A_data );
89+ Py_DECREF (data_array );
90+
91+ if (!node )
92+ {
93+ PyErr_SetString (PyExc_RuntimeError ,
94+ "failed to create dense_right_matmul node" );
95+ return NULL ;
96+ }
97+ expr_retain (node );
98+ return PyCapsule_New (node , EXPR_CAPSULE_NAME , expr_capsule_destructor );
99+ }
100+
101+ #endif /* ATOM_DENSE_MATMUL_H */
0 commit comments