Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
113 commits
Select commit Hold shift + click to select a range
24668c3
feat: create ml pipeline for linear probe
vojtech-cifka May 1, 2026
f340038
refactor(ml): switch DataModule to HF datasets with fold-based split
vojtech-cifka May 4, 2026
c3ef38a
feat(ml): wire up linear probe training with k-fold CV on cached embe…
vojtech-cifka May 7, 2026
c644f22
fix(configs): use override for class_mapping in experiment yaml
vojtech-cifka May 7, 2026
564b0b1
fix(scripts): drop duplicate +ml= from linear-probe submit command
vojtech-cifka May 7, 2026
3a77adc
fix(ml): register random_seed/len resolvers and unflatten class_mappi…
vojtech-cifka May 7, 2026
11b19f0
fix(ml): accept already-canonical labels in datamodule label map
vojtech-cifka May 8, 2026
c6bfe8e
feat(ml): class-weighted CE, raise class_coverage_min to 0.5
vojtech-cifka May 8, 2026
894c27b
fix: sort only tiles parquet
vojtech-cifka May 8, 2026
fc824ad
fix: log join types of tile keys
vojtech-cifka May 8, 2026
11931d1
fix: remove embeddings from the join
vojtech-cifka May 8, 2026
fb6b320
fix: remove label column
vojtech-cifka May 8, 2026
7434ae9
fix: prevent overflow
vojtech-cifka May 8, 2026
1b18daa
Merge remote-tracking branch 'origin/master' into feature/linear-probe
vojtech-cifka May 8, 2026
bef70df
feat: add embedding dataset build pipeline
vojtech-cifka May 8, 2026
911bec2
feat: add class tresholds and run ids
vojtech-cifka May 8, 2026
1a02395
fix: wrong run id
vojtech-cifka May 8, 2026
08d7ba5
Merge remote-tracking branch 'origin/master' into feature/embedding-d…
vojtech-cifka May 9, 2026
b38465e
feat: add timing
vojtech-cifka May 9, 2026
bfc9578
refactor: use pyarrow to avoid to pandas conversion
vojtech-cifka May 9, 2026
eb213c6
fix: join on keys only
vojtech-cifka May 9, 2026
c92d9a1
fix: typing
vojtech-cifka May 9, 2026
01cc394
fix: add prints
vojtech-cifka May 9, 2026
cad0d37
refactor: use combine chunks
vojtech-cifka May 9, 2026
ae04552
fix: lazy-cast embeddings to large_list and stay in Arrow during join
vojtech-cifka May 9, 2026
82320db
fix: validate label/tissue_prop columns when derive=False
vojtech-cifka May 9, 2026
3b0137f
chore: remove time
vojtech-cifka May 9, 2026
8df47aa
feat: add timing
vojtech-cifka May 10, 2026
926753d
chore: revert to the previous state
vojtech-cifka May 10, 2026
b0e9ba4
feat: add prints
vojtech-cifka May 10, 2026
6a915de
refactor: use discusssed thresholds
vojtech-cifka May 11, 2026
0f50307
refactor: use different labeling strategy
vojtech-cifka May 11, 2026
4d953dc
feat: implement training pipeline
vojtech-cifka May 11, 2026
d5798bc
feat: add class weights
vojtech-cifka May 11, 2026
ae45cd5
refactor: join embeddings with metadata while loading the dataset
vojtech-cifka May 11, 2026
bdce760
feat: add prints
vojtech-cifka May 11, 2026
ac633d5
fix: use chunks
vojtech-cifka May 11, 2026
2793562
fix: use numpy chunks
vojtech-cifka May 11, 2026
e81973e
fix: call end at the end of the main
vojtech-cifka May 11, 2026
0071592
chore: remove prints
vojtech-cifka May 11, 2026
c0a7499
chore: remove debug prints, stale TODO, and unused preprocessing pipe…
vojtech-cifka May 11, 2026
fe918d1
chore: remove markdown file
vojtech-cifka May 11, 2026
6b7d1e8
fix: edge cases
vojtech-cifka May 12, 2026
4ff988e
feat: normalize the confusion matrix rows per class recall
vojtech-cifka May 12, 2026
32375b2
fix: format
vojtech-cifka May 12, 2026
af9538a
feat: use stratified k fold run
vojtech-cifka May 12, 2026
bc0819a
fix: remove criterion
vojtech-cifka May 12, 2026
b8e85e0
fix: remove criterion from configs
vojtech-cifka May 12, 2026
c387189
feat: implement test pipeline
vojtech-cifka May 13, 2026
1216504
fix: Hydra unreached
vojtech-cifka May 13, 2026
7ec86ef
fix: set weights only to false
vojtech-cifka May 13, 2026
c9b566e
fix: criterion weight
vojtech-cifka May 13, 2026
ff4d307
Merge branch 'master' into feature/ml-linear-classifier
vojtech-cifka May 13, 2026
3cc670d
feat: add option to use different kfold strategies
vojtech-cifka May 13, 2026
ad0a4e7
feat: add training without validation
vojtech-cifka May 13, 2026
811e21c
feat: implement final test run
vojtech-cifka May 13, 2026
27ceea3
fix: lower LR and patience
vojtech-cifka May 13, 2026
efde82a
fix: use f1 macro as a monitor
vojtech-cifka May 14, 2026
c8102de
fix: rever back to validation loss
vojtech-cifka May 14, 2026
c5bab90
fix: add weight decay 1e-3 to linear classifier
vojtech-cifka May 14, 2026
475b67c
Revert "fix: add weight decay 1e-3 to linear classifier"
vojtech-cifka May 14, 2026
43663a9
feat: add logistic regression
vojtech-cifka May 14, 2026
a2fe451
feat: polish and add two distinct submission scripts
vojtech-cifka May 14, 2026
31ecf6d
fix: submission scripts
vojtech-cifka May 14, 2026
ff8d0bf
feat: implement knn
vojtech-cifka May 14, 2026
1f87154
refactor: focus on convergence
vojtech-cifka May 14, 2026
7039307
Remove kNN sklearn baseline
vojtech-cifka May 14, 2026
729eccd
fix: change monitor to focus on train losss
vojtech-cifka May 14, 2026
d3ed2ed
feat: add run name
vojtech-cifka May 14, 2026
e9fd559
chore: remove logistic regression
vojtech-cifka May 15, 2026
6dadbd7
feat: implement lbfgs
vojtech-cifka May 15, 2026
d5d3edd
fix: run id
vojtech-cifka May 15, 2026
2163699
fix: cache the tiles and embeddings so they do not need to be downloa…
vojtech-cifka May 15, 2026
9286807
fix: limit num of workers
vojtech-cifka May 15, 2026
bb8a043
fix: support checkpoint test and prediction export
vojtech-cifka May 15, 2026
c284d8d
Merge remote-tracking branch 'origin/feature/linear-probe' into featu…
vojtech-cifka May 15, 2026
efddcd6
Revert "Merge remote-tracking branch 'origin/feature/linear-probe' in…
vojtech-cifka May 15, 2026
420534e
Merge remote-tracking branch 'origin/feature/ml-linear-classifier' in…
vojtech-cifka May 15, 2026
14909e2
feat: add functionality to submit final train for both adamw and lbfgs
vojtech-cifka May 15, 2026
8167363
feat: implement prediction maps
vojtech-cifka May 16, 2026
4e45ce1
fix: change the adamw checkpoint dir name to last
vojtech-cifka May 16, 2026
8f9ce70
fix: lower the batch so the compute does not hang
vojtech-cifka May 16, 2026
99c2d0d
fix: put num workers to 0
vojtech-cifka May 16, 2026
01486bd
feat: add prints
vojtech-cifka May 16, 2026
64963ac
Merge branch 'master' into feature/ml-test-mode
vojtech-cifka May 16, 2026
85270fd
feat: add diagnostic prints
vojtech-cifka May 16, 2026
5db671c
fix: use numpy buffer
vojtech-cifka May 16, 2026
3aea3c2
refactor: use HeatmapAssembler
vojtech-cifka May 16, 2026
756642a
chore: clean config structure
vojtech-cifka May 16, 2026
2771e78
fix: prediction maps class indices
vojtech-cifka May 16, 2026
918b691
fix: format and mypy
vojtech-cifka May 16, 2026
4032df3
feat: add posibility to predict the whole slide with tissue area
vojtech-cifka May 17, 2026
ca50a7c
feat: add embeddings for whole slide
vojtech-cifka May 17, 2026
6489cd0
refactor: compute grayscale mask per each class
vojtech-cifka May 17, 2026
9d8729a
refactor: do not generate error masks
vojtech-cifka May 18, 2026
ac0ce16
chore: config cleanup
vojtech-cifka May 18, 2026
099e277
feat: add prints to the prediction maps writer
vojtech-cifka May 18, 2026
4909324
feat: add embeddings run id for the whole tissue tiles run
vojtech-cifka May 18, 2026
ee9d2da
feat: add prediction maps in configs
vojtech-cifka May 18, 2026
8b3a82d
chore: deduplicate, apply safety nets
vojtech-cifka May 18, 2026
e16426e
fix: pytorch checkpoint loading
vojtech-cifka May 18, 2026
fd3fdd6
chore: remove redundancy, rename variables
vojtech-cifka May 18, 2026
c401015
chore: remove username and branch
vojtech-cifka May 18, 2026
847c3cc
refactor: rename configs
vojtech-cifka May 18, 2026
2ba0562
fix: keep criterion.weight in state_dict for strict checkpoint load
vojtech-cifka May 18, 2026
e370417
fix: criterion weight
vojtech-cifka May 18, 2026
632a8f6
fix: keep space in MUG prediction masks names
vojtech-cifka May 18, 2026
3cd0243
fix: log test accuracy as jsons
vojtech-cifka May 18, 2026
76e4194
chore: remove username from the submission script
vojtech-cifka May 18, 2026
597e348
fix: force the entering of the write phase of the prediction maps
vojtech-cifka May 18, 2026
e4a4cc5
fix: surface why prediction-map write phase skips
vojtech-cifka May 18, 2026
3829ebd
fix: remove username
vojtech-cifka May 18, 2026
79b47a2
fix: preserve original wsi name
vojtech-cifka May 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion configs/data/dataset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dataset:
test_split_filename: "split_mapping/test_split.csv"
tiling_run_id: "fdf7550a2004474f8c7a05dc0cf1fd86"
filter_tiles_run_id: "4e8f5d3c82124ea5a8f871a42d3ed9ba"
stratified_kfold_run_id: "c7eafdffa32743aa9eb6dd2bf3a185b5"
stratified_kfold_run_id: "850c81506684450b9af92296acfd045a"
stratified_group_kfold_run_id: "382b41d2fa894514908e8067949c4326"
embedding_run_id: "c325e3a5033b4077b6febb0e3e6b0bd6"
tissue_masks_run_id: "52bc0924f8624b259819c480c7cf213f"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
# @package _global_

