|
27 | 27 | import torch |
28 | 28 | from ax import Metric, Runner |
29 | 29 | from ax.core.data import Data |
| 30 | +from ax.core.experiment import Experiment |
30 | 31 | from ax.core.generator_run import GeneratorRun |
31 | 32 | from ax.core.multi_type_experiment import MultiTypeExperiment |
32 | 33 | from ax.core.objective import Objective |
|
43 | 44 | AxParameterWarning = Warning |
44 | 45 |
|
45 | 46 | from ax.modelbridge.factory import get_sobol |
46 | | -from ax.modelbridge.registry import Models, ST_MTGP_trans |
| 47 | +from ax.modelbridge.registry import MBM_X_trans, Models, ST_MTGP_trans |
47 | 48 | from ax.modelbridge.torch import TorchModelBridge |
48 | | -from ax.modelbridge.transforms.convert_metric_names import tconfig_from_mt_experiment |
| 49 | +from ax.modelbridge.transforms.convert_metric_names import ConvertMetricNames, tconfig_from_mt_experiment |
| 50 | +from ax.modelbridge.transforms.derelativize import Derelativize |
| 51 | +from ax.modelbridge.transforms.stratified_standardize_y import StratifiedStandardizeY |
| 52 | +from ax.modelbridge.transforms.task_encode import TaskChoiceToIntTaskChoice |
| 53 | +from ax.modelbridge.transforms.trial_as_task import TrialAsTask |
49 | 54 | from ax.runners import SyntheticRunner |
50 | 55 | from ax.storage.json_store.save import save_experiment |
51 | 56 | from ax.storage.metric_registry import register_metrics |
52 | 57 | from ax.storage.runner_registry import register_runner |
53 | 58 | from ax.utils.common.result import Ok |
54 | 59 |
|
55 | | -try: |
56 | | - # For Ax >= 0.5.0 |
57 | | - from ax.modelbridge.registry import MBM_X_trans |
58 | | - from ax.modelbridge.transforms.convert_metric_names import ConvertMetricNames |
59 | | - from ax.modelbridge.transforms.derelativize import Derelativize |
60 | | - from ax.modelbridge.transforms.stratified_standardize_y import StratifiedStandardizeY |
61 | | - from ax.modelbridge.transforms.task_encode import TaskChoiceToIntTaskChoice |
62 | | - from ax.modelbridge.transforms.trial_as_task import TrialAsTask |
63 | | - |
64 | | - MT_MTGP_trans = list(MBM_X_trans) + [ |
65 | | - Derelativize, |
66 | | - ConvertMetricNames, |
67 | | - TrialAsTask, |
68 | | - StratifiedStandardizeY, |
69 | | - TaskChoiceToIntTaskChoice, |
70 | | - ] |
71 | | - |
72 | | -except ImportError: |
73 | | - # For Ax < 0.5.0 |
74 | | - from ax.modelbridge.registry import MT_MTGP_trans |
75 | | - |
76 | 60 | from libensemble.message_numbers import EVAL_GEN_TAG, FINISHED_PERSISTENT_GEN_TAG, PERSIS_STOP, STOP_TAG |
77 | 61 | from libensemble.tools.persistent_support import PersistentSupport |
78 | 62 |
|
|
90 | 74 | category=AxParameterWarning, |
91 | 75 | ) |
92 | 76 |
|
| 77 | +MT_MTGP_trans = list(MBM_X_trans) + [ |
| 78 | + Derelativize, |
| 79 | + ConvertMetricNames, |
| 80 | + TrialAsTask, |
| 81 | + StratifiedStandardizeY, |
| 82 | + TaskChoiceToIntTaskChoice, |
| 83 | +] |
| 84 | + |
93 | 85 |
|
94 | 86 | # get_MTGP based on https://ax.dev/docs/tutorials/multi_task/ |
95 | 87 | def get_MTGP( |
96 | | - experiment, |
| 88 | + experiment: Experiment, |
97 | 89 | data: Data, |
98 | | - search_space: SearchSpace | None = None, # noqa: MDA501 |
99 | | - trial_index: int | None = None, # noqa: MDA501 |
| 90 | + search_space: SearchSpace | None = None, |
| 91 | + trial_index: int | None = None, |
100 | 92 | device: torch.device = torch.device("cpu"), |
101 | 93 | dtype: torch.dtype = torch.double, |
102 | 94 | ) -> TorchModelBridge: |
|
0 commit comments