 from omegaconf import OmegaConf
 from pytorch_lightning.utilities import rank_zero_only
 from torch.utils.data import DataLoader
+from transformers.utils.import_utils import is_torch_tf32_available

 from mmlearn.cli._instantiators import (
     instantiate_callbacks,

@@ -41,7 +42,11 @@ def main(cfg: MMLearnConf) -> None:  # noqa: PLR0912
     cfg_copy = copy.deepcopy(cfg)  # copy of the config for logging

     L.seed_everything(cfg.seed, workers=True)
-    torch.set_float32_matmul_precision("high")
+
+    if is_torch_tf32_available():
+        torch.backends.cuda.matmul.allow_tf32 = True
+        if "16-mixed" in cfg.trainer.precision:
+            cfg.trainer.precision = "bf16-mixed"

     # setup trainer first so that we can get some variables for distributed training
     callbacks = instantiate_callbacks(cfg.trainer.get("callbacks"))
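For context, the hunk above swaps the blanket torch.set_float32_matmul_precision("high") call for a hardware check: TF32 matmuls are only enabled, and fp16 mixed precision only upgraded to bf16, when the GPU actually supports them. A minimal, standalone sketch of that gating logic follows; trainer_precision is a hypothetical stand-in for cfg.trainer.precision, and the availability check reuses transformers.utils.import_utils.is_torch_tf32_available as in the diff.

import torch
from transformers.utils.import_utils import is_torch_tf32_available

trainer_precision = "16-mixed"  # hypothetical stand-in for cfg.trainer.precision

if is_torch_tf32_available():
    # TF32 is available (CUDA device with compute capability >= 8 and a
    # recent enough torch), so allow TF32 for matmuls.
    torch.backends.cuda.matmul.allow_tf32 = True
    # Such hardware also supports bf16, so prefer it over fp16 mixed precision.
    if "16-mixed" in trainer_precision:
        trainer_precision = "bf16-mixed"

print(trainer_precision)  # "bf16-mixed" on TF32-capable GPUs, otherwise unchanged

On machines without a TF32-capable GPU the block is a no-op, so the configured precision is left untouched.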