Skip to content

Commit 7367914

Browse files
committed
we don't always have to call forward for parameter node in vector mult
1 parent 5fc185f commit 7367914

2 files changed

Lines changed: 12 additions & 6 deletions

File tree

src/atoms/affine/left_matmul.c

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,7 @@ static void forward(expr *node, const double *u)
6464
{
6565
left_matmul_expr *lnode = (left_matmul_expr *) node;
6666

67-
/* Always call forward on param_source if it exists */
68-
/* Should we also adopt a convention that left_matmul always
69-
points to a param_source, even if its constant? */
67+
/* call forward on param_source if it exists and needs refresh */
7068
if (lnode->param_source != NULL && lnode->base.needs_parameter_refresh)
7169
{
7270
lnode->param_source->forward(lnode->param_source, NULL);

src/atoms/affine/vector_mult.c

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,14 @@ static void forward(expr *node, const double *u)
3030
expr *child = node->left;
3131
vector_mult_expr *vnode = (vector_mult_expr *) node;
3232

33-
/* call forward for param_source expr tree
34-
ex: broadcast(param) or promote(const)*/
35-
vnode->param_source->forward(vnode->param_source, NULL);
33+
/* call forward for param_source expr tree (this extra logic is needed
34+
in case the parameter is a broadcast or promote node which needs to refresh
35+
its values) */
36+
if (vnode->base.needs_parameter_refresh)
37+
{
38+
vnode->param_source->forward(vnode->param_source, NULL);
39+
vnode->base.needs_parameter_refresh = false;
40+
}
3641

3742
const double *a = vnode->param_source->value;
3843

@@ -135,5 +140,8 @@ expr *new_vector_mult(expr *param_node, expr *child)
135140
vnode->param_source = param_node;
136141
expr_retain(param_node);
137142

143+
/* special case for handling broadcasting of constants correctly */
144+
vnode->base.needs_parameter_refresh = true;
145+
138146
return node;
139147
}

0 commit comments

Comments
 (0)