File tree Expand file tree Collapse file tree 4 files changed +36
-2
lines changed
Expand file tree Collapse file tree 4 files changed +36
-2
lines changed Original file line number Diff line number Diff line change @@ -63,6 +63,14 @@ static void refresh_param_values(left_matmul_expr *lnode)
6363static void forward (expr * node , const double * u )
6464{
6565 left_matmul_expr * lnode = (left_matmul_expr * ) node ;
66+
67+ /* Refresh param_source expression tree if parameters changed.*/
68+ if (lnode -> param_source != NULL && lnode -> base .needs_parameter_refresh )
69+ {
70+ /* pass NULL to forward: constant param_source never depends on u */
71+ lnode -> param_source -> forward (lnode -> param_source , NULL );
72+ }
73+
6674 refresh_param_values (lnode );
6775
6876 expr * x = node -> left ;
Original file line number Diff line number Diff line change 2828static void forward (expr * node , const double * u )
2929{
3030 expr * child = node -> left ;
31- double a = ((scalar_mult_expr * ) node )-> param_source -> value [0 ];
31+ scalar_mult_expr * snode = (scalar_mult_expr * ) node ;
32+
33+ /* Refresh param_source expression tree if parameters changed.*/
34+ if (node -> needs_parameter_refresh )
35+ {
36+ /* pass NULL to forward: constant param_source never depends on u */
37+ snode -> param_source -> forward (snode -> param_source , NULL );
38+ node -> needs_parameter_refresh = false;
39+ }
40+
41+ double a = snode -> param_source -> value [0 ];
3242
3343 /* child's forward pass */
3444 child -> forward (child , u );
Original file line number Diff line number Diff line change 2828static void forward (expr * node , const double * u )
2929{
3030 expr * child = node -> left ;
31- const double * a = ((vector_mult_expr * ) node )-> param_source -> value ;
31+ vector_mult_expr * vnode = (vector_mult_expr * ) node ;
32+
33+ /* Refresh param_source expression tree if parameters changed.*/
34+ if (node -> needs_parameter_refresh )
35+ {
36+ /* pass NULL to forward: constant param_source never depends on u */
37+ vnode -> param_source -> forward (vnode -> param_source , NULL );
38+ node -> needs_parameter_refresh = false;
39+ }
40+
41+ const double * a = vnode -> param_source -> value ;
3242
3343 /* child's forward pass */
3444 child -> forward (child , u );
Original file line number Diff line number Diff line change 4949#include "jacobian_tests/other/test_prod_axis_zero.h"
5050#include "jacobian_tests/other/test_quad_form.h"
5151#include "numerical_diff/test_numerical_diff.h"
52+ #include "problem/test_param_broadcast.h"
5253#include "problem/test_param_prob.h"
5354#include "problem/test_problem.h"
5455#include "utils/test_cblas.h"
@@ -358,6 +359,11 @@ int main(void)
358359 mu_run_test (test_param_right_matmul_rectangular , tests_run );
359360 mu_run_test (test_param_shared_left_matmul_problem , tests_run );
360361 mu_run_test (test_param_fixed_skip_in_update , tests_run );
362+
363+ printf ("\n--- Parameter + Broadcast Tests ---\n" );
364+ mu_run_test (test_param_broadcast_vector_mult , tests_run );
365+ mu_run_test (test_param_sum_scalar_mult , tests_run );
366+ mu_run_test (test_param_broadcast_left_matmul , tests_run );
361367#endif /* PROFILE_ONLY */
362368
363369#ifdef PROFILE_ONLY
You can’t perform that action at this time.
0 commit comments