|
26 | 26 | import shutil |
27 | 27 | import argparse |
28 | 28 | from typing import Dict, List, Tuple, Type, Union, Optional |
29 | | -from RLA.utils.utils import deprecated_alias, load_yaml |
| 29 | +from RLA.utils.utils import deprecated_alias, load_yaml, get_dir_seperator |
30 | 30 | from RLA.const import DEFAULT_X_NAME, FRAMEWORK |
31 | 31 | import pathspec |
32 | 32 |
|
@@ -150,10 +150,11 @@ def configure(self, task_table_name: str, rla_config: Union[str, dict], data_roo |
150 | 150 | logger.info("private_config: ") |
151 | 151 | self.dl_framework = self.private_config["DL_FRAMEWORK"] |
152 | 152 | self.is_master_node = is_master_node |
153 | | - |
| 153 | + self.septor = get_dir_seperator() |
154 | 154 | if code_root is None: |
155 | 155 | if isinstance(rla_config, str): |
156 | | - self.project_root = "/".join(rla_config.split("/")[:-1]) |
| 156 | + |
| 157 | + self.project_root = self.septor.join(rla_config.split(self.septor)[:-1]) |
157 | 158 | else: |
158 | 159 | raise NotImplementedError("If you pass the rla_config dict directly, " |
159 | 160 | "you should define the root of your codebase (for backup) explicitly by pass the code_root.") |
@@ -309,14 +310,19 @@ def load_tester(cls, record_date, task_table_name, log_root): |
309 | 310 | def add_record_param(self, keys): |
310 | 311 | for k in keys: |
311 | 312 | if '.' in k: |
| 313 | + sub_k = None |
312 | 314 | try: |
313 | 315 | sub_k_list = k.split('.') |
314 | | - v = self.hyper_param[sub_k_list[0]] |
| 316 | + sub_k = sub_k_list[0] |
| 317 | + v = self.hyper_param[sub_k] |
315 | 318 | for sub_k in sub_k_list[1:]: |
316 | 319 | v = v[sub_k] |
317 | 320 | self.hyper_param_record.append(str(k) + '=' + str(v).replace('[', '{').replace(']', '}').replace('/', '_')) |
318 | 321 | except KeyError as e: |
319 | | - print("do not include dot ('.') in your hyperparemeter name") |
| 322 | + print(f"the current key to parsed is: {k}. Can not find a matching key for {sub_k}." |
| 323 | + "\n Hint: do not include dot ('.') in your hyperparemeter name." |
| 324 | + "\n The recorded hyper parameters are") |
| 325 | + self.print_args() |
320 | 326 | else: |
321 | 327 | self.hyper_param_record.append(str(k) + '=' + str(self.hyper_param[k]).replace('[', '{').replace(']', '}').replace('/', '_')) |
322 | 328 |
|
@@ -620,7 +626,7 @@ def time_record_end(self, name:str): |
620 | 626 | end_time = time.time() |
621 | 627 | start_time = self._rc_start_time[name] |
622 | 628 | logger.record_tabular("time_used/{}".format(name), end_time - start_time) |
623 | | - logger.info("[test] func {0} time used {1:.2f}".format(name, end_time - start_time)) |
| 629 | + # logger.info("[test] func {0} time used {1:.2f}".format(name, end_time - start_time)) |
624 | 630 | del self._rc_start_time[name] |
625 | 631 |
|
626 | 632 | # Saver manger. |
|
0 commit comments