Skip to content

Commit 3b24649

Browse files
committed
fix: types
1 parent 6310613 commit 3b24649

21 files changed

Lines changed: 181 additions & 135 deletions

File tree

.github/workflows/test.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
run: uvx ruff format --check
2828

2929
- name: Typecheck
30-
run: uvx mypy src --enable-incomplete-feature=NewGenericSyntax
30+
run: uvx ty check
3131

3232
- name: Test
3333
run: uv run pytest tests

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ repos:
2121
stages: [commit-msg]
2222
- id: typecheck
2323
name: typecheck
24-
entry: uvx mypy src/scratch/
24+
entry: uvx ty check
2525
language: system
2626
types: [python]
2727
pass_filenames: false

pyproject.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ dev = [
4242
"ipykernel>=6.29.5",
4343
"jupyter>=1.1.1",
4444
"ruff>=0.8.4",
45+
"ty>=0.0.1a10",
46+
4547
]
4648

4749
[tool.uv]
@@ -62,6 +64,8 @@ packages = ["src/scratch"]
6264
target-version = "py312"
6365
include = ["src/**", "tests/**"]
6466
line-length = 88
67+
68+
[tool.ruff.lint]
6569
select = [
6670
"E", # pycodestyle
6771
"W", # pycodestyle
@@ -79,10 +83,10 @@ select = [
7983
]
8084
ignore = ["F722"]
8185

82-
[tool.ruff.pydocstyle]
86+
[tool.ruff.lint.pydocstyle]
8387
convention = "google"
8488

85-
[tool.ruff.isort]
89+
[tool.ruff.lint.isort]
8690
known-first-party = ["src"]
8791

8892
[tool.pytest.ini_options]

src/scratch/datasets/utils.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import warnings
55
from dataclasses import dataclass, field
66

7-
from transformers import AutoTokenizer, PreTrainedTokenizerBase
7+
from transformers import AutoTokenizer
8+
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
89

910

1011
def patch_datasets_warning():
@@ -32,16 +33,6 @@ def filter_specific_warning(warning):
3233
frame = frame.f_back
3334
return False
3435

35-
# Register the custom filter
36-
warnings.filterwarnings("ignore", category=UserWarning, module=r".*")
37-
warnings.showwarning = (
38-
lambda message, category, filename, lineno, file=None, line=None: None
39-
if filter_specific_warning(
40-
warnings.WarningMessage(message, category, filename, lineno)
41-
)
42-
else warnings.showwarning(message, category, filename, lineno)
43-
)
44-
4536

4637
@dataclass
4738
class TokenizerMetadata:
@@ -62,7 +53,7 @@ class TokenizerMetadata:
6253
def from_tokenizer(cls, tokenizer: PreTrainedTokenizerBase, max_length: int):
6354
"""Create metadata from a tokenizer instance."""
6455
vocab_size = tokenizer.vocab_size # type: ignore
65-
if not vocab_size:
56+
if not vocab_size or not isinstance(vocab_size, int):
6657
raise ValueError("The tokenizer does not have a vocab size.")
6758
return cls(
6859
vocab_size=vocab_size,

src/scratch/deep_learning/layers/attention/rope.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ def precompute_theta_pos_freqs(dim: int, end: int, theta: float = 10000.0):
9292
"""
9393
freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)] / dim))
9494
t = jnp.arange(end, dtype=jnp.float32)
95+
96+
assert isinstance(t, jnp.ndarray)
97+
assert isinstance(freqs, jnp.ndarray)
98+
9599
freqs = jnp.outer(t, freqs)
96100
freqs_cis = jnp.exp(1j * freqs) # Using Euler's formula to create complex numbers
97101
return freqs_cis

src/scratch/image_classification/cnn.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,5 +112,7 @@ def __call__(self, x):
112112
# And comment out the following line
113113
logger = None
114114

115-
trainer = ImageClassificationParallelTrainer(model, trainer_config, logger=logger)
115+
trainer = ImageClassificationParallelTrainer[CNN](
116+
model, trainer_config, logger=logger
117+
)
116118
trainer.train_and_evaluate(dataset.train, dataset.test)

src/scratch/image_classification/resnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,5 +321,5 @@ def __call__(self, x):
321321
trainer_config = ImageClassificationParallelTrainerConfig(
322322
batch_size=batch_size, learning_rate=0.01, epochs=3
323323
)
324-
trainer = ImageClassificationParallelTrainer(model, trainer_config)
324+
trainer = ImageClassificationParallelTrainer[ResNet](model, trainer_config)
325325
trainer.train_and_evaluate(dataset.train, dataset.test)

src/scratch/image_classification/swin_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -741,5 +741,5 @@ def __call__(self, x: jnp.ndarray, train=True):
741741
trainer_config = ImageClassificationParallelTrainerConfig(
742742
batch_size=batch_size, epochs=5
743743
)
744-
trainer = ImageClassificationParallelTrainer(model, trainer_config)
744+
trainer = ImageClassificationParallelTrainer[SwinTransformer](model, trainer_config)
745745
trainer.train_and_evaluate(dataset.train, dataset.test)

src/scratch/image_classification/trainer.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
devices.
66
"""
77

8-
from collections.abc import Callable
98
from dataclasses import dataclass
10-
from typing import TypeVar
9+
from typing import Protocol, TypeVar
1110

1211
import jax
1312
import jax.numpy as jnp
@@ -26,6 +25,14 @@
2625
M = TypeVar("M", bound=nnx.Module)
2726

2827

28+
class CallableModule(Protocol):
29+
"""Protocol for callable modules."""
30+
31+
def __call__(self, *args, **kwargs) -> jnp.ndarray:
32+
"""Call the module."""
33+
... # pragma: no cover
34+
35+
2936
@dataclass
3037
class ImageClassificationParallelTrainerConfig(SupervisedTrainerConfig):
3138
"""Configuration for the ImageClassificationParallelTrainer."""
@@ -83,8 +90,8 @@ def train(
8390
def train_step(
8491
model: M, train_state: TrainState, inputs: jnp.ndarray, targets: jnp.ndarray
8592
):
86-
def loss_fn(model: Callable):
87-
logits = model(inputs)
93+
def loss_fn(model: nnx.Module):
94+
logits = model(inputs) # type: ignore
8895
assert logits.shape == targets.shape
8996
loss = optax.softmax_cross_entropy(logits=logits, labels=targets).mean()
9097
return loss, logits

src/scratch/image_classification/vision_transformer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,5 +214,7 @@ def img_to_patch(x: jnp.ndarray, patch_size: int):
214214
trainer_config = ImageClassificationParallelTrainerConfig(
215215
batch_size=batch_size, epochs=5
216216
)
217-
trainer = ImageClassificationParallelTrainer(model, trainer_config)
217+
trainer = ImageClassificationParallelTrainer[VisionTransformer](
218+
model, trainer_config
219+
)
218220
trainer.train_and_evaluate(dataset.train, dataset.test)

0 commit comments

Comments (0)