Skip to content

Commit ddd38fd

Browse files
Transurgeonclaudedance858
authored
Diag vec (#35)
* Add diag_vec atom for creating diagonal matrices from vectors Implements diag_vec which converts a vector of size n into an n×n diagonal matrix. Includes forward pass, Jacobian, and Hessian computations. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> * Clean up diag_vec jacobian_init and eval_jacobian - Use standard CSR building pattern (J->p[row] = nnz) - Use next_diag counter instead of checking row == child_row * (n+1) - Simplify eval_jacobian to O(n) loop computing out_row directly Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com> Co-authored-by: Daniel Cederberg <101940375+dance858@users.noreply.github.com>
1 parent cb7b17a commit ddd38fd

5 files changed

Lines changed: 166 additions & 0 deletions

File tree

include/affine.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ expr *new_variable(int d1, int d2, int var_id, int n_vars);
2121
expr *new_index(expr *child, int d1, int d2, const int *indices, int n_idxs);
2222
expr *new_reshape(expr *child, int d1, int d2);
2323
expr *new_broadcast(expr *child, int target_d1, int target_d2);
24+
expr *new_diag_vec(expr *child);
2425
expr *new_transpose(expr *child);
2526

2627
#endif /* AFFINE_H */

python/atoms/diag_vec.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#ifndef ATOM_DIAG_VEC_H
2+
#define ATOM_DIAG_VEC_H
3+
4+
#include "common.h"
5+
6+
static PyObject *py_make_diag_vec(PyObject *self, PyObject *args)
7+
{
8+
PyObject *child_capsule;
9+
10+
if (!PyArg_ParseTuple(args, "O", &child_capsule))
11+
{
12+
return NULL;
13+
}
14+
15+
expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME);
16+
if (!child)
17+
{
18+
return NULL;
19+
}
20+
21+
expr *node = new_diag_vec(child);
22+
if (!node)
23+
{
24+
PyErr_SetString(PyExc_RuntimeError, "failed to create diag_vec node");
25+
return NULL;
26+
}
27+
28+
expr_retain(node);
29+
return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor);
30+
}
31+
32+
#endif /* ATOM_DIAG_VEC_H */

python/bindings.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "atoms/const_vector_mult.h"
1212
#include "atoms/constant.h"
1313
#include "atoms/cos.h"
14+
#include "atoms/diag_vec.h"
1415
#include "atoms/entr.h"
1516
#include "atoms/exp.h"
1617
#include "atoms/getters.h"
@@ -96,6 +97,7 @@ static PyMethodDef DNLPMethods[] = {
9697
"Create prod_axis_one node"},
9798
{"make_sin", py_make_sin, METH_VARARGS, "Create sin node"},
9899
{"make_cos", py_make_cos, METH_VARARGS, "Create cos node"},
100+
{"make_diag_vec", py_make_diag_vec, METH_VARARGS, "Create diag_vec node"},
99101
{"make_tan", py_make_tan, METH_VARARGS, "Create tan node"},
100102
{"make_sinh", py_make_sinh, METH_VARARGS, "Create sinh node"},
101103
{"make_tanh", py_make_tanh, METH_VARARGS, "Create tanh node"},

