@@ -46,7 +46,7 @@ const std::string save_init_splits = "save_init_splits";
4646const std::string keep_checkpoints = " keep_checkpoints" ;
4747const std::string use_disentangled_ssm = " use_disentangled_ssm" ;
4848const std::string use_linear_regression = " use_linear_regression" ;
49- const std::string use_mixed_effects_model = " use_mixed_effects_model " ;
49+ const std::string time_points_per_subject = " time_points_per_subject " ;
5050const std::string field_attributes = " field_attributes" ;
5151const std::string field_attribute_weights = " field_attribute_weights" ;
5252const std::string use_geodesics_to_landmarks = " use_geodesics_to_landmarks" ;
@@ -96,7 +96,7 @@ OptimizeParameters::OptimizeParameters(ProjectHandle project) {
9696 Keys::keep_checkpoints,
9797 Keys::use_disentangled_ssm,
9898 Keys::use_linear_regression,
99- Keys::use_mixed_effects_model ,
99+ Keys::time_points_per_subject ,
100100 Keys::particle_format,
101101 Keys::geodesic_remesh_percent,
102102 Keys::shared_boundary,
@@ -212,10 +212,10 @@ bool OptimizeParameters::get_use_linear_regression() { return params_.get(Keys::
212212void OptimizeParameters::set_use_linear_regression (bool value) { params_.set (Keys::use_linear_regression, value); }
213213
214214// ---------------------------------------------------------------------------
215- bool OptimizeParameters::get_use_mixed_effects_model () { return params_.get (Keys::use_mixed_effects_model, false ); }
215+ int OptimizeParameters::get_time_points_per_subject () { return params_.get (Keys::time_points_per_subject, 1 ); }
216216
217217// ---------------------------------------------------------------------------
218- void OptimizeParameters::set_use_mixed_effects_model ( bool value) { params_.set (Keys::use_mixed_effects_model , value); }
218+ void OptimizeParameters::set_time_points_per_subject ( int value) { params_.set (Keys::time_points_per_subject , value); }
219219
220220// ---------------------------------------------------------------------------
221221bool OptimizeParameters::get_use_procrustes_scaling () { return params_.get (Keys::procrustes_scaling, false ); }
@@ -449,6 +449,9 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) {
449449 optimize->SetMeshFFCMode (get_mesh_ffc_mode ());
450450 optimize->SetUseDisentangledSpatiotemporalSSM (get_use_disentangled_ssm ());
451451 optimize->set_particle_format (get_particle_format ());
452+ optimize->SetTimePtsPerSubject (get_time_points_per_subject ());
453+ optimize->SetUseRegression (get_use_linear_regression ());
454+ optimize->SetUseMixedEffects (get_time_points_per_subject () > 1 ? true : false );
452455 optimize->SetSharedBoundaryEnabled (get_shared_boundary ());
453456 optimize->SetSharedBoundaryWeight (get_shared_boundary_weight ());
454457
@@ -644,6 +647,21 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) {
644647 }
645648 }
646649
650+ // get explanatory variables for subjects if used for regression
651+ if (get_use_linear_regression ())
652+ {
653+ std::vector<double > exp_vars;
654+ for (const auto & s : subjects) {
655+ exp_vars.push_back (s->get_explanatory_variable ());
656+ }
657+ dynamic_cast <LinearRegressionShapeMatrix*>(
658+ optimize->GetSampler ()->GetEnsembleRegressionEntropyFunction ()->GetShapeMatrix ())
659+ ->SetExplanatory (exp_vars);
660+ dynamic_cast <MixedEffectsShapeMatrix*>(
661+ optimize->GetSampler ()->GetEnsembleMixedEffectsEntropyFunction ()->GetShapeMatrix ())
662+ ->SetExplanatory (exp_vars);
663+ }
664+
647665 std::vector<std::string> filenames;
648666 int count = 0 ;
649667 domain_count = 0 ;
0 commit comments