Skip to content

Commit e85bee5

Browse files
committed
working on foldcomp conversion
1 parent a6d6856 commit e85bee5

8 files changed

Lines changed: 601 additions & 376 deletions

File tree

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Representation Conversion Guide (PDB ⇄ RT ⇄ Quaternion)
2+
3+
This guide documents the conversion flow used by
4+
`foldtree2/scripts/test_representation_conversions.py`.
5+
6+
## Overview
7+
8+
The script follows this sequence:
9+
10+
1. **PDB backbone extraction**
11+
- Extract per-residue backbone atom coordinates: `N`, `CA`, `C`.
12+
2. **Backbone coordinates → local frames**
13+
- Use `PDB2PyG.compute_local_frame(coords)` with `coords` of shape `(N, 3, 3)` in order `[N, CA, C]`.
14+
- Output:
15+
- `R`: rotation matrices, shape `(N, 3, 3)`
16+
- `t`: translation vectors, shape `(N, 3)`
17+
3. **Rotation matrices → quaternions**
18+
- Use `rotation_matrix_to_quaternion(R)`.
19+
- Quaternion convention in this repo: `(w, x, y, z)` (scalar first).
20+
4. **Quaternions → rotation matrices**
21+
- Use `quaternion_to_rotation_matrix(q)` for roundtrip reconstruction.
22+
5. **RT → chain coordinates**
23+
- Use `reconstruct_positions(R, t)` to reconstruct coordinates from transforms.
24+
25+
## Noise Experiments
26+
27+
The script evaluates robustness by injecting noise in each representation:
28+
29+
- **Coordinate noise**: add Gaussian noise to `(N, CA, C)` coordinates, then recompute `R, t, q`.
30+
- **RT noise**:
31+
- left-multiply random small rotations onto `R`
32+
- add Gaussian noise to `t`
33+
- **Quaternion noise**:
34+
- add Gaussian noise to quaternion components
35+
- renormalize quaternions to unit norm
36+
- convert back to rotation matrices
37+
38+
## Losses
39+
40+
For each noisy variant, the script reports:
41+
42+
- **FAPE loss** via `fape_loss(true_R, true_t, pred_R, pred_t, batch)`
43+
- **lDDT-style loss** via `compute_lddt_loss(true_positions, pred_positions)`,
  where positions come from `reconstruct_positions`.
45+
46+
Lower values indicate better consistency with the baseline representation.
47+
48+
## Run
49+
50+
From the repository root:
51+
52+
```bash
53+
python -m foldtree2.scripts.test_representation_conversions \
54+
--pdb-path foldtree2/config/1eei.pdb \
55+
--coord-noise 0.25 \
56+
--rot-noise-rad 0.05 \
57+
--trans-noise 0.10 \
58+
--quat-noise 0.05 \
59+
--seed 0
60+
```
61+
62+
## Key Notes
63+
64+
- Use backbone triplets `[N, CA, C]` to define local frames.
65+
- CA-only coordinates are not sufficient for unique residue local orientation without extra assumptions.
66+
- Keep quaternion convention consistent as `(w, x, y, z)` throughout conversions.

