
Commit 971c733

fixing radians, dataset building, fape
1 parent e85bee5

10 files changed

Lines changed: 1254 additions & 3266 deletions

foldtree2/notebooks/experiments/dd.ipynb

Lines changed: 0 additions & 2732 deletions
This file was deleted.

foldtree2/notebooks/experiments/test_monodecoders.ipynb

Lines changed: 698 additions & 319 deletions
Large diffs are not rendered by default.

foldtree2/notebooks/monomer_graph_trainingdata.ipynb

Lines changed: 73 additions & 72 deletions
Large diffs are not rendered by default.

foldtree2/src/encoder.py

Lines changed: 16 additions & 3 deletions
```diff
@@ -284,9 +284,22 @@ def forward(self, data, edge_attr_dict=None, **kwargs):
         for i, convs in enumerate(self.body['convs']):
             # Apply graph convolutions and average over all edge types
             if edge_attr_dict is not None:
-                x_list = [conv(x, edge_index=edge_index_dict[tuple(edge_type.split('_'))],
-                               edge_attr = edge_attr_dict[tuple(edge_type.split('_') )] )
-                          for edge_type, conv in convs.items()]
+                x_list = []
+                for edge_type, conv in convs.items():
+                    edge_key = tuple(edge_type.split('_'))
+                    edge_attr = edge_attr_dict[edge_key]
+
+                    # Normalize edge attributes to [num_edges, edge_dim] for TransformerConv.
+                    if edge_attr is not None and edge_attr.dim() == 1:
+                        edge_attr = edge_attr.unsqueeze(-1)
+                    if edge_attr is not None and edge_attr.size(-1) != self.edge_dim:
+                        if edge_attr.size(-1) > self.edge_dim:
+                            edge_attr = edge_attr[:, :self.edge_dim]
+                        else:
+                            pad_cols = self.edge_dim - edge_attr.size(-1)
+                            edge_attr = F.pad(edge_attr, (0, pad_cols))
+
+                    x_list.append(conv(x, edge_index=edge_index_dict[edge_key], edge_attr=edge_attr))
             else:
                 x_list = [conv(x, edge_index=edge_index_dict[tuple(edge_type.split('_'))])
                           for edge_type, conv in convs.items()]
```
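For reference, the added normalization logic reads naturally as a standalone helper. A minimal sketch, assuming `edge_attr` is a float tensor and `edge_dim` matches the width the TransformerConv layers were built with (the helper name is hypothetical, not part of the commit):

```python
import torch
import torch.nn.functional as F

def normalize_edge_attr(edge_attr, edge_dim):
    # Hypothetical helper mirroring the logic added in encoder.py.
    if edge_attr is None:
        return None
    if edge_attr.dim() == 1:
        # Promote [num_edges] to [num_edges, 1] so there is a feature axis.
        edge_attr = edge_attr.unsqueeze(-1)
    if edge_attr.size(-1) > edge_dim:
        # Truncate extra feature columns.
        edge_attr = edge_attr[:, :edge_dim]
    elif edge_attr.size(-1) < edge_dim:
        # Zero-pad missing feature columns on the right.
        edge_attr = F.pad(edge_attr, (0, edge_dim - edge_attr.size(-1)))
    return edge_attr

# e.g. a 1-D distance per edge, widened to edge_dim=4:
attr = normalize_edge_attr(torch.rand(10), edge_dim=4)  # -> shape (10, 4)
```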

foldtree2/src/losses/fape.py

Lines changed: 72 additions & 87 deletions
```diff
@@ -334,40 +334,23 @@ def quaternion_to_rotation_matrix(quat):
 
 def compute_chain_positions(quaternions, translations, reference_coords=None):
     """
-    Apply rotation (quaternion) and translation to a set of 3D reference coordinates using PyTorch.
+    Build chain coordinates from quaternion + translation predictions.
 
     Parameters:
     - quaternions: (N, 4) tensor of quaternions (w, x, y, z) - scalar first
-    - translations: (N, 3) tensor of translations (tx, ty, tz)
+    - translations: (N, 3) tensor of per-step translations
     - reference_coords: (M, 3) tensor of reference points (default is [[0, 0, 0]])
 
     Returns:
-    - transformed_coords: (N, 3) tensor of transformed coordinates
+    - transformed_coords: (N, 3) tensor of reconstructed coordinates
     """
-    device = quaternions.device
-    quaternions = quaternions / quaternions.norm(dim=-1, keepdim=True)  # Normalize quaternions
-
-    if reference_coords is None:
-        reference_coords = torch.zeros(1, 3, device=device)
-
-    N = quaternions.shape[0]
-
-    w, x, y, z = quaternions.unbind(-1)
-
-    # Rotation matrix components
-    R = torch.stack([
-        1 - 2*(y**2 + z**2), 2*(x*y - z*w), 2*(x*z + y*w),
-        2*(x*y + z*w), 1 - 2*(x**2 + z**2), 2*(y*z - x*w),
-        2*(x*z - y*w), 2*(y*z + x*w), 1 - 2*(x**2 + y**2)
-    ], dim=-1).reshape(N, 3, 3)
-
-    # Apply rotation to reference coordinates (take first point if multiple)
-    rotated = torch.matmul(reference_coords[0:1], R.transpose(1, 2)).squeeze(0)  # (N, 3)
-
-    # Apply translation
-    transformed_coords = rotated + translations
-
-    return transformed_coords
+    if reference_coords is not None:
+        # Legacy argument kept for API compatibility. Chain reconstruction does not
+        # use an external reference point.
+        pass
+    R = quaternion_to_rotation_matrix(quaternions)
+    # For generic RT chain predictions, translations are interpreted in local frame.
+    return reconstruct_positions(R, translations, translation_frame='local', include_origin=False)
 
 
 def compute_chain_positions_rotmat(rotations, translations):
```
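A quick sanity check of the rewired `compute_chain_positions`: with identity quaternions no rotation is composed, so the 'local' chain reduces to a cumulative sum of translations. A sketch, assuming the package imports as `foldtree2.src.losses.fape`:

```python
import torch
from foldtree2.src.losses.fape import compute_chain_positions

# Identity quaternion (w=1, x=y=z=0) at every step: each local frame equals
# the global frame, so the chain is just the running sum of translations.
N = 5
quats = torch.zeros(N, 4)
quats[:, 0] = 1.0                 # scalar-first convention
trans = torch.randn(N, 3)

coords = compute_chain_positions(quats, trans)  # (N, 3)
assert torch.allclose(coords, torch.cumsum(trans, dim=0), atol=1e-6)
```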
```diff
@@ -381,37 +364,14 @@ def compute_chain_positions_rotmat(rotations, translations):
     Returns:
         Tensor: Shape (*, N, 3) global coordinates for each position
     """
-    # Handle batched or unbatched input
-    orig_shape = rotations.shape[:-2]
     if rotations.ndim == 3:
-        rotations = rotations.unsqueeze(0)
-        translations = translations.unsqueeze(0)
+        return reconstruct_positions(rotations, translations, translation_frame='local', include_origin=False)
 
-    batch_size = rotations.shape[0]
-    N = rotations.shape[1]
-    positions = []
-
-    for b in range(batch_size):
-        # Initialize starting position and rotation
-        global_R = torch.eye(3, dtype=rotations.dtype, device=rotations.device)
-        curr_pos = torch.zeros(3, dtype=translations.dtype, device=translations.device)
-        chain_positions = []
-
-        for i in range(N):
-            chain_positions.append(curr_pos.clone())
-            # Update global rotation and position
-            global_R = torch.matmul(global_R, rotations[b, i])
-            curr_pos = curr_pos + torch.matmul(global_R, translations[b, i])
-
-        positions.append(torch.stack(chain_positions))
-
-    positions = torch.stack(positions)
-
-    # Return to original shape if unbatched input
-    if len(orig_shape) == 1:
-        positions = positions.squeeze(0)
-
-    return positions
+    # Batched input: process each structure and stack
+    coords = []
+    for b in range(rotations.shape[0]):
+        coords.append(reconstruct_positions(rotations[b], translations[b], translation_frame='local', include_origin=False))
+    return torch.stack(coords, dim=0)
 
 
```
```diff
@@ -466,13 +426,14 @@ def transform_rt_to_coordinates(rotations, translations):
     """
     Convert R, t matrices into global 3D coordinates.
     """
-    batch_size, num_residues, _ = rotations.shape
-    coords = torch.zeros((batch_size, num_residues, 3), device=rotations.device)
-    for b in range(batch_size):
-        transform = torch.eye(4, device=rotations.device)
-        for i in range(num_residues):
-            pass  # Implementation needed
-    return coords
+    if rotations.ndim == 3:
+        return reconstruct_positions(rotations, translations, translation_frame='local', include_origin=False)
+    if rotations.ndim != 4:
+        raise ValueError(f"Expected rotations ndim 3 or 4, got {rotations.ndim}")
+    coords = []
+    for b in range(rotations.shape[0]):
+        coords.append(reconstruct_positions(rotations[b], translations[b], translation_frame='local', include_origin=False))
+    return torch.stack(coords, dim=0)
 
 
 # ============================================================================
```
```diff
@@ -937,36 +898,60 @@ def rotation_matrix_to_quaternion(rot_matrices):
     return quat
 
 
-def reconstruct_positions(R, T , batch_idx=None):
+def reconstruct_positions(R, T, batch_idx=None, translation_frame='global', include_origin=True):
     """
-    Reconstruct 3D CA positions from CA-to-CA displacement vectors.
+    Reconstruct CA positions from rotations and translations.
 
-    T[i] = CA[i+1] - CA[i] in the global frame, so positions are simply the
-    cumulative sum of translations starting from the origin. The rotation
-    matrices R are kept as a parameter for API compatibility but are not used
-    here — they represent the *local frame orientation*, not a transform that
-    should be applied to global-frame displacements.
+    Supports two translation conventions:
+    - `global`: T[i] is already in global frame (e.g., CA[i+1] - CA[i]).
+    - `local`: T[i] is in the current local frame and must be rotated into
+      global coordinates as the chain is composed.
 
     Args:
-        R (torch.Tensor): Local rotation matrices, shape (N, 3, 3) [unused]
-        T (torch.Tensor): CA-to-CA displacement vectors, shape (N, 3)
+        R (torch.Tensor): Rotation matrices, shape (N, 3, 3).
+        T (torch.Tensor): Translation vectors, shape (N, 3).
+        batch_idx (torch.Tensor, optional): Per-residue batch indices (N,).
+        translation_frame (str): 'global' or 'local'.
+        include_origin (bool): If True, prepend origin row for each chain.
 
     Returns:
-        torch.Tensor: Reconstructed CA positions of shape (N+1, 3), starting
-        from the origin.
+        torch.Tensor:
+        - Unbatched: (N+1, 3) if include_origin else (N, 3)
+        - Batched: concatenation of per-chain outputs in batch order.
     """
-    if batch_idx is not None:
-        # Handle batched input
-        unique_batches = batch_idx.unique()
-        positions = []
-        for b in unique_batches:
-            mask = batch_idx == b
-            T_b = T[mask]
-            origin = torch.zeros(1, 3, dtype=T.dtype, device=T.device)
-            positions.append(torch.cat([origin, torch.cumsum(T_b, dim=0)], dim=0))
-        return torch.cat(positions, dim=0)
-    else:
-        origin = torch.zeros(1, 3, dtype=T.dtype, device=T.device)
-        return torch.cat([origin, torch.cumsum(T, dim=0)], dim=0)
+    if translation_frame not in ('global', 'local'):
+        raise ValueError(f"Unknown translation_frame: {translation_frame}")
+
+    def _reconstruct_single(R_s, T_s):
+        if translation_frame == 'global':
+            pos_no_origin = torch.cumsum(T_s, dim=0)
+        else:
+            # Compose rigid transforms along chain: x_{i+1} = x_i + R_global @ t_i
+            N = T_s.shape[0]
+            curr_pos = torch.zeros(3, dtype=T_s.dtype, device=T_s.device)
+            curr_R = torch.eye(3, dtype=T_s.dtype, device=T_s.device)
+            positions = []
+            for i in range(N):
+                step_global = curr_R @ T_s[i]
+                curr_pos = curr_pos + step_global
+                positions.append(curr_pos.clone())
+                curr_R = curr_R @ R_s[i]
+            pos_no_origin = torch.stack(positions, dim=0) if positions else T_s.new_zeros((0, 3))
+
+        if include_origin:
+            origin = torch.zeros(1, 3, dtype=T_s.dtype, device=T_s.device)
+            return torch.cat([origin, pos_no_origin], dim=0)
+        return pos_no_origin
+
+    if batch_idx is None:
+        return _reconstruct_single(R, T)
+
+    outs = []
+    for b in torch.unique(batch_idx):
+        mask = batch_idx == b
+        outs.append(_reconstruct_single(R[mask], T[mask]))
+    if len(outs) == 0:
+        return torch.zeros((0, 3), dtype=T.dtype, device=T.device)
+    return torch.cat(outs, dim=0)
 
 
```
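To see how the two translation conventions relate, here is a small usage sketch (assuming the same import path as above): with identity rotations the 'local' composition degenerates to the 'global' cumulative sum, so the two modes should agree.

```python
import torch
from foldtree2.src.losses.fape import reconstruct_positions

R = torch.eye(3).repeat(4, 1, 1)   # (N, 3, 3) identity rotations
T = torch.randn(4, 3)              # (N, 3) translations

pos_global = reconstruct_positions(R, T, translation_frame='global')
pos_local = reconstruct_positions(R, T, translation_frame='local')

# Both return (N+1, 3): an origin row followed by the chain positions.
assert torch.allclose(pos_global, pos_local, atol=1e-6)
```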

foldtree2/src/losses/losses.py

Lines changed: 4 additions & 9 deletions
```diff
@@ -1198,7 +1198,7 @@ def ss_reconstruction_loss(ss, recon_ss, mask_plddt=False, plddt_threshold=0.3 ,
     return ss_loss
 
 
-def angles_reconstruction_loss(true, pred, beta=0.5 , plddt_mask = None , plddt_thresh = 0.3 , normalize = False , convert_to_radians = True):
+def angles_reconstruction_loss(true, pred, beta=0.5 , plddt_mask = None , plddt_thresh = 0.3 , normalize = False):
     """Compute backbone dihedral angle reconstruction loss with circular distance.
 
     This loss trains the decoder to predict protein backbone torsion angles (phi, psi, omega)
@@ -1237,19 +1237,14 @@ def angles_reconstruction_loss(true, pred, beta=0.5 , plddt_mask = None , plddt_
         ... )
 
     Note:
-        Angles are computed from PDB coordinates during preprocessing using
-        BioPython's calc_dihedral function. They represent protein backbone geometry.
-        Circular distance is essential because 179° and -179° are actually very close!
+        Angles are expected in radians throughout the training path.
+        Circular distance is essential because angles near +π and -π are actually very close.
 
     Reference:
         Smooth L1 (Huber) loss: Girshick, R. (2015). Fast R-CNN. ICCV.
     """
     # Compute circular angular difference in [-π, π]
-    # atan2 correctly handles the wrap-around at ±180°
-
-    if convert_to_radians:
-        true = true * (torch.pi / 180.0)
-        pred = pred * (torch.pi / 180.0)
+    # atan2 correctly handles the wrap-around at ±pi
 
     delta = torch.atan2(torch.sin(pred - true), torch.cos(pred - true))
```
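The circular difference is the crux of this loss. A tiny numeric sketch of why plain subtraction fails at the ±π seam:

```python
import torch

true = torch.tensor([3.10])    # just below +pi
pred = torch.tensor([-3.10])   # just above -pi; geometrically ~0.08 rad away

naive = (pred - true).abs()    # 6.20 rad: wildly overestimates the error
delta = torch.atan2(torch.sin(pred - true), torch.cos(pred - true)).abs()
print(naive.item(), delta.item())   # ~6.20 vs ~0.083
```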

foldtree2/src/mono_decoders.py

Lines changed: 6 additions & 7 deletions
```diff
@@ -274,8 +274,7 @@ def forward(self, data , contact_pred_index, **kwargs):
 
         if self.angles_mlp is not None:
             angles = self.angles_mlp( z )
-            #tanh is -1 to 1, multiply by 180 to get angles in degrees
-            angles = angles * 180  # Scale from [-1, 1] to [-180, 180]
+            angles = angles * torch.pi  # Scale from [-1, 1] to [-pi, pi]
 
         if contact_pred_index is None:
             return { 'edge_probs': None , 'zgodnode' :None , 'fft2pred':fft2_pred , 'rt_pred': None , 'angles': angles , 'edge_logits': edge_logits , 'ss_pred': ss_pred , 'z': z }
@@ -581,7 +580,7 @@ def forward(self, data, contact_pred_index, **kwargs):
         angles = None
         if 'angles_mlp' in self.head:
             angles = self.head['angles_mlp'](z)
-            angles = angles * 180  # Scale from [-1, 1] to [-180, 180]
+            angles = angles * torch.pi  # Scale from [-1, 1] to [-pi, pi]
 
         # Contact prediction
         edge_logits = None
@@ -1350,7 +1349,7 @@ def forward(self, data, contact_pred_index=None, **kwargs):
                 if self.output_angles and 'angles_cnn' in self.head:
                     angles_out = self.head['angles_cnn'](xi_cnn)  # (1, 3, seq_len)
                     angles_out = angles_out.permute(2, 0, 1).squeeze(1)  # (seq_len, 3)
-                    angles_out = angles_out * 180  # Scale from [-1, 1] to [-180, 180]
+                    angles_out = angles_out * torch.pi  # Scale from [-1, 1] to [-pi, pi]
                     angles_list.append(angles_out)
                 else:
                     # DNN decoder path
@@ -1371,7 +1370,7 @@ def forward(self, data, contact_pred_index=None, **kwargs):
             ss_pred = torch.cat(ss_list, dim=0)
             if angles_list:
                 angles = torch.cat(angles_list, dim=0)
-                angles = angles * 180  # Scale from [-1, 1] to [-180, 180]
+                angles = angles * torch.pi  # Scale from [-1, 1] to [-pi, pi]
         else:
             # Single graph case
             if use_cnn:
@@ -1393,7 +1392,7 @@ def forward(self, data, contact_pred_index=None, **kwargs):
                 if self.output_angles and 'angles_cnn' in self.head:
                     angles = self.head['angles_cnn'](x_cnn)  # (1, 3, seq_len)
                     angles = angles.permute(2, 0, 1).squeeze(1)  # (seq_len, 3)
-                    angles = angles * 180  # Scale from [-1, 1] to [-180, 180]
+                    angles = angles * torch.pi  # Scale from [-1, 1] to [-pi, pi]
             else:
                 # DNN decoder path
                 x = x.squeeze(1)  # (seq_len, d_model)
@@ -1406,7 +1405,7 @@ def forward(self, data, contact_pred_index=None, **kwargs):
 
         if self.output_angles and 'angles_head' in self.head:
             angles = self.head['angles_head'](x)
-            angles = angles * 180  # Scale from [-1, 1] to [-180, 180]
+            angles = angles * torch.pi  # Scale from [-1, 1] to [-pi, pi]
 
         # Normalize quaternion part (first 4 dims) of rt_pred for proper geometry
         if rt_pred is not None:
```
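The common pattern across these decoder heads: a tanh-bounded output scaled by π, so predictions land directly in [-π, π] radians and match the radians-native loss above. A minimal sketch (layer sizes are illustrative, not the model's actual dimensions):

```python
import torch
import torch.nn as nn

# Illustrative angle head: tanh bounds outputs to [-1, 1]; multiplying by pi
# yields phi/psi/omega predictions in radians, ready for
# angles_reconstruction_loss without any degree conversion.
angles_mlp = nn.Sequential(nn.Linear(64, 3), nn.Tanh())

z = torch.randn(10, 64)             # per-residue latent embeddings
angles = angles_mlp(z) * torch.pi   # (10, 3), values in [-pi, pi]
```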
