@@ -2547,10 +2547,10 @@ def forward(
25472547 return input_embeds , pos_embeds
25482548
25492549
2550- class AtomformerPreTrainedModel (PreTrainedModel ):
2550+ class AtomformerPreTrainedModel (PreTrainedModel ): # type: ignore[no-untyped-call]
25512551 """Base class for all transformer models."""
25522552
2553- config_class = AtomformerConfig # type: ignore[assignment]
2553+ config_class = AtomformerConfig
25542554 base_model_prefix = "model"
25552555 supports_gradient_checkpointing = True
25562556 _no_split_modules = ["ParallelBlock" ]
@@ -2562,7 +2562,7 @@ def _set_gradient_checkpointing( # type: ignore[override]
25622562 module .gradient_checkpointing = value
25632563
25642564
2565- class AtomformerModel (AtomformerPreTrainedModel ):
2565+ class AtomformerModel (AtomformerPreTrainedModel ): # type: ignore[no-untyped-call]
25662566 """Atomformer model for atom modeling."""
25672567
25682568 def __init__ (self , config : AtomformerConfig ):
@@ -2581,7 +2581,7 @@ def forward(
25812581 return output
25822582
25832583
2584- class AtomformerForMaskedAM (AtomformerPreTrainedModel ):
2584+ class AtomformerForMaskedAM (AtomformerPreTrainedModel ): # type: ignore[no-untyped-call]
25852585 """Atomformer with an atom modeling head on top for masked atom modeling."""
25862586
25872587 def __init__ (self , config : AtomformerConfig ):
@@ -2611,7 +2611,7 @@ def forward(
26112611 return loss , logits
26122612
26132613
2614- class AtomformerForCoordinateAM (AtomformerPreTrainedModel ):
2614+ class AtomformerForCoordinateAM (AtomformerPreTrainedModel ): # type: ignore[no-untyped-call]
26152615 """Atomformer with an atom coordinate head on top for coordinate denoising."""
26162616
26172617 def __init__ (self , config : AtomformerConfig ):
@@ -2641,7 +2641,7 @@ def forward(
26412641 return loss , coords_pred
26422642
26432643
2644- class InitialStructure2RelaxedStructure (AtomformerPreTrainedModel ):
2644+ class InitialStructure2RelaxedStructure (AtomformerPreTrainedModel ): # type: ignore[no-untyped-call]
26452645 """Atomformer with an coordinate head on top for relaxed structure prediction."""
26462646
26472647 def __init__ (self , config : AtomformerConfig ):
@@ -2674,7 +2674,7 @@ def forward(
26742674 return loss , coords_pred
26752675
26762676
2677- class InitialStructure2RelaxedEnergy (AtomformerPreTrainedModel ):
2677+ class InitialStructure2RelaxedEnergy (AtomformerPreTrainedModel ): # type: ignore[no-untyped-call]
26782678 """Atomformer with an energy head on top for relaxed energy prediction."""
26792679
26802680 def __init__ (self , config : AtomformerConfig ):
@@ -2704,7 +2704,7 @@ def forward(
27042704 return loss , energy
27052705
27062706
2707- class InitialStructure2RelaxedStructureAndEnergy (AtomformerPreTrainedModel ):
2707+ class InitialStructure2RelaxedStructureAndEnergy (AtomformerPreTrainedModel ): # type: ignore[no-untyped-call]
27082708 """Atomformer with an coordinate and energy head."""
27092709
27102710 def __init__ (self , config : AtomformerConfig ):
@@ -2757,7 +2757,7 @@ def forward(
27572757 return loss , (formation_energy_pred , coords_pred )
27582758
27592759
2760- class Structure2Energy (AtomformerPreTrainedModel ):
2760+ class Structure2Energy (AtomformerPreTrainedModel ): # type: ignore[no-untyped-call]
27612761 """Atomformer with an atom modeling head on top for masked atom modeling."""
27622762
27632763 def __init__ (self , config : AtomformerConfig ):
@@ -2799,7 +2799,7 @@ def forward(
27992799 )
28002800
28012801
2802- class Structure2Forces (AtomformerPreTrainedModel ):
2802+ class Structure2Forces (AtomformerPreTrainedModel ): # type: ignore[no-untyped-call]
28032803 """Atomformer with a forces head on top for forces prediction."""
28042804
28052805 def __init__ (self , config : AtomformerConfig ):
@@ -2841,7 +2841,7 @@ def forward(
28412841 )
28422842
28432843
2844- class Structure2EnergyAndForces (AtomformerPreTrainedModel ):
2844+ class Structure2EnergyAndForces (AtomformerPreTrainedModel ): # type: ignore[no-untyped-call]
28452845 """Atomformer with an energy and forces head for energy and forces prediction."""
28462846
28472847 def __init__ (self , config : AtomformerConfig ):
@@ -2892,7 +2892,7 @@ def forward(
28922892 return loss , (formation_energy_pred , forces_pred , attention_mask )
28932893
28942894
2895- class Structure2TotalEnergyAndForces (AtomformerPreTrainedModel ):
2895+ class Structure2TotalEnergyAndForces (AtomformerPreTrainedModel ): # type: ignore[no-untyped-call]
28962896 """Atomformer with an energy and forces head for energy and forces prediction."""
28972897
28982898 def __init__ (self , config : AtomformerConfig ):
@@ -2949,7 +2949,7 @@ def forward(
29492949 return loss , (total_energy_pred , forces_pred , attention_mask )
29502950
29512951
2952- class AtomFormerForSystemClassification (AtomformerPreTrainedModel ):
2952+ class AtomFormerForSystemClassification (AtomformerPreTrainedModel ): # type: ignore[no-untyped-call]
29532953 """Atomformer with a classification head for system classification."""
29542954
29552955 def __init__ (self , config : AtomformerConfig ):
0 commit comments