diff --git a/Cargo.toml b/Cargo.toml index 84e8226..c00e7ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,7 @@ either = "1.13.0" macro_rules_attribute = "0.2.0" once_cell = "1.19.0" paste = "1.0.14" -pyo3 = { version = "^0.27", features = ["extension-module"] } +pyo3 = { version = "^0.28", features = ["extension-module"] } regex = "1.10.3" serde = "1.0.197" serde_json = "1.0.114" diff --git a/src/tokenizer.rs b/src/tokenizer.rs index 41c61f8..a6c53c5 100644 --- a/src/tokenizer.rs +++ b/src/tokenizer.rs @@ -131,7 +131,7 @@ impl SmirkTokenizer { .map(|x| EncodeInput::from(x.to_string())) .collect(); // Release the GIL while tokenizing batch - let out = py.allow_threads(|| { + let out = py.detach(|| { self.tokenizer .encode_batch_char_offsets(inputs, add_special_tokens) .unwrap() @@ -149,7 +149,7 @@ impl SmirkTokenizer { ids: Vec>, skip_special_tokens: bool, ) -> PyResult> { - py.allow_threads(|| { + py.detach(|| { let sequences = ids.iter().map(|x| &x[..]).collect::>(); Ok(self .tokenizer @@ -434,7 +434,7 @@ impl SmirkTokenizer { // Train tokenizer let mut trainer: TrainerWrapper = builder.build().unwrap().into(); - let _ = py.allow_threads(|| tokenizer.train_from_files(&mut trainer, files).unwrap()); + let _ = py.detach(|| tokenizer.train_from_files(&mut trainer, files).unwrap()); Ok(SmirkTokenizer::new(tokenizer)) } }