defaults:
- /ml/linear_classifier
- /ml/task: kfold_linear_classifier
- _self_

kfold_strategy: stratified_group
kfold_run_id: ${dataset.mlflow_artifacts.stratified_group_kfold_run_id}

model:
optimizer: adamw
learning_rate: 1.0e-4
weight_decay: 0.0
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
# @package _global_

defaults:
- /ml/linear_classifier
- /ml/task: kfold_linear_classifier
- _self_

kfold_strategy: stratified
kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id}

model:
optimizer: adamw
learning_rate: 1.0e-4
weight_decay: 0.0
23 changes: 23 additions & 0 deletions configs/experiment/ml/linear_classifier_final_adamw.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# @package _global_

defaults:
- /ml/task: final_linear_classifier
- override /ml/trainer: early_stopping
- _self_

# AdamW final: trained to convergence with the same early-stopping rule as the
# k-fold sweep (monitor train/loss_epoch, patience 1, min_delta 1e-4), not a
# fixed 6-epoch budget. weight_decay=1e-3 = best AdamW sweep point (flat curve).
model:
optimizer: adamw
learning_rate: 1.0e-4
weight_decay: 1.0e-3

trainer:
callbacks:
model_checkpoint:
save_last: true

metadata:
run_name: Final Linear Classifier AdamW ${dataset.name}
description: "Final AdamW linear probe over frozen Virchow2 embeddings, trained on all training folds with early stopping on train/loss_epoch."
37 changes: 37 additions & 0 deletions configs/experiment/ml/linear_classifier_final_lbfgs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# @package _global_

