# List the evaluation result files being aggregated (basenames only, for readability).
print('eval_txts:', [_.split(os.sep)[-1] for _ in eval_txts])

# ckpt name -> list of per-eval-file scores for the metric currently being processed.
score_panel = {}
# Column separator used inside the result .txt files (LaTeX-style tables).
sep = '&'

# Metrics to rank checkpoints by: HCE is used for DIS, wFm for the other tasks.
metrics = ['sm', 'wfm', 'hce']
if 'DIS5K' not in config.task:
    # HCE is only defined/reported for the DIS5K task.
    metrics.remove('hce')
# For each ranking metric: gather per-checkpoint scores from every eval file,
# pick the best checkpoint(s), and write a summary file for that metric.
for metric in metrics:
    # Column index of this metric in a result line; the column layout differs
    # per task, so index first by metric, then by the task's position.
    # (Hoisted out of the per-line loop below: it only depends on metric/task.)
    target_idx = {
        'sm': [5, 2, 2, 5, 2],
        'wfm': [3, 3, 8, 3, 8],
        'hce': [7, -1, -1, 7, -1],
    }[metric][['DIS5K', 'COD', 'HRSOD', 'DIS5K+HRSOD+HRS10K', 'P3M-10k'].index(config.task)]

    # Count usable result lines per eval file (skip the 3 header lines; a data
    # line always contains a '.'), so all files can be truncated to the shortest.
    current_line_nums = []
    for idx_et, eval_txt in enumerate(eval_txts):
        with open(eval_txt, 'r') as f:
            lines = [l for l in f.readlines()[3:] if '.' in l]
        current_line_nums.append(len(lines))

    # Collect this metric's score for every checkpoint, from every eval file.
    for idx_et, eval_txt in enumerate(eval_txts):
        with open(eval_txt, 'r') as f:
            lines = [l for l in f.readlines()[3:] if '.' in l]
        for idx_line, line in enumerate(lines[:min(current_line_nums)]):  # Consist line numbers by the minimal result file.
            properties = line.strip().strip(sep).split(sep)
            dataset = properties[0].strip()
            ckpt = properties[1].strip()
            # Skip placeholder entries with a negative epoch number.
            if int(ckpt.split('--epoch_')[-1].strip()) < 0:
                continue
            if metric != 'hce':
                score_sm = float(properties[target_idx].strip())
            else:
                # HCE is an integer count; a trailing '.' may remain from formatting.
                score_sm = int(properties[target_idx].strip().strip('.'))
            if idx_et == 0:
                # First eval file for this metric: (re)start this ckpt's score list,
                # clearing any scores accumulated for the previous metric.
                score_panel[ckpt] = []
            score_panel[ckpt].append(score_sm)

    # Lower is better for these metrics; higher is better for the rest.
    metrics_min = ['hce', 'mae']
    max_or_min = min if metric in metrics_min else max
    score_max = max_or_min(score_panel.values(), key=lambda x: np.sum(x))

    # Keep every checkpoint whose summed score ties the best one.
    good_models = []
    for k, v in score_panel.items():
        if (np.sum(v) <= np.sum(score_max)) if metric in metrics_min else (np.sum(v) >= np.sum(score_max)):
            print(k, v)
            good_models.append(k)

    # Write: assemble a summary file with the header plus the lines of the best
    # checkpoints. NOTE: eval_txt/lines here deliberately reuse the last file
    # from the loop above — all files share the same 3-line header format.
    with open(eval_txt, 'r') as f:
        lines = f.readlines()
    info4good_models = lines[:3]
    metric_names = [m.strip() for m in lines[1].strip().strip('&').split('&')[2:]]
    testset_mean_values = {metric_name: [] for metric_name in metric_names}
    for good_model in good_models:
        for idx_et, eval_txt in enumerate(eval_txts):
            with open(eval_txt, 'r') as f:
                lines = f.readlines()
            for line in lines:
                # Match the checkpoint name against any column of the line.
                if set([good_model]) & set([_.strip() for _ in line.split(sep)]):
                    info4good_models.append(line)
                    metric_scores = [float(m.strip()) for m in line.strip().strip('&').split('&')[2:]]
                    for idx_score, metric_score in enumerate(metric_scores):
                        testset_mean_values[metric_names[idx_score]].append(metric_score)

    if 'DIS5K' in config.task:
        # Append a synthetic 'DIS-TEs' line holding the mean over the DIS test
        # sets; [:-1] drops DIS-VD from the mean. HCE means are rounded ints,
        # other metrics formatted as .3f without the leading zero.
        testset_mean_values_lst = ['{:<4}'.format(int(np.mean(v_lst[:-1]).round())) if name == 'HCE' else '{:.3f}'.format(np.mean(v_lst[:-1])).lstrip('0') for name, v_lst in testset_mean_values.items()]  # [:-1] to remove DIS-VD
        sample_line_for_placing_mean_values = info4good_models[-2]
        # Reuse an existing line's spacing as a template, swapping in the means.
        numbers_placed_well = sample_line_for_placing_mean_values.replace(sample_line_for_placing_mean_values.split('&')[1].strip(), 'DIS-TEs').strip().split('&')[3:]
        for idx_number, (number_placed_well, testset_mean_value) in enumerate(zip(numbers_placed_well, testset_mean_values_lst)):
            numbers_placed_well[idx_number] = number_placed_well.replace(number_placed_well.strip(), testset_mean_value)
        testset_mean_line = '&'.join(sample_line_for_placing_mean_values.replace(sample_line_for_placing_mean_values.split('&')[1].strip(), 'DIS-TEs').split('&')[:3] + numbers_placed_well) + '\n'
        info4good_models.append(testset_mean_line)
    # Close the table with the footer line of the last-read eval file.
    info4good_models.append(lines[-1])
    info = ''.join(info4good_models)
    print(info)
    with open(os.path.join('e_results', 'eval-{}_best_on_{}.txt'.format(config.task, metric)), 'w') as f:
        f.write(info + '\n')
0 commit comments