-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathleft_matmul.c
More file actions
208 lines (174 loc) · 7.04 KB
/
left_matmul.c
File metadata and controls
208 lines (174 loc) · 7.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
/*
* Copyright 2026 Daniel Cederberg and William Zhang
*
* This file is part of the DNLP-differentiation-engine project.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "bivariate.h"
#include "subexpr.h"
#include "utils/Timer.h"
#include "utils/linalg_sparse_matmuls.h"
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
/* This file implements the atom 'left_matmul' corresponding to the operation
y = A @ f(x), where A is a given matrix (either a fixed constant or the values
of a parameter node) and f(x) is an arbitrary expression. Here, f(x) can be
either a vector-valued or a matrix-valued expression. The dimensions are
A - m x n, f(x) - n x p, y - m x p. A child of shape (1, n) is treated like
the numpy vector (n,), and the result then has shape (1, m).
Note that here A does not have global column indices; it is a local matrix.
This is an important distinction compared to linear_op_expr.
* To compute the forward pass: vec(y) = A_kron @ vec(f(x)),
where A_kron = I_p kron A is a Kronecker product of size (m*p) x (n*p),
or more specifically, a block-diagonal matrix with p blocks of A along the
diagonal. In the refactored implementation we don't form A_kron explicitly,
only conceptually. This led to a 100x speedup in the initialization of the
Jacobian sparsity pattern.
* To compute the Jacobian: J_y = A_kron @ J_f(x), where J_f(x) is the
Jacobian of f(x) of size (n*p) x n_vars.
* To compute the contribution to the Lagrangian Hessian: we form
w = A_kron^T @ lambda and then evaluate the weighted-sum Hessian of f(x)
with weights w (valid because A does not depend on the variables x).
Working in terms of A_kron unifies the implementation of f(x) being
vector-valued or matrix-valued.
*/
#include "utils/utils.h"
#include <string.h>
/* Copy the current parameter values into A and rebuild the values of AT.
   A is the small m x n matrix (NOT the block-diagonal A_kron).
   Nothing happens when param_source is NULL: A then holds fixed constant
   values that were set once when the node was created.
   NOTE(review): assumes param_source->value holds at least A->nnz entries
   laid out in the same order as A->x -- confirm against the parameter node. */
static void refresh_param_values(left_matmul_expr *lin_node)
{
    const expr *src = lin_node->param_source;
    if (src == NULL)
    {
        return;
    }

    size_t n_bytes = lin_node->A->nnz * sizeof(double);
    memcpy(lin_node->A->x, src->value, n_bytes);

    /* AT holds A's nonzeros in transposed order, so its values must be
       recomputed whenever A changes; base.iwork serves as scratch space. */
    AT_fill_values(lin_node->A, lin_node->AT, lin_node->base.iwork);
}
/* Forward pass: vec(y) = A_kron @ vec(f(x)).
   A is first refreshed from its parameter source (if any), then the child is
   evaluated so its value is current, and finally the product is taken
   without ever materializing A_kron.
   @param node  this left_matmul node
   @param u     variable values, forwarded unchanged to the child */
static void forward(expr *node, const double *u)
{
    left_matmul_expr *lmul = (left_matmul_expr *) node;
    expr *child = node->left;

    /* Pick up the latest parameter values (no-op for a constant A). */
    refresh_param_values(lmul);

    child->forward(child, u);

    /* Multiply by A_kron implicitly: n_blocks diagonal copies of the small A. */
    block_left_multiply_vec(lmul->A, child->value, node->value, lmul->n_blocks);
}
static bool is_affine(const expr *node)
{
return node->left->is_affine(node->left);
}
/* Release the resources specific to a left_matmul node: the local matrix A,
   its transpose AT, the cached CSC Jacobians, the CSC->CSR workspace, and
   param_source (which new_left_matmul retained).
   NOTE(review): param_source may be NULL and the CSC matrices stay NULL if
   jacobian_init never ran; presumably free_expr / free_csc_matrix accept
   NULL -- confirm. */
static void free_type_data(expr *node)
{
    left_matmul_expr *lmul = (left_matmul_expr *) node;

    /* sparse matrices owned by this node */
    free_csr_matrix(lmul->A);
    free_csr_matrix(lmul->AT);
    free_csc_matrix(lmul->Jchild_CSC);
    free_csc_matrix(lmul->J_CSC);

    /* scratch buffer and the parameter reference */
    free(lmul->csc_to_csr_workspace);
    free_expr(lmul->param_source);
}
/* Build the sparsity pattern of J_y = A_kron @ J_f(x).
   Steps:
   1. Convert the child's CSR Jacobian to a CSC pattern.
   2. Form the pattern of the product in CSC.
   3. Convert that product to CSR and store it as node->jacobian.
   Both CSC intermediates are cached on the node so eval_jacobian only has to
   fill values. */
static void jacobian_init(expr *node)
{
    left_matmul_expr *lmul = (left_matmul_expr *) node;
    expr *child = node->left;

    child->jacobian_init(child);

    /* CSC pattern of the child's Jacobian (node->iwork is scratch). */
    CSC_Matrix *child_csc = csr_to_csc_fill_sparsity(child->jacobian, node->iwork);
    lmul->Jchild_CSC = child_csc;

    /* CSC pattern of A_kron @ J_f(x), derived from the small A and n_blocks. */
    CSC_Matrix *prod_csc =
        block_left_multiply_fill_sparsity(lmul->A, child_csc, lmul->n_blocks);
    lmul->J_CSC = prod_csc;

    /* CSR pattern for the node; the workspace is reused by later value fills. */
    node->jacobian =
        csc_to_csr_fill_sparsity(prod_csc, lmul->csc_to_csr_workspace);
}
/* Fill the values of J_y = A_kron @ J_f(x) on the pattern built by
   jacobian_init. */
