Skip to content

Commit e707803

Browse files
authored
feat: add util for tokenizer pad id (#310)
* add tests for new util
* update test for coverage
* test pad_token passing
1 parent b137dc6 commit e707803

5 files changed

Lines changed: 83 additions & 7 deletions

File tree

model2vec/train/base.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from torch.utils.data import DataLoader, Dataset
1212

1313
from model2vec import StaticModel
14+
from model2vec.train.utils import get_probable_pad_token_id
1415

1516
logger = logging.getLogger(__name__)
1617

@@ -82,7 +83,7 @@ def from_pretrained(
8283

8384
@classmethod
8485
def from_static_model(
85-
cls: type[ModelType], *, model: StaticModel, out_dim: int = 2, pad_token: str = "[PAD]", **kwargs: Any
86+
cls: type[ModelType], *, model: StaticModel, out_dim: int = 2, pad_token: str | None = None, **kwargs: Any
8687
) -> ModelType:
8788
"""Load the model from a static model."""
8889
model.embedding = np.nan_to_num(model.embedding)
@@ -92,9 +93,13 @@ def from_static_model(
9293
token_mapping = model.token_mapping.tolist()
9394
else:
9495
token_mapping = None
96+
if pad_token is not None:
97+
pad_id = model.tokenizer.get_vocab()[pad_token]
98+
else:
99+
pad_id = get_probable_pad_token_id(model.tokenizer)
95100
return cls(
96101
vectors=embeddings_converted,
97-
pad_id=model.tokenizer.token_to_id(pad_token),
102+
pad_id=pad_id,
98103
out_dim=out_dim,
99104
tokenizer=model.tokenizer,
100105
token_mapping=token_mapping,

model2vec/train/utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import logging
2+
3+
from tokenizers import Tokenizer
4+
5+
logger = logging.getLogger(__name__)
6+
7+
_KNOWN_PAD_TOKENS = ("[PAD]", "<pad>")
8+
9+
10+
def get_probable_pad_token_id(tokenizer: Tokenizer) -> int:
11+
"""Get a probable pad token by using the padding module and falling back to guessing."""
12+
if tokenizer.padding is not None:
13+
return tokenizer.padding["pad_id"]
14+
vocab = tokenizer.get_vocab()
15+
for token in _KNOWN_PAD_TOKENS:
16+
token_id = vocab.get(token)
17+
if token_id is not None:
18+
return token_id
19+
20+
logger.warning("No known pad token found, using 0 as default")
21+
return 0

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ dev = [
6060
"ruff",
6161
]
6262

63-
distill = ["torch", "transformers", "scikit-learn", "skeletoken>=0.3.1"]
63+
distill = ["torch", "transformers", "scikit-learn", "skeletoken>=0.3.2"]
6464
onnx = ["onnx", "torch"]
6565
# train also installs inference
6666
train = ["torch", "lightning", "scikit-learn", "skops"]

tests/test_trainable.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1+
import logging
12
from tempfile import TemporaryDirectory
23

34
import numpy as np
45
import pytest
56
import torch
7+
from skeletoken import TokenizerModel
68
from tokenizers import Tokenizer
79
from transformers import AutoTokenizer
810

911
from model2vec.model import StaticModel
1012
from model2vec.train import StaticModelForClassification
1113
from model2vec.train.base import FinetunableStaticModel, TextDataset
14+
from model2vec.train.utils import get_probable_pad_token_id
1215

1316

1417
@pytest.mark.parametrize("n_layers", [0, 1, 2, 3])
@@ -67,6 +70,21 @@ def test_init_classifier_from_model(mock_vectors: np.ndarray, mock_tokenizer: To
6770
assert s.w.shape[0] == mock_vectors.shape[0]
6871

6972

73+
def test_pad_token(mock_tokenizer: Tokenizer) -> None:
    """Test that an explicitly passed pad_token is resolved via the vocabulary, and that unknown tokens raise."""
    tokenizer_model = TokenizerModel.from_tokenizer(mock_tokenizer)
    # Setting a previously unseen pad token adds it to the vocabulary
    # (presumably at the next free id, 5 — verified by the assert below).
    tokenizer_model.pad_token = "[HELLO]"
    tokenizer = tokenizer_model.to_tokenizer()
    vectors = np.random.RandomState().randn(6, 10)
    model = StaticModel(vectors=vectors, tokenizer=tokenizer)
    s = StaticModelForClassification.from_static_model(model=model, pad_token="[HELLO]")
    assert s.w.shape[0] == vectors.shape[0]
    assert s.pad_id == 5

    # A pad_token not present in the vocabulary should raise (direct vocab lookup).
    with pytest.raises(KeyError):
        StaticModelForClassification.from_static_model(model=model, pad_token="[BRR]")
7088
def test_encode(mock_trained_pipeline: StaticModelForClassification) -> None:
7189
"""Test the encode function."""
7290
result = mock_trained_pipeline._encode(torch.tensor([[0, 1], [1, 0]]).long())
@@ -231,3 +249,35 @@ def test_evaluate(mock_trained_pipeline: StaticModelForClassification) -> None:
231249
else:
232250
# Ignore the type error since we don't support int labels in our typing, but the code does
233251
mock_trained_pipeline.evaluate(["dog cat", "dog"], [1, 1]) # type: ignore
252+
253+
254+
def test_get_probable_pad_token_id(mock_tokenizer: Tokenizer, caplog: pytest.LogCaptureFixture) -> None:
    """Test pad-id resolution: padding config, known-token fallback, and the warning default of 0."""
    tokenizer_model = TokenizerModel.from_tokenizer(mock_tokenizer)
    t = tokenizer_model.to_tokenizer()
    token_id = get_probable_pad_token_id(t)
    assert token_id == 0

    # Adds new token (presumably appended at the next free id, 5 — see assert).
    tokenizer_model.pad_token = "haha"
    t = tokenizer_model.to_tokenizer()
    token_id = get_probable_pad_token_id(t)
    assert token_id == 5

    # Using an existing vocabulary entry as pad token reuses its id.
    tokenizer_model.pad_token = "word1"
    t = tokenizer_model.to_tokenizer()
    token_id = get_probable_pad_token_id(t)
    assert token_id == 1

    # Remove padding token: with no padding config, the util falls back to the
    # known "[PAD]" entry still present in the vocabulary.
    tokenizer_model.pad_token = None
    t = tokenizer_model.to_tokenizer()
    token_id = get_probable_pad_token_id(t)
    assert token_id == tokenizer_model.vocabulary["[PAD]"]

    # With no padding config and no known pad token at all, it warns and returns 0.
    tokenizer_model = tokenizer_model.remove_token_from_vocabulary("[PAD]")
    t = tokenizer_model.to_tokenizer()
    with caplog.at_level(logging.WARNING, logger="model2vec.train.utils"):
        token_id = get_probable_pad_token_id(t)
    assert token_id == 0
    assert "No known pad token found, using 0 as default" in caplog.text

uv.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)