Skip to content

Commit d67e707

Browse files
expose and link linear regression opt params
1 parent d452bfc commit d67e707

4 files changed

Lines changed: 37 additions & 8 deletions

File tree

Libs/Optimize/Optimize.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1667,6 +1667,12 @@ void Optimize::SetTimePtsPerSubject(int time_pts_per_subject) { this->m_timepts_
16671667
//---------------------------------------------------------------------------
16681668
int Optimize::GetTimePtsPerSubject() { return this->m_timepts_per_subject; }
16691669

1670+
//---------------------------------------------------------------------------
1671+
void Optimize::SetExplanatoryVariables(std::vector<double> val) { this->m_explanatory_variables = val; }
1672+
1673+
//---------------------------------------------------------------------------
1674+
std::vector<double> Optimize::GetExplanatoryVariables() { return this->m_explanatory_variables; }
1675+
16701676
//---------------------------------------------------------------------------
16711677
void Optimize::SetOptimizationIterations(int optimization_iterations) {
16721678
this->m_optimization_iterations = optimization_iterations;

Libs/Optimize/Optimize.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,14 @@ class Optimize {
158158
m_mesh_ffc_mode = mesh_ffc_mode;
159159
m_sampler->SetMeshFFCMode(mesh_ffc_mode);
160160
}
161-
//! Set the number of time points per subject (TODO: details)
161+
//! Set the number of time points per subject used for Linear Regression or Mixed Effects Model optimization
162162
void SetTimePtsPerSubject(int time_pts_per_subject);
163-
//! Get the number of time points per subject (TODO: details)
163+
//! Get the number of time points per subject used for Linear Regression or Mixed Effects Model optimization
164164
int GetTimePtsPerSubject();
165+
//! Set Explanatory Variable for | used for Linear Regression or Mixed Effects Model optimization
166+
void SetExplanatoryVariables(std::vector<double> vals);
167+
//! Get the number of time points per subject used for Linear Regression or Mixed Effects Model optimization
168+
std::vector<double> GetExplanatoryVariables();
165169
//! Set the number of optimization iterations
166170
void SetOptimizationIterations(int optimization_iterations);
167171
//! Set the number of optimization iterations already completed (TODO: details)
@@ -388,6 +392,7 @@ class Optimize {
388392
bool m_mesh_ffc_mode = 0;
389393

390394
unsigned int m_timepts_per_subject = 1;
395+
std::vector<double> m_explanatory_variables;
391396
int m_optimization_iterations = 2000;
392397
int m_optimization_iterations_completed = 0;
393398
int m_iterations_per_split = 1000;

Libs/Optimize/OptimizeParameters.cpp

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ const std::string save_init_splits = "save_init_splits";
4646
const std::string keep_checkpoints = "keep_checkpoints";
4747
const std::string use_disentangled_ssm = "use_disentangled_ssm";
4848
const std::string use_linear_regression = "use_linear_regression";
49-
const std::string use_mixed_effects_model = "use_mixed_effects_model";
49+
const std::string time_points_per_subject = "time_points_per_subject";
5050
const std::string field_attributes = "field_attributes";
5151
const std::string field_attribute_weights = "field_attribute_weights";
5252
const std::string use_geodesics_to_landmarks = "use_geodesics_to_landmarks";
@@ -96,7 +96,7 @@ OptimizeParameters::OptimizeParameters(ProjectHandle project) {
9696
Keys::keep_checkpoints,
9797
Keys::use_disentangled_ssm,
9898
Keys::use_linear_regression,
99-
Keys::use_mixed_effects_model,
99+
Keys::time_points_per_subject,
100100
Keys::particle_format,
101101
Keys::geodesic_remesh_percent,
102102
Keys::shared_boundary,
@@ -212,10 +212,10 @@ bool OptimizeParameters::get_use_linear_regression() { return params_.get(Keys::
212212
void OptimizeParameters::set_use_linear_regression(bool value) { params_.set(Keys::use_linear_regression, value); }
213213

214214
//---------------------------------------------------------------------------
215-
bool OptimizeParameters::get_use_mixed_effects_model() { return params_.get(Keys::use_mixed_effects_model, false); }
215+
int OptimizeParameters::get_time_points_per_subject() { return params_.get(Keys::time_points_per_subject, 1); }
216216

217217
//---------------------------------------------------------------------------
218-
void OptimizeParameters::set_use_mixed_effects_model(bool value) { params_.set(Keys::use_mixed_effects_model, value); }
218+
void OptimizeParameters::set_time_points_per_subject(int value) { params_.set(Keys::time_points_per_subject, value); }
219219

220220
//---------------------------------------------------------------------------
221221
bool OptimizeParameters::get_use_procrustes_scaling() { return params_.get(Keys::procrustes_scaling, false); }
@@ -449,6 +449,9 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) {
449449
optimize->SetMeshFFCMode(get_mesh_ffc_mode());
450450
optimize->SetUseDisentangledSpatiotemporalSSM(get_use_disentangled_ssm());
451451
optimize->set_particle_format(get_particle_format());
452+
optimize->SetTimePtsPerSubject(get_time_points_per_subject());
453+
optimize->SetUseRegression(get_use_linear_regression());
454+
optimize->SetUseMixedEffects(get_time_points_per_subject() > 1 ? true : false);
452455
optimize->SetSharedBoundaryEnabled(get_shared_boundary());
453456
optimize->SetSharedBoundaryWeight(get_shared_boundary_weight());
454457

@@ -644,6 +647,21 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) {
644647
}
645648
}
646649

650+
// get explanatory variables for subjects if used for regression
651+
if (get_use_linear_regression())
652+
{
653+
std::vector<double> exp_vars;
654+
for (const auto& s : subjects) {
655+
exp_vars.push_back(s->get_explanatory_variable());
656+
}
657+
dynamic_cast<LinearRegressionShapeMatrix*>(
658+
optimize->GetSampler()->GetEnsembleRegressionEntropyFunction()->GetShapeMatrix())
659+
->SetExplanatory(exp_vars);
660+
dynamic_cast<MixedEffectsShapeMatrix*>(
661+
optimize->GetSampler()->GetEnsembleMixedEffectsEntropyFunction()->GetShapeMatrix())
662+
->SetExplanatory(exp_vars);
663+
}
664+
647665
std::vector<std::string> filenames;
648666
int count = 0;
649667
domain_count = 0;

Libs/Optimize/OptimizeParameters.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ class OptimizeParameters {
6363
bool get_use_linear_regression();
6464
void set_use_linear_regression(bool value);
6565

66-
bool get_use_mixed_effects_model();
67-
void set_use_mixed_effects_model(bool value);
66+
int get_time_points_per_subject();
67+
void set_time_points_per_subject(int value);
6868

6969
bool get_use_procrustes();
7070
void set_use_procrustes(bool value);

0 commit comments

Comments
 (0)