static void eval_jacobian(expr *node)
{
    left_matmul_expr *lmul = (left_matmul_expr *) node;
    expr *child = node->left;

    /* A may come from a parameter; make sure its values are current. */
    refresh_param_values(lmul);

    /* Child Jacobian values, transferred into the cached CSC copy. */
    child->eval_jacobian(child);
    csr_to_csc_fill_values(child->jacobian, lmul->Jchild_CSC, node->iwork);

    /* Product values in CSC, then transferred into the CSR jacobian. */
    block_left_multiply_fill_values(lmul->A, lmul->Jchild_CSC, lmul->J_CSC);
    csc_to_csr_fill_values(lmul->J_CSC, node->jacobian,
                           lmul->csc_to_csr_workspace);
}
/* Prepare the weighted-sum Hessian.
   A does not depend on the variables, so the Hessian of w^T (A_kron f(x))
   equals that of (A_kron^T w)^T f(x) and shares the child's sparsity exactly.
   Also allocates node->dwork to hold A_kron^T w: A->n * n_blocks entries,
   i.e. one per element of f(x). */
static void wsum_hess_init(expr *node)
{
    expr *child = node->left;
    const left_matmul_expr *lmul = (const left_matmul_expr *) node;

    child->wsum_hess_init(child);

    /* Same pattern as the child: copy row pointers and column indices. */
    const int nv = node->n_vars;
    const int nnz = child->wsum_hess->nnz;
    node->wsum_hess = new_csr_matrix(nv, nv, nnz);
    memcpy(node->wsum_hess->p, child->wsum_hess->p, (nv + 1) * sizeof(int));
    memcpy(node->wsum_hess->i, child->wsum_hess->i, nnz * sizeof(int));

    /* Scratch buffer for A_kron^T w. */
    const int dwork_len = lmul->A->n * lmul->n_blocks;
    node->dwork = (double *) malloc(dwork_len * sizeof(double));
}
/* Evaluate the weighted-sum Hessian sum_i w_i * Hess(y_i).
   Since y = A_kron @ vec(f(x)) and A does not depend on x, this equals the
   child's weighted-sum Hessian with weights A_kron^T w.
   @param node  this left_matmul node
   @param w     weights, one per entry of y (presumably m * n_blocks values) */
static void eval_wsum_hess(expr *node, const double *w)
{
    left_matmul_expr *lmul = (left_matmul_expr *) node;
    expr *child = node->left;

    /* Like forward and eval_jacobian, refresh A/AT from the parameter source
       first. Otherwise a parameter update that is not followed by a forward
       pass would leave AT stale here. No-op for a constant A. */
    refresh_param_values(lmul);

    /* dwork = A_kron^T w, computed implicitly from the small AT. */
    block_left_multiply_vec(lmul->AT, w, node->dwork, lmul->n_blocks);

    child->eval_wsum_hess(child, node->dwork);

    /* This node's Hessian was allocated with the child's sparsity, so the
       values can be copied directly. */
    memcpy(node->wsum_hess->x, child->wsum_hess->x,
           node->wsum_hess->nnz * sizeof(double));
}
/* Create a node computing y = A @ child.
   @param param_node  optional parameter node whose values refresh A on every
                      evaluation; NULL means A is a fixed constant. A
                      reference is retained when non-NULL.
   @param child       expression f(x); a reference is retained.
   @param A           m x n CSR matrix; it is copied, the caller keeps ownership.
   @return the new node.
   A child of shape (n, p) yields y of shape (m, p) with p diagonal blocks.
   A child of shape (1, n) is treated like the numpy vector (n,) and yields
   y of shape (1, m). Any other shape prints an error and exits. */
expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A)
{
    const int rows_A = A->m;
    const int cols_A = A->n;

    /* Reject shapes that match neither supported layout. */
    if (child->d1 != cols_A && !(child->d2 == cols_A && child->d1 == 1))
    {
        fprintf(stderr, "Error in new_left_matmul: dimension mismatch\n");
        exit(1);
    }

    int out_d1, out_d2, blocks;
    if (child->d1 == cols_A)
    {
        /* (n, p) child: one block of A per column of the child. */
        out_d1 = rows_A;
        out_d2 = child->d2;
        blocks = child->d2;
    }
    else
    {
        /* (1, n) child broadcast as (n,): a single block. */
        out_d1 = 1;
        out_d2 = rows_A;
        blocks = 1;
    }

    left_matmul_expr *lmul =
        (left_matmul_expr *) calloc(1, sizeof(left_matmul_expr));
    expr *node = &lmul->base;
    init_expr(node, out_d1, out_d2, child->n_vars, forward, jacobian_init,
              eval_jacobian, is_affine, wsum_hess_init, eval_wsum_hess,
              free_type_data);

    node->left = child;
    expr_retain(child);

    /* iwork serves both transposing A (A's n columns) and CSR<->CSC
       conversion of the child Jacobian (n_vars columns). */
    const int iwork_len = MAX(cols_A, node->n_vars);
    node->iwork = (int *) malloc(iwork_len * sizeof(int));
    lmul->csc_to_csr_workspace = (int *) malloc(node->size * sizeof(int));
    lmul->n_blocks = blocks;

    /* Keep only the small A (never A_kron) plus its transpose. */
    lmul->A = new_csr(A);
    lmul->AT = transpose(lmul->A, node->iwork);

    lmul->param_source = param_node;
    if (param_node != NULL)
    {
        expr_retain(param_node);
    }
    return node;
}