Skip to content

Commit 78945aa

Browse files
committed
minor bug fix
1 parent a00ec4e commit 78945aa

1 file changed

Lines changed: 43 additions & 1 deletion

File tree

scripts/models.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def forward_from_projected_and_residual(
6262
):
6363
if cavs_list is not None:
6464
y_hat_x_avs, y_hat_remainder = self.disentangle_avs_x_cavs(
65-
y_projected, y_residual, cavs_list
65+
y_projected, y_residual, cavs_list, mute_x_avs, mute_remainder
6666
)
6767

6868
if mute_x_avs: # which part of the activations to use
@@ -98,3 +98,45 @@ def project_post_attn_to_pca(self, y):
9898
return y_residual, y_projected
9999
else:
100100
return y, None
101+
102+
def disentangle_avs_x_cavs(
    self, y_projected, y_residual, cavs_list, mute_x_avs=False, mute_remainder=False
):
    """Split activations into the CAV-span component and its orthogonal remainder.

    Projected and residual activations are concatenated (the CAVs were computed
    on projected + residual), projected onto the orthonormalized span of the
    CAVs, and returned re-split into the original projected/residual partitions.
    The forward value is unchanged — x_avs + remainder reconstructs the input
    exactly; the mute flags only zero the *gradient* flowing through the
    corresponding component via backward hooks.

    Args:
        y_projected: [batch, d_proj] projected activations, or None when no
            projection was applied (then y_residual holds the full activations
            — matches the `return y, None` branch of project_post_attn_to_pca).
        y_residual: [batch, d_res] residual activations.
        cavs_list: list of 1-D CAV tensors, each of length d_proj + d_res.
        mute_x_avs: if True, zero gradients through the CAV-span component.
        mute_remainder: if True, zero gradients through the remainder.

    Returns:
        Tuple (projected_part, residual_part) matching the input partition
        widths; the projected part is [batch, 0] when y_projected is None.
    """
    # Merge projected and residual because the CAVs are computed on
    # projected + residual. Tolerate y_projected=None (bug fix: the later
    # dim_projected guard already anticipated this case, but torch.cat
    # would have crashed on a None entry).
    if y_projected is None:
        y_all = y_residual
    else:
        y_all = torch.cat([y_projected, y_residual], dim=1)

    cavs_matrix = torch.stack(cavs_list, dim=1).to(y_all.device)  # [#dims, #cavs]

    if cavs_matrix.shape[1] > cavs_matrix.shape[0]:
        print(
            "Warning: CAVs matrix has more CAVs than dimensions! Remainder should be super close to 0"
        )

    # Rank of the CAV matrix: with linearly dependent CAVs, QR produces
    # meaningless trailing columns, so only the first `mrank` basis vectors
    # are kept. Cast to int so torch.eye / slicing get a plain Python int.
    mrank = int(torch.linalg.matrix_rank(cavs_matrix))

    cavs_ortho_matrix = (
        torch.linalg.qr(cavs_matrix, mode="reduced").Q[:, :mrank].detach()
    )  # [#dims, mrank] orthonormal basis of the CAV span (detached: no grads into CAVs)
    assert torch.allclose(
        cavs_ortho_matrix.T @ cavs_ortho_matrix,
        torch.eye(mrank).to(cavs_ortho_matrix.device),
        atol=1e-3,
        rtol=1e-3,
    ), f"Q^TQ is not identity matrix! Please check the CAVs matrix. {cavs_ortho_matrix.T @ cavs_ortho_matrix}"

    # Orthogonal projection onto the CAV span, and its complement.
    y_x_avs = y_all @ cavs_ortho_matrix @ cavs_ortho_matrix.T
    y_remainder = y_all - y_x_avs  # [# batches, # dims]

    dim_projected = y_projected.shape[1] if y_projected is not None else 0

    # Backward hooks: forward values stay intact, only gradients are zeroed.
    # NOTE(review): register_hook requires the tensor to require grad —
    # presumably always true in training; confirm for eval-time callers.
    if mute_x_avs:
        y_x_avs.register_hook(lambda grad: torch.zeros_like(grad))
    if mute_remainder:
        y_remainder.register_hook(lambda grad: torch.zeros_like(grad))

    # x_avs + remainder == y_all, so the forward output equals the input
    # split; the decomposition only matters for the gradient muting above.
    return (
        y_x_avs[:, :dim_projected] + y_remainder[:, :dim_projected],
        y_x_avs[:, dim_projected:] + y_remainder[:, dim_projected:],
    )

0 commit comments

Comments
 (0)