Skip to content

Commit 4353001

Browse files
Transurgeonclaude
andcommitted
Merge origin/main into parameter-support-v2
Resolves conflicts from folder structure refactor (#63) and more tests (#62). Moves scalar_mult/vector_mult to src/affine/ and test files to new subdirectory structure. Updates affine.h declarations with parameter support signatures. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2 parents 0dd827c + 22bf057 commit 4353001

80 files changed

Lines changed: 174 additions & 129 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

include/affine.h

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,28 @@ expr *new_reshape(expr *child, int d1, int d2);
4141
expr *new_broadcast(expr *child, int target_d1, int target_d2);
4242
expr *new_diag_vec(expr *child);
4343
expr *new_transpose(expr *child);
44-
expr *new_diag_vec(expr *child);
44+
45+
/* Left matrix multiplication: A @ f(x) where A is a constant or parameter
46+
* sparse matrix. param_node is NULL for fixed constants. */
47+
expr *new_left_matmul(expr *param_node, expr *u, const CSR_Matrix *A);
48+
49+
/* Left matrix multiplication: A @ f(x) where A is a constant or parameter
50+
* dense matrix (row-major, m x n). Uses CBLAS for efficient computation. */
51+
expr *new_left_matmul_dense(expr *param_node, expr *u, int m, int n,
52+
const double *data);
53+
54+
/* Right matrix multiplication: f(x) @ A where A is a constant or parameter
55+
* matrix. */
56+
expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A);
57+
58+
expr *new_right_matmul_dense(expr *param_node, expr *u, int m, int n,
59+
const double *data);
60+
61+
/* Scalar multiplication: a * f(x) where a comes from param_node */
62+
expr *new_scalar_mult(expr *param_node, expr *child);
63+
64+
/* Vector elementwise multiplication: a . f(x) where a comes from
65+
* param_node */
66+
expr *new_vector_mult(expr *param_node, expr *child);
4567

4668
#endif /* AFFINE_H */

include/bivariate.h

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -18,38 +18,11 @@
1818
#ifndef BIVARIATE_H
1919
#define BIVARIATE_H
2020

21-
#include "expr.h"
22-
23-
expr *new_elementwise_mult(expr *left, expr *right);
24-
expr *new_rel_entr_vector_args(expr *left, expr *right);
25-
expr *new_quad_over_lin(expr *left, expr *right);
26-
27-
expr *new_rel_entr_first_arg_scalar(expr *left, expr *right);
28-
expr *new_rel_entr_second_arg_scalar(expr *left, expr *right);
29-
30-
/* Matrix multiplication: Z = X @ Y */
31-
expr *new_matmul(expr *x, expr *y);
32-
33-
/* Left matrix multiplication: A @ f(x) where A is a constant or parameter
34-
* sparse matrix. param_node is NULL for fixed constants. */
35-
expr *new_left_matmul(expr *param_node, expr *u, const CSR_Matrix *A);
36-
37-
/* Left matrix multiplication: A @ f(x) where A is a constant or parameter
38-
* dense matrix (row-major, m x n). Uses CBLAS for efficient computation. */
39-
expr *new_left_matmul_dense(expr *param_node, expr *u, int m, int n,
40-
const double *data);
41-
42-
/* Right matrix multiplication: f(x) @ A where A is a constant or parameter
43-
* matrix. */
44-
expr *new_right_matmul(expr *param_node, expr *u, const CSR_Matrix *A);
45-
46-
expr *new_right_matmul_dense(expr *param_node, expr *u, int m, int n,
47-
const double *data);
48-
49-
/* Scalar multiplication: a * f(x) where a comes from param_node */
50-
expr *new_scalar_mult(expr *param_node, expr *child);
51-
52-
/* Vector elementwise multiplication: a ∘ f(x) where a comes from param_node */
53-
expr *new_vector_mult(expr *param_node, expr *child);
21+
/* Compatibility header — includes all bivariate-related declarations.
22+
* Prefer including the specific header directly:
23+
* affine.h, bivariate_full_dom.h, bivariate_restricted_dom.h */
24+
#include "affine.h"
25+
#include "bivariate_full_dom.h"
26+
#include "bivariate_restricted_dom.h"
5427

5528
#endif /* BIVARIATE_H */

include/bivariate_full_dom.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#ifndef BIVARIATE_FULL_DOM_H
2+
#define BIVARIATE_FULL_DOM_H
3+
4+
#include "expr.h"
5+
6+
expr *new_elementwise_mult(expr *left, expr *right);
7+
8+
/* Matrix multiplication: Z = X @ Y */
9+
expr *new_matmul(expr *x, expr *y);
10+
11+
#endif /* BIVARIATE_FULL_DOM_H */

include/bivariate_restricted_dom.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#ifndef BIVARIATE_RESTRICTED_DOM_H
2+
#define BIVARIATE_RESTRICTED_DOM_H
3+
4+
#include "expr.h"
5+
6+
expr *new_quad_over_lin(expr *left, expr *right);
7+
expr *new_rel_entr_vector_args(expr *left, expr *right);
8+
expr *new_rel_entr_first_arg_scalar(expr *left, expr *right);
9+
expr *new_rel_entr_second_arg_scalar(expr *left, expr *right);
10+
11+
#endif /* BIVARIATE_RESTRICTED_DOM_H */

src/affine/diag_vec.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ static void forward(expr *node, const double *u)
3434
/* child's forward pass */
3535
x->forward(x, u);
3636

37-
/* zero-initialize output */
37+
/* zero-initialize output, TODO: do we need to do this? */
3838
memset(node->value, 0, node->size * sizeof(double));
3939

4040
/* place input elements on the diagonal */
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* See the License for the specific language governing permissions and
1616
* limitations under the License.
1717
*/
18-
#include "bivariate.h"
18+
#include "affine.h"
1919
#include "subexpr.h"
2020
#include "utils/matrix.h"
2121
#include <assert.h>
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
* limitations under the License.
1717
*/
1818
#include "affine.h"
19-
#include "bivariate.h"
2019
#include "subexpr.h"
2120
#include "utils/CSR_Matrix.h"
2221
#include <stdlib.h>
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* See the License for the specific language governing permissions and
1616
* limitations under the License.
1717
*/
18-
#include "bivariate.h"
18+
#include "bivariate_full_dom.h"
1919
#include "subexpr.h"
2020
#include "utils/mini_numpy.h"
2121
#include <assert.h>

0 commit comments

Comments
 (0)