-
Notifications
You must be signed in to change notification settings - Fork 31
Expand file tree
/
Copy pathconfig.py
More file actions
200 lines (174 loc) · 10.4 KB
/
config.py
File metadata and controls
200 lines (174 loc) · 10.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import os
import logging
import typing as t
from omegaconf import (OmegaConf, DictConfig)
from mlcube.errors import ConfigurationError
from mlcube.runner import Runner
# Names exported by `from <this module> import *`.
__all__ = [
    'IOType',
    'ParameterType',
    'MLCubeConfig',
]
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class IOType(object):
    """Direction of a task parameter: an `input` the task reads or an `output` the task writes."""

    INPUT = 'input'
    OUTPUT = 'output'

    @staticmethod
    def is_valid(io: t.Text) -> bool:
        """Return True when `io` equals one of the two known directions (`input` or `output`)."""
        known_directions = (IOType.INPUT, IOType.OUTPUT)
        return io in known_directions
class ParameterType(object):
    """Kind of filesystem object a task parameter refers to; `unknown` when it has not been determined."""

    FILE = 'file'
    DIRECTORY = 'directory'
    UNKNOWN = 'unknown'

    @staticmethod
    def is_valid(io: t.Text) -> bool:
        """Return True when `io` equals one of `file`, `directory` or `unknown`."""
        known_types = (ParameterType.FILE, ParameterType.DIRECTORY, ParameterType.UNKNOWN)
        return io in known_types
