Skip to content

Commit d2d7a70

Browse files
authored
Merge pull request #50 from polixir/dev
Dev (0.6.2)
2 parents fa81881 + d44ec40 commit d2d7a70

22 files changed

Lines changed: 606 additions & 288 deletions

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ pip install -e .
122122
We build an example project to include most of the features of RLA, which can be seen in ./example/simplest_code. Now we summarize the steps to use it.
123123

124124
### Step1: Configuration.
125-
1. To configure the experiment "database", you need to create a YAML file rla_config.yaml. You can use the template provided in ./example/simplest_code/rla_config.yaml as a starting point.
125+
1. To configure the experiment "database", you need to create a YAML file rla_config.yaml. You can use the template provided in ./example/rla_config.yaml as a starting point.
126126
2. Before starting your experiment, you should configure the RLA.exp_manager object. Here's an example:
127127

128128
```python
@@ -256,7 +256,7 @@ from RLA import MatplotlibRecorder as mpr
256256
def plot_func():
257257
import matplotlib.pyplot as plt
258258
plt.plot([1,1,1], [2,2,2])
259-
mpr.pretty_plot_wrapper('func', plot_func, xlabel='x', ylabel='y', title='react test', )
259+
mpr.pretty_plot_wrapper('func', plot_func, pretty_plot=True, xlabel='x', ylabel='y', title='react test')
260260
```
261261

262262
This code plots a figure using Matplotlib and saves it in the "results" directory.

RLA/const.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
DEFAULT_X_NAME = 'time-step'
2-
2+
DEFAULT_TIMESTAMP = 'timestamp'
33
class FRAMEWORK:
44
tensorflow = 'tensorflow'
55
torch = 'torch'
@@ -9,4 +9,11 @@ class FTP_PROTOCOL_NAME:
99
SFTP = 'sftp'
1010

1111
class LOG_NAME_FORMAT_VERSION:
12-
V1 = 'v1'
12+
V1 = 'v1'
13+
14+
15+
16+
class PLATFORM_TYPE:
17+
WIN = 'win'
18+
LINUX = 'linux'
19+
OTHER = 'other'

RLA/easy_log/complex_data_recorder.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def save(cls, name=None, fig=None, cover=False, add_timestamp=True, **kwargs):
3434