src/affine/diag_vec.c

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
3+
#include "affine.h"
4+
#include <assert.h>
5+
#include <stdlib.h>
6+
#include <string.h>
7+
8+
/* diag_vec: converts a vector of size n into an n×n diagonal matrix.
9+
* In Fortran (column-major) order, element i of the input maps to
10+
* position i*(n+1) in the flattened output (the diagonal positions). */
11+
12+
static void forward(expr *node, const double *u)
13+
{
14+
expr *x = node->left;
15+
int n = x->size;
16+
17+
/* child's forward pass */
18+
x->forward(x, u);
19+
20+
/* zero-initialize output */
21+
memset(node->value, 0, node->size * sizeof(double));
22+
23+
/* place input elements on the diagonal */
24+
for (int i = 0; i < n; i++)
25+
{
26+
node->value[i * (n + 1)] = x->value[i];
27+
}
28+
}
29+
30+
static void jacobian_init(expr *node)
31+
{
32+
expr *x = node->left;
33+
int n = x->size;
34+
x->jacobian_init(x);
35+
36+
CSR_Matrix *Jx = x->jacobian;
37+
CSR_Matrix *J = new_csr_matrix(node->size, node->n_vars, Jx->nnz);
38+
39+
/* Output has n² rows but only n diagonal positions are non-empty.
40+
* Diagonal position i is at row i*(n+1) in Fortran order. */
41+
int nnz = 0;
42+
int next_diag = 0;
43+
for (int row = 0; row < node->size; row++)
44+
{
45+
J->p[row] = nnz;
46+
if (row == next_diag)
47+
{
48+
int child_row = row / (n + 1);
49+
int len = Jx->p[child_row + 1] - Jx->p[child_row];
50+
memcpy(J->i + nnz, Jx->i + Jx->p[child_row], len * sizeof(int));
51+
nnz += len;
52+
next_diag += n + 1;
53+
}
54+
}
55+
J->p[node->size] = nnz;
56+
57+
node->jacobian = J;
58+
}
59+
60+
static void eval_jacobian(expr *node)
61+
{
62+
expr *x = node->left;
63+
int n = x->size;
64+
x->eval_jacobian(x);
65+
66+
CSR_Matrix *J = node->jacobian;
67+
CSR_Matrix *Jx = x->jacobian;
68+
69+
/* Copy values from child row i to output diagonal row i*(n+1) */
70+
for (int i = 0; i < n; i++)
71+
{
72+
int out_row = i * (n + 1);
73+
int len = J->p[out_row + 1] - J->p[out_row];
74+
memcpy(J->x + J->p[out_row], Jx->x + Jx->p[i], len * sizeof(double));
75+
}
76+
}
77+
78+
static void wsum_hess_init(expr *node)
79+
{
80+
expr *x = node->left;
81+
82+
/* initialize child's wsum_hess */
83+
x->wsum_hess_init(x);
84+
85+
/* workspace for extracting diagonal weights */
86+
node->dwork = (double *) calloc(x->size, sizeof(double));
87+
88+
/* Copy child's Hessian structure (diag_vec is linear, so its own Hessian is zero) */
89+
CSR_Matrix *Hx = x->wsum_hess;
90+
node->wsum_hess = new_csr_matrix(Hx->m, Hx->n, Hx->nnz);
91+
memcpy(node->wsum_hess->p, Hx->p, (Hx->m + 1) * sizeof(int));
92+
memcpy(node->wsum_hess->i, Hx->i, Hx->nnz * sizeof(int));
93+
}
94+
95+
static void eval_wsum_hess(expr *node, const double *w)
96+
{
97+
expr *x = node->left;
98+
int n = x->size;
99+
100+
/* Extract weights from diagonal positions of w (which has n² elements) */
101+
for (int i = 0; i < n; i++)
102+
{
103+
node->dwork[i] = w[i * (n + 1)];
104+
}
105+
106+
/* Evaluate child's Hessian with extracted weights */
107+
x->eval_wsum_hess(x, node->dwork);
108+
memcpy(node->wsum_hess->x, x->wsum_hess->x, x->wsum_hess->nnz * sizeof(double));
109+
}
110+
111+
static bool is_affine(const expr *node)
112+
{
113+
return node->left->is_affine(node->left);
114+
}
115+
116+
expr *new_diag_vec(expr *child)
117+
{
118+
/* child must be a vector: either column (n, 1) or row (1, n) */
119+
assert(child->d1 == 1 || child->d2 == 1);
120+
121+
/* n is the number of elements (works for both row and column vectors) */
122+
int n = child->size;
123+
expr *node = (expr *) calloc(1, sizeof(expr));
124+
init_expr(node, n, n, child->n_vars, forward, jacobian_init, eval_jacobian,
125+
is_affine, wsum_hess_init, eval_wsum_hess, NULL);
126+
node->left = child;
127+
expr_retain(child);
128+
129+
return node;
130+
}

src/dnlp_diff_engine/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"make_promote",
2121
"make_index",
2222
"make_reshape",
23+
"make_diag_vec",
2324
"make_log",
2425
"make_exp",
2526
"make_power",

0 commit comments

Comments
 (0)