Skip to content

Commit 36ec239

Browse files
Transurgeonclaude
andcommitted
Remove redundant A_m, A_n params from new_left_param_matmul
These dimensions are always equal to param_node->d1 and param_node->d2, which are set during make_parameter. Read them from the node directly. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent fea0708 commit 36ec239

2 files changed

Lines changed: 12 additions & 9 deletions

File tree

include/bivariate.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ expr *new_const_scalar_mult(double a, expr *child);
4343
expr *new_const_vector_mult(const double *a, expr *child);
4444

4545
/* Left matrix multiplication with parameter source: P @ f(x) where P is a parameter */
46-
expr *new_left_param_matmul(expr *param_node, expr *u, int A_m, int A_n);
46+
expr *new_left_param_matmul(expr *param_node, expr *child);
4747

4848
/* Parameter scalar multiplication: p * f(x) where p is a parameter */
4949
expr *new_param_scalar_mult(expr *param_node, expr *child);

src/bivariate/left_matmul.c

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -253,17 +253,20 @@ static void free_param_matmul_type_data(expr *node)
253253
lin_node->param_source = NULL;
254254
}
255255

256-
expr *new_left_param_matmul(expr *param_node, expr *u, int A_m, int A_n)
256+
expr *new_left_param_matmul(expr *param_node, expr *child)
257257
{
258+
int A_m = param_node->d1;
259+
int A_n = param_node->d2;
260+
258261
/* Same dimension logic as new_left_matmul */
259262
int d1, d2, n_blocks;
260-
if (u->d1 == A_n)
263+
if (child->d1 == A_n)
261264
{
262265
d1 = A_m;
263-
d2 = u->d2;
264-
n_blocks = u->d2;
266+
d2 = child->d2;
267+
n_blocks = child->d2;
265268
}
266-
else if (u->d2 == A_n && u->d1 == 1)
269+
else if (child->d2 == A_n && child->d1 == 1)
267270
{
268271
d1 = 1;
269272
d2 = A_m;
@@ -296,10 +299,10 @@ expr *new_left_param_matmul(expr *param_node, expr *u, int A_m, int A_n)
296299
left_matmul_expr *lin_node =
297300
(left_matmul_expr *) calloc(1, sizeof(left_matmul_expr));
298301
expr *node = &lin_node->base;
299-
init_expr(node, d1, d2, u->n_vars, forward, jacobian_init, eval_jacobian,
302+
init_expr(node, d1, d2, child->n_vars, forward, jacobian_init, eval_jacobian,
300303
is_affine, wsum_hess_init, eval_wsum_hess, free_param_matmul_type_data);
301-
node->left = u;
302-
expr_retain(u);
304+
node->left = child;
305+
expr_retain(child);
303306

304307
/* Initialize type-specific fields */
305308
lin_node->A = block_diag_repeat_csr(A_tmp, n_blocks);

0 commit comments

Comments
 (0)