Skip to content

Commit 1650c9f

Browse files
authored
Merge pull request #2056 from SCIInstitute/disentangled_4d_ssm
Spatiotemporal SSM - Disentangled Approach
2 parents 700e4e3 + 7f2fd37 commit 1650c9f

12 files changed

Lines changed: 646 additions & 3 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
.DS_*
3636
.directory
3737
.idea/
38+
.vscode/
3839

3940
# PYC files
4041
*.pyc

Libs/Optimize/CorrespondenceMode.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ namespace shapeworks {
88
EnsembleRegressionEntropy = 3,
99
EnsembleMixedEffectsEntropy = 4,
1010
MeshBasedGeneralEntropy = 5,
11-
MeshBasedGeneralMeanEnergy = 6
11+
MeshBasedGeneralMeanEnergy = 6,
12+
DisentagledEnsembleEntropy = 7,
13+
DisentangledEnsembleMeanEnergy = 8
1214
};
1315
}
Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
#include "DisentangledCorrespondenceFunction.h"
2+
#include <string>
3+
#include "Libs/Optimize/Domain/ImageDomainWithGradients.h"
4+
#include "Libs/Optimize/Utils/ParticleGaussianModeWriter.h"
5+
#include "Libs/Utils/Utils.h"
6+
#include <tbb/parallel_for.h>
7+
#include "vnl/algo/vnl_symmetric_eigensystem.h"
8+
9+
namespace shapeworks {
10+
void DisentangledCorrespondenceFunction ::WriteModes(const std::string& prefix, int n) const {
11+
typename ParticleGaussianModeWriter<VDimension>::Pointer writer = ParticleGaussianModeWriter<VDimension>::New();
12+
writer->SetShapeMatrix(m_ShapeMatrix);
13+
writer->SetFileName(prefix.c_str());
14+
writer->SetNumberOfModes(n);
15+
writer->Update();
16+
}
17+
18+
void DisentangledCorrespondenceFunction::ComputeCovarianceMatrices() {
19+
const unsigned int num_N = m_ShapeMatrix->cols(); // Total Number of subjects
20+
const unsigned int num_T = m_ShapeMatrix->GetDomainsPerShape(); // Total Number of Time points
21+
const unsigned int num_dims = m_ShapeMatrix->rows() / num_T; // (dM X T) / T = dM
22+
this->Initialize();
23+
// computation across time cohort
24+
tbb::parallel_for(tbb::blocked_range<size_t>{0, num_T}, [&](const tbb::blocked_range<size_t>& r)
25+
{
26+
// Iterate t = 1....T
27+
for (size_t time_inst = r.begin(); time_inst < r.end(); ++time_inst) {
28+
// Build objective matrix Z
29+
vnl_matrix_type z;
30+
z.clear();
31+
z.set_size(num_dims, num_N);
32+
z.fill(0.0);
33+
unsigned int row_idx_start = time_inst * num_dims;
34+
z = m_ShapeMatrix->extract(num_dims, num_N, row_idx_start, 0);
35+
36+
// Resize Gradient Updates matrix for current time instance
37+
if (m_Time_PointsUpdate->at(time_inst).rows() != num_dims || m_Time_PointsUpdate->at(time_inst).cols() != num_N)
38+
{
39+
m_Time_PointsUpdate->at(time_inst).set_size(num_dims, num_N);
40+
m_Time_PointsUpdate->at(time_inst).fill(0.0);
41+
}
42+
43+
// Compute mean and mean centred objective matrix for current time instance t_i
44+
vnl_matrix_type points_minus_mean_t;
45+
points_minus_mean_t.clear();
46+
points_minus_mean_t.set_size(num_dims, num_N);
47+
points_minus_mean_t.fill(0.0);
48+
Eigen::MatrixXd inv_cov_t;
49+
inv_cov_t.setZero();
50+
51+
m_points_mean_time_cohort->at(time_inst).clear();
52+
m_points_mean_time_cohort->at(time_inst).set_size(num_dims, 1);
53+
54+
for (unsigned int j = 0; j < num_dims; ++j)
55+
{
56+
double sum_across_col = 0.0;
57+
for (unsigned int i = 0; i < num_N; ++i)
58+
{
59+
sum_across_col += z(j, i);
60+
}
61+
m_points_mean_time_cohort->at(time_inst).put(j,0, sum_across_col/(double)num_N);
62+
}
63+
64+
for (unsigned int j = 0; j < num_dims; j++)
65+
{
66+
for (unsigned int i = 0; i < num_N; i++)
67+
{
68+
points_minus_mean_t(j, i) = z(j, i) - m_points_mean_time_cohort->at(time_inst).get(j,0);
69+
}
70+
}
71+
72+
vnl_diag_matrix<double> W_t;
73+
vnl_matrix_type gramMat_t(num_N, num_N, 0.0); // gram matrix = Y^T X Y
74+
vnl_matrix_type pinvMat_t(num_N, num_N, 0.0); // inverse of gram Matrix
75+
76+
if (this->m_UseMeanEnergy)
77+
{
78+
pinvMat_t.set_identity();
79+
m_InverseCovMatrices_time_cohort->at(time_inst).setZero();
80+
}
81+
else
82+
{
83+
gramMat_t = points_minus_mean_t.transpose()* points_minus_mean_t;
84+
vnl_svd <double> svd(gramMat_t);
85+
vnl_matrix_type UG = svd.U();
86+
W_t = svd.W();
87+
vnl_diag_matrix<double> invLambda_t = svd.W();
88+
invLambda_t.set_diagonal(invLambda_t.get_diagonal()/(double)(num_N-1) + m_MinimumVariance);
89+
invLambda_t.invert_in_place();
90+
91+
pinvMat_t = (UG * invLambda_t) * UG.transpose();
92+
vnl_matrix_type projMat_t = points_minus_mean_t * UG;
93+
const auto lhs = projMat_t * invLambda_t;
94+
const auto rhs = invLambda_t * projMat_t.transpose();
95+
inv_cov_t.resize(num_dims, num_dims);
96+
Utils::multiply_into(inv_cov_t, lhs, rhs);
97+
}
98+
// Update Gradient points update infor
99+
m_Time_PointsUpdate->at(time_inst).update(points_minus_mean_t * pinvMat_t);
100+
double currentEnergy_t = 0.0;
101+
if (m_UseMeanEnergy) currentEnergy_t = points_minus_mean_t.frobenius_norm();
102+
else
103+
{
104+
m_MinimumEigenValue_time_cohort[time_inst] = W_t(0)*W_t(0) + m_MinimumVariance;
105+
for (unsigned int i = 0; i < num_N; i++)
106+
{
107+
double val_i = W_t(i)*W_t(i) + m_MinimumVariance;
108+
if ( val_i < m_MinimumEigenValue_time_cohort[time_inst])
109+
m_MinimumEigenValue_time_cohort[time_inst] = val_i;
110+
currentEnergy_t += log(val_i);
111+
}
112+
}
113+
currentEnergy_t /= 2.0;
114+
if (m_UseMeanEnergy) m_MinimumEigenValue_time_cohort[time_inst] = currentEnergy_t / 2.0;
115+
// Update Inv Covariance Matrix
116+
m_InverseCovMatrices_time_cohort->at(time_inst) = inv_cov_t;
117+
}
118+
});
119+
120+
// computation across shape cohort
121+
tbb::parallel_for(tbb::blocked_range<size_t>{0, num_N}, [&](const tbb::blocked_range<size_t>& r)
122+
{
123+
// Iterate n = 1....N
124+
for (size_t sub = r.begin(); sub < r.end(); ++sub) {
125+
// Build objective matrix Z
126+
vnl_matrix_type z;
127+
z.clear();
128+
z.set_size(num_dims, num_N);
129+
z.fill(0.0);
130+
for(unsigned int t = 0; t < num_T; ++t){
131+
unsigned int row_start = num_dims * t;
132+
vnl_matrix_type time_vec = m_ShapeMatrix->extract(num_dims, 1, row_start, sub);
133+
z.set_columns(t, time_vec);
134+
}
135+
136+
// Resize Gradient Updates matrix for current time instance
137+
if (m_Shape_PointsUpdate->at(sub).rows() != num_dims || m_Shape_PointsUpdate->at(sub).cols() != num_T)
138+
{
139+
m_Shape_PointsUpdate->at(sub).set_size(num_dims, num_T);
140+
m_Shape_PointsUpdate->at(sub).fill(0.0);
141+
}
142+
143+
// Compute mean and mean centred objective matrix for current time instance t_i
144+
vnl_matrix_type points_minus_mean_n;
145+
points_minus_mean_n.clear();
146+
points_minus_mean_n.set_size(num_dims, num_T);
147+
points_minus_mean_n.fill(0.0);
148+
Eigen::MatrixXd inv_cov_n;
149+
inv_cov_n.setZero();
150+
151+
m_points_mean_shape_cohort->at(sub).clear();
152+
m_points_mean_shape_cohort->at(sub).set_size(num_dims, 1);
153+
154+
for (unsigned int j = 0; j < num_dims; ++j)
155+
{
156+
double sum_across_col = 0.0;
157+
for (unsigned int i = 0; i < num_T; ++i)
158+
{
159+
sum_across_col += z(j, i);
160+
}
161+
m_points_mean_shape_cohort->at(sub).put(j,0, sum_across_col/(double)num_T);
162+
}
163+
164+
for (unsigned int j = 0; j < num_dims; j++)
165+
{
166+
for (unsigned int i = 0; i < num_T; i++)
167+
{
168+
points_minus_mean_n(j, i) = z(j, i) - m_points_mean_shape_cohort->at(sub).get(j,0);
169+
}
170+
}
171+
172+
vnl_diag_matrix<double> W_n;
173+
vnl_matrix_type gramMat_n(num_T, num_T, 0.0); // gram matrix = Y^T X Y
174+
vnl_matrix_type pinvMat_n(num_T, num_T, 0.0); // inverse of gram Matrix
175+
176+
if (this->m_UseMeanEnergy)
177+
{
178+
pinvMat_n.set_identity();
179+
m_InverseCovMatrices_shape_cohort->at(sub).setZero();
180+
}
181+
else
182+
{
183+
gramMat_n = points_minus_mean_n.transpose() * points_minus_mean_n;
184+
vnl_svd <double> svd(gramMat_n);
185+
vnl_matrix_type UG = svd.U();
186+
W_n = svd.W();
187+
vnl_diag_matrix<double> invLambda_n = svd.W();
188+
invLambda_n.set_diagonal(invLambda_n.get_diagonal()/(double)(num_T-1) + m_MinimumVariance);
189+
invLambda_n.invert_in_place();
190+
191+
pinvMat_n = (UG * invLambda_n) * UG.transpose();
192+
vnl_matrix_type projMat_n = points_minus_mean_n * UG;
193+
const auto lhs = projMat_n * invLambda_n;
194+
const auto rhs = invLambda_n * projMat_n.transpose();
195+
inv_cov_n.resize(num_dims, num_dims);
196+
Utils::multiply_into(inv_cov_n, lhs, rhs);
197+
}
198+
199+
// Update Gradient points update infor
200+
m_Shape_PointsUpdate->at(sub).update(points_minus_mean_n * pinvMat_n);
201+
double currentEnergy_n = 0.0;
202+
if (m_UseMeanEnergy) currentEnergy_n = points_minus_mean_n.frobenius_norm();
203+
else
204+
{
205+
m_MinimumEigenValue_shape_cohort[sub] = W_n(0)*W_n(0) + m_MinimumVariance;
206+
for (unsigned int i = 0; i < num_T; i++)
207+
{
208+
double val_i = W_n(i) * W_n(i) + m_MinimumVariance;
209+
if (val_i < m_MinimumEigenValue_shape_cohort[sub])
210+
m_MinimumEigenValue_shape_cohort[sub] = val_i;
211+
currentEnergy_n += log(val_i);
212+
}
213+
}
214+
currentEnergy_n /= 2.0;
215+
if (m_UseMeanEnergy) m_MinimumEigenValue_shape_cohort[sub] = currentEnergy_n / 2.0;
216+
// Update Inv Covariance Matrix
217+
m_InverseCovMatrices_shape_cohort->at(sub) = inv_cov_n;
218+
}
219+
});
220+
221+
}
222+
223+
DisentangledCorrespondenceFunction::VectorType DisentangledCorrespondenceFunction ::Evaluate(unsigned int idx, unsigned int d,
224+
const ParticleSystem* system,
225+
double& maxdt,
226+
double& energy) const {
227+
228+
const unsigned int num_N = m_ShapeMatrix->cols(); // Total number of subjects
229+
const unsigned int num_T = m_ShapeMatrix->GetDomainsPerShape(); // Total number of time points
230+
231+
const unsigned int cur_sub = d / num_T; // index of current subject
232+
const unsigned int cur_time_point = d % num_T;
233+
234+
// maximum update possible = sum of max possible updates across both cohorts (time and shape)
235+
maxdt = m_MinimumEigenValue_shape_cohort[cur_sub] + m_MinimumEigenValue_time_cohort[cur_time_point];
236+
237+
VectorType gradE; // gradient update vector for Point defined by ParticleSystem system of domain d and dimension index idx
238+
unsigned int shape_matrix_start_idx = 0;
239+
240+
int dom = d % num_T;
241+
for (int i = 0; i < dom; i++)
242+
shape_matrix_start_idx += system->GetNumberOfParticles(i) * VDimension;
243+
shape_matrix_start_idx += idx*VDimension;
244+
245+
unsigned int particle_idx = VDimension * idx;
246+
247+
// Energy computation across time cohort
248+
vnl_matrix_type Xi_time_cohort(3,1,0.0);
249+
Xi_time_cohort(0,0) = m_ShapeMatrix->operator()(shape_matrix_start_idx , cur_sub) - m_points_mean_time_cohort->at(cur_time_point).get(particle_idx, 0);
250+
Xi_time_cohort(1,0) = m_ShapeMatrix->operator()(shape_matrix_start_idx+1, cur_sub) - m_points_mean_time_cohort->at(cur_time_point).get(particle_idx+1, 0);
251+
Xi_time_cohort(2,0) = m_ShapeMatrix->operator()(shape_matrix_start_idx+2, cur_sub) - m_points_mean_time_cohort->at(cur_time_point).get(particle_idx+2, 0);
252+
vnl_matrix_type tmp1_time(3, 3, 0.0);
253+
if (this->m_UseMeanEnergy) {
254+
tmp1_time.set_identity();
255+
} else {
256+
Eigen::MatrixXd region = m_InverseCovMatrices_time_cohort->at(cur_time_point).block(particle_idx, particle_idx, 3, 3);
257+
// convert to vnl
258+
for (unsigned int i = 0; i < 3; i++) {
259+
for (unsigned int j = 0; j < 3; j++) {
260+
tmp1_time(i, j) = region(i, j);
261+
}
262+
}
263+
}
264+
vnl_matrix_type tmp_time = Xi_time_cohort.transpose()*tmp1_time;
265+
tmp_time *= Xi_time_cohort;
266+
267+
// Energy computation across shape cohort
268+
vnl_matrix_type Xi_shape_cohort(3,1,0.0);
269+
Xi_shape_cohort(0,0) = m_ShapeMatrix->operator()(shape_matrix_start_idx , cur_sub) - m_points_mean_shape_cohort->at(cur_sub).get(particle_idx, 0);
270+
Xi_shape_cohort(1,0) = m_ShapeMatrix->operator()(shape_matrix_start_idx+1, cur_sub) - m_points_mean_shape_cohort->at(cur_sub).get(particle_idx+1, 0);
271+
Xi_shape_cohort(2,0) = m_ShapeMatrix->operator()(shape_matrix_start_idx+2, cur_sub) - m_points_mean_shape_cohort->at(cur_sub).get(particle_idx+2, 0);
272+
vnl_matrix_type tmp1_shape(3, 3, 0.0);
273+
if (this->m_UseMeanEnergy) {
274+
tmp1_shape.set_identity();
275+
} else {
276+
Eigen::MatrixXd region = m_InverseCovMatrices_shape_cohort->at(cur_sub).block(particle_idx, particle_idx, 3, 3);
277+
// convert to vnl
278+
for (unsigned int i = 0; i < 3; i++) {
279+
for (unsigned int j = 0; j < 3; j++) {
280+
tmp1_time(i, j) = region(i, j);
281+
}
282+
}
283+
}
284+
vnl_matrix_type tmp_shape = Xi_shape_cohort.transpose()*tmp1_shape;
285+
tmp_shape *= Xi_shape_cohort;
286+
287+
// Net Energy
288+
energy = tmp_time(0,0) + tmp_shape(0, 0);
289+
290+
// Net Gradient
291+
for (unsigned int i = 0; i< VDimension; i++)
292+
{
293+
gradE[i] = m_Time_PointsUpdate->at(cur_time_point).get(particle_idx + i, cur_sub) + m_Shape_PointsUpdate->at(cur_sub).get(particle_idx + i, cur_time_point);
294+
}
295+
return system->TransformVector(gradE,
296+
system->GetInversePrefixTransform(d) *
297+
system->GetInverseTransform(d));
298+
}
299+
300+
} // namespace shapeworks

0 commit comments

Comments
 (0)