Skip to content

Commit 83baaf4

Browse files
authored
fix model forward to accomodate disentangle x cavs function
1 parent 78945aa commit 83baaf4

1 file changed

Lines changed: 6 additions & 12 deletions

File tree

scripts/models.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,26 +61,20 @@ def forward_from_projected_and_residual(
6161
mute_remainder=False,
6262
):
6363
if cavs_list is not None:
64-
y_hat_x_avs, y_hat_remainder = self.disentangle_avs_x_cavs(
64+
y_projected, y_residual = self.disentangle_avs_x_cavs(
6565
y_projected, y_residual, cavs_list, mute_x_avs, mute_remainder
6666
)
67-
68-
if mute_x_avs: # which part of the activations to use
69-
y_hat_x_avs.register_hook(lambda grad: torch.zeros_like(grad))
70-
if mute_remainder:
71-
y_hat_remainder.register_hook(lambda grad: torch.zeros_like(grad))
72-
y_hat = y_hat_x_avs + y_hat_remainder + self.zscore_mean
73-
else:
74-
y_hat = (
75-
self.merge_projected_and_residual(y_projected, y_residual)
76-
+ self.zscore_mean
77-
)
67+
y_hat = (
68+
self.merge_projected_and_residual(y_projected, y_residual)
69+
+ self.zscore_mean
70+
)
7871

7972
if self.orig_shape is not None:
8073
y_hat = y_hat.reshape((-1, *self.orig_shape[1:]))
8174

8275
# resume back to normal forward process
8376
y_hat = self.resume_forward_from_select_layer(y_hat)
77+
8478
return y_hat
8579

8680
def project_avs_to_pca(self, y):

0 commit comments

Comments
 (0)