2424#include "atoms/matmul.h"
2525#include "atoms/multiply.h"
2626#include "atoms/neg.h"
27+ #include "atoms/parameter.h"
2728#include "atoms/power.h"
2829#include "atoms/prod.h"
2930#include "atoms/prod_axis_one.h"
3637#include "atoms/rel_entr_vector_scalar.h"
3738#include "atoms/reshape.h"
3839#include "atoms/right_matmul.h"
40+ #include "atoms/scalar_mult.h"
3941#include "atoms/sin.h"
4042#include "atoms/sinh.h"
4143#include "atoms/sum.h"
4446#include "atoms/trace.h"
4547#include "atoms/transpose.h"
4648#include "atoms/variable.h"
49+ #include "atoms/vector_mult.h"
4750#include "atoms/xexp.h"
4851
4952/* Include problem bindings */
5659#include "problem/jacobian.h"
5760#include "problem/make_problem.h"
5861#include "problem/objective_forward.h"
62+ #include "problem/register_params.h"
63+ #include "problem/update_params.h"
5964
6065static int numpy_initialized = 0 ;
6166
@@ -70,6 +75,7 @@ static int ensure_numpy(void)
7075static PyMethodDef DNLPMethods [] = {
7176 {"make_variable" , py_make_variable , METH_VARARGS , "Create variable node" },
7277 {"make_constant" , py_make_constant , METH_VARARGS , "Create constant node" },
78+ {"make_parameter" , py_make_parameter , METH_VARARGS , "Create parameter node" },
7379 {"make_linear" , py_make_linear , METH_VARARGS , "Create linear op node" },
7480 {"make_log" , py_make_log , METH_VARARGS , "Create log node" },
7581 {"make_exp" , py_make_exp , METH_VARARGS , "Create exp node" },
@@ -110,7 +116,11 @@ static PyMethodDef DNLPMethods[] = {
110116 {"make_logistic" , py_make_logistic , METH_VARARGS , "Create logistic node" },
111117 {"make_xexp" , py_make_xexp , METH_VARARGS , "Create xexp node" },
112118 {"make_left_matmul" , py_make_left_matmul , METH_VARARGS ,
113- "Create left matmul node (A @ f(x))" },
119+ "Create left matmul node (A @ f(x)): pass None or param capsule as first arg" },
120+ {"make_param_scalar_mult" , py_make_param_scalar_mult , METH_VARARGS ,
121+ "Create scalar mult from parameter (p * f(x))" },
122+ {"make_param_vector_mult" , py_make_param_vector_mult , METH_VARARGS ,
123+ "Create vector mult from parameter (p ∘ f(x))" },
114124 {"make_right_matmul" , py_make_right_matmul , METH_VARARGS ,
115125 "Create right matmul node (f(x) @ A)" },
116126 {"make_quad_form" , py_make_quad_form , METH_VARARGS ,
@@ -150,6 +160,10 @@ static PyMethodDef DNLPMethods[] = {
150160 "Compute Lagrangian Hessian" },
151161 {"get_hessian" , py_get_hessian , METH_VARARGS ,
152162 "Get Lagrangian Hessian without recomputing" },
163+ {"problem_register_params" , py_problem_register_params , METH_VARARGS ,
164+ "Register parameter nodes with the problem" },
165+ {"problem_update_params" , py_problem_update_params , METH_VARARGS ,
166+ "Update parameter values" },
153167 {NULL , NULL , 0 , NULL }};
154168
155169static struct PyModuleDef sparsediffpy_module = {
0 commit comments