
Commit cadf95c

Merge branch 'main' into dependabot/github_actions/astral-sh/setup-uv-6.6.1

2 parents e2a398f + 0063966

9 files changed
Lines changed: 2263 additions & 2089 deletions


.github/workflows/docs.yml

Lines changed: 1 addition & 1 deletion

@@ -85,7 +85,7 @@ jobs:
           git config user.email 41898282+github-actions[bot]@users.noreply.github.com

       - name: Download artifact
-        uses: actions/download-artifact@v4
+        uses: actions/download-artifact@v5
         with:
           name: docs-site
           path: site

.pre-commit-config.yaml

Lines changed: 4 additions & 4 deletions

@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v5.0.0  # Use the ref you want to point at
+    rev: v6.0.0  # Use the ref you want to point at
     hooks:
       - id: trailing-whitespace
       - id: check-ast
@@ -17,7 +17,7 @@ repos:
       - id: check-toml

   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: 'v0.9.10'
+    rev: 'v0.12.11'
     hooks:
       - id: ruff
         args: [--fix, --exit-non-zero-on-fix]
@@ -26,7 +26,7 @@ repos:
         types_or: [python, jupyter]

   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.15.0
+    rev: v1.17.1
     hooks:
       - id: mypy
         entry: python3 -m mypy --config-file pyproject.toml
@@ -35,7 +35,7 @@ repos:
         exclude: "tests"

   - repo: https://github.com/crate-ci/typos
-    rev: v1
+    rev: v1.35.6
     hooks:
       - id: typos
         args: [--force-exclude]

atomgen/models/configuration_atomformer.py

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@ def __init__(
         cls_token_id: int = 122,
         **kwargs: Any,
     ) -> None:
-        super().__init__(**kwargs)  # type: ignore[no-untyped-call]
+        super().__init__(**kwargs)
         self.vocab_size = vocab_size
         self.dim = dim
         self.num_heads = num_heads

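The comment removed above is an error-code-specific mypy suppression: `# type: ignore[no-untyped-call]` silences only the "call to untyped function in typed context" error on that one line. A minimal sketch of the pattern, with hypothetical names (`untyped_helper` and `typed_caller` are not from this repo):

    from typing import Any


    def untyped_helper(**kwargs):  # no annotations, so mypy treats calls as untyped
        return kwargs


    def typed_caller(**kwargs: Any) -> None:
        # With disallow_untyped_calls enabled, the next line is an error;
        # the bracketed code suppresses only [no-untyped-call], nothing else:
        untyped_helper(**kwargs)  # type: ignore[no-untyped-call]

Presumably the `super().__init__` call into transformers' config base class no longer registers as untyped under the bumped mypy/transformers versions, so the suppression became stale and could be dropped.
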
atomgen/models/modeling_atomformer.py

Lines changed: 13 additions & 13 deletions

@@ -2547,10 +2547,10 @@ def forward(
         return input_embeds, pos_embeds


-class AtomformerPreTrainedModel(PreTrainedModel):
+class AtomformerPreTrainedModel(PreTrainedModel):  # type: ignore[no-untyped-call]
     """Base class for all transformer models."""

-    config_class = AtomformerConfig  # type: ignore[assignment]
+    config_class = AtomformerConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["ParallelBlock"]
@@ -2562,7 +2562,7 @@ def _set_gradient_checkpointing(  # type: ignore[override]
         module.gradient_checkpointing = value


-class AtomformerModel(AtomformerPreTrainedModel):
+class AtomformerModel(AtomformerPreTrainedModel):  # type: ignore[no-untyped-call]
     """Atomformer model for atom modeling."""

     def __init__(self, config: AtomformerConfig):
@@ -2581,7 +2581,7 @@ def forward(
         return output


-class AtomformerForMaskedAM(AtomformerPreTrainedModel):
+class AtomformerForMaskedAM(AtomformerPreTrainedModel):  # type: ignore[no-untyped-call]
     """Atomformer with an atom modeling head on top for masked atom modeling."""

     def __init__(self, config: AtomformerConfig):
@@ -2611,7 +2611,7 @@ def forward(
         return loss, logits


-class AtomformerForCoordinateAM(AtomformerPreTrainedModel):
+class AtomformerForCoordinateAM(AtomformerPreTrainedModel):  # type: ignore[no-untyped-call]
     """Atomformer with an atom coordinate head on top for coordinate denoising."""

     def __init__(self, config: AtomformerConfig):
@@ -2641,7 +2641,7 @@ def forward(
         return loss, coords_pred


-class InitialStructure2RelaxedStructure(AtomformerPreTrainedModel):
+class InitialStructure2RelaxedStructure(AtomformerPreTrainedModel):  # type: ignore[no-untyped-call]
     """Atomformer with an coordinate head on top for relaxed structure prediction."""

     def __init__(self, config: AtomformerConfig):
@@ -2674,7 +2674,7 @@ def forward(
         return loss, coords_pred


-class InitialStructure2RelaxedEnergy(AtomformerPreTrainedModel):
+class InitialStructure2RelaxedEnergy(AtomformerPreTrainedModel):  # type: ignore[no-untyped-call]
     """Atomformer with an energy head on top for relaxed energy prediction."""

     def __init__(self, config: AtomformerConfig):
@@ -2704,7 +2704,7 @@ def forward(
         return loss, energy


-class InitialStructure2RelaxedStructureAndEnergy(AtomformerPreTrainedModel):
+class InitialStructure2RelaxedStructureAndEnergy(AtomformerPreTrainedModel):  # type: ignore[no-untyped-call]
     """Atomformer with an coordinate and energy head."""

     def __init__(self, config: AtomformerConfig):
@@ -2757,7 +2757,7 @@ def forward(
         return loss, (formation_energy_pred, coords_pred)


-class Structure2Energy(AtomformerPreTrainedModel):
+class Structure2Energy(AtomformerPreTrainedModel):  # type: ignore[no-untyped-call]
     """Atomformer with an atom modeling head on top for masked atom modeling."""

     def __init__(self, config: AtomformerConfig):
@@ -2799,7 +2799,7 @@ def forward(
         )


-class Structure2Forces(AtomformerPreTrainedModel):
+class Structure2Forces(AtomformerPreTrainedModel):  # type: ignore[no-untyped-call]
     """Atomformer with a forces head on top for forces prediction."""

     def __init__(self, config: AtomformerConfig):
@@ -2841,7 +2841,7 @@ def forward(
         )


-class Structure2EnergyAndForces(AtomformerPreTrainedModel):
+class Structure2EnergyAndForces(AtomformerPreTrainedModel):  # type: ignore[no-untyped-call]
     """Atomformer with an energy and forces head for energy and forces prediction."""

     def __init__(self, config: AtomformerConfig):
@@ -2892,7 +2892,7 @@ def forward(
         return loss, (formation_energy_pred, forces_pred, attention_mask)


-class Structure2TotalEnergyAndForces(AtomformerPreTrainedModel):
+class Structure2TotalEnergyAndForces(AtomformerPreTrainedModel):  # type: ignore[no-untyped-call]
     """Atomformer with an energy and forces head for energy and forces prediction."""

     def __init__(self, config: AtomformerConfig):
@@ -2949,7 +2949,7 @@ def forward(
         return loss, (total_energy_pred, forces_pred, attention_mask)


-class AtomFormerForSystemClassification(AtomformerPreTrainedModel):
+class AtomFormerForSystemClassification(AtomformerPreTrainedModel):  # type: ignore[no-untyped-call]
     """Atomformer with a classification head for system classification."""

     def __init__(self, config: AtomformerConfig):

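A note on why these suppressions move between lines rather than simply disappearing: the reshuffling above is what happens when warn_unused_ignores is enabled (a common pairing with narrow ignores), under which a suppression that no longer matches any error is itself reported. A hypothetical sketch:

    def now_typed(x: int) -> int:  # suppose this was untyped in the old release
        return x


    def caller() -> int:
        # Old code:  return now_typed(1)  # type: ignore[no-untyped-call]
        # Once an upgrade makes the call typed, that comment suppresses
        # nothing, mypy reports [unused-ignore], and the comment must be
        # deleted or moved to wherever the error is now raised, which is
        # exactly the pattern in the diff above:
        return now_typed(1)
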
atomgen/models/schnet.py

Lines changed: 4 additions & 4 deletions

@@ -105,7 +105,7 @@ def __init__(
         cls_token_id: int = 122,
         **kwargs: Any,
     ):
-        super().__init__(**kwargs)  # type: ignore[no-untyped-call]
+        super().__init__(**kwargs)
         self.vocab_size = vocab_size
         self.hidden_channels = hidden_channels
         self.num_filters = num_filters
@@ -126,20 +126,20 @@ def __init__(
         self.cls_token_id = cls_token_id


-class SchNetPreTrainedModel(PreTrainedModel):
+class SchNetPreTrainedModel(PreTrainedModel):  # type: ignore[no-untyped-call]
     """
     A base class for all SchNet models.

     An abstract class to handle weights initialization and a
     simple interface for loading and exporting models.
     """

-    config_class = SchNetConfig  # type: ignore[assignment]
+    config_class = SchNetConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = False


-class SchNetModel(SchNetPreTrainedModel):
+class SchNetModel(SchNetPreTrainedModel):  # type: ignore[no-untyped-call]
     """
     SchNet model for energy prediction.

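The dropped `config_class = SchNetConfig  # type: ignore[assignment]` here and in modeling_atomformer.py suggests the base class once declared that attribute with a type the assignment did not match. A generic sketch of what an [assignment] ignore suppresses, with hypothetical names:

    class Base:
        label: str = "base"


    class Child(Base):
        # Redefining a class attribute with an incompatible type is the
        # [assignment] error; the narrow ignore marks it as intentional:
        label = 3.14  # type: ignore[assignment]
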
atomgen/models/tokengt.py

Lines changed: 3 additions & 3 deletions

@@ -2351,7 +2351,7 @@ def __init__(
         gradient_checkpointing: bool = False,
         **kwargs: Any,
     ):
-        super().__init__(**kwargs)  # type: ignore[no-untyped-call]
+        super().__init__(**kwargs)
         self.vocab_size = vocab_size
         self.dim = dim
         self.num_heads = num_heads
@@ -2507,15 +2507,15 @@ def custom_forward(*inputs: Any) -> Any:
         return input_embeds


-class TransformerPreTrainedModel(PreTrainedModel):
+class TransformerPreTrainedModel(PreTrainedModel):  # type: ignore[no-untyped-call]
     """Base class for all transformer models."""

     config_class = TransformerConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["ParallelBlock"]

-    def _set_gradient_checkpointing(
+    def _set_gradient_checkpointing(  # type: ignore[override]
         self, module: nn.Module, value: bool = False
     ) -> None:
         if isinstance(module, (TransformerEncoder)):

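The `# type: ignore[override]` added to `_set_gradient_checkpointing` targets mypy's check that a subclass method stays call-compatible with the base class method. A minimal illustration with hypothetical names (the real base here is transformers' `PreTrainedModel`, whose current method signature differs from this override):

    class Base:
        def set_checkpointing(self, enable: bool = True) -> None:
            ...


    class Child(Base):
        # The base method is callable with no arguments, but this override
        # requires `module`, so mypy reports [override]; the narrow ignore
        # records that the mismatch is deliberate:
        def set_checkpointing(  # type: ignore[override]
            self, module: object, value: bool = False
        ) -> None:
            ...
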
scripts/training/pretrain_s2ef.py

Lines changed: 4 additions & 4 deletions

@@ -162,7 +162,7 @@ def train(args: argparse.Namespace) -> None:
     config.gradient_checkpointing = (
         args.gradient_checkpointing if args.gradient_checkpointing else False
     )
-    model = Structure2EnergyAndForces(config)  # type: ignore[arg-type]
+    model = Structure2EnergyAndForces(config)

     tokenizer = AtomTokenizer(vocab_file=args.tokenizer_json)
     data_collator = DataCollatorForAtomModeling(
@@ -173,7 +173,7 @@ def train(args: argparse.Namespace) -> None:
         return_edge_indices=False,
     )

-    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
     if local_rank == 0:
         wandb.login(key=os.environ["WANDB_API_KEY"])
         wandb.init(project=args.project, config=vars(args), name=args.name)
@@ -207,14 +207,14 @@ def train(args: argparse.Namespace) -> None:
         weight_decay=args.weight_decay,
     )

-    trainer = Trainer(  # type: ignore[no-untyped-call]
+    trainer = Trainer(
         model=model,
         args=training_args,
         train_dataset=dataset,
         data_collator=data_collator,
     )

-    trainer.train(resume_from_checkpoint=args.checkpoint_exists)  # type: ignore[attr-defined]
+    trainer.train(resume_from_checkpoint=args.checkpoint_exists)

     model.save_pretrained(args.output_dir)

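The `LOCAL_RANK` change here (and in run_atom3d.py below) looks like a typing cleanup rather than a behavior change: `os.environ` is a `Mapping[str, str]`, so a non-string default widens the inferred type of `.get()`. A sketch:

    import os

    # os.environ maps str to str, so the checker infers:
    rank_a = os.environ.get("LOCAL_RANK", 0)    # str | int
    rank_b = os.environ.get("LOCAL_RANK", "0")  # str

    # Both convert identically at runtime, but the string default keeps
    # the expression a single type before the int() conversion:
    local_rank = int(rank_b)
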
scripts/training/run_atom3d.py

Lines changed: 6 additions & 6 deletions

@@ -121,7 +121,7 @@ def run_atom3d(args: argparse.Namespace) -> None:
             else False,
             problem_type=task_config["problem_type"],
         )
-        model = AtomFormerForSystemClassification(config)  # type: ignore[arg-type]
+        model = AtomFormerForSystemClassification(config)
     else:
         config = AtomformerConfig.from_pretrained(
             args.model,
@@ -157,7 +157,7 @@ def run_atom3d(args: argparse.Namespace) -> None:
         return_edge_indices=False,
     )

-    local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
     if local_rank == 0:
         wandb.login(key=os.environ.get("WANDB_API_KEY"))
         wandb.init(project=args.project, config=vars(args), name=args.name)
@@ -182,7 +182,7 @@ def run_atom3d(args: argparse.Namespace) -> None:
     )

     # Initialize trainer
-    trainer = Trainer(  # type: ignore[no-untyped-call]
+    trainer = Trainer(
         model=model,
         args=training_args,
         train_dataset=dataset["train"],
@@ -192,12 +192,12 @@ def run_atom3d(args: argparse.Namespace) -> None:
     )

     # Train the model
-    trainer.train()  # type: ignore[attr-defined]
+    trainer.train()

-    trainer.evaluate(dataset["test"])  # type: ignore[attr-defined]
+    trainer.evaluate(dataset["test"])

     # Save the model
-    trainer.save_model(args.output_dir)  # type: ignore[attr-defined]
+    trainer.save_model(args.output_dir)


 if __name__ == "__main__":
