Skip to content

Commit bec476f

Browse files
committed
new bindings
1 parent 91b1c48 commit bec476f

4 files changed

Lines changed: 112 additions & 6 deletions

File tree

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
#ifndef ATOM_DENSE_MATMUL_H
2+
#define ATOM_DENSE_MATMUL_H
3+
4+
#include "bivariate.h"
5+
#include "common.h"
6+
7+
/* Dense left matrix multiplication: A @ f(x) where A is a dense matrix.
8+
*
9+
* Python signature:
10+
* make_dense_left_matmul(child, A_data_flat, m, n)
11+
*
12+
* - child: the child expression capsule f(x).
13+
* - A_data_flat: contiguous row-major numpy float64 array of size m*n.
14+
* - m, n: dimensions of matrix A. */
15+
static PyObject *py_make_dense_left_matmul(PyObject *self, PyObject *args)
16+
{
17+
PyObject *child_capsule;
18+
PyObject *data_obj;
19+
int m, n;
20+
if (!PyArg_ParseTuple(args, "OOii", &child_capsule, &data_obj, &m, &n))
21+
{
22+
return NULL;
23+
}
24+
25+
expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
26+
if (!child)
27+
{
28+
PyErr_SetString(PyExc_ValueError, "invalid child capsule");
29+
return NULL;
30+
}
31+
32+
PyArrayObject *data_array =
33+
(PyArrayObject *) PyArray_FROM_OTF(data_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
34+
if (!data_array)
35+
{
36+
return NULL;
37+
}
38+
39+
double *A_data = (double *) PyArray_DATA(data_array);
40+
41+
expr *node = new_left_matmul_dense(child, m, n, A_data);
42+
Py_DECREF(data_array);
43+
44+
if (!node)
45+
{
46+
PyErr_SetString(PyExc_RuntimeError,
47+
"failed to create dense_left_matmul node");
48+
return NULL;
49+
}
50+
expr_retain(node);
51+
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
52+
}
53+
54+
/* Dense right matrix multiplication: f(x) @ A where A is a dense matrix.
55+
*
56+
* Python signature:
57+
* make_dense_right_matmul(child, A_data_flat, m, n)
58+
*
59+
* - child: the child expression capsule f(x).
60+
* - A_data_flat: contiguous row-major numpy float64 array of size m*n.
61+
* - m, n: dimensions of matrix A. */
62+
static PyObject *py_make_dense_right_matmul(PyObject *self, PyObject *args)
63+
{
64+
PyObject *child_capsule;
65+
PyObject *data_obj;
66+
int m, n;
67+
if (!PyArg_ParseTuple(args, "OOii", &child_capsule, &data_obj, &m, &n))
68+
{
69+
return NULL;
70+
}
71+
72+
expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
73+
if (!child)
74+
{
75+
PyErr_SetString(PyExc_ValueError, "invalid child capsule");
76+
return NULL;
77+
}
78+
79+
PyArrayObject *data_array =
80+
(PyArrayObject *) PyArray_FROM_OTF(data_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
81+
if (!data_array)
82+
{
83+
return NULL;
84+
}
85+
86+
double *A_data = (double *) PyArray_DATA(data_array);
87+
88+
expr *node = new_right_matmul_dense(child, m, n, A_data);
89+
Py_DECREF(data_array);
90+
91+
if (!node)
92+
{
93+
PyErr_SetString(PyExc_RuntimeError,
94+
"failed to create dense_right_matmul node");
95+
return NULL;
96+
}
97+
expr_retain(node);
98+
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
99+
}
100+
101+
#endif /* ATOM_DENSE_MATMUL_H */

sparsediffpy/_bindings/atoms/left_matmul.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include "common.h"
66

77
/* Left matrix multiplication: A @ f(x) where A is a constant matrix */
8-
static PyObject *py_make_left_matmul(PyObject *self, PyObject *args)
8+
static PyObject *py_make_sparse_left_matmul(PyObject *self, PyObject *args)
99
{
1010
PyObject *child_capsule;
1111
PyObject *data_obj, *indices_obj, *indptr_obj;

sparsediffpy/_bindings/atoms/right_matmul.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include "common.h"
66

77
/* Right matrix multiplication: f(x) @ A where A is a constant matrix */
8-
static PyObject *py_make_right_matmul(PyObject *self, PyObject *args)
8+
static PyObject *py_make_sparse_right_matmul(PyObject *self, PyObject *args)
99
{
1010
PyObject *child_capsule;
1111
PyObject *data_obj, *indices_obj, *indptr_obj;

sparsediffpy/_bindings/bindings.c

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "atoms/const_vector_mult.h"
1212
#include "atoms/constant.h"
1313
#include "atoms/cos.h"
14+
#include "atoms/dense_matmul.h"
1415
#include "atoms/diag_vec.h"
1516
#include "atoms/entr.h"
1617
#include "atoms/exp.h"
@@ -109,10 +110,14 @@ static PyMethodDef DNLPMethods[] = {
109110
{"make_entr", py_make_entr, METH_VARARGS, "Create entr node"},
110111
{"make_logistic", py_make_logistic, METH_VARARGS, "Create logistic node"},
111112
{"make_xexp", py_make_xexp, METH_VARARGS, "Create xexp node"},
112-
{"make_left_matmul", py_make_left_matmul, METH_VARARGS,
113-
"Create left matmul node (A @ f(x))"},
114-
{"make_right_matmul", py_make_right_matmul, METH_VARARGS,
115-
"Create right matmul node (f(x) @ A)"},
113+
{"make_sparse_left_matmul", py_make_sparse_left_matmul, METH_VARARGS,
114+
"Create sparse left matmul node (A @ f(x))"},
115+
{"make_dense_left_matmul", py_make_dense_left_matmul, METH_VARARGS,
116+
"Create dense left matmul node (A @ f(x)) where A is dense"},
117+
{"make_sparse_right_matmul", py_make_sparse_right_matmul, METH_VARARGS,
118+
"Create sparse right matmul node (f(x) @ A)"},
119+
{"make_dense_right_matmul", py_make_dense_right_matmul, METH_VARARGS,
120+
"Create dense right matmul node (f(x) @ A) where A is dense"},
116121
{"make_quad_form", py_make_quad_form, METH_VARARGS,
117122
"Create quadratic form node (x' * Q * x)"},
118123
{"make_quad_over_lin", py_make_quad_over_lin, METH_VARARGS,

0 commit comments

Comments
 (0)