11#include "affine.h"
22#include <assert.h>
3+ #include <stdlib.h>
34#include <string.h>
45
56static void forward (expr * node , const double * u )
67{
8+ hstack_expr * hnode = (hstack_expr * ) node ;
79
810 /* children's forward passes */
9- for (int i = 0 ; i < node -> n_args ; i ++ )
11+ for (int i = 0 ; i < hnode -> n_args ; i ++ )
1012 {
11- node -> args [i ]-> forward (node -> args [i ], u );
13+ hnode -> args [i ]-> forward (hnode -> args [i ], u );
1214 }
1315
1416 /* concatenate values horizontally */
1517 int offset = 0 ;
16- for (int i = 0 ; i < node -> n_args ; i ++ )
18+ for (int i = 0 ; i < hnode -> n_args ; i ++ )
1719 {
18- expr * child = node -> args [i ];
20+ expr * child = hnode -> args [i ];
1921 memcpy (node -> value + offset , child -> value , child -> size * sizeof (double ));
2022 offset += child -> size ;
2123 }
2224}
2325
2426static void jacobian_init (expr * node )
2527{
28+ hstack_expr * hnode = (hstack_expr * ) node ;
29+
2630 /* initialize children's jacobians */
2731 int nnz = 0 ;
28- for (int i = 0 ; i < node -> n_args ; i ++ )
32+ for (int i = 0 ; i < hnode -> n_args ; i ++ )
2933 {
30- node -> args [i ]-> jacobian_init (node -> args [i ]);
31- nnz += node -> args [i ]-> jacobian -> nnz ;
34+ hnode -> args [i ]-> jacobian_init (hnode -> args [i ]);
35+ nnz += hnode -> args [i ]-> jacobian -> nnz ;
3236 }
3337
3438 node -> jacobian = new_csr_matrix (node -> size , node -> n_vars , nnz );
3539}
3640
3741static void eval_jacobian (expr * node )
3842{
43+ hstack_expr * hnode = (hstack_expr * ) node ;
44+
3945 /* evaluate children's jacobians */
4046 int row_offset = 0 ;
4147 CSR_Matrix * A = node -> jacobian ;
4248 A -> nnz = 0 ;
4349
44- for (int i = 0 ; i < node -> n_args ; i ++ )
50+ for (int i = 0 ; i < hnode -> n_args ; i ++ )
4551 {
46- expr * child = node -> args [i ];
52+ expr * child = hnode -> args [i ];
4753 child -> eval_jacobian (child );
4854 CSR_Matrix * B = child -> jacobian ;
4955
@@ -65,16 +71,55 @@ static void eval_jacobian(expr *node)
6571
6672static bool is_affine (expr * node )
6773{
68- for (int i = 0 ; i < node -> n_args ; i ++ )
74+ hstack_expr * hnode = (hstack_expr * ) node ;
75+
76+ for (int i = 0 ; i < hnode -> n_args ; i ++ )
6977 {
70- if (!node -> args [i ]-> is_affine (node -> args [i ]))
78+ if (!hnode -> args [i ]-> is_affine (hnode -> args [i ]))
7179 {
7280 return false;
7381 }
7482 }
7583 return true;
7684}
7785
86+ static void free_type_data (expr * node )
87+ {
88+ hstack_expr * hnode = (hstack_expr * ) node ;
89+ for (int i = 0 ; i < hnode -> n_args ; i ++ )
90+ {
91+ free_expr (hnode -> args [i ]);
92+ }
93+ }
94+
95+ /* Helper function to initialize an hstack expr */
96+ void init_hstack (expr * node , int d1 , int d2 , int n_vars )
97+ {
98+ node -> d1 = d1 ;
99+ node -> d2 = d2 ;
100+ node -> size = d1 * d2 ;
101+ node -> n_vars = n_vars ;
102+ node -> var_id = -1 ;
103+ node -> refcount = 1 ;
104+ node -> left = NULL ;
105+ node -> right = NULL ;
106+ node -> dwork = NULL ;
107+ node -> iwork = NULL ;
108+ node -> value = (double * ) calloc (node -> size , sizeof (double ));
109+ node -> jacobian = NULL ;
110+ node -> wsum_hess = NULL ;
111+ node -> CSR_work = NULL ;
112+ node -> jacobian_init = jacobian_init ;
113+ node -> wsum_hess_init = NULL ;
114+ node -> eval_jacobian = eval_jacobian ;
115+ node -> eval_wsum_hess = NULL ;
116+ node -> local_jacobian = NULL ;
117+ node -> local_wsum_hess = NULL ;
118+ node -> forward = forward ;
119+ node -> is_affine = is_affine ;
120+ node -> free_type_data = free_type_data ;
121+ }
122+
78123expr * new_hstack (expr * * args , int n_args , int n_vars )
79124{
80125 /* compute second dimension */
@@ -84,20 +129,30 @@ expr *new_hstack(expr **args, int n_args, int n_vars)
84129 d2 += args [i ]-> d2 ;
85130 }
86131
87- expr * node = new_expr (args [0 ]-> d1 , d2 , n_vars );
88- if (!node ) return NULL ;
89- node -> args = args ;
90- node -> n_args = n_args ;
132+ /* Allocate the type-specific struct */
133+ hstack_expr * hnode = (hstack_expr * ) malloc (sizeof (hstack_expr ));
134+ if (!hnode ) return NULL ;
135+
136+ expr * node = & hnode -> base ;
137+
138+ /* Initialize base hstack fields */
139+ init_hstack (node , args [0 ]-> d1 , d2 , n_vars );
140+
141+ /* Check if allocation succeeded */
142+ if (!node -> value )
143+ {
144+ free (hnode );
145+ return NULL ;
146+ }
147+
148+ /* Set type-specific fields */
149+ hnode -> args = args ;
150+ hnode -> n_args = n_args ;
91151
92152 for (int i = 0 ; i < n_args ; i ++ )
93153 {
94154 expr_retain (args [i ]);
95155 }
96156
97- node -> forward = forward ;
98- node -> is_affine = is_affine ;
99- node -> jacobian_init = jacobian_init ;
100- node -> eval_jacobian = eval_jacobian ;
101-
102157 return node ;
103158}
0 commit comments