foldtree2/foldcomp2fasta.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import argparse
2+
import os
3+
4+
import torch
5+
6+
7+
def _read_ids(ids_file):
8+
ids = []
9+
with open(ids_file, 'r', encoding='utf-8') as f:
10+
for line in f:
11+
line = line.strip()
12+
if not line or line.startswith('#'):
13+
continue
14+
ids.append(line.split()[0])
15+
return ids
16+
17+
18+
def main():
    """CLI entry point: encode a Foldcomp DB into a FoldTree2 token FASTA.

    Parses command-line arguments, loads the trained encoder checkpoint onto
    the requested (or auto-detected) device, and delegates the actual
    encoding to ``encoder.encode_foldcomp_fasta``.

    Raises:
        FileNotFoundError: If the model checkpoint path does not exist.
    """
    ap = argparse.ArgumentParser(
        description='Encode a Foldcomp DB directly to FoldTree2 token FASTA.'
    )
    ap.add_argument('model', type=str, help='Path to trained encoder .pt file')
    ap.add_argument('foldcomp_db', type=str, help='Path to Foldcomp DB basename (without .lookup)')
    ap.add_argument('output_fasta', type=str, help='Output encoded FASTA path')

    ap.add_argument('--device', type=str, default=None, help='Device (e.g., cuda, cuda:0, cpu)')
    ap.add_argument('--ids-file', type=str, default=None, help='Optional text file with Foldcomp IDs (one per line)')
    ap.add_argument('--max-structures', type=int, default=None, help='Optional max number of structures to encode')
    ap.add_argument('--chunk-size', type=int, default=1024, help='Foldcomp prefetch chunk size (default: 1024)')
    ap.add_argument('--queue-size', type=int, default=4, help='Producer/consumer queue size (default: 4)')
    ap.add_argument('--batch-size', type=int, default=16, help='Encoder batch size per chunk (default: 16)')
    ap.add_argument('--cache-size', type=int, default=0, help='Graph cache size in Foldcomp dataset (default: 0)')
    ap.add_argument('--no-replace', action='store_true', help='Disable FASTA special-character replacement')
    ap.add_argument('--quiet', action='store_true', help='Disable progress bar')

    cli = ap.parse_args()

    if not os.path.exists(cli.model):
        raise FileNotFoundError(f'Model not found: {cli.model}')

    # Optional explicit ID subset; None means "encode the whole DB".
    id_list = _read_ids(cli.ids_file) if cli.ids_file is not None else None

    # Honour an explicit --device, otherwise prefer CUDA when available.
    if cli.device:
        run_device = torch.device(cli.device)
    else:
        run_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): weights_only=False unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    encoder = torch.load(cli.model, map_location=run_device, weights_only=False)
    encoder = encoder.to(run_device)
    encoder.device = run_device
    encoder.eval()

    output = encoder.encode_foldcomp_fasta(
        foldcomp_db=cli.foldcomp_db,
        filename=cli.output_fasta,
        ids=id_list,
        max_structures=cli.max_structures,
        chunk_size=cli.chunk_size,
        queue_size=cli.queue_size,
        batch_size=cli.batch_size,
        cache_size=cli.cache_size,
        replace=not cli.no_replace,
        verbose=not cli.quiet,
    )

    print(f'Encoded FASTA written to: {output}')
65+
66+
67+
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()

foldtree2/foldcomp_otf.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""Worked example of the foldcomp Python API.