3535
@classmethod
3636
def pretty_plot_wrapper(cls, name:str, plot_func:Callable,
37-
cover=False, legend_outside=False, xlabel='', ylabel='', title='',
37+
cover=False, legend_outside=False, pretty_plot=False, xlabel='', ylabel='', title='',
3838
add_timestamp=True, *args, **kwargs):
3939
"""
4040
Save the customized plot figure to the RLA database.
@@ -47,6 +47,8 @@ def pretty_plot_wrapper(cls, name:str, plot_func:Callable,
4747
:type cover: bool
4848
:param legend_outside: let legend be outside of the figure.
4949
:type legend_outside: bool
50+
:param pretty_plot: use predefined configurations for plotting.
51+
:type pretty_plot: bool
5052
:param xlabel: name of xlabel
5153
:type xlabel: str
5254
:param ylabel: name of xlabel
@@ -66,12 +68,13 @@ def pretty_plot_wrapper(cls, name:str, plot_func:Callable,
6668
plot_func()
6769
lgd = plt.legend(prop={'size': 15}, loc=2 if legend_outside else None,
6870
bbox_to_anchor=(1, 1) if legend_outside else None)
69-
plt.xlabel(xlabel, fontsize=18)
70-
plt.ylabel(ylabel, fontsize=18)
71-
plt.xticks(fontsize=14)
72-
plt.yticks(fontsize=14)
73-
plt.title(title, fontsize=13)
74-
plt.grid(True)
71+
if pretty_plot:
72+
plt.xlabel(xlabel, fontsize=18)
73+
plt.ylabel(ylabel, fontsize=18)
74+
plt.xticks(fontsize=14)
75+
plt.yticks(fontsize=14)
76+
plt.title(title, fontsize=13)
77+
plt.grid(True)
7578
if lgd is not None:
7679
cls.save(name, cover=cover, add_timestamp=add_timestamp, bbox_extra_artists=tuple([lgd]),
7780
bbox_inches='tight', *args, **kwargs)

RLA/easy_log/exp_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def is_valid_config(self):
4848
if self.load_date is not None and self.task_name is not None and self.data_root is not None:
4949
return True
5050
else:
51-
logger.warn("meet invalid loader config when use it")
51+
logger.warn("meet invalid loader config when using it")
5252
logger.warn("load_date", self.load_date)
5353
logger.warn("task_name", self.task_name)
5454
logger.warn("root", self.data_root)

RLA/easy_log/log_tools.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import json
1818
from RLA.easy_log.tester import Tester
19-
19+
from RLA.utils.utils import get_dir_seperator
2020
class Filter(object):
2121
ALL = 'all'
2222
SMALL_TIMESTEP = 'small_ts'
@@ -214,8 +214,9 @@ def _archive_log(self, show=False):
214214
empty = False
215215
if os.path.exists(root_dir):
216216
# remove the overlapped path.
217+
septor = get_dir_seperator()
217218
archiving_target = osp.join(archive_root_dir, root_dir[prefix_len+1:])
218-
archiving_target_dir = '/'.join(archiving_target.split('/')[:-1])
219+
archiving_target_dir = septor.join(archiving_target.split(septor)[:-1])
219220
os.makedirs(archiving_target_dir, exist_ok=True)
220221
if os.path.isdir(root_dir):
221222
if not show:
@@ -269,21 +270,23 @@ def _migrate_log(self, show=False):
269270
if os.path.exists(root_dir):
270271
# remove the overlapped path.
271272
archiving_target = osp.join(target_root_dir, root_dir[prefix_len+1:])
272-
archiving_target_dir = '/'.join(archiving_target.split('/')[:-1])
273+
274+
septor = get_dir_seperator()
275+
archiving_target_dir = septor.join(archiving_target.split(septor)[:-1])
276+
print("target dir", archiving_target_dir)
273277
os.makedirs(archiving_target_dir, exist_ok=True)
274278
if os.path.isdir(root_dir):
279+
print("copy dir {}, to {}".format(root_dir, archiving_target))
275280
if not show:
276281
# os.makedirs(archiving_target, exist_ok=True)
277282
try:
278283
shutil.copytree(root_dir, archiving_target)
279284
except FileExistsError as e:
280285
print(e)
281-
282-
print("copy dir {}, to {}".format(root_dir, archiving_target))
283286
else:
287+
print("copy file {}, to {}".format(root_dir, archiving_target))
284288
if not show:
285289
shutil.copy(root_dir, archiving_target)
286-
print("copy file {}, to {}".format(root_dir, archiving_target))
287290
else:
288291
print("no dir {}".format(root_dir))
289292
if empty: print("empty regex {}".format(root_dir_regex))

RLA/easy_log/logger.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from typing import Any, Dict, List, Optional, Sequence, TextIO, Tuple, Union
1515
from contextlib import contextmanager
16-
from RLA.const import DEFAULT_X_NAME, FRAMEWORK
16+
from RLA.const import DEFAULT_X_NAME, FRAMEWORK, DEFAULT_TIMESTAMP
1717

1818
DEBUG = 10
1919
INFO = 20
@@ -510,6 +510,12 @@ def dumpkvs():
510510
print(e)
511511
for fmt in Logger.CURRENT.output_formats:
512512
print(fmt)
513+
try:
514+
get_current().logkv(DEFAULT_TIMESTAMP, time.time())
515+
except NotImplementedError as e:
516+
print(e)
517+
for fmt in Logger.CURRENT.output_formats:
518+
print(fmt)
513519
return get_current().dumpkvs()
514520

515521
def getkvs():

RLA/easy_log/tester.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import shutil
2727
import argparse
2828
from typing import Dict, List, Tuple, Type, Union, Optional
29-
from RLA.utils.utils import deprecated_alias, load_yaml
29+
from RLA.utils.utils import deprecated_alias, load_yaml, get_dir_seperator
3030
from RLA.const import DEFAULT_X_NAME, FRAMEWORK
3131
import pathspec
3232

@@ -150,10 +150,11 @@ def configure(self, task_table_name: str, rla_config: Union[str, dict], data_roo
150150
logger.info("private_config: ")
151151
self.dl_framework = self.private_config["DL_FRAMEWORK"]
152152
self.is_master_node = is_master_node
153-
153+
self.septor = get_dir_seperator()
154154
if code_root is None:
155155
if isinstance(rla_config, str):
156-
self.project_root = "/".join(rla_config.split("/")[:-1])
156+
157+
self.project_root = self.septor.join(rla_config.split(self.septor)[:-1])
157158
else:
158159
raise NotImplementedError("If you pass the rla_config dict directly, "
159160
"you should define the root of your codebase (for backup) explicitly by pass the code_root.")
@@ -309,14 +310,19 @@ def load_tester(cls, record_date, task_table_name, log_root):
309310
def add_record_param(self, keys):
310311
for k in keys:
311312
if '.' in k:
313+
sub_k = None
312314
try:
313315
sub_k_list = k.split('.')
314-
v = self.hyper_param[sub_k_list[0]]
316+
sub_k = sub_k_list[0]
317+
v = self.hyper_param[sub_k]
315318
for sub_k in sub_k_list[1:]:
316319
v = v[sub_k]
317320
self.hyper_param_record.append(str(k) + '=' + str(v).replace('[', '{').replace(']', '}').replace('/', '_'))
318321
except KeyError as e:
319-
print("do not include dot ('.') in your hyperparemeter name")
322+
print(f"the current key to parsed is: {k}. Can not find a matching key for {sub_k}."
323+
"\n Hint: do not include dot ('.') in your hyperparemeter name."
324+
"\n The recorded hyper parameters are")
325+
self.print_args()
320326
else:
321327
self.hyper_param_record.append(str(k) + '=' + str(self.hyper_param[k]).replace('[', '{').replace(']', '}').replace('/', '_'))
322328

@@ -620,7 +626,7 @@ def time_record_end(self, name:str):
620626
end_time = time.time()
621627
start_time = self._rc_start_time[name]
622628
logger.record_tabular("time_used/{}".format(name), end_time - start_time)
623-
logger.info("[test] func {0} time used {1:.2f}".format(name, end_time - start_time))
629+
# logger.info("[test] func {0} time used {1:.2f}".format(name, end_time - start_time))
624630
del self._rc_start_time[name]
625631

626632
# Saver manger.

RLA/easy_log/time_step.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from RLA.easy_log import logger
2-
2+
import time
33

44
class TimeStepHolder(object):
55

@@ -8,20 +8,10 @@ def __init__(self, time, epoch, tf_log=False):
88
self.__outer_epoch = epoch
99
self.tf_log = tf_log
1010

11-
def config(self, time=0, epoch=0, tf_log=False):
11+
def config(self, time=0, tf_log=False):
1212
self.__timesteps = time
13-
self.__outer_epoch = epoch
1413
self.tf_log = tf_log
1514

16-
def set_outer_epoch(self, epoch):
17-
self.__outer_epoch = epoch
18-
19-
def get_outer_epoch(self):
20-
return self.__outer_epoch
21-
22-
def inc_outer_epoch(self):
23-
self.__outer_epoch +=1
24-
2515
def set_time(self, time):
2616
self.__timesteps = time
2717
self.__update_tf_times()

RLA/easy_log/time_used_recorder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,5 +103,5 @@ def time_record_end(name):
103103
end_time = time.time()
104104
start_time = rc_start_time[name]
105105
logger.record_tabular("time_used/{}".format(name), end_time - start_time)
106-
logger.info("[test] func {0} time used {1:.2f}".format(name, end_time - start_time))
106+
# logger.info("[test] func {0} time used {1:.2f}".format(name, end_time - start_time))
107107
del rc_start_time[name]

RLA/easy_plot/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from RLA.easy_plot.plot_saved_images import plot_saved_images
2+
from RLA.easy_plot.plot_func_v2 import plot_func

0 commit comments

Comments
 (0)