Skip to content

Commit ff0110f

Browse files
committed
add initial attempt for fixing parameters and broadcasting
1 parent 30a7868 commit ff0110f

File tree

4 files changed

+36
-2
lines changed

4 files changed

+36
-2
lines changed

src/atoms/affine/left_matmul.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,14 @@ static void refresh_param_values(left_matmul_expr *lnode)
6363
static void forward(expr *node, const double *u)
6464
{
6565
left_matmul_expr *lnode = (left_matmul_expr *) node;
66+
67+
/* Refresh param_source expression tree if parameters changed.*/
68+
if (lnode->param_source != NULL && lnode->base.needs_parameter_refresh)
69+
{
70+
/* pass NULL to forward: constant param_source never depends on u */
71+
lnode->param_source->forward(lnode->param_source, NULL);
72+
}
73+
6674
refresh_param_values(lnode);
6775

6876
expr *x = node->left;

src/atoms/affine/scalar_mult.c

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,17 @@
2828
static void forward(expr *node, const double *u)
2929
{
3030
expr *child = node->left;
31-
double a = ((scalar_mult_expr *) node)->param_source->value[0];
31+
scalar_mult_expr *snode = (scalar_mult_expr *) node;
32+
33+
/* Refresh param_source expression tree if parameters changed.*/
34+
if (node->needs_parameter_refresh)
35+
{
36+
/* pass NULL to forward: constant param_source never depends on u */
37+
snode->param_source->forward(snode->param_source, NULL);
38+
node->needs_parameter_refresh = false;
39+
}
40+
41+
double a = snode->param_source->value[0];
3242

3343
/* child's forward pass */
3444
child->forward(child, u);

src/atoms/affine/vector_mult.c

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,17 @@
2828
static void forward(expr *node, const double *u)
2929
{
3030
expr *child = node->left;
31-
const double *a = ((vector_mult_expr *) node)->param_source->value;
31+
vector_mult_expr *vnode = (vector_mult_expr *) node;
32+
33+
/* Refresh param_source expression tree if parameters changed.*/
34+
if (node->needs_parameter_refresh)
35+
{
36+
/* pass NULL to forward: constant param_source never depends on u */
37+
vnode->param_source->forward(vnode->param_source, NULL);
38+
node->needs_parameter_refresh = false;
39+
}
40+
41+
const double *a = vnode->param_source->value;
3242

3343
/* child's forward pass */
3444
child->forward(child, u);

tests/all_tests.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
#include "jacobian_tests/other/test_prod_axis_zero.h"
5050
#include "jacobian_tests/other/test_quad_form.h"
5151
#include "numerical_diff/test_numerical_diff.h"
52+
#include "problem/test_param_broadcast.h"
5253
#include "problem/test_param_prob.h"
5354
#include "problem/test_problem.h"
5455
#include "utils/test_cblas.h"
@@ -358,6 +359,11 @@ int main(void)
358359
mu_run_test(test_param_right_matmul_rectangular, tests_run);
359360
mu_run_test(test_param_shared_left_matmul_problem, tests_run);
360361
mu_run_test(test_param_fixed_skip_in_update, tests_run);
362+
363+
printf("\n--- Parameter + Broadcast Tests ---\n");
364+
mu_run_test(test_param_broadcast_vector_mult, tests_run);
365+
mu_run_test(test_param_sum_scalar_mult, tests_run);
366+
mu_run_test(test_param_broadcast_left_matmul, tests_run);
361367
#endif /* PROFILE_ONLY */
362368

363369
#ifdef PROFILE_ONLY

0 commit comments

Comments
 (0)