Skip to content

Commit 2ed525b

Browse files
committed
refactor of affine atoms done
1 parent 618efc2 commit 2ed525b

5 files changed

Lines changed: 23 additions & 66 deletions

File tree

include/affine.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,6 @@ expr *new_linear(expr *u, const CSR_Matrix *A);
99

1010
expr *new_add(expr *left, expr *right);
1111

12-
/* Helper function to initialize a sum expr (can be used with derived types) */
13-
void init_sum(expr *node, expr *child, int d1);
14-
1512
expr *new_sum(expr *child, int axis);
1613
expr *new_hstack(expr **args, int n_args, int n_vars);
1714

src/affine/linear_op.c

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
static void forward(expr *node, const double *u)
55
{
66
expr *x = node->left;
7+
78
/* child's forward pass */
89
node->left->forward(node->left, u);
910

@@ -19,42 +20,30 @@ static bool is_affine(const expr *node)
1920
static void free_type_data(expr *node)
2021
{
2122
linear_op_expr *lin_node = (linear_op_expr *) node;
22-
free_csr_matrix(lin_node->A_csr);
23+
/* memory pointing to by A_csr will already be freed when the
24+
jacobian is freed, so free_csr_matrix(lin_node->A_csr) should
25+
be commented out */
26+
// free_csr_matrix(lin_node->A_csr);
2327
free_csc_matrix(lin_node->A_csc);
28+
lin_node->A_csr = NULL;
29+
lin_node->A_csc = NULL;
2430
}
2531

2632
expr *new_linear(expr *u, const CSR_Matrix *A)
2733
{
2834
/* Allocate the type-specific struct */
2935
linear_op_expr *lin_node = (linear_op_expr *) calloc(1, sizeof(linear_op_expr));
30-
if (!lin_node) return NULL;
31-
3236
expr *node = &lin_node->base;
33-
34-
/* Initialize base fields */
3537
init_expr(node, A->m, 1, u->n_vars, forward, NULL, NULL, is_affine,
3638
free_type_data);
37-
38-
/* Set left child */
3939
node->left = u;
4040
expr_retain(u);
4141

42-
/* Check if allocation succeeded */
43-
if (!node->value)
44-
{
45-
free(lin_node);
46-
return NULL;
47-
}
48-
49-
/* allocate jacobian and copy A into it */
50-
// TODO: this should eventually be removed
51-
node->jacobian = new_csr_matrix(A->m, A->n, A->nnz);
52-
copy_csr_matrix(A, node->jacobian);
53-
5442
/* Initialize type-specific fields */
5543
lin_node->A_csr = new_csr_matrix(A->m, A->n, A->nnz);
5644
copy_csr_matrix(A, lin_node->A_csr);
5745
lin_node->A_csc = csr_to_csc(A);
46+
node->jacobian = lin_node->A_csr;
5847

5948
return node;
6049
}

src/affine/sum.c

Lines changed: 8 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -145,35 +145,6 @@ static void free_type_data(expr *node)
145145
free_int_double_pair_array(snode->int_double_pairs);
146146
}
147147

148-
/* Helper function to initialize a sum expr */
149-
void init_sum(expr *node, expr *child, int d1)
150-
{
151-
node->d1 = d1;
152-
node->d2 = 1;
153-
node->size = d1 * 1;
154-
node->n_vars = child->n_vars;
155-
node->var_id = -1;
156-
node->refcount = 1;
157-
node->left = child;
158-
node->right = NULL;
159-
node->dwork = NULL;
160-
node->iwork = NULL;
161-
node->value = (double *) calloc(node->size, sizeof(double));
162-
node->jacobian = NULL;
163-
node->wsum_hess = NULL;
164-
node->jacobian_init = jacobian_init;
165-
node->wsum_hess_init = wsum_hess_init;
166-
node->eval_jacobian = eval_jacobian;
167-
node->eval_wsum_hess = eval_wsum_hess;
168-
node->local_jacobian = NULL;
169-
node->local_wsum_hess = NULL;
170-
node->is_affine = is_affine;
171-
node->forward = forward;
172-
node->free_type_data = free_type_data;
173-
174-
expr_retain(child);
175-
}
176-
177148
expr *new_sum(expr *child, int axis)
178149
{
179150
int d1 = 0;
@@ -195,20 +166,16 @@ expr *new_sum(expr *child, int axis)
195166
}
196167

197168
/* Allocate the type-specific struct */
198-
sum_expr *snode = (sum_expr *) malloc(sizeof(sum_expr));
199-
if (!snode) return NULL;
200-
169+
sum_expr *snode = (sum_expr *) calloc(1, sizeof(sum_expr));
201170
expr *node = &snode->base;
171+
init_expr(node, d1, 1, child->n_vars, forward, jacobian_init, eval_jacobian,
172+
is_affine, free_type_data);
173+
node->left = child;
174+
expr_retain(child);
202175

203-
/* Initialize base sum fields */
204-
init_sum(node, child, d1);
205-
206-
/* Check if allocation succeeded */
207-
if (!node->value)
208-
{
209-
free(snode);
210-
return NULL;
211-
}
176+
/* hessian function pointers */
177+
node->wsum_hess_init = wsum_hess_init;
178+
node->eval_wsum_hess = eval_wsum_hess;
212179

213180
/* Set type-specific fields */
214181
snode->axis = axis;

src/affine/variable.c

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,10 @@ static bool is_affine(const expr *node)
2929
expr *new_variable(int d1, int d2, int var_id, int n_vars)
3030
{
3131
expr *node = new_expr(d1, d2, n_vars);
32-
if (!node) return NULL;
33-
3432
node->forward = forward;
3533
node->var_id = var_id;
3634
node->is_affine = is_affine;
3735
node->jacobian_init = jacobian_init;
38-
// node->jacobian = NULL;
3936

4037
return node;
4138
}

src/expr.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,20 @@ void free_expr(expr *node)
3838
/* recursively free children */
3939
free_expr(node->left);
4040
free_expr(node->right);
41+
node->left = NULL;
42+
node->right = NULL;
4143

4244
/* free value array and jacobian */
4345
free(node->value);
4446
free_csr_matrix(node->jacobian);
4547
free_csr_matrix(node->wsum_hess);
4648
free(node->dwork);
4749
free(node->iwork);
50+
node->value = NULL;
51+
node->jacobian = NULL;
52+
node->wsum_hess = NULL;
53+
node->dwork = NULL;
54+
node->iwork = NULL;
4855

4956
/* free type-specific data */
5057
if (node->free_type_data)

0 commit comments

Comments
 (0)