|
23 | 23 |
|
24 | 24 | /* Constant vector elementwise multiplication: y = a \circ child */ |
25 | 25 |
|
26 | | -static inline const double *get_vector(const const_vector_mult_expr *vn) |
27 | | -{ |
28 | | - return vn->param_source ? vn->param_source->value : vn->a; |
29 | | -} |
30 | | - |
31 | 26 | static void forward(expr *node, const double *u) |
32 | 27 | { |
33 | 28 | expr *child = node->left; |
34 | | - const double *a = get_vector((const_vector_mult_expr *) node); |
| 29 | + const_vector_mult_expr *vn = (const_vector_mult_expr *) node; |
| 30 | + const double *a = vn->param_source ? vn->param_source->value : vn->a; |
35 | 31 |
|
36 | 32 | /* child's forward pass */ |
37 | 33 | child->forward(child, u); |
@@ -59,7 +55,8 @@ static void jacobian_init(expr *node) |
59 | 55 | static void eval_jacobian(expr *node) |
60 | 56 | { |
61 | 57 | expr *x = node->left; |
62 | | - const double *a = get_vector((const_vector_mult_expr *) node); |
| 58 | + const_vector_mult_expr *vn = (const_vector_mult_expr *) node; |
| 59 | + const double *a = vn->param_source ? vn->param_source->value : vn->a; |
63 | 60 |
|
64 | 61 | /* evaluate x */ |
65 | 62 | x->eval_jacobian(x); |
@@ -92,7 +89,8 @@ static void wsum_hess_init(expr *node) |
92 | 89 | static void eval_wsum_hess(expr *node, const double *w) |
93 | 90 | { |
94 | 91 | expr *x = node->left; |
95 | | - const double *a = get_vector((const_vector_mult_expr *) node); |
| 92 | + const_vector_mult_expr *vn = (const_vector_mult_expr *) node; |
| 93 | + const double *a = vn->param_source ? vn->param_source->value : vn->a; |
96 | 94 |
|
97 | 95 | /* scale weights w by a */ |
98 | 96 | for (int i = 0; i < node->size; i++) |
|
0 commit comments