Skip to content

Commit 22bf057

Browse files
authored
Folder structure refactor (#63)
* move affine atoms from bivariate to affine * run formatter * split tests up into folders * run formatter
1 parent 10340df commit 22bf057

80 files changed

Lines changed: 147 additions & 131 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

include/affine.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,27 @@ expr *new_reshape(expr *child, int d1, int d2);
4141
expr *new_broadcast(expr *child, int target_d1, int target_d2);
4242
expr *new_diag_vec(expr *child);
4343
expr *new_transpose(expr *child);
44-
expr *new_diag_vec(expr *child);
44+
45+
/* Left matrix multiplication: A @ f(x) where A is a constant sparse
46+
* matrix */
47+
expr *new_left_matmul(expr *u, const CSR_Matrix *A);
48+
49+
/* Left matrix multiplication: A @ f(x) where A is a constant dense
50+
* matrix (row-major, m x n). Uses CBLAS for efficient computation. */
51+
expr *new_left_matmul_dense(expr *u, int m, int n, const double *data);
52+
53+
/* Right matrix multiplication: f(x) @ A where A is a constant
54+
* matrix */
55+
expr *new_right_matmul(expr *u, const CSR_Matrix *A);
56+
57+
expr *new_right_matmul_dense(expr *u, int m, int n, const double *data);
58+
59+
/* Constant scalar multiplication: a * f(x) where a is a constant
60+
* double */
61+
expr *new_const_scalar_mult(double a, expr *child);
62+
63+
/* Constant vector elementwise multiplication: a . f(x) where a is
64+
* constant */
65+
expr *new_const_vector_mult(const double *a, expr *child);
4566

4667
#endif /* AFFINE_H */

include/bivariate.h

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,34 +18,11 @@
1818
#ifndef BIVARIATE_H
1919
#define BIVARIATE_H
2020

21-
#include "expr.h"
22-
23-
expr *new_elementwise_mult(expr *left, expr *right);
24-
expr *new_rel_entr_vector_args(expr *left, expr *right);
25-
expr *new_quad_over_lin(expr *left, expr *right);
26-
27-
expr *new_rel_entr_first_arg_scalar(expr *left, expr *right);
28-
expr *new_rel_entr_second_arg_scalar(expr *left, expr *right);
29-
30-
/* Matrix multiplication: Z = X @ Y */
31-
expr *new_matmul(expr *x, expr *y);
32-
33-
/* Left matrix multiplication: A @ f(x) where A is a constant sparse matrix */
34-
expr *new_left_matmul(expr *u, const CSR_Matrix *A);
35-
36-
/* Left matrix multiplication: A @ f(x) where A is a constant dense matrix
37-
* (row-major, m x n). Uses CBLAS for efficient computation. */
38-
expr *new_left_matmul_dense(expr *u, int m, int n, const double *data);
39-
40-
/* Right matrix multiplication: f(x) @ A where A is a constant matrix */
41-
expr *new_right_matmul(expr *u, const CSR_Matrix *A);
42-
43-
expr *new_right_matmul_dense(expr *u, int m, int n, const double *data);
44-
45-
/* Constant scalar multiplication: a * f(x) where a is a constant double */
46-
expr *new_const_scalar_mult(double a, expr *child);
47-
48-
/* Constant vector elementwise multiplication: a ∘ f(x) where a is constant */
49-
expr *new_const_vector_mult(const double *a, expr *child);
21+
/* Compatibility header — includes all bivariate-related declarations.
22+
* Prefer including the specific header directly:
23+
* affine.h, bivariate_full_dom.h, bivariate_restricted_dom.h */
24+
#include "affine.h"
25+
#include "bivariate_full_dom.h"
26+
#include "bivariate_restricted_dom.h"
5027

5128
#endif /* BIVARIATE_H */

include/bivariate_full_dom.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#ifndef BIVARIATE_FULL_DOM_H
2+
#define BIVARIATE_FULL_DOM_H
3+
4+
#include "expr.h"
5+
6+
expr *new_elementwise_mult(expr *left, expr *right);
7+
8+
/* Matrix multiplication: Z = X @ Y */
9+
expr *new_matmul(expr *x, expr *y);
10+
11+
#endif /* BIVARIATE_FULL_DOM_H */

include/bivariate_restricted_dom.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#ifndef BIVARIATE_RESTRICTED_DOM_H
2+
#define BIVARIATE_RESTRICTED_DOM_H
3+
4+
#include "expr.h"
5+
6+
expr *new_quad_over_lin(expr *left, expr *right);
7+
expr *new_rel_entr_vector_args(expr *left, expr *right);
8+
expr *new_rel_entr_first_arg_scalar(expr *left, expr *right);
9+
expr *new_rel_entr_second_arg_scalar(expr *left, expr *right);
10+
11+
#endif /* BIVARIATE_RESTRICTED_DOM_H */
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* See the License for the specific language governing permissions and
1616
* limitations under the License.
1717
*/
18-
#include "bivariate.h"
18+
#include "affine.h"
1919
#include "subexpr.h"
2020
#include <assert.h>
2121
#include <stdio.h>
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* See the License for the specific language governing permissions and
1616
* limitations under the License.
1717
*/
18-
#include "bivariate.h"
18+
#include "affine.h"
1919
#include "subexpr.h"
2020
#include <stdio.h>
2121
#include <stdlib.h>

src/affine/diag_vec.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ static void forward(expr *node, const double *u)
3434
/* child's forward pass */
3535
x->forward(x, u);
3636

37-
/* zero-initialize output */
37+
/* zero-initialize output, TODO: do we need to do this? */
3838
memset(node->value, 0, node->size * sizeof(double));
3939

4040
/* place input elements on the diagonal */
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* See the License for the specific language governing permissions and
1616
* limitations under the License.
1717
*/
18-
#include "bivariate.h"
18+
#include "affine.h"
1919
#include "subexpr.h"
2020
#include "utils/matrix.h"
2121
#include <assert.h>
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
* limitations under the License.
1717
*/
1818
#include "affine.h"
19-
#include "bivariate.h"
19+
2020
#include "utils/CSR_Matrix.h"
2121
#include <stdlib.h>
2222

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* See the License for the specific language governing permissions and
1616
* limitations under the License.
1717
*/
18-
#include "bivariate.h"
18+
#include "bivariate_full_dom.h"
1919
#include "subexpr.h"
2020
#include "utils/mini_numpy.h"
2121
#include <assert.h>

0 commit comments

Comments
 (0)