@@ -15,21 +15,26 @@ AdamOptimizer::AdamOptimizer(
1515 double eps
1616) :
1717 _context(context),
18- _runtime(build_runtime(function, context)),
1918 _learning_rate(learning_rate),
2019 _schedule(schedule),
2120 _step(0 ),
2221 _step_count(step_count),
2322 _beta1(beta1),
2423 _beta2(beta2),
25- _eps(eps) {
24+ _eps(eps),
25+ _one(1.0 , context->device ()) {
2626 DevicePtr device = context->device ();
27+ std::vector<std::string> param_names;
2728 for (auto & [name, value] : function.globals ()) {
28- Tensor global = context->global (name);
29- _parameters.push_back (global);
30- _exp_avgs.emplace_back (global.dtype (), global.shape (), device).zero ();
31- _exp_avg_sqs.emplace_back (global.dtype (), global.shape (), device).zero ();
29+ if (context->global_requires_grad (name)) {
30+ param_names.push_back (name);
31+ }
3232 }
33+ _parameter = context->reallocate_globals_contiguously (param_names);
34+ _runtime = build_runtime (function, context);
35+ _parameter = Tensor (_parameter.dtype (), _parameter.shape (), _parameter.device ());
36+ _exp_avg = Tensor (_parameter.dtype (), _parameter.shape (), _parameter.device ());
37+ _exp_avg_sq = Tensor (_parameter.dtype (), _parameter.shape (), _parameter.device ());
3338 _input_types.reserve (function.inputs ().size ());
3439 for (auto & input : function.inputs ()) {
3540 _input_types.push_back (input.type );
@@ -47,20 +52,20 @@ TensorVec AdamOptimizer::step(const TensorVec& inputs) {
4752 _runtime->run_with_grad (inputs, std::vector<bool >(inputs.size (), false ));
4853 TensorVec output_grads (outputs.size ());
4954 DevicePtr device = _context->device ();
50- output_grads.at (0 ) = Tensor ( 1.0 , device) ;
55+ output_grads.at (0 ) = _one ;
5156 auto [input_grads, global_grads] =
5257 _runtime->run_backward (output_grads, stored_locals, eval_grad);
53- /* device->adam_step(
54- global_grads,
55- _parameters ,
56- _exp_avgs ,
57- _exp_avg_sqs ,
58+ device->adam_step (
59+ global_grads. at ( 0 ) ,
60+ _parameter ,
61+ _exp_avg ,
62+ _exp_avg_sq ,
5863 step_size,
5964 _beta1,
6065 _beta2,
6166 _eps,
6267 bias_corr2_sqrt
63- );*/
68+ );
6469 return outputs;
6570}
6671
0 commit comments