def compute_chain_positions(quaternions, translations, reference_coords=None):
    """Build chain coordinates from quaternion + translation predictions.

    Parameters:
    - quaternions: (N, 4) tensor of quaternions (w, x, y, z) - scalar first
    - translations: (N, 3) tensor of per-step translations
    - reference_coords: legacy argument kept only for API compatibility.
      Chain reconstruction does not use an external reference point, so any
      value passed here is silently ignored.

    Returns:
    - transformed_coords: (N, 3) tensor of reconstructed coordinates
    """
    # NOTE(review): `reference_coords` is accepted but deliberately ignored
    # (legacy parameter) — removed the dead `if ...: pass` guard that used to
    # mark this; callers passing it get no error and no effect.
    R = quaternion_to_rotation_matrix(quaternions)
    # For generic RT chain predictions, translations are interpreted in the
    # local frame; composition into global coordinates happens inside
    # reconstruct_positions.
    return reconstruct_positions(R, translations, translation_frame='local', include_origin=False)
def compute_chain_positions_rotmat(rotations, translations):
    """Convert per-step rotation matrices and translations into coordinates.

    Parameters:
    - rotations: (N, 3, 3) unbatched or (B, N, 3, 3) batched rotation matrices
    - translations: (N, 3) or (B, N, 3) local-frame translation vectors

    Returns:
        Tensor: Shape (*, N, 3) global coordinates for each position

    NOTE(review): this delegates to reconstruct_positions with
    translation_frame='local' and include_origin=False, i.e. the first output
    row is the first translated step, not the origin — confirm this matches
    what downstream callers expect.
    """
    if rotations.ndim == 3:
        # Unbatched input: a single chain.
        return reconstruct_positions(rotations, translations,
                                     translation_frame='local',
                                     include_origin=False)

    # Batched input: reconstruct each chain independently and stack.
    coords = [
        reconstruct_positions(rotations[b], translations[b],
                              translation_frame='local', include_origin=False)
        for b in range(rotations.shape[0])
    ]
    return torch.stack(coords, dim=0)
def transform_rt_to_coordinates(rotations, translations):
    """Convert R, t matrices into global 3D coordinates.

    Parameters:
    - rotations: (N, 3, 3) unbatched or (B, N, 3, 3) batched rotation matrices
    - translations: (N, 3) or (B, N, 3) local-frame translation vectors

    Returns:
        Tensor: (N, 3) for unbatched input, (B, N, 3) for batched input.

    Raises:
        ValueError: if rotations is neither rank 3 nor rank 4.
    """
    if rotations.ndim == 3:
        # Single chain.
        return reconstruct_positions(rotations, translations,
                                     translation_frame='local',
                                     include_origin=False)
    if rotations.ndim != 4:
        raise ValueError(f"Expected rotations ndim 3 or 4, got {rotations.ndim}")

    # Batched: reconstruct each chain independently, then stack.
    coords = [
        reconstruct_positions(rotations[b], translations[b],
                              translation_frame='local', include_origin=False)
        for b in range(rotations.shape[0])
    ]
    return torch.stack(coords, dim=0)
477438
478439# ============================================================================
def reconstruct_positions(R, T, batch_idx=None, translation_frame='global', include_origin=True):
    """Reconstruct CA positions from rotations and translations.

    Supports two translation conventions:
    - 'global': T[i] is already in the global frame (e.g., CA[i+1] - CA[i]),
      so positions are a plain cumulative sum; R is intentionally unused.
    - 'local': T[i] is expressed in the current local frame and must be
      rotated into global coordinates as the chain is composed:
          x_{i+1} = x_i + R_global_i @ t_i
          R_global_{i+1} = R_global_i @ R_i      (R_global_0 = identity)

    Args:
        R (torch.Tensor): Rotation matrices, shape (N, 3, 3).
        T (torch.Tensor): Translation vectors, shape (N, 3).
        batch_idx (torch.Tensor, optional): Per-residue batch indices (N,).
        translation_frame (str): 'global' or 'local'.
        include_origin (bool): If True, prepend an origin row for each chain.

    Returns:
        torch.Tensor:
            - Unbatched: (N+1, 3) if include_origin else (N, 3)
            - Batched: concatenation of per-chain outputs. NOTE: chains are
              emitted in sorted batch-id order (torch.unique sorts), which
              matches input order only when batch_idx is already sorted.

    Raises:
        ValueError: If translation_frame is not 'global' or 'local'.
    """
    if translation_frame not in ('global', 'local'):
        raise ValueError(f"Unknown translation_frame: {translation_frame}")

    def _reconstruct_single(R_s, T_s):
        # Build the (N, 3) positions without the origin row, then optionally
        # prepend the origin.
        if translation_frame == 'global':
            pos_no_origin = torch.cumsum(T_s, dim=0)
        elif T_s.shape[0] == 0:
            # Empty chain: nothing to compose.
            pos_no_origin = T_s.new_zeros((0, 3))
        else:
            # Rotate each local step into the global frame first, then take a
            # single cumulative sum. Same addition order as a running-position
            # loop, but avoids per-step clones of the accumulator.
            curr_R = torch.eye(3, dtype=T_s.dtype, device=T_s.device)
            steps = []
            for i in range(T_s.shape[0]):
                steps.append(curr_R @ T_s[i])
                # Rotation for step i applies from step i+1 onward.
                curr_R = curr_R @ R_s[i]
            pos_no_origin = torch.cumsum(torch.stack(steps, dim=0), dim=0)

        if include_origin:
            origin = torch.zeros(1, 3, dtype=T_s.dtype, device=T_s.device)
            return torch.cat([origin, pos_no_origin], dim=0)
        return pos_no_origin

    if batch_idx is None:
        return _reconstruct_single(R, T)

    # Batched: reconstruct each chain separately and concatenate.
    outs = [
        _reconstruct_single(R[batch_idx == b], T[batch_idx == b])
        for b in torch.unique(batch_idx)
    ]
    if not outs:
        return torch.zeros((0, 3), dtype=T.dtype, device=T.device)
    return torch.cat(outs, dim=0)
972957
0 commit comments