2424
2525/* Constant scalar multiplication: y = a * child where a is a constant double */
2626
27- static inline double get_scalar (const const_scalar_mult_expr * sn )
28- {
29- return sn -> param_source ? sn -> param_source -> value [0 ] : sn -> a ;
30- }
31-
3227static void forward (expr * node , const double * u )
3328{
3429 expr * child = node -> left ;
@@ -37,7 +32,8 @@ static void forward(expr *node, const double *u)
3732 child -> forward (child , u );
3833
3934 /* local forward pass: multiply each element by scalar a */
40- double a = get_scalar ((const_scalar_mult_expr * ) node );
35+ const_scalar_mult_expr * sn = (const_scalar_mult_expr * ) node ;
36+ double a = sn -> param_source ? sn -> param_source -> value [0 ] : sn -> a ;
4137 for (int i = 0 ; i < node -> size ; i ++ )
4238 {
4339 node -> value [i ] = a * child -> value [i ];
@@ -60,7 +56,8 @@ static void jacobian_init(expr *node)
6056static void eval_jacobian (expr * node )
6157{
6258 expr * child = node -> left ;
63- double a = get_scalar ((const_scalar_mult_expr * ) node );
59+ const_scalar_mult_expr * sn = (const_scalar_mult_expr * ) node ;
60+ double a = sn -> param_source ? sn -> param_source -> value [0 ] : sn -> a ;
6461
6562 /* evaluate child */
6663 child -> eval_jacobian (child );
@@ -90,7 +87,8 @@ static void eval_wsum_hess(expr *node, const double *w)
9087 expr * x = node -> left ;
9188 x -> eval_wsum_hess (x , w );
9289
93- double a = get_scalar ((const_scalar_mult_expr * ) node );
90+ const_scalar_mult_expr * sn = (const_scalar_mult_expr * ) node ;
91+ double a = sn -> param_source ? sn -> param_source -> value [0 ] : sn -> a ;
9492 for (int j = 0 ; j < x -> wsum_hess -> nnz ; j ++ )
9593 {
9694 node -> wsum_hess -> x [j ] = a * x -> wsum_hess -> x [j ];
0 commit comments