defaults:
- /ml/task: final_linear_classifier
- _self_

# LBFGS final: exact solve of the convex objective on the full training batch.
# weight_decay=1e-2 = best LBFGS sweep point. Full-batch guard requires
# train_shuffle=false, train_drop_last=false, train_batch_size >= len(train);
# num_workers=0 avoids the single-batch IPC deadlock.
trainer:
max_epochs: 10

data:
train_batch_size: 1000000000
eval_batch_size: 1024
train_shuffle: false
train_drop_last: false
num_workers: 0

model:
optimizer: lbfgs
learning_rate: 1.0
weight_decay: 1.0e-2
lbfgs:
max_iter: 100
max_eval: null
tolerance_grad: 1.0e-7
tolerance_change: 1.0e-9
history_size: 100
line_search_fn: strong_wolfe
accumulate_batches: 1
accumulate_on_cpu: false

metadata:
run_name: Final Linear Classifier LBFGS ${dataset.name}
description: "Final LBFGS linear probe over frozen Virchow2 embeddings, exact full-batch solve of the convex objective on all training folds."
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ trainer:
max_epochs: 10

data:
batch_size: 1000000000
train_batch_size: 1000000000
train_shuffle: false
train_drop_last: false
num_workers: 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ trainer:
max_epochs: 10

