Skip to content

Commit 04c2268

Browse files
committed
dense right matmul
1 parent 4813f42 commit 04c2268

4 files changed

Lines changed: 25 additions & 2 deletions

File tree

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
cmake_minimum_required(VERSION 3.15)
22
project(DNLP_Diff_Engine C)
33
set(CMAKE_C_STANDARD 99)
4+
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
45

56
set(DIFF_ENGINE_VERSION_MAJOR 0)
67
set(DIFF_ENGINE_VERSION_MINOR 1)

include/bivariate.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ expr *new_left_matmul_dense(expr *u, int m, int n, const double *data);
4040
/* Right matrix multiplication: f(x) @ A where A is a constant matrix */
4141
expr *new_right_matmul(expr *u, const CSR_Matrix *A);
4242

43+
expr *new_right_matmul_dense(expr *u, int m, int n, const double *data);
44+
4345
/* Constant scalar multiplication: a * f(x) where a is a constant double */
4446
expr *new_const_scalar_mult(double a, expr *child);
4547

src/bivariate/right_matmul.c

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717
*/
1818
#include "affine.h"
1919
#include "bivariate.h"
20-
#include "subexpr.h"
2120
#include "utils/CSR_Matrix.h"
22-
#include "utils/linalg_sparse_matmuls.h"
2321
#include <stdlib.h>
2422

2523
/* This file implements the atom 'right_matmul' corresponding to the operation y =
@@ -41,3 +39,24 @@ expr *new_right_matmul(expr *u, const CSR_Matrix *A)
4139
free(work_transpose);
4240
return node;
4341
}
42+
43+
expr *new_right_matmul_dense(expr *u, int m, int n, const double *data)
44+
{
45+
/* We express: u @ A = (A^T @ u^T)^T
46+
A is m x n, so A^T is n x m. */
47+
double *AT = (double *) malloc(n * m * sizeof(double));
48+
for (int i = 0; i < m; i++)
49+
{
50+
for (int j = 0; j < n; j++)
51+
{
52+
AT[j * m + i] = data[i * n + j];
53+
}
54+
}
55+
56+
expr *u_transpose = new_transpose(u);
57+
expr *left_matmul_node = new_left_matmul_dense(u_transpose, m, n, AT);
58+
expr *node = new_transpose(left_matmul_node);
59+
60+
free(AT);
61+
return node;
62+
}

tests/jacobian_tests/test_prod.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "minunit.h"
55
#include "other.h"
66
#include "test_helpers.h"
7+
#include "affine.h"
78

89
const char *test_jacobian_prod_no_zero(void)
910
{

0 commit comments

Comments
 (0)