Skip to content

Commit bb64ce1

Browse files
committed
Special cases for Float and Bool
Now accept normal Python floats or bools along with Float and Bool objects
1 parent 966ab34 commit bb64ce1

3 files changed

Lines changed: 92 additions & 18 deletions

File tree

build.rs

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ struct PythonBindGenerator {
2626

2727
impl PythonBindGenerator {
2828
const BASE_TYPES: [&'static str; 6] = ["bool", "i32", "u32", "f32", "String", "u8"];
29+
const SPECIAL_BASE_TYPES: [&'static str; 2] = ["FloatT", "BoolT"];
2930

3031
fn new(path: &Path) -> Option<Self> {
3132
// get the filename without the extension
@@ -952,6 +953,7 @@ impl PythonBindGenerator {
952953
}
953954

954955
let mut signature_parts = Vec::new();
956+
let mut needs_python = false;
955957

956958
for variable_info in &self.types {
957959
let variable_name = &variable_info[0];
@@ -962,6 +964,15 @@ impl PythonBindGenerator {
962964
}
963965

964966
if variable_type.starts_with("Option<") {
967+
let inner_type = variable_type
968+
.trim_start_matches("Option<")
969+
.trim_start_matches("Box<")
970+
.trim_end_matches('>');
971+
972+
if Self::SPECIAL_BASE_TYPES.contains(&inner_type) {
973+
needs_python = true;
974+
}
975+
965976
signature_parts.push(format!("{variable_name}=None"));
966977
} else if !Self::BASE_TYPES.contains(&variable_type.as_str())
967978
&& (variable_type.starts_with("Box<") || variable_type.ends_with('T'))
@@ -973,9 +984,12 @@ impl PythonBindGenerator {
973984
}
974985

975986
self.write_string(format!(" #[pyo3(signature = ({}))]", signature_parts.join(", ")));
976-
977987
self.write_str(" pub fn new(");
978988

989+
if needs_python {
990+
self.write_str(" py: Python,");
991+
}
992+
979993
for variable_info in &self.types {
980994
let variable_name = &variable_info[0];
981995
let mut variable_type = variable_info[1].to_string();
@@ -1002,6 +1016,10 @@ impl PythonBindGenerator {
10021016

10031017
if Self::BASE_TYPES.contains(&inner_type) {
10041018
variable_type = format!("Option<{inner_type}>");
1019+
} else if inner_type == "Float" {
1020+
variable_type = String::from("Option<crate::Floats>");
1021+
} else if inner_type == "Bool" {
1022+
variable_type = String::from("Option<crate::Bools>");
10051023
} else {
10061024
variable_type = format!("Option<Py<super::{inner_type}>>");
10071025
}
@@ -1022,8 +1040,18 @@ impl PythonBindGenerator {
10221040

10231041
for variable_info in &self.types {
10241042
let variable_name = &variable_info[0];
1043+
let variable_type = variable_info[1]
1044+
.trim_start_matches("Option<")
1045+
.trim_start_matches("Box<")
1046+
.trim_end_matches('>');
10251047

1026-
self.file_contents.push(Cow::Owned(format!(" {variable_name},")));
1048+
if Self::SPECIAL_BASE_TYPES.contains(&variable_type) {
1049+
self.file_contents.push(Cow::Owned(format!(
1050+
" {variable_name}: {variable_name}.map(|x| x.into_gil(py)),"
1051+
)));
1052+
} else {
1053+
self.file_contents.push(Cow::Owned(format!(" {variable_name},")));
1054+
}
10271055
}
10281056

10291057
self.write_str(" }");
@@ -1407,6 +1435,10 @@ fn pyi_generator(type_data: &[(String, String, Vec<Vec<String>>)]) -> io::Result
14071435
"float"
14081436
} else if type_name == "String" {
14091437
"str"
1438+
} else if type_name == "Float" {
1439+
"Float | float"
1440+
} else if type_name == "Bool" {
1441+
"Bool | bool"
14101442
} else {
14111443
type_name
14121444
};

pytest.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,6 @@
33
from rlbot_flatbuffers import *
44

55
if __name__ == "__main__":
6-
color = Color(255, 0, 0)
7-
print(repr(color))
8-
print(color)
9-
eval(repr(color))
10-
print()
11-
12-
controller = ControllerState(throttle=1)
13-
controller.boost = True
14-
15-
player_input = PlayerInput(0, controller)
16-
176
ready_message = ReadyMessage(True, wants_game_messages=True)
187
print(repr(ready_message))
198
print(ready_message)
@@ -45,18 +34,29 @@
4534

4635
num_trials = 1_000_000
4736

37+
total_make_time = 0
4838
total_pack_time = 0
4939
total_unpack_time = 0
5040
for _ in range(num_trials):
5141
start = time_ns()
52-
packed_bytes = player_input.pack()
42+
desired_game_state = DesiredGameState(
43+
DesiredBallState(DesiredPhysics()),
44+
car_states=[DesiredCarState(boost_amount=Float(100))],
45+
game_info_state=DesiredGameInfoState(game_speed=1, world_gravity_z=Float(-650), end_match=True),
46+
console_commands=[ConsoleCommand("freeze")],
47+
)
48+
total_make_time += time_ns() - start
49+
50+
start = time_ns()
51+
packed_bytes = desired_game_state.pack()
5352
total_pack_time += time_ns() - start
5453

5554
start = time_ns()
56-
PlayerInput.unpack(packed_bytes)
55+
DesiredGameState.unpack(packed_bytes)
5756
total_unpack_time += time_ns() - start
5857

59-
print(f"Average time to pack: {total_pack_time / num_trials} ns")
60-
print(f"Average time to unpack: {total_unpack_time / num_trials} ns")
58+
print(f"Average time to make: {round(total_make_time / num_trials, 2)}ns")
59+
print(f"Average time to pack: {round(total_pack_time / num_trials, 2)}ns")
60+
print(f"Average time to unpack: {round(total_unpack_time / num_trials, 2)}ns")
6161

62-
print(f"Total time: {(total_pack_time + total_unpack_time) / 1000000 }ms")
62+
print(f"Total time: {round((total_pack_time + total_unpack_time) / 1000000, 2)}ms")

src/lib.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,48 @@ pub const fn bool_to_str(b: bool) -> &'static str {
6161
}
6262
}
6363

64+
#[derive(Debug, Clone, FromPyObject)]
65+
pub enum Floats {
66+
Flat(Py<Float>),
67+
Num(f32),
68+
}
69+
70+
impl Default for Floats {
71+
fn default() -> Self {
72+
Floats::Flat(get_py_default())
73+
}
74+
}
75+
76+
impl FromGil<Floats> for Py<Float> {
77+
fn from_gil(py: Python, floats: Floats) -> Self {
78+
match floats {
79+
Floats::Flat(float) => float,
80+
Floats::Num(num) => Py::new(py, Float::new(num)).unwrap(),
81+
}
82+
}
83+
}
84+
85+
#[derive(Debug, Clone, FromPyObject)]
86+
pub enum Bools {
87+
Flat(Py<Bool>),
88+
Num(bool),
89+
}
90+
91+
impl Default for Bools {
92+
fn default() -> Self {
93+
Self::Flat(get_py_default())
94+
}
95+
}
96+
97+
impl FromGil<Bools> for Py<Bool> {
98+
fn from_gil(py: Python, bools: Bools) -> Self {
99+
match bools {
100+
Bools::Flat(float) => float,
101+
Bools::Num(num) => Py::new(py, Bool::new(num)).unwrap(),
102+
}
103+
}
104+
}
105+
64106
macro_rules! pynamedmodule {
65107
(doc: $doc:literal, name: $name:tt, classes: [$($class_name:ident),*], vars: [$(($var_name:literal, $value:expr)),*]) => {
66108
#[doc = $doc]

0 commit comments

Comments
 (0)