@@ -26,6 +26,7 @@ struct PythonBindGenerator {
2626
2727impl PythonBindGenerator {
2828 const BASE_TYPES : [ & ' static str ; 6 ] = [ "bool" , "i32" , "u32" , "f32" , "String" , "u8" ] ;
29+ const SPECIAL_BASE_TYPES : [ & ' static str ; 2 ] = [ "FloatT" , "BoolT" ] ;
2930
3031 fn new ( path : & Path ) -> Option < Self > {
3132 // get the filename without the extension
@@ -952,6 +953,7 @@ impl PythonBindGenerator {
952953 }
953954
954955 let mut signature_parts = Vec :: new ( ) ;
956+ let mut needs_python = false ;
955957
956958 for variable_info in & self . types {
957959 let variable_name = & variable_info[ 0 ] ;
@@ -962,6 +964,15 @@ impl PythonBindGenerator {
962964 }
963965
964966 if variable_type. starts_with ( "Option<" ) {
967+ let inner_type = variable_type
968+ . trim_start_matches ( "Option<" )
969+ . trim_start_matches ( "Box<" )
970+ . trim_end_matches ( '>' ) ;
971+
972+ if Self :: SPECIAL_BASE_TYPES . contains ( & inner_type) {
973+ needs_python = true ;
974+ }
975+
965976 signature_parts. push ( format ! ( "{variable_name}=None" ) ) ;
966977 } else if !Self :: BASE_TYPES . contains ( & variable_type. as_str ( ) )
967978 && ( variable_type. starts_with ( "Box<" ) || variable_type. ends_with ( 'T' ) )
@@ -973,9 +984,12 @@ impl PythonBindGenerator {
973984 }
974985
975986 self . write_string ( format ! ( " #[pyo3(signature = ({}))]" , signature_parts. join( ", " ) ) ) ;
976-
977987 self . write_str ( " pub fn new(" ) ;
978988
989+ if needs_python {
990+ self . write_str ( " py: Python," ) ;
991+ }
992+
979993 for variable_info in & self . types {
980994 let variable_name = & variable_info[ 0 ] ;
981995 let mut variable_type = variable_info[ 1 ] . to_string ( ) ;
@@ -1002,6 +1016,10 @@ impl PythonBindGenerator {
10021016
10031017 if Self :: BASE_TYPES . contains ( & inner_type) {
10041018 variable_type = format ! ( "Option<{inner_type}>" ) ;
1019+ } else if inner_type == "Float" {
1020+ variable_type = String :: from ( "Option<crate::Floats>" ) ;
1021+ } else if inner_type == "Bool" {
1022+ variable_type = String :: from ( "Option<crate::Bools>" ) ;
10051023 } else {
10061024 variable_type = format ! ( "Option<Py<super::{inner_type}>>" ) ;
10071025 }
@@ -1022,8 +1040,18 @@ impl PythonBindGenerator {
10221040
10231041 for variable_info in & self . types {
10241042 let variable_name = & variable_info[ 0 ] ;
1043+ let variable_type = variable_info[ 1 ]
1044+ . trim_start_matches ( "Option<" )
1045+ . trim_start_matches ( "Box<" )
1046+ . trim_end_matches ( '>' ) ;
10251047
1026- self . file_contents . push ( Cow :: Owned ( format ! ( " {variable_name}," ) ) ) ;
1048+ if Self :: SPECIAL_BASE_TYPES . contains ( & variable_type) {
1049+ self . file_contents . push ( Cow :: Owned ( format ! (
1050+ " {variable_name}: {variable_name}.map(|x| x.into_gil(py)),"
1051+ ) ) ) ;
1052+ } else {
1053+ self . file_contents . push ( Cow :: Owned ( format ! ( " {variable_name}," ) ) ) ;
1054+ }
10271055 }
10281056
10291057 self . write_str ( " }" ) ;
@@ -1407,6 +1435,10 @@ fn pyi_generator(type_data: &[(String, String, Vec<Vec<String>>)]) -> io::Result
14071435 "float"
14081436 } else if type_name == "String" {
14091437 "str"
1438+ } else if type_name == "Float" {
1439+ "Float | float"
1440+ } else if type_name == "Bool" {
1441+ "Bool | bool"
14101442 } else {
14111443 type_name
14121444 } ;
0 commit comments