class MLCubeConfig(object):
    """Helpers to build, canonicalize and validate MLCube configurations."""

    @staticmethod
    def ensure_values_exist(config: DictConfig, keys: t.Union[t.Text, t.List], constructor: t.Callable) -> t.List:
        """Make sure `config` has non-None values for all `keys`, creating missing ones in place.

        Args:
            config: Configuration to check and, possibly, update in place.
            keys: A single key or a list of keys.
            constructor: Zero-argument callable that creates a value for every key that is missing or None.
        Returns:
            List of values for `keys` (in the same order).
        """
        if isinstance(keys, str):
            keys = [keys]
        for key in keys:
            if config.get(key, None) is None:
                config[key] = constructor()
        # Values are read back from `config` (not taken from `constructor`) so that callers get whatever object
        # `config` stores for them after assignment.
        return [config[key] for key in keys]

    @staticmethod
    def get_uri(value: t.Text) -> t.Text:
        """Convert a user-provided path into an absolute local path.

        Args:
            value: Path, possibly starting with `~`.
        Returns:
            Absolute path with the user home directory expanded.
        Raises:
            ValueError: If `value` uses the `storage:` schema (not supported yet).
        """
        if value.startswith('storage:'):
            raise ValueError("Storage schema is not yet supported")
        return os.path.abspath(os.path.expanduser(value))

    @staticmethod
    def create_mlcube_config(mlcube_config_file: t.Text, mlcube_cli_args: t.Optional[DictConfig] = None,
                             task_cli_args: t.Optional[t.Dict] = None, runner_config: t.Optional[DictConfig] = None,
                             workspace: t.Optional[t.Text] = None, tasks: t.Optional[t.List[t.Text]] = None,
                             resolve: bool = True, runner_cls: t.Optional[t.Type[Runner]] = None) -> DictConfig:
        """Create MLCube configuration by merging several configuration sources.

        Sources are merged in this order (later ones override earlier ones): MLCube configuration file, MLCube CLI
        arguments, `runtime`/`runner` sections. Then the runner class (if provided) merges its configuration, and the
        MLCube CLI arguments are applied once again.

        Args:
            mlcube_config_file: Path to mlcube.yaml file.
            mlcube_cli_args: MLCube configuration overrides from command line (e.g., -Pdocker.build_strategy=...).
            task_cli_args: Task parameters from command line.
            runner_config: MLCube runner configuration, usually comes from system settings file.
            workspace: Workspace path to use in this MLCube run.
            tasks: List of tasks to be executed. If empty or None, consider all tasks in MLCube config.
            resolve: If true, compute all values (some of them may reference other parameters or environmental
                variables).
            runner_cls: A python class for the runner type specified in `runner_config`.
        Returns:
            MLCube configuration where input/output parameters of all tasks are in canonical form
            (see `check_parameters`).
        Raises:
            ValueError: If `workspace` uses the unsupported `storage:` schema.
            ConfigurationError: If `task_cli_args` contains names that are not input/output parameters of any task
                to run.
        """
        if mlcube_cli_args is None:
            mlcube_cli_args = OmegaConf.create({})
        if runner_config is None:
            runner_config = OmegaConf.create({})
        logger.debug("mlcube_config_file = %s", mlcube_config_file)
        logger.debug("mlcube_cli_args = %s", mlcube_cli_args)
        logger.debug("task_cli_args = %s", task_cli_args)
        logger.debug("runner_config = %s", runner_config)
        logger.debug("workspace = %s", workspace)
        logger.debug("tasks = %s", tasks)

        # Load MLCube configuration and maybe override parameters from command line (like -Pdocker.build_strategy=...).
        actual_workspace = '${runtime.root}/workspace' if workspace is None else MLCubeConfig.get_uri(workspace)
        mlcube_config = OmegaConf.merge(
            OmegaConf.load(mlcube_config_file),
            mlcube_cli_args,
            OmegaConf.create({
                'runtime': {
                    # Make the root absolute: for a bare file name (e.g. `mlcube.yaml`) `dirname` returns an empty
                    # string, and the default workspace `${runtime.root}/workspace` would resolve to `/workspace`.
                    'root': os.path.dirname(os.path.abspath(mlcube_config_file)),
                    'workspace': actual_workspace
                },
                'runner': runner_config
            })
        )
        # Maybe this is not the best idea, but originally MLCube used $WORKSPACE token to refer to the internal
        # workspace. So, this value is here to simplify access to workspace value. BTW, in general, if files are to be
        # located inside workspace (internal or custom), users are encouraged not to use ${runtime.workspace} or
        # ${workspace} in their MLCube configuration files.
        mlcube_config['workspace'] = actual_workspace

        # Merge, for instance, docker runner config from system settings with docker config from MLCube config.
        if runner_cls:
            runner_cls.CONFIG.merge(mlcube_config)
        # Need to apply CLI arguments again just in case users provided something like -Prunner.build_strategy=...
        mlcube_config = OmegaConf.merge(mlcube_config, mlcube_cli_args)
        if runner_cls:
            runner_cls.CONFIG.validate(mlcube_config)

        # TODO: This needs some discussion. Originally, one task was supposed to run at a time. Now, we seem to converge
        #       to support lists of tasks. Current implementation continues to use old format of task parameters, i.e.
        #       `param_name=param_value`. This may result in an unexpected behavior when many tasks run, so we should
        #       think about a different rule: `task_name.param_name=param_value`. This is similar to parameter
        #       specification in DVC.
        # We iterate over all tasks to make their representation canonical, but use only the tasks to run to check
        # whether users provided unrecognized parameters.
        task_cli_args = task_cli_args or dict()                     # Dictionary of task arguments from a CL.
        tasks_to_run = tasks or list(mlcube_config.tasks.keys())    # These tasks need to run later.
        overridden_parameters = set()                               # Parameters from task_cli_args that were used.
        for task_name in mlcube_config.tasks.keys():
            [task] = MLCubeConfig.ensure_values_exist(mlcube_config.tasks, task_name, dict)
            [parameters] = MLCubeConfig.ensure_values_exist(task, 'parameters', dict)
            [inputs, outputs] = MLCubeConfig.ensure_values_exist(parameters, ['inputs', 'outputs'], dict)

            overridden_inputs = MLCubeConfig.check_parameters(inputs, task_cli_args)
            overridden_outputs = MLCubeConfig.check_parameters(outputs, task_cli_args)
            if task_name in tasks_to_run:
                overridden_parameters.update(overridden_inputs)
                overridden_parameters.update(overridden_outputs)

        unknown_task_cli_args = {name for name in task_cli_args if name not in overridden_parameters}
        if unknown_task_cli_args:
            MLCubeConfig.report_unknown_task_cli_args(task_cli_args, unknown_task_cli_args)
            raise ConfigurationError(f"Unknown task CLI arguments: {unknown_task_cli_args}")

        if resolve:
            OmegaConf.resolve(mlcube_config)
        return mlcube_config

    @staticmethod
    def _infer_directory_type(param_def: DictConfig) -> None:
        """Set `type` to `directory` if it is unknown and the `default` value ends with a path separator."""
        # Check the type first so that `default` (which may be an interpolation) is only read when needed.
        if param_def.type != ParameterType.UNKNOWN:
            return
        # On Windows both `\` (os.sep) and `/` (os.altsep) separate path components; on POSIX only `/` (os.altsep
        # is None there).
        separators = tuple(sep for sep in (os.sep, os.altsep) if sep)
        default = param_def.default
        # Non-string values (e.g., numbers from YAML or command line) cannot be identified as directories.
        if isinstance(default, str) and default.endswith(separators):
            param_def.type = ParameterType.DIRECTORY

    @staticmethod
    def check_parameters(parameters: DictConfig, task_cli_args: t.Dict) -> t.Set:
        """Bring task parameters to canonical form and apply command-line overrides (in place).

        After this call, every parameter definition is a dictionary with at least `default` and `type` fields:
          - a string value becomes `{'default': value}`;
          - a missing (or None) `default` is set to the parameter name;
          - a missing (or None) `type` is set to `ParameterType.UNKNOWN`;
          - a value from `task_cli_args` with the same name replaces `default`.
        An unknown type is changed to `ParameterType.DIRECTORY` when the default value ends with a path separator.
        Otherwise, the type stays unknown (it cannot always be inferred at this point).

        Args:
            parameters: Task parameters (`inputs` or `outputs`).
            task_cli_args: Task parameters that a user provided on a command line.
        Returns:
            Set of parameter names that were overridden using task parameters on a command line.
        Raises:
            ConfigurationError: If a parameter definition is neither a string nor a dictionary.
        """
        overridden_parameters = set()
        for name in parameters.keys():
            # A parameter declared without a value (None) gets an empty definition.
            [param_def] = MLCubeConfig.ensure_values_exist(parameters, name, dict)
            if isinstance(param_def, str):
                # Short form: the value is the default value.
                parameters[name] = {'default': param_def}
                param_def = parameters[name]
            elif not isinstance(param_def, DictConfig):
                raise ConfigurationError(
                    f"Task parameter '{name}' must be a string or a dictionary, got {type(param_def).__name__}."
                )

            # If `default` key is not present, use parameter name as value.
            MLCubeConfig.ensure_values_exist(param_def, 'default', lambda: name)
            # One challenge is how to identify type (file, directory) of input/output parameters if users have not
            # provided these types. Every parameter definition gets a `type` field. If it's unknown, we assume it's a
            # directory when its value ends with a path separator.
            MLCubeConfig.ensure_values_exist(param_def, 'type', lambda: ParameterType.UNKNOWN)
            MLCubeConfig._infer_directory_type(param_def)

            # See if there is value on a command line.
            if name in task_cli_args:
                param_def.default = task_cli_args[name]
                overridden_parameters.add(name)
                # Users will often not provide a final slash on a command line for directories, so types were inferred
                # above from default values. Just in case, try the same with the user-provided value.
                MLCubeConfig._infer_directory_type(param_def)

            # TODO: For some input parameters, that generally speaking must exist, we can figure out types later,
            #       when we actually use them (in one of the runners). One problem is when inputs are optional. In this
            #       case, we need to know their type in advance.
            # It probably does not make much sense to check, let's say, whether an input parameter exists and set its
            # type at this moment, because MLCube can run on remote hosts.
        return overridden_parameters

    @staticmethod
    def check_task_cli_args(tasks: DictConfig, task_cli_args: t.Dict) -> None:
        """Find any unknown task CLI arguments and report error.

        A task CLI argument is recognized if it is the name of an input or output parameter of at least one task in
        `tasks`.

        Args:
            tasks: Dictionary of task definitions.
            task_cli_args: Dictionary of task parameters that a user provided on a command line.
        Raises:
            ConfigurationError: If at least one task CLI argument is not recognized.
        """
        known_parameters = set()
        for task_def in (tasks or {}).values():
            # Tasks and their sections may be empty (None) if the configuration has not been canonicalized yet.
            parameters = (task_def or {}).get('parameters', None) or {}
            for section in ('inputs', 'outputs'):
                known_parameters.update((parameters.get(section, None) or {}).keys())

        unknown_task_cli_args = {name for name in (task_cli_args or {}) if name not in known_parameters}
        if unknown_task_cli_args:
            MLCubeConfig.report_unknown_task_cli_args(task_cli_args, unknown_task_cli_args)
            raise ConfigurationError(f"Unknown task CLI arguments: {unknown_task_cli_args}")

    @staticmethod
    def report_unknown_task_cli_args(task_cli_args: t.Dict, unknown_task_cli_args: t.Set) -> None:
        """Print task CLI arguments that have not been recognized, together with their values.

        Args:
            task_cli_args: Dictionary of all task CLI arguments.
            unknown_task_cli_args: Arguments that have not been used (recognized).
        """
        print("The following task CLI arguments have not been used:")
        # Sort for deterministic output (set iteration order of strings varies between runs).
        for arg in sorted(unknown_task_cli_args, key=str):
            print(f"\t{arg} = {task_cli_args[arg]}")