data:
batch_size: 1000000000
train_batch_size: 1000000000
train_shuffle: false
train_drop_last: false

Expand Down
58 changes: 58 additions & 0 deletions configs/experiment/ml/linear_classifier_predict_tissue_tiles.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# @package _global_

defaults:
- /ml/task: final_linear_classifier
- _self_

# Predict over every test-split tile that intersects the tissue mask. This run
# is unlabeled: it writes predictions/maps for review and does not compute test
# metrics.
mode: predict
final_train_run_id: 0e2230c722134ce0985e09a18ccadf75
checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/last/checkpoint.ckpt
checkpoint_weights_only: false

tissue_embedding_run_id: 95a02c93c164415e94702ad5c83ccca2
Comment thread
vojtech-cifka marked this conversation as resolved.
tissue_stats_run_id: ${dataset.mlflow_artifacts.tissue_stats_run_id}
tissue_stats_artifact_path: tissue_stats
tissue_column: tile_tissue_coverage
tissue_min: 0.0

test_embedding_uri: runs:/${tissue_embedding_run_id}/test/tiles
test_metadata_uri: runs:/${tissue_stats_run_id}/${tissue_stats_artifact_path}/test_tiles.parquet

data:
# num_workers MUST be 0 for this config. UnlabeledEmbeddingTilesDataset and
# EmbeddingTilesDataset load the entire split into memory in __init__; using
# worker processes (num_workers > 0) can deadlock after multiprocessing/fork.
num_workers: 0
Comment thread
vojtech-cifka marked this conversation as resolved.
predict:
_target_: ml.data.datasets.UnlabeledEmbeddingTilesDataset
embedding_uri: ${test_embedding_uri}
metadata_uri: ${test_metadata_uri}
tissue_column: ${tissue_column}
tissue_min: ${tissue_min}

model:
optimizer: lbfgs
learning_rate: 1.0
weight_decay: 1.0e-2
lbfgs:
max_iter: 100
max_eval: null
tolerance_grad: 1.0e-7
tolerance_change: 1.0e-9
history_size: 100
line_search_fn: strong_wolfe
accumulate_batches: 1
accumulate_on_cpu: false

