Skip to content

Commit 4f13598

Browse files
committed
feat(easy_log,easy_plot): add parameter.json to LOG folder. support the hyper-parmater loading (for easy_plot) from LOG folder.
1 parent d64e64d commit 4f13598

8 files changed

Lines changed: 42 additions & 13 deletions

File tree

RLA/auto_ftp.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import traceback
66
from RLA.const import *
77
from RLA.easy_log import logger
8-
98
import pysftp
109

1110

@@ -17,6 +16,7 @@ def ftp_factory(name, server, username, password, port, ignore=None):
1716
else:
1817
raise NotImplementedError
1918

19+
2020
class FTPHandler(object):
2121

2222
def __init__(self, ftp_server, username, password, port, ignore=None):
@@ -139,6 +139,7 @@ def close(self):
139139
self.ftp.quit()
140140
self.ftp.close()
141141

142+
142143
class SFTPHandler(FTPHandler):
143144

144145
def __init__(self, sftp_server, username, password, port, ignore=None):

RLA/easy_log/const.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
OTHER_RESULTS = 'results'
66
ARCHIVED_TABLE = 'arc'
77
default_log_types = [LOG, CODE, CHECKPOINT, ARCHIVE_TESTER, OTHER_RESULTS]
8-
8+
HYPARAM = 'parameter'
99

1010
class LoadTesterMode:
1111
FORK_TO_NEW = 'fork'

RLA/easy_log/tester.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,11 @@ def log_files_gen(self):
215215
self.serialize_object_and_save()
216216
self.__copy_source_code(self.run_file, code_dir)
217217
self._feed_hyper_params_to_tb()
218+
params = self.hyper_param
219+
for param_dir in [self.code_dir, self.log_dir]:
220+
with open(osp.join(param_dir, HYPARAM + '.json'), 'w') as f:
221+
json.dump(params, f, sort_keys=True, indent=4, allow_nan=True, default=lambda o: '<not serializable>')
222+
print("gen:", osp.join(param_dir, 'parameter.json'))
218223
self.print_log_dir()
219224

220225
def update_log_files_location(self, root:str):
@@ -782,9 +787,6 @@ def print_args(self):
782787
# formatted_log_name = self.log_name_formatter(self.get_task_table_name(), self.record_date)
783788
params = exp_manager.hyper_param
784789
# params['formatted_log_name'] = formatted_log_name
785-
json.dump(params, open(osp.join(self.code_dir, 'parameter.json'), 'w'),
786-
sort_keys=True, indent=4, allow_nan=True, default=lambda o: '<not serializable>')
787-
print("gen:", osp.join(self.code_dir, 'parameter.json'))
788790

789791

790792
def print_large_memory_variable(self):

RLA/easy_plot/plot_func_v2.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
11
# Created by xionghuichen at 2022/8/10
22
# Email: chenxh@lamda.nju.edu.cn
33
import glob
4+
import json
45
import os.path as osp
56
import os
67
import dill
78
import copy
89
import numpy as np
910
from typing import Dict, List, Tuple, Type, Union, Optional, Callable
1011
import matplotlib.pyplot as plt
11-
1212
from RLA import logger
1313
from RLA.const import DEFAULT_X_NAME
1414
from RLA.query_tool import experiment_data_query, extract_valid_index
15-
1615
from RLA.easy_plot import plot_util
17-
from RLA.easy_log.const import LOG, ARCHIVE_TESTER, OTHER_RESULTS
16+
from RLA.easy_log.const import LOG, ARCHIVE_TESTER, OTHER_RESULTS, HYPARAM
1817

1918

2019

@@ -25,7 +24,6 @@ def default_key_to_legend(parse_dict, split_keys, y_name, use_y_name=True):
2524
else:
2625
return task_split_key
2726

28-
2927
def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, metrics:list,
3028
use_buf=False, verbose=True,
3129
x_bound: Optional[int]=None,
@@ -97,7 +95,11 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
9795
if verbose:
9896
print("find log", v.dirname)
9997
counter += 1
100-
result.hyper_param = tester_dict[k].exp_manager.hyper_param
98+
if os.path.exists(osp.join(v.dirname, HYPARAM + '.json')):
99+
with open(osp.join(v.dirname, HYPARAM + '.json')) as f:
100+
result.hyper_param = json.load(f)
101+
else:
102+
result.hyper_param = tester_dict[k].exp_manager.hyper_param
101103
results.append(result)
102104
reg_group[reg].append(result)
103105
print("find log number", counter)
@@ -126,7 +128,6 @@ def plot_func(data_root:str, task_table_name:str, regs:list, split_keys:list, me
126128
split_by_metrics=split_by_metrics, regs2legends=regs2legends, *args, **kwargs)
127129
print("--- complete process ---")
128130
if save_name is not None:
129-
import os
130131
file_name = osp.join(data_root, OTHER_RESULTS, 'easy_plot', save_name)
131132
os.makedirs(os.path.dirname(file_name), exist_ok=True)
132133
if lgd is not None:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"env_id": "Test-v1",
3+
"info": "default exp info",
4+
"input_size": 16,
5+
"learning_rate": 0.001,
6+
"loaded_date": true,
7+
"loaded_task_name": "",
8+
"seed": 88
9+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"input_size": 16,
3+
"learning_rate": 0.0001
4+
}

test/test_plot.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@
731731
},
732732
{
733733
"cell_type": "code",
734-
"execution_count": 22,
734+
"execution_count": 16,
735735
"id": "a55937ed",
736736
"metadata": {
737737
"scrolled": false

test/test_proj/proj/test_manager.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_log_tf(self):
5555

5656
exp_manager.new_saver(var_prefix='', max_to_keep=1)
5757
# synthetic target function.
58-
for i in range(0, 1000):
58+
for i in range(0, 100):
5959
exp_manager.time_step_holder.set_time(i)
6060
x_input = np.random.normal(0, 3, [64, kwargs["input_size"]])
6161
y = target_func(x_input)
@@ -143,10 +143,22 @@ def test_sent_to_master(self):
143143
yaml = self._load_rla_config()
144144
try:
145145
from test.test_proj.proj import private_config
146+
# try to import libs
146147
except ImportError as e:
147148
print("[WARN] for this test, you should config your username, password, and the remote root firstly.")
148149
return
149150
# raise RuntimeError
151+
try:
152+
if private_config.protocol == 'ftp':
153+
import ftplib
154+
elif private_config.protocol == 'sftp':
155+
import pysftp
156+
else:
157+
raise NotImplementedError
158+
except ImportError as e:
159+
print(e)
160+
print(f"[WARN] the select protocol {private_config.protocol} cannot be loaded. skip the unittest.")
161+
return
150162
yaml['DL_FRAMEWORK'] = 'torch'
151163
yaml['SEND_LOG_FILE'] = True
152164
yaml['REMOTE_SETTING']['ftp_server'] = '127.0.0.1'

0 commit comments

Comments
 (0)