Part 01 decompresses a single ``.fcz`` file and writes it back out as PDB;
part 02 iterates selected entries of a Foldcomp database. The hard-coded
``test/...`` paths assume this is run from a directory containing foldcomp's
test fixtures — NOTE(review): presumably mirrors the foldcomp README example;
confirm the paths before running.
"""
import foldcomp
# 01. Handling a FCZ file
# Open a fcz file
with open("test/compressed.fcz", "rb") as fcz:
    fcz_binary = fcz.read()

# Decompress
(name, pdb) = foldcomp.decompress(fcz_binary)  # pdb_out[0]: file name, pdb_out[1]: pdb binary string

# Save to a pdb file
with open(name, "w") as pdb_file:
    pdb_file.write(pdb)

# Get data as dictionary
data_dict = foldcomp.get_data(fcz_binary)  # foldcomp.get_data(pdb) also works
# Keys: phi, psi, omega, torsion_angles, residues, bond_angles, coordinates
data_dict["phi"]  # phi angles (C-N-CA-C)
data_dict["psi"]  # psi angles (N-CA-C-N)
data_dict["omega"]  # omega angles (CA-C-N-CA)
data_dict["torsion_angles"]  # torsion angles of the backbone as list (phi + psi + omega)
data_dict["bond_angles"]  # bond angles of the backbone as list
data_dict["residues"]  # amino acid residues as string
data_dict["coordinates"]  # coordinates of the backbone as list

# 02. Iterate over a database of FCZ files
# Open a foldcomp database
ids = ["d1asha_", "d1it2a_"]
with foldcomp.open("test/example_db", ids=ids) as db:
    # Iterate through database
    for (name, pdb) in db:
        # save entries as separate pdb files
        with open(name + ".pdb", "w") as pdb_file:
            pdb_file.write(pdb)

foldtree2/learn_lightning.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -537,11 +537,11 @@ def training_step(self, batch, batch_idx):
537537

538538
# lDDT loss
539539
lddt_loss = torch.tensor(0.0, device=self.device)
540-
if (self.args.lddt_weight > 0 or getattr(self.args, 'lddt_loss', False)) and out.get('coords') is not None and hasattr(data, 'coords') and hasattr(data['coords'], 'x'):
540+
if (self.args.lddt_weight > 0 or getattr(self.args, 'lddt_loss', False)) and out.get('quat_pred') is not None and out.get('trans_pred') is not None and hasattr(data, 'coords') and hasattr(data['coords'], 'x'):
541541
from foldtree2.src.losses.losses import batch_lddt_loss
542542
lddt_loss = batch_lddt_loss(
543-
pred_q=out.get('quat', None),
544-
pred_t=out.get('trans', None),
543+
pred_q=out.get('quat_pred'),
544+
pred_t=out.get('trans_pred'),
545545
true_coords=data['coords'].x,
546546
batch=getattr(data['res'], 'batch', None),
547547
plddt=data['plddt'].x if self.args.mask_plddt else None,
@@ -550,26 +550,37 @@ def training_step(self, batch, batch_idx):
550550

551551
# FAPE loss
552552
fape_loss = torch.tensor(0.0, device=self.device)
553-
if (self.args.fape_weight > 0 or getattr(self.args, 'fape_loss', False)) and out.get('quat') is not None and out.get('trans') is not None and hasattr(data, 'quat') and hasattr(data['quat'], 'x') and hasattr(data, 'trans') and hasattr(data['trans'], 'x'):
553+
if (self.args.fape_weight > 0 or getattr(self.args, 'fape_loss', False)) and out.get('quat_pred') is not None and out.get('trans_pred') is not None and hasattr(data, 'q_true') and hasattr(data['q_true'], 'x') and hasattr(data, 'coords') and hasattr(data['coords'], 'x'):
554554
from foldtree2.src.losses.losses import batch_fape_loss
555+
_fape_batch = getattr(data['res'], 'batch', None)
556+
_pred_disp = out['trans_pred']
557+
# Convert CA-to-CA displacements to CA positions (cumsum per structure)
558+
# FAPE requires absolute positions; cumsum from origin is translation-invariant
559+
if _fape_batch is not None:
560+
_pred_pos = torch.zeros_like(_pred_disp)
561+
for _b in torch.unique(_fape_batch):
562+
_m = (_fape_batch == _b).nonzero(as_tuple=True)[0]
563+
_pred_pos[_m] = torch.cumsum(_pred_disp[_m], dim=0)
564+
else:
565+
_pred_pos = torch.cumsum(_pred_disp, dim=0)
555566
fape_loss = batch_fape_loss(
556-
true_q=data['quat'].x,
557-
true_t=data['trans'].x,
558-
pred_q=out['quat'],
559-
pred_t=out['trans'],
560-
batch=getattr(data['res'], 'batch', None),
567+
true_q=data['q_true'].x,
568+
true_t=data['coords'].x,
569+
pred_q=out['quat_pred'],
570+
pred_t=_pred_pos,
571+
batch=_fape_batch,
561572
)
562573

563574
# Delta loss
564575
delta_loss_val = torch.tensor(0.0, device=self.device)
565-
if (self.args.delta_weight > 0 or getattr(self.args, 'delta_loss', False)) and out.get('coords') is not None and hasattr(data, 'coords') and hasattr(data['coords'], 'x'):
576+
if (self.args.delta_weight > 0 or getattr(self.args, 'delta_loss', False)) and (out.get('quat_pred') is not None or out.get('coords') is not None) and hasattr(data, 'coords') and hasattr(data['coords'], 'x'):
566577
from foldtree2.src.losses.losses import batch_delta_loss
567578
try:
568-
if out.get('quat') is not None and out.get('trans') is not None:
579+
if out.get('quat_pred') is not None and out.get('trans_pred') is not None:
569580
delta_loss_val = batch_delta_loss(
570581
true_ca=data['coords'].x,
571-
pred_q=out['quat'],
572-
pred_t=out['trans'],
582+
pred_q=out['quat_pred'],
583+
pred_t=out['trans_pred'],
573584
batch=getattr(data['res'], 'batch', None),
574585
plddt=data['plddt'].x if self.args.mask_plddt else None,
575586
plddt_thresh=self.args.plddt_threshold if self.args.mask_plddt else 0.0,

0 commit comments

Comments
 (0)