Skip to content

Commit c1040ff

Browse files
committed
checkout persistent_ax_multitask as-is already on develop, but remove use of "Optional"
1 parent af65de6 commit c1040ff

1 file changed

Lines changed: 18 additions & 26 deletions

File tree

libensemble/gen_funcs/persistent_ax_multitask.py

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import torch
2828
from ax import Metric, Runner
2929
from ax.core.data import Data
30+
from ax.core.experiment import Experiment
3031
from ax.core.generator_run import GeneratorRun
3132
from ax.core.multi_type_experiment import MultiTypeExperiment
3233
from ax.core.objective import Objective
@@ -43,36 +44,19 @@
4344
AxParameterWarning = Warning
4445

4546
from ax.modelbridge.factory import get_sobol
46-
from ax.modelbridge.registry import Models, ST_MTGP_trans
47+
from ax.modelbridge.registry import MBM_X_trans, Models, ST_MTGP_trans
4748
from ax.modelbridge.torch import TorchModelBridge
48-
from ax.modelbridge.transforms.convert_metric_names import tconfig_from_mt_experiment
49+
from ax.modelbridge.transforms.convert_metric_names import ConvertMetricNames, tconfig_from_mt_experiment
50+
from ax.modelbridge.transforms.derelativize import Derelativize
51+
from ax.modelbridge.transforms.stratified_standardize_y import StratifiedStandardizeY
52+
from ax.modelbridge.transforms.task_encode import TaskChoiceToIntTaskChoice
53+
from ax.modelbridge.transforms.trial_as_task import TrialAsTask
4954
from ax.runners import SyntheticRunner
5055
from ax.storage.json_store.save import save_experiment
5156
from ax.storage.metric_registry import register_metrics
5257
from ax.storage.runner_registry import register_runner
5358
from ax.utils.common.result import Ok
5459

55-
try:
56-
# For Ax >= 0.5.0
57-
from ax.modelbridge.registry import MBM_X_trans
58-
from ax.modelbridge.transforms.convert_metric_names import ConvertMetricNames
59-
from ax.modelbridge.transforms.derelativize import Derelativize
60-
from ax.modelbridge.transforms.stratified_standardize_y import StratifiedStandardizeY
61-
from ax.modelbridge.transforms.task_encode import TaskChoiceToIntTaskChoice
62-
from ax.modelbridge.transforms.trial_as_task import TrialAsTask
63-
64-
MT_MTGP_trans = list(MBM_X_trans) + [
65-
Derelativize,
66-
ConvertMetricNames,
67-
TrialAsTask,
68-
StratifiedStandardizeY,
69-
TaskChoiceToIntTaskChoice,
70-
]
71-
72-
except ImportError:
73-
# For Ax < 0.5.0
74-
from ax.modelbridge.registry import MT_MTGP_trans
75-
7660
from libensemble.message_numbers import EVAL_GEN_TAG, FINISHED_PERSISTENT_GEN_TAG, PERSIS_STOP, STOP_TAG
7761
from libensemble.tools.persistent_support import PersistentSupport
7862

@@ -90,13 +74,21 @@
9074
category=AxParameterWarning,
9175
)
9276

77+
MT_MTGP_trans = list(MBM_X_trans) + [
78+
Derelativize,
79+
ConvertMetricNames,
80+
TrialAsTask,
81+
StratifiedStandardizeY,
82+
TaskChoiceToIntTaskChoice,
83+
]
84+
9385

9486
# get_MTGP based on https://ax.dev/docs/tutorials/multi_task/
9587
def get_MTGP(
96-
experiment,
88+
experiment: Experiment,
9789
data: Data,
98-
search_space: SearchSpace | None = None, # noqa: MDA501
99-
trial_index: int | None = None, # noqa: MDA501
90+
search_space: SearchSpace | None = None,
91+
trial_index: int | None = None,
10092
device: torch.device = torch.device("cpu"),
10193
dtype: torch.dtype = torch.double,
10294
) -> TorchModelBridge:

0 commit comments

Comments
 (0)