Skip to content

Commit b3e2304

Browse files
committed
small edits
1 parent 978aa03 commit b3e2304

2 files changed

Lines changed: 17 additions & 14 deletions

File tree

src/affine/parameter.c

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,17 @@ static bool is_affine(const expr *node)
7070
expr *new_parameter(int d1, int d2, int param_id, int n_vars, const double *values)
7171
{
7272
parameter_expr *pnode = (parameter_expr *) calloc(1, sizeof(parameter_expr));
73-
init_expr(&pnode->base, d1, d2, n_vars, forward, jacobian_init, eval_jacobian,
74-
is_affine, wsum_hess_init, eval_wsum_hess, NULL);
73+
expr *node = &pnode->base;
74+
init_expr(node, d1, d2, n_vars, forward, jacobian_init, eval_jacobian, is_affine,
75+
wsum_hess_init, eval_wsum_hess, NULL);
7576
pnode->param_id = param_id;
7677

77-
/* If values provided (fixed constant), copy them now */
78+
/* If values provided (fixed constant), copy them now.
79+
Otherwise values will be populated by problem_update_params. */
7880
if (values != NULL)
7981
{
80-
memcpy(pnode->base.value, values, pnode->base.size * sizeof(double));
82+
memcpy(node->value, values, node->size * sizeof(double));
8183
}
82-
/* Otherwise values will be populated by problem_update_params */
8384

84-
return &pnode->base;
85+
return node;
8586
}

src/bivariate/left_matmul.c

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,21 +162,23 @@ static void eval_wsum_hess(expr *node, const double *w)
162162

163163
expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A)
164164
{
165-
int A_m = A->m;
166-
int A_n = A->n;
165+
/* Dimension logic: handle numpy broadcasting (1, n) as (n, )/
166+
We expect u->d1 == A->n. However, numpy's broadcasting rules allow users
167+
to do A @ u where u is (n, ) which in C is actually (1, n). In that case
168+
the result of A @ u is (m, ), which is (1, m) according to broadcasting
169+
rules. We therefore check if this is the case. */
167170

168-
/* Dimension logic: handle numpy broadcasting (1, n) as (n, ) */
169171
int d1, d2, n_blocks;
170-
if (child->d1 == A_n)
172+
if (child->d1 == A->n)
171173
{
172-
d1 = A_m;
174+
d1 = A->m;
173175
d2 = child->d2;
174176
n_blocks = child->d2;
175177
}
176-
else if (child->d2 == A_n && child->d1 == 1)
178+
else if (child->d2 == A->n && child->d1 == 1)
177179
{
178180
d1 = 1;
179-
d2 = A_m;
181+
d2 = A->m;
180182
n_blocks = 1;
181183
}
182184
else
@@ -195,7 +197,7 @@ expr *new_left_matmul(expr *param_node, expr *child, const CSR_Matrix *A)
195197
expr_retain(child);
196198

197199
/* Store small A (NOT block-diagonal) — block functions handle the rest */
198-
node->iwork = (int *) malloc(MAX(A_n, node->n_vars) * sizeof(int));
200+
node->iwork = (int *) malloc(MAX(A->n, node->n_vars) * sizeof(int));
199201
lin_node->csc_to_csr_workspace = (int *) malloc(node->size * sizeof(int));
200202
lin_node->n_blocks = n_blocks;
201203
lin_node->A = new_csr(A);

0 commit comments

Comments
 (0)