metadata:
run_name: Predict Linear Classifier tissue tiles ${dataset.name}
description: "Predict final linear classifier over all held-out test tiles intersecting tissue masks for external doctor review."
hyperparams:
tissue_embedding_run_id: ${tissue_embedding_run_id}
tissue_stats_run_id: ${tissue_stats_run_id}
tissue_stats_artifact_path: ${tissue_stats_artifact_path}
tissue_column: ${tissue_column}
tissue_min: ${tissue_min}
32 changes: 32 additions & 0 deletions configs/experiment/ml/linear_classifier_test_adamw.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# @package _global_

defaults:
- /experiment/ml/linear_classifier_final_adamw
- _self_

# Test the AdamW final checkpoint on the held-out test split. Same model
# architecture as the final run (required for state_dict load); optimizer
# fields are inert at test.
#
mode: test
final_train_run_id: a23e478b00b04da79cfbf4d91cada8cd
checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/last/checkpoint.ckpt
Comment thread
vojtech-cifka marked this conversation as resolved.
checkpoint_weights_only: false

# num_workers MUST stay 0. EmbeddingTilesDataset loads the entire split into
# one in-memory numpy array in __init__ and __getitem__ is pure numpy indexing
# (no per-item IO), so workers give zero speedup; num_workers>0 forks the
# parent (pyarrow/mlflow/fsspec thread state + large array) and deadlocks
# before the first test batch. final_embedding_tiles defaults to 4; override here.
data:
num_workers: 0

trainer:
callbacks:
tiff_prediction_maps:
_target_: ml.callbacks.TiffPredictionMapWriter
slides_uri: runs:/${dataset.mlflow_artifacts.tiling_run_id}/test_split/slides.parquet
artifact_path: prediction_maps_tiff
draw_region: central_stride
slide_selection: all
max_slides: null
36 changes: 36 additions & 0 deletions configs/experiment/ml/linear_classifier_test_lbfgs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# @package _global_

defaults:
- /experiment/ml/linear_classifier_final_lbfgs
- override /ml/trainer: early_stopping
- _self_

# Test the LBFGS final checkpoint on the held-out test split. The full-batch
# train_batch_size=1e9 is a TRAINING requirement for the convex LBFGS solve
# only; at test there is no optimization, so use a normal batch to avoid
# loading the whole test set as one tensor (OOM).
#
# num_workers MUST stay 0. EmbeddingTilesDataset loads the entire split into
# one in-memory numpy array in __init__ and __getitem__ is pure numpy indexing
# (no per-item IO), so workers give zero speedup; num_workers>0 forks the
# parent (which holds pyarrow/mlflow/fsspec thread state + the large array),
# which deadlocks before the first test batch.
# Same model architecture as the final run (required for state_dict load).
mode: test
final_train_run_id: 0e2230c722134ce0985e09a18ccadf75
checkpoint: mlflow-artifacts:/104/${final_train_run_id}/artifacts/checkpoints/last/checkpoint.ckpt
Comment thread
vojtech-cifka marked this conversation as resolved.
checkpoint_weights_only: false

data:
train_batch_size: 1024
num_workers: 0

trainer:
callbacks:
tiff_prediction_maps:
_target_: ml.callbacks.TiffPredictionMapWriter
slides_uri: runs:/${dataset.mlflow_artifacts.tiling_run_id}/test_split/slides.parquet
artifact_path: prediction_maps_tiff
draw_region: central_stride
slide_selection: all
max_slides: null
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# @package _global_

defaults:
- /experiment/preprocessing/embeddings_virchow2_05mpp
- _self_

# Embeddings for every tile in the train/test tiling split that intersects the
# tissue mask. This is the tile universe used for doctor-review prediction maps.
splits:
- test
tile_source_run_id: ${dataset.mlflow_artifacts.tissue_stats_run_id}
tile_source_artifact_template: "tissue_stats/{split}_tiles.parquet"
tile_filter_column: tile_tissue_coverage

metadata:
run_name: "Embeddings: ${model} tissue tiles"
description: "Tile embeddings using ${model} over held-out test split tiles with tile_tissue_coverage > 0."
1 change: 1 addition & 0 deletions configs/ml.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ defaults:
seed: ${random_seed:}
mode: ???
checkpoint: null
checkpoint_weights_only: null

trainer: {}

