|
| 1 | +#include <stdio.h> |
| 2 | +#include <stdlib.h> |
| 3 | + |
| 4 | +#include "bivariate.h" |
| 5 | +#include "expr.h" |
| 6 | +#include "minunit.h" |
| 7 | +#include "test_helpers.h" |
| 8 | + |
| 9 | +const char *test_left_matmul_dense(void) |
| 10 | +{ |
| 11 | + /* Test: Z = A @ X where |
| 12 | + * A is 3x3 (row-major): [1 2 3; 4 5 6; 7 8 9] |
| 13 | + * X is 3x3 variable (col-major): [1 4 7; 2 5 8; 3 6 9] |
| 14 | + * |
| 15 | + * Z = A @ X = [14 32 50] |
| 16 | + * [32 77 122] |
| 17 | + * [50 122 194] |
| 18 | + */ |
| 19 | + |
| 20 | + /* Create X variable (3 x 3) */ |
| 21 | + expr *X = new_variable(3, 3, 0, 9); |
| 22 | + |
| 23 | + /* Constant matrix A in row-major order */ |
| 24 | + double A_data[9] = {1.0, 2.0, 3.0, |
| 25 | + 4.0, 5.0, 6.0, |
| 26 | + 7.0, 8.0, 9.0}; |
| 27 | + |
| 28 | + /* Build expression Z = A @ X */ |
| 29 | + expr *Z = new_left_matmul_dense(X, 3, 3, A_data); |
| 30 | + |
| 31 | + /* Variable values in column-major order */ |
| 32 | + double u[9] = {1.0, 2.0, 3.0, /* first column */ |
| 33 | + 4.0, 5.0, 6.0, /* second column */ |
| 34 | + 7.0, 8.0, 9.0}; /* third column */ |
| 35 | + |
| 36 | + /* Evaluate forward pass */ |
| 37 | + Z->forward(Z, u); |
| 38 | + |
| 39 | + /* Expected result (3 x 3) in column-major order */ |
| 40 | + double expected[9] = {14.0, 32.0, 50.0, /* first column */ |
| 41 | + 32.0, 77.0, 122.0, /* second column */ |
| 42 | + 50.0, 122.0, 194.0}; /* third column */ |
| 43 | + |
| 44 | + /* Verify dimensions */ |
| 45 | + mu_assert("left_matmul_dense result should have d1=3", |
| 46 | + Z->d1 == 3); |
| 47 | + mu_assert("left_matmul_dense result should have d2=3", |
| 48 | + Z->d2 == 3); |
| 49 | + mu_assert("left_matmul_dense result should have size=9", |
| 50 | + Z->size == 9); |
| 51 | + |
| 52 | + /* Verify values */ |
| 53 | + mu_assert("Left matmul dense forward pass test failed", |
| 54 | + cmp_double_array(Z->value, expected, 9)); |
| 55 | + |
| 56 | + free_expr(Z); |
| 57 | + return 0; |
| 58 | +} |
0 commit comments