Skip to content

Commit c70a84f

Browse files
src/radmeth/radmeth_optimize.cpp: fit regression model now no longer accepts initial parameter values as those weren't used. Out-params have been added to return the estimates of p for the beta-binoms in each group and the shared value of the estimated dispersion
1 parent ed00b7e commit c70a84f

1 file changed

Lines changed: 25 additions & 19 deletions

File tree

src/radmeth/radmeth_optimize.cpp

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ get_cumulative(const std::vector<std::uint32_t> &group_id,
212212
const auto val = get_value(mc[c_idx]);
213213
auto &vec = get_vector(cumul[g_idx]);
214214
for (auto i = 0u; i < val; ++i)
215-
vec[i]++;
215+
++vec[i];
216216
}
217217
};
218218
// call the lambda 3 times for m_counts, r_counts, d_counts
@@ -229,24 +229,20 @@ get_cumulative(const std::vector<std::uint32_t> &group_id,
229229
[](cumul_counts &c) -> std::vector<std::uint32_t> & { return c.d_counts; });
230230
}
231231

232-
[[nodiscard]] bool
233-
fit_regression_model(Regression &r, std::vector<double> &params_init) {
232+
void
233+
fit_regression_model(Regression &r, std::vector<double> &p_estimates,
234+
double &dispersion_estimate) {
234235
static constexpr auto init_dispersion_param = -2.5;
235236
const auto stepsize = Regression::stepsize;
236237
const auto max_iter = Regression::max_iter;
237238

238-
get_cumulative(r.design.group_id, r.design.n_groups(), r.props.mc, r.cumul);
239+
const auto n_groups = r.n_groups();
240+
get_cumulative(r.design.group_id, n_groups, r.props.mc, r.cumul);
239241
set_max_r_count(r);
240242

241-
// one more than the number of factors
242-
const std::size_t n_params = r.n_factors() + 1;
243-
if (params_init.empty()) {
244-
params_init.resize(n_params, 0.0);
245-
params_init.back() = init_dispersion_param;
246-
}
247-
if (std::size(params_init) != n_params)
248-
throw std::runtime_error("Wrong number of initial parameters.");
249-
r.p_v.resize(r.n_groups());
243+
r.p_v.resize(n_groups);
244+
245+
const std::size_t n_params = r.n_params();
250246
const auto tol = std::sqrt(n_params) * r.n_samples() * Regression::tolerance;
251247
// clang-format off
252248
auto loglik_bundle = gsl_multimin_function_fdf{
@@ -258,6 +254,12 @@ fit_regression_model(Regression &r, std::vector<double> &params_init) {
258254
};
259255
// clang-format on
260256

257+
// set the parameters: zero for "p" parameters and the final one for
258+
// dispersion using the constant
259+
auto params = gsl_vector_alloc(n_params);
260+
gsl_vector_set_all(params, 0.0);
261+
gsl_vector_set(params, n_params - 1, init_dispersion_param);
262+
261263
// Alternatives:
262264
// - gsl_multimin_fdfminimizer_conjugate_pr
263265
// - gsl_multimin_fdfminimizer_conjugate_fr
@@ -266,10 +268,6 @@ fit_regression_model(Regression &r, std::vector<double> &params_init) {
266268
const auto minimizer = gsl_multimin_fdfminimizer_conjugate_pr;
267269
auto s = gsl_multimin_fdfminimizer_alloc(minimizer, n_params);
268270

269-
auto params = gsl_vector_alloc(n_params);
270-
for (auto i = 0u; i < n_params; ++i)
271-
gsl_vector_set(params, i, params_init[i]);
272-
273271
gsl_multimin_fdfminimizer_set(s, &loglik_bundle, params, stepsize, tol);
274272

275273
int status = 0;
@@ -281,12 +279,20 @@ fit_regression_model(Regression &r, std::vector<double> &params_init) {
281279
// check status from gradient
282280
status = gsl_multimin_test_gradient(s->gradient, tol);
283281
} while (status == GSL_CONTINUE && ++iter < max_iter);
282+
if (status != GSL_SUCCESS)
283+
throw std::runtime_error("failed to fit model parameters");
284284

285285
const auto param_estimates = gsl_multimin_fdfminimizer_x(s);
286+
287+
const auto &groups = r.design.groups;
288+
p_estimates.clear();
289+
for (auto g_idx = 0u; g_idx < n_groups; ++g_idx)
290+
p_estimates.push_back(get_p(groups[g_idx], param_estimates));
291+
const auto disp_param = gsl_vector_get(param_estimates, n_params - 1);
292+
dispersion_estimate = 1.0 / std::exp(disp_param);
293+
286294
r.max_loglik = log_likelihood(param_estimates, r);
287295

288296
gsl_multimin_fdfminimizer_free(s);
289297
gsl_vector_free(params);
290-
291-
return status == GSL_SUCCESS;
292298
}

0 commit comments

Comments
 (0)