Expand Down
23 changes: 23 additions & 0 deletions configs/ml/data/final_embedding_tiles.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# @package _global_

data:
train_batch_size: 1024
num_workers: 4
train_shuffle: true
train_drop_last: false

train:
_target_: ml.data.datasets.EmbeddingTilesDataset
embedding_uri: ${train_embedding_uri}
metadata_uri: ${train_metadata_uri}
class_indices: ${class_indices}
thresholds: ${thresholds}
tissue_prop_min: ${tissue_prop_min}

test:
_target_: ml.data.datasets.EmbeddingTilesDataset
embedding_uri: ${test_embedding_uri}
metadata_uri: ${test_metadata_uri}
class_indices: ${class_indices}
thresholds: ${thresholds}
tissue_prop_min: ${tissue_prop_min}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# @package _global_

data:
batch_size: 1024
train_batch_size: 1024
num_workers: 4
train_shuffle: true
train_drop_last: true
Expand Down
16 changes: 4 additions & 12 deletions configs/ml/model/linear_classifier.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,7 @@ model:

class_indices: ${class_indices}

optimizer: adamw
learning_rate: 1.0e-4
weight_decay: 0.0
lbfgs:
max_iter: 100
max_eval: null
tolerance_grad: 1.0e-7
tolerance_change: 1.0e-9
history_size: 100
line_search_fn: strong_wolfe
accumulate_batches: 1
accumulate_on_cpu: false
optimizer: ???
learning_rate: ???
weight_decay: ???
lbfgs: null
48 changes: 48 additions & 0 deletions configs/ml/task/final_linear_classifier.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# @package _global_

defaults:
- /data: dataset
- /class_mapping: collapse_alterations_to_other
- /ml/trainer: final_with_prediction_maps
- /ml/data: final_embedding_tiles
- /ml/model: linear_classifier
- _self_

mode: fit

embedding_run_id: ${dataset.mlflow_artifacts.embedding_run_id}
kfold_strategy: stratified
kfold_run_id: ${dataset.mlflow_artifacts.stratified_kfold_run_id}
filter_tiles_run_id: ${dataset.mlflow_artifacts.filter_tiles_run_id}

train_embedding_uri: runs:/${embedding_run_id}/train/tiles
test_embedding_uri: runs:/${embedding_run_id}/test/tiles
train_metadata_uri: runs:/${kfold_run_id}/kfold_split/kfold_tiles.parquet
test_metadata_uri: runs:/${filter_tiles_run_id}/filter_tiles/test_tiles.parquet

tissue_prop_min: 0.2
thresholds:
Nerve: 0.0
Blood: 0.0
Connective-Tissue: 0.4
Fat: 0.6
Epithelium: 0.2
Muscle: 0.5
Other: 0.5

mlflow_artifact_path: linear_classifier_final

metadata:
run_name: Final Linear Classifier ${dataset.name}
description: "Final linear probe over frozen Virchow2 embeddings trained on all training folds for ${trainer.max_epochs} epochs."
hyperparams:
embedding_run_id: ${embedding_run_id}
kfold_strategy: ${kfold_strategy}
kfold_run_id: ${kfold_run_id}
filter_tiles_run_id: ${filter_tiles_run_id}
tissue_prop_min: ${tissue_prop_min}
thresholds: ${thresholds}
learning_rate: ${model.learning_rate}
weight_decay: ${model.weight_decay}
batch_size: ${data.train_batch_size}
max_epochs: ${trainer.max_epochs}
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
defaults:
- /data: dataset
- /class_mapping: collapse_alterations_to_other
- /ml/trainer: default
- /ml/data: embedding
- /ml/trainer: early_stopping
- /ml/data: kfold_embedding_tiles
- /ml/model: linear_classifier
- _self_

Expand Down Expand Up @@ -49,6 +49,6 @@ metadata:
learning_rate: ${model.learning_rate}
weight_decay: ${model.weight_decay}
lbfgs: ${model.lbfgs}
batch_size: ${data.batch_size}
batch_size: ${data.train_batch_size}
train_shuffle: ${data.train_shuffle}
train_drop_last: ${data.train_drop_last}
File renamed without changes.
Loading