File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -64,9 +64,7 @@ static void forward(expr *node, const double *u)
6464{
6565 left_matmul_expr * lnode = (left_matmul_expr * ) node ;
6666
67- /* Always call forward on param_source if it exists */
68- /* Should we also adopt a convention that left_matmul always
69- points to a param_source, even if its constant? */
67+ /* call forward on param_source if it exists and needs refresh */
7068 if (lnode -> param_source != NULL && lnode -> base .needs_parameter_refresh )
7169 {
7270 lnode -> param_source -> forward (lnode -> param_source , NULL );
Original file line number Diff line number Diff line change @@ -30,9 +30,14 @@ static void forward(expr *node, const double *u)
3030 expr * child = node -> left ;
3131 vector_mult_expr * vnode = (vector_mult_expr * ) node ;
3232
33- /* call forward for param_source expr tree
34- ex: broadcast(param) or promote(const)*/
35- vnode -> param_source -> forward (vnode -> param_source , NULL );
33+ /* call forward for param_source expr tree (this extra logic is needed
34+ in case the parameter is a broadcast or promote node which needs to refresh
35+ its values) */
36+ if (vnode -> base .needs_parameter_refresh )
37+ {
38+ vnode -> param_source -> forward (vnode -> param_source , NULL );
39+ vnode -> base .needs_parameter_refresh = false;
40+ }
3641
3742 const double * a = vnode -> param_source -> value ;
3843
@@ -135,5 +140,8 @@ expr *new_vector_mult(expr *param_node, expr *child)
135140 vnode -> param_source = param_node ;
136141 expr_retain (param_node );
137142
143+ /* special case for handling broadcasting of constants correctly */
144+ vnode -> base .needs_parameter_refresh = true;
145+
138146 return node ;
139147}
You can’t perform that action at this time.
0 commit comments