@@ -62,7 +62,7 @@ def forward_from_projected_and_residual(
6262 ):
6363 if cavs_list is not None :
6464 y_hat_x_avs , y_hat_remainder = self .disentangle_avs_x_cavs (
65- y_projected , y_residual , cavs_list
65+ y_projected , y_residual , cavs_list , mute_x_avs , mute_remainder
6666 )
6767
6868 if mute_x_avs : # which part of the activations to use
@@ -98,3 +98,45 @@ def project_post_attn_to_pca(self, y):
9898 return y_residual , y_projected
9999 else :
100100 return y , None
101+
def disentangle_avs_x_cavs(
    self, y_projected, y_residual, cavs_list, mute_x_avs=False, mute_remainder=False
):
    """Split activations into a CAV-subspace component and its remainder.

    The projected and residual activations are concatenated along dim 1,
    orthogonally projected onto the span of the given CAVs (``y_x_avs``), and
    the rest is kept as ``y_remainder``. The two components are summed again
    before returning, so the forward values equal the inputs (up to floating
    point error); the mute flags only change which component carries gradient.

    Args:
        y_projected: 2-D tensor ``[#batches, #projected dims]`` or ``None``
            when there is no projected part.
        y_residual: 2-D tensor ``[#batches, #residual dims]``.
        cavs_list: sequence of 1-D tensors, each of length
            ``#projected dims + #residual dims``.
        mute_x_avs: if True, no gradient flows through the CAV-subspace
            component.
        mute_remainder: if True, no gradient flows through the remainder.

    Returns:
        Tuple ``(projected_part, residual_part)`` obtained by splitting the
        recombined activations at the projected width. ``projected_part`` has
        zero columns when ``y_projected`` is ``None``.
    """
    # CAVs are computed on projected + residual, so merge the two parts first.
    if y_projected is None:
        dim_projected = 0
        y_all = y_residual
    else:
        dim_projected = y_projected.shape[1]
        y_all = torch.cat([y_projected, y_residual], dim=1)

    cavs_matrix = torch.stack(cavs_list, dim=1).to(y_all.device)  # [#dims, #cavs]

    if cavs_matrix.shape[1] > cavs_matrix.shape[0]:
        print(
            "Warning: CAVs matrix has more CAVs than dimensions! Remainder should be super close to 0"
        )

    # Number of linearly independent CAVs.
    mrank = int(torch.linalg.matrix_rank(cavs_matrix))

    # Orthonormal basis of the CAV column space. The leading left singular
    # vectors always span that space, whereas the leading columns of an
    # unpivoted QR may not when some CAVs are linearly dependent.
    left_vectors, _, _ = torch.linalg.svd(cavs_matrix, full_matrices=False)
    cavs_ortho_matrix = left_vectors[:, :mrank].detach()  # [#dims, mrank]

    gram = cavs_ortho_matrix.T @ cavs_ortho_matrix
    assert torch.allclose(
        gram,
        torch.eye(mrank, dtype=gram.dtype, device=gram.device),
        atol=1e-3,
        rtol=1e-3,
    ), f"Q^TQ is not identity matrix! Please check the CAVs matrix. {gram}"

    # Match the activation dtype so the projection below does not fail on a
    # dtype mismatch between CAVs and activations.
    cavs_ortho_matrix = cavs_ortho_matrix.to(y_all.dtype)

    y_x_avs = y_all @ cavs_ortho_matrix @ cavs_ortho_matrix.T
    y_remainder = y_all - y_x_avs  # [# batches, # dims]

    # Mute by detaching the component in the output sum. A gradient hook on
    # y_x_avs would see the already-cancelled accumulated gradient (+g from
    # the output, -g via y_remainder) and therefore mute nothing; hooks also
    # raise when the tensor does not require grad.
    if mute_x_avs:
        y_x_avs = y_x_avs.detach()
    if mute_remainder:
        y_remainder = y_remainder.detach()

    y_recombined = y_x_avs + y_remainder
    return (
        y_recombined[:, :dim_projected],
        y_recombined[:, dim_projected:],
    )
0 commit comments