Skip to content

Commit 2b38d34

Browse files
committed
optimization on contiguous parameters
1 parent 85c8cbe commit 2b38d34

6 files changed

Lines changed: 112 additions & 24 deletions

File tree

madspace/include/madspace/driver/adam_optimizer.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,10 @@ class AdamOptimizer {
3737
double _beta1;
3838
double _beta2;
3939
double _eps;
40-
TensorVec _parameters;
41-
TensorVec _exp_avgs;
42-
TensorVec _exp_avg_sqs;
40+
Tensor _one;
41+
Tensor _parameter;
42+
Tensor _exp_avg;
43+
Tensor _exp_avg_sq;
4344
TypeVec _input_types;
4445
};
4546

madspace/include/madspace/driver/context.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ class Context {
128128
std::vector<std::string> global_names() const;
129129
void delete_global(const std::string& name);
130130
void copy_globals_from(Context& context);
131+
Tensor reallocate_globals_contiguously(const std::vector<std::string>& names);
131132
const MatrixElementApi& matrix_element(std::size_t index) const;
132133
void save_globals(const std::string& dir) const;
133134
void load_globals(const std::string& dir);

madspace/include/madspace/driver/tensor.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ class Sizes {
5151
const std::size_t* data() const { return &_values[0]; }
5252
std::size_t& back() { return _values[_size - 1]; }
5353
const std::size_t& back() const { return _values[_size - 1]; }
54+
std::size_t product() const {
55+
std::size_t size = 1;
56+
for (std::size_t dim_size : *this) {
57+
size *= dim_size;
58+
}
59+
return size;
60+
}
5461

5562
private:
5663
std::size_t _values[max_size];
@@ -485,14 +492,7 @@ class Tensor {
485492
}
486493
}
487494

488-
std::size_t byte_size() const {
489-
check_impl();
490-
std::size_t size = dtype_size();
491-
for (auto dim_size : impl->shape) {
492-
size *= dim_size;
493-
}
494-
return size;
495-
}
495+
std::size_t byte_size() const {
    // Bytes occupied by the tensor's data: element count times the width of
    // one element of this dtype.
    return shape().product() * dtype_size();
}
496496

497497
void reset() {
498498
if (impl == nullptr) {
@@ -517,7 +517,9 @@ class Tensor {
517517
std::vector<Tensor> unstack(std::size_t axis) const;
518518
Tensor unsqueeze(std::size_t axis) const;
519519
Tensor expand(const Sizes& shape) const;
520+
Tensor reshape(const Sizes& shape) const;
520521
Tensor factor_dim(std::size_t axis, std::size_t factor);
522+
std::vector<Tensor> split_and_reshape(const std::vector<Sizes>& shapes) const;
521523

522524
template <typename D>
523525
Tensor cpu(const D& device) const {

madspace/src/driver/adam_optimizer.cpp

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,26 @@ AdamOptimizer::AdamOptimizer(
1515
double eps
1616
) :
1717
_context(context),
18-
_runtime(build_runtime(function, context)),
1918
_learning_rate(learning_rate),
2019
_schedule(schedule),
2120
_step(0),
2221
_step_count(step_count),
2322
_beta1(beta1),
2423
_beta2(beta2),
25-
_eps(eps) {
24+
_eps(eps),
25+
_one(1.0, context->device()) {
2626
DevicePtr device = context->device();
27+
std::vector<std::string> param_names;
2728
for (auto& [name, value] : function.globals()) {
28-
Tensor global = context->global(name);
29-
_parameters.push_back(global);
30-
_exp_avgs.emplace_back(global.dtype(), global.shape(), device).zero();
31-
_exp_avg_sqs.emplace_back(global.dtype(), global.shape(), device).zero();
29+
if (context->global_requires_grad(name)) {
30+
param_names.push_back(name);
31+
}
3232
}
33+
_parameter = context->reallocate_globals_contiguously(param_names);
34+
_runtime = build_runtime(function, context);
35+
_parameter = Tensor(_parameter.dtype(), _parameter.shape(), _parameter.device());
36+
_exp_avg = Tensor(_parameter.dtype(), _parameter.shape(), _parameter.device());
37+
_exp_avg_sq = Tensor(_parameter.dtype(), _parameter.shape(), _parameter.device());
3338
_input_types.reserve(function.inputs().size());
3439
for (auto& input : function.inputs()) {
3540
_input_types.push_back(input.type);
@@ -47,20 +52,20 @@ TensorVec AdamOptimizer::step(const TensorVec& inputs) {
4752
_runtime->run_with_grad(inputs, std::vector<bool>(inputs.size(), false));
4853
TensorVec output_grads(outputs.size());
4954
DevicePtr device = _context->device();
50-
output_grads.at(0) = Tensor(1.0, device);
55+
output_grads.at(0) = _one;
5156
auto [input_grads, global_grads] =
5257
_runtime->run_backward(output_grads, stored_locals, eval_grad);
53-
/*device->adam_step(
54-
global_grads,
55-
_parameters,
56-
_exp_avgs,
57-
_exp_avg_sqs,
58+
device->adam_step(
59+
global_grads.at(0),
60+
_parameter,
61+
_exp_avg,
62+
_exp_avg_sq,
5863
step_size,
5964
_beta1,
6065
_beta2,
6166
_eps,
6267
bias_corr2_sqrt
63-
);*/
68+
);
6469
return outputs;
6570
}
6671

madspace/src/driver/context.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,38 @@ void Context::copy_globals_from(Context& context) {
195195
}
196196
}
197197

198+
Tensor Context::reallocate_globals_contiguously(const std::vector<std::string>& names) {
    // Replaces the listed globals with views into a single, freshly allocated
    // contiguous buffer so they can all be updated by one fused kernel call
    // (e.g. a single adam_step). Returns the flat parent tensor that owns the
    // storage.
    //
    // NOTE(review): the previous contents of the globals are NOT copied into
    // the new buffer — callers must re-initialize the returned tensor. Confirm
    // this is intended before relying on old values surviving the call.
    //
    // Throws std::runtime_error when `names` is empty, when a global is
    // externally referenced (reallocation would invalidate the outside view),
    // or when the globals disagree on dtype.
    if (names.empty()) {
        // Without this guard, `dtype` below would be read uninitialized.
        throw std::runtime_error(
            "reallocate_globals_contiguously: at least one global is required"
        );
    }
    std::vector<Sizes> shapes;
    shapes.reserve(names.size());
    std::size_t total_size = 0;
    DataType dtype;
    for (bool first = true; auto& name : names) {
        auto& glob = _globals.at(name).first;
        if (!glob.is_only_reference()) {
            throw std::runtime_error(
                std::format(
                    "Global {}: cannot reallocate as it is externally referenced", name
                )
            );
        }
        if (first) {
            dtype = glob.dtype();
            first = false;
        } else if (dtype != glob.dtype()) {
            // All globals must share one dtype to live in the same buffer.
            throw std::runtime_error(
                std::format("Global {}: incompatible dtype", name)
            );
        }
        shapes.push_back(glob.shape());
        total_size += glob.shape().product();
    }
    // One flat parent tensor holding all listed globals back to back.
    Tensor parent(dtype, {total_size}, device());
    // Bind by reference to avoid copying the name string and tensor handle on
    // every iteration; each global is repointed at its slice of the parent.
    for (auto&& [name, tensor] : zip(names, parent.split_and_reshape(shapes))) {
        _globals.at(name).first = tensor;
    }
    return parent;
}
229+
198230
const MatrixElementApi& Context::matrix_element(std::size_t index) const {
199231
if (index >= _matrix_elements.size()) {
200232
throw std::runtime_error("Matrix element index out of bounds");

madspace/src/driver/tensor.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,31 @@ Tensor Tensor::expand(const Sizes& shape) const {
120120
});
121121
}
122122

123+
Tensor Tensor::reshape(const Sizes& new_shape) const {
    // Returns a view of this tensor with the same underlying data
    // reinterpreted under `new_shape`. Only valid for contiguous tensors, and
    // only when the new shape describes exactly as many elements as the
    // current one; no data is copied.
    check_impl();
    if (!is_contiguous()) {
        throw std::runtime_error("Tensor must be contiguous");
    }
    if (new_shape.product() != shape().product()) {
        throw std::runtime_error("Incompatible shapes");
    }
    // Share storage with this tensor's impl; only the shape (and the strides
    // derived from it below) differ. The impl fields are positional — keep
    // this aggregate in the order TensorImpl declares them.
    Tensor view(new Tensor::TensorImpl{
        impl->dtype,
        new_shape,
        impl->device,
        impl->data,
        false,
        std::nullopt,
        impl,
        1,
        {},
        impl->offset,
        0
    });
    view.init_stride();
    return view;
}
147+
123148
Tensor Tensor::factor_dim(std::size_t axis, std::size_t factor) {
124149
check_impl();
125150
auto new_dim = impl->shape.size() + 1;
@@ -157,6 +182,28 @@ Tensor Tensor::factor_dim(std::size_t axis, std::size_t factor) {
157182
});
158183
}
159184

185+
std::vector<Tensor> Tensor::split_and_reshape(const std::vector<Sizes>& shapes) const {
186+
check_impl();
187+
if (!is_contiguous() || shape().size() != 1) {
188+
throw std::runtime_error(
189+
"split_and_reshape is only available for single-dimensional contiguous "
190+
"tensors"
191+
);
192+
}
193+
SizeVec size_prods;
194+
size_prods.reserve(shapes.size());
195+
for (auto& shape : shapes) {
196+
size_prods.push_back(shape.product());
197+
}
198+
TensorVec split_tensors = split(0, size_prods);
199+
TensorVec ret;
200+
ret.reserve(shapes.size());
201+
for (auto [tensor, shape] : zip(split_tensors, shapes)) {
202+
ret.push_back(tensor.reshape(shape));
203+
}
204+
return ret;
205+
}
206+
160207
std::size_t Tensor::init_stride() {
161208
std::size_t stride_prod = 1;
162209
bool first = true;

0 commit comments

Comments
 (0)