-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy patheval_metric.py
More file actions
92 lines (80 loc) · 2.97 KB
/
eval_metric.py
File metadata and controls
92 lines (80 loc) · 2.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import argparse
import functools
import multiprocessing as mp
import logging
import itertools
import json
import pandas as pd
import traceback
import pickle
import numpy as np
import pdb
import re
import csv
from abx.metric import eval_metric, make_coords, cdr_numbering, InterfaceEnergy
def parse_list(data_dir):
    """Yield the path of every usable ``.pdb`` file found under *data_dir*.

    The tree rooted at *data_dir* is walked recursively. A file is yielded
    when its name ends in ``.pdb``, does not end in ``_relaxed.pdb`` and its
    size on disk is non-zero.

    Args:
        data_dir: Root directory to search. A missing directory yields
            nothing.

    Yields:
        str: ``os.path.join(parent, fname)`` for each accepted file, in a
        deterministic (sorted) traversal order.
    """
    for parent, dirs, files in os.walk(data_dir):
        # Sorting in place pins the traversal order, which os.walk otherwise
        # inherits from the filesystem.
        dirs.sort()
        for fname in sorted(files):
            # Plain suffix tests replace the former non-raw regex literals,
            # whose invalid escape sequences warn on recent Python versions.
            if not fname.endswith('.pdb') or fname.endswith('_relaxed.pdb'):
                continue
            fpath = os.path.join(parent, fname)
            try:
                if os.path.getsize(fpath) == 0:
                    continue
            except OSError:
                # Unreadable entry (e.g. dangling symlink, file removed while
                # walking): skip it just like an empty file.
                continue
            # os.walk reports each path once, so no de-duplication is needed.
            yield fpath
def main(args):
    """Evaluate every structure under ``args.data_dir`` and report the metrics.

    Steps:
      1. For each PDB returned by ``parse_list`` on ``<data_dir>/reference``,
         build a reference entry (CDR definition, coordinates, sequence)
         keyed by the file name without its ``.pdb`` suffix.
      2. Call ``eval_metric`` on every PDB returned by
         ``parse_list(args.data_dir)`` using a pool of ``args.cpus`` workers.
      3. Print the per-column mean of the numeric result columns whose name
         contains ``RMSD`` or ``AAR``.
      4. Write all per-file results to ``<data_dir>/results.csv``.

    Args:
        args: Parsed command-line namespace. ``data_dir`` and ``cpus`` are
            read here; the whole namespace is forwarded to ``eval_metric``.

    Raises:
        ValueError: If a reference file name does not consist of exactly four
            ``_``-separated fields.
    """
    reference_data = {}
    for ref_pdb in parse_list(os.path.join(args.data_dir, 'reference')):
        # basename() instead of split('/') so the name is extracted correctly
        # whatever path separator the platform uses.
        pdb_name = os.path.basename(ref_pdb).split('.pdb')[0]

        # Reference names must have four '_'-separated fields (the original
        # code unpacked them as code, heavy chain, light chain, antigen
        # chain). Only the shape is checked; the fields are not used here.
        fields = pdb_name.split('_')
        if len(fields) != 4:
            raise ValueError(
                f"reference file name {pdb_name!r} must have exactly four "
                f"'_'-separated fields, got {len(fields)}"
            )

        ref_ab_ca, ref_ab_str_seq, ref_heavy_str, ref_light_str = make_coords(ref_pdb)
        # NOTE: interface-energy reference values are not computed here.
        reference_data[pdb_name] = {
            'cdr_def': cdr_numbering(ref_heavy_str, ref_light_str),
            'coords': ref_ab_ca,
            'str_seq': ref_ab_str_seq,
        }

    func = functools.partial(eval_metric, args=args, reference_data=reference_data)
    # NOTE(review): parse_list walks data_dir recursively, so files inside
    # the 'reference' sub-directory are evaluated too - confirm this is intended.
    with mp.Pool(args.cpus) as pool:
        results = pool.map(func, parse_list(args.data_dir))

    # Average results for each metric.
    df = pd.DataFrame(results)
    # numeric_only=True: pandas >= 2.0 raises on non-numeric columns otherwise.
    column_means = df.mean(numeric_only=True)
    # pd.concat replaces Series.append, which was removed in pandas 2.0.
    filtered_means = pd.concat(
        [column_means.filter(like='RMSD'), column_means.filter(like='AAR')]
    )
    print("---------------------")
    print("Average Results for each Metric")
    print("---------------------")
    print(filtered_means)

    csv_file_path = os.path.join(args.data_dir, 'results.csv')
    with open(csv_file_path, mode='w', newline='') as file:
        # The first result defines the CSV columns.
        fieldnames = list(results[0]) if results else []
        writer = csv.DictWriter(file, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string - including
    ``"False"`` and ``"0"`` - as ``True``; this converter parses the text.

    Args:
        value: The raw option value (or an actual ``bool``).

    Returns:
        bool: ``True`` for true/t/yes/y/1/on, ``False`` for
        false/f/no/n/0/off (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def build_parser():
    """Create the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: Parser with ``-o/--data_dir`` (required),
        ``-c/--cpus`` (int, default 1) and the boolean switches
        ``-e/--energy`` and ``-v/--verbose`` (default ``False``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-o', '--data_dir', type=str, required=True)
    parser.add_argument('-c', '--cpus', type=int, default=1)
    # nargs='?' + const=True keeps "-e True" working and also allows a bare
    # "-e"; _str2bool makes "-e False" actually mean False.
    parser.add_argument('-e', '--energy', type=_str2bool, nargs='?',
                        const=True, default=False)
    parser.add_argument('-v', '--verbose', type=_str2bool, nargs='?',
                        const=True, default=False)
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()
    main(args)