Skip to content

Respect specified model class for heterogeneous MultiTaskDatasets#5107

Open
hvarfner wants to merge 4 commits intofacebook:mainfrom
hvarfner:export-D98672892
Open

Respect specified model class for heterogeneous MultiTaskDatasets#5107
hvarfner wants to merge 4 commits intofacebook:mainfrom
hvarfner:export-D98672892

Conversation

@hvarfner
Copy link
Copy Markdown

Summary:
Previously, choose_model_class in utils.py would always force-override
any user-specified model class to HeterogeneousMTGP when a heterogeneous
MultiTaskDataset was detected. This made it impossible to use alternative
models for heterogeneous transfer learning (e.g., ImputedMultiTaskGP).

This diff changes the behavior so that when a specified_model_class is
provided by the user (via ModelConfig.botorch_model_class), it is respected
for heterogeneous datasets. When no model class is specified, the default
behavior of selecting HeterogeneousMTGP is preserved.

This is a prerequisite for using ImputedMultiTaskGP through Ax's model
selection pipeline with heterogeneous search spaces.

Differential Revision: D98672892

Carl Hvarfner added 4 commits March 29, 2026 11:07
Summary:

D94693361 introduced a regression: when source experiments have more
parameters than target and status_quo is set, FillMissingParameters adds
extra columns to target arm data during _compute_in_design.
check_membership_df then returns [False] for all rows because
df_cols != ss_params. All target arms are incorrectly filtered out.

Fix: Override _compute_in_design in TransferLearningAdapter to use
check_all_parameters_present=False.

Differential Revision: D97625737
…ing (facebook#5102)

Summary:

When merging search spaces for transfer learning, a parameter may be
FixedParameter in one experiment and ChoiceParameter in another (e.g., a
parameter was fixed to a single value in the source but is tunable in the
target). Previously this raised a ValueError. Now we merge them into a
ChoiceParameter whose values include the union of the choice values and
the fixed value.

Differential Revision: D98247197
…acebook#5106)

Summary:
X-link: https://github.com/facebookexternal/botorch_fb/pull/34


Wire LearnedFeatureImputation and ImputedMultiTaskGP into Ax:

1. **input_transform_argparse dispatcher**: Computes `feature_indices` and `d`
   from a heterogeneous MultiTaskDataset using target-first feature ordering.
   Validates that the dataset is a MultiTaskDataset with heterogeneous features.

2. **Storage registry**: Register ImputedMultiTaskGP in MODEL_REGISTRY and
   LearnedFeatureImputation in INPUT_TRANSFORM_REGISTRY.

3. **Model selection (utils.py)**: When a heterogeneous MultiTaskDataset is
   detected and a model class is specified (e.g. ImputedMultiTaskGP), use the
   specified class instead of force-overriding to HeterogeneousMTGP. Also add
   automatic Normalize + LearnedFeatureImputation transform chaining for
   ImputedMultiTaskGP.

Differential Revision: D97625733
Summary:
Previously, `choose_model_class` in `utils.py` would always force-override
any user-specified model class to `HeterogeneousMTGP` when a heterogeneous
`MultiTaskDataset` was detected. This made it impossible to use alternative
models for heterogeneous transfer learning (e.g., `ImputedMultiTaskGP`).

This diff changes the behavior so that when a `specified_model_class` is
provided by the user (via `ModelConfig.botorch_model_class`), it is respected
for heterogeneous datasets. When no model class is specified, the default
behavior of selecting `HeterogeneousMTGP` is preserved.

This is a prerequisite for using `ImputedMultiTaskGP` through Ax's model
selection pipeline with heterogeneous search spaces.

Differential Revision: D98672892
@meta-codesync
Copy link
Copy Markdown

meta-codesync bot commented Mar 29, 2026

@hvarfner has exported this pull request. If you are a Meta employee, you can view the originating Diff in D98672892.

@meta-cla meta-cla bot added the CLA Signed Do not delete this pull request or issue due to inactivity. label Mar 29, 2026
@codecov-commenter
Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 43.75000% with 9 lines in your changes missing coverage. Please review.
✅ Project coverage is 96.40%. Comparing base (316dbad) to head (f915c0a).

Files with missing lines Patch % Lines
ax/adapter/transfer_learning/utils.py 0.00% 4 Missing ⚠️
ax/storage/botorch_modular_registry.py 25.00% 3 Missing ⚠️
ax/adapter/transfer_learning/adapter.py 33.33% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #5107      +/-   ##
==========================================
- Coverage   96.41%   96.40%   -0.01%     
==========================================
  Files         613      613              
  Lines       68106    68117      +11     
==========================================
+ Hits        65663    65670       +7     
- Misses       2443     2447       +4     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed Do not delete this pull request or issue due to inactivity. fb-exported meta-exported

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants