Skip to content

Commit ce2a59d

Browse files
committed
Add support for tf32 and set precision to bf16-mixed if available
1 parent 0651a73 commit ce2a59d

1 file changed

Lines changed: 6 additions & 1 deletion

File tree

mmlearn/cli/run.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from omegaconf import OmegaConf
1313
from pytorch_lightning.utilities import rank_zero_only
1414
from torch.utils.data import DataLoader
15+
from transformers.utils.import_utils import is_torch_tf32_available
1516

1617
from mmlearn.cli._instantiators import (
1718
instantiate_callbacks,
@@ -41,7 +42,11 @@ def main(cfg: MMLearnConf) -> None: # noqa: PLR0912
4142
cfg_copy = copy.deepcopy(cfg) # copy of the config for logging
4243

4344
L.seed_everything(cfg.seed, workers=True)
44-
torch.set_float32_matmul_precision("high")
45+
46+
if is_torch_tf32_available():
47+
torch.backends.cuda.matmul.allow_tf32 = True
48+
if "16-mixed" in cfg.trainer.precision:
49+
cfg.trainer.precision = "bf16-mixed"
4550

4651
# setup trainer first so that we can get some variables for distributed training
4752
callbacks = instantiate_callbacks(cfg.trainer.get("callbacks"))

0 commit comments

Comments (0)