Skip to content

Commit f83f990

Browse files
committed
add test
1 parent 04c2268 commit f83f990

4 files changed

Lines changed: 74 additions & 6 deletions

File tree

src/bivariate/right_matmul.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ expr *new_right_matmul_dense(expr *u, int m, int n, const double *data)
5454
}
5555

5656
expr *u_transpose = new_transpose(u);
57-
expr *left_matmul_node = new_left_matmul_dense(u_transpose, m, n, AT);
57+
expr *left_matmul_node = new_left_matmul_dense(u_transpose, n, m, AT);
5858
expr *node = new_transpose(left_matmul_node);
5959

6060
free(AT);

src/utils/dense_matrix.c

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,22 @@ static void dense_block_left_mult_vec(const Matrix *A, const double *x, double *
3030

3131
/* y = kron(I_p, A) @ x via a single dgemm call:
3232
Treat x as n x p (column-major blocks) and y as m x p.
33-
But x and y are stored as p blocks of length n and m respectively
34-
(i.e. block-interleaved). This is the same as treating them as
35-
row-major matrices of shape p x n and p x m, so:
33+
But x and y are stored as p blocks of length n and m
34+
respectively (i.e. block-interleaved). This is the same as
35+
treating them as row-major matrices of shape p x n and
36+
p x m, so:
3637
y (p x m) = x (p x n) * A^T (n x m), all row-major.
3738
cblas with RowMajor: C = alpha * A * B + beta * C
3839
where A = x (p x n), B = A^T (n x m), C = y (p x m). */
39-
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, p, m, n, 1.0, x, n, dm->x,
40-
n, 0.0, y, m);
40+
/* cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
41+
p, m, n, 1.0, x, n, dm->x,
42+
n, 0.0, y, m); */
43+
for (int b = 0; b < p; b++)
44+
{
45+
cblas_dgemv(CblasRowMajor, CblasNoTrans, m, n, 1.0,
46+
dm->x, n, x + b * n, 1,
47+
0.0, y + b * m, 1);
48+
}
4149
}
4250

4351
static CSC_Matrix *dense_block_left_mult_sparsity(const Matrix *A,

tests/all_tests.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "forward_pass/composite/test_composite.h"
1616
#include "forward_pass/elementwise/test_exp.h"
1717
#include "forward_pass/elementwise/test_log.h"
18+
#include "forward_pass/test_left_matmul_dense.h"
1819
#include "forward_pass/test_matmul.h"
1920
#include "forward_pass/test_prod_axis_one.h"
2021
#include "forward_pass/test_prod_axis_zero.h"
@@ -112,6 +113,7 @@ int main(void)
112113
mu_run_test(test_forward_prod_axis_zero, tests_run);
113114
mu_run_test(test_forward_prod_axis_one, tests_run);
114115
mu_run_test(test_matmul, tests_run);
116+
mu_run_test(test_left_matmul_dense, tests_run);
115117

116118
printf("\n--- Jacobian Tests ---\n");
117119
mu_run_test(test_neg_jacobian, tests_run);
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#include <stdio.h>
2+
#include <stdlib.h>
3+
4+
#include "bivariate.h"
5+
#include "expr.h"
6+
#include "minunit.h"
7+
#include "test_helpers.h"
8+
9+
const char *test_left_matmul_dense(void)
10+
{
11+
/* Test: Z = A @ X where
12+
* A is 3x3 (row-major): [1 2 3; 4 5 6; 7 8 9]
13+
* X is 3x3 variable (col-major): [1 4 7; 2 5 8; 3 6 9]
14+
*
15+
* Z = A @ X = [14 32 50]
16+
* [32 77 122]
17+
* [50 122 194]
18+
*/
19+
20+
/* Create X variable (3 x 3) */
21+
expr *X = new_variable(3, 3, 0, 9);
22+
23+
/* Constant matrix A in row-major order */
24+
double A_data[9] = {1.0, 2.0, 3.0,
25+
4.0, 5.0, 6.0,
26+
7.0, 8.0, 9.0};
27+
28+
/* Build expression Z = A @ X */
29+
expr *Z = new_left_matmul_dense(X, 3, 3, A_data);
30+
31+
/* Variable values in column-major order */
32+
double u[9] = {1.0, 2.0, 3.0, /* first column */
33+
4.0, 5.0, 6.0, /* second column */
34+
7.0, 8.0, 9.0}; /* third column */
35+
36+
/* Evaluate forward pass */
37+
Z->forward(Z, u);
38+
39+
/* Expected result (3 x 3) in column-major order */
40+
double expected[9] = {14.0, 32.0, 50.0, /* first column */
41+
32.0, 77.0, 122.0, /* second column */
42+
50.0, 122.0, 194.0}; /* third column */
43+
44+
/* Verify dimensions */
45+
mu_assert("left_matmul_dense result should have d1=3",
46+
Z->d1 == 3);
47+
mu_assert("left_matmul_dense result should have d2=3",
48+
Z->d2 == 3);
49+
mu_assert("left_matmul_dense result should have size=9",
50+
Z->size == 9);
51+
52+
/* Verify values */
53+
mu_assert("Left matmul dense forward pass test failed",
54+
cmp_double_array(Z->value, expected, 9));
55+
56+
free_expr(Z);
57+
return 0;
58+
}

0 commit comments

Comments
 (0)