22
33from collections .abc import Iterable
44from pathlib import Path
5- from typing import Any
5+ from typing import Any , Literal
66
77from aiohttp import ClientWebSocketResponse
88from loguru import logger
@@ -36,8 +36,10 @@ def __init__(self, ws: ClientWebSocketResponse, save_replay_path: str | None = N
3636 # How many frames will be waited between iterations before the next one is called
3737 self .game_step : int = 4
3838 self .save_replay_path : str | None = save_replay_path
39- self ._player_id : int = None
40- self ._game_result : dict [int , Result ] = None
39+ # The following will be set on join_game()
40+ self ._player_id : int = None # pyrefly: ignore
41+ # The following will be set on leave()
42+ self ._game_result : dict [int , Result ] = None # pyrefly: ignore
4143 # Store a hash value of all the debug requests to prevent sending the same ones again if they haven't changed last frame
4244 self ._debug_hash_tuple_last_iteration : tuple [int , int , int , int ] = (0 , 0 , 0 , 0 )
4345 self ._debug_draw_last_frame = False
@@ -88,26 +90,26 @@ async def join_game(
8890 if race is None :
8991 assert isinstance (observed_player_id , int ), f"observed_player_id is of type { type (observed_player_id )} "
9092 # join as observer
91- req = sc_pb .RequestJoinGame (observed_player_id = observed_player_id , options = ifopts )
93+ request = sc_pb .RequestJoinGame (observed_player_id = observed_player_id , options = ifopts )
9294 else :
9395 assert isinstance (race , Race )
94- req = sc_pb .RequestJoinGame (race = race .value , options = ifopts )
96+ request = sc_pb .RequestJoinGame (race = race .value , options = ifopts )
9597
9698 if portconfig :
97- req .server_ports .game_port = portconfig .server [0 ]
98- req .server_ports .base_port = portconfig .server [1 ]
99+ request .server_ports .game_port = portconfig .server [0 ]
100+ request .server_ports .base_port = portconfig .server [1 ]
99101
100102 for ppc in portconfig .players :
101- p = req .client_ports .add ()
103+ p = request .client_ports .add () # pyrefly: ignore
102104 p .game_port = ppc [0 ]
103105 p .base_port = ppc [1 ]
104106
105107 if name is not None :
106108 assert isinstance (name , str ), f"name is of type { type (name )} "
107- req .player_name = name
109+ request .player_name = name
108110
109- result = await self ._execute (join_game = req )
110- self ._game_result = None
111+ result = await self ._execute (join_game = request )
112+ self ._game_result = None # pyrefly: ignore
111113 self ._player_id = result .join_game .player_id
112114 return result .join_game .player_id
113115
@@ -149,7 +151,7 @@ async def observation(self, game_loop: int | None = None):
149151 result = await self ._execute (observation = sc_pb .RequestObservation ())
150152 assert result .observation .player_result
151153
152- player_id_to_result = {}
154+ player_id_to_result = dict [ int , Result ]()
153155 for pr in result .observation .player_result :
154156 player_id_to_result [pr .player_id ] = Result (pr .result )
155157 self ._game_result = player_id_to_result
@@ -210,14 +212,19 @@ async def actions(self, actions: list[UnitCommand], return_successes: bool = Fal
210212
211213 # On realtime=True, might get an error here: sc2.protocol.ProtocolError: ['Not in a game']
212214 try :
213- res = await self ._execute (
214- action = sc_pb .RequestAction (actions = (sc_pb .Action (action_raw = a ) for a in combine_actions (actions )))
215+ response = await self ._execute (
216+ action = sc_pb .RequestAction (
217+ # pyrefly: ignore
218+ actions = (sc_pb .Action (action_raw = action ) for action in combine_actions (actions ))
219+ )
215220 )
216221 except ProtocolError :
217222 return []
218223 if return_successes :
219- return [ActionResult (r ) for r in res .action .result ]
220- return [ActionResult (r ) for r in res .action .result if ActionResult (r ) != ActionResult .Success ]
224+ return [ActionResult (result ) for result in response .action .result ]
225+ return [
226+ ActionResult (result ) for result in response .action .result if ActionResult (result ) != ActionResult .Success
227+ ]
221228
222229 async def query_pathing (self , start : Unit | Point2 | Point3 , end : Point2 | Point3 ) -> int | float | None :
223230 """Caution: returns "None" when path not found
@@ -237,25 +244,25 @@ async def query_pathing(self, start: Unit | Point2 | Point3, end: Point2 | Point
237244 return None
238245 return distance
239246
240- async def query_pathings (self , zipped_list : list [list [Unit | Point2 | Point3 ]]) -> list [float ]:
247+ async def query_pathings (self , zipped_list : list [tuple [Unit | Point2 | Point3 , Point2 | Point3 ]]) -> list [float ]:
241248 """Usage: await self.query_pathings([[unit1, target2], [unit2, target2]])
242249 -> returns [distance1, distance2]
243250 Caution: returns 0 when path not found
244251
245252 :param zipped_list:
246253 """
247- assert zipped_list , "No zipped_list"
248- assert isinstance (zipped_list , list ), f"{ type (zipped_list )} "
249- assert isinstance (zipped_list [0 ], list ), f"{ type (zipped_list [0 ])} "
250- assert len (zipped_list [0 ]) == 2 , f"{ len (zipped_list [0 ])} "
251- assert isinstance (zipped_list [0 ][0 ], (Point2 , Unit )), f"{ type (zipped_list [0 ][0 ])} "
252- assert isinstance (zipped_list [0 ][1 ], Point2 ), f"{ type (zipped_list [0 ][1 ])} "
253- if isinstance (zipped_list [0 ][0 ], Point2 ):
254- path = (
255- query_pb .RequestQueryPathing (start_pos = p1 .as_Point2D , end_pos = p2 .as_Point2D ) for p1 , p2 in zipped_list
254+ assert zipped_list , "No entry in zipped_list"
255+ path = (
256+ query_pb .RequestQueryPathing (
257+ # pyrefly: ignore
258+ unit_tag = p1 .tag if isinstance (p1 , Unit ) else None ,
259+ # pyrefly: ignore
260+ start_pos = None if isinstance (p1 , Unit ) else p1 .as_Point2D ,
261+ end_pos = p2 .as_Point2D ,
256262 )
257- else :
258- path = (query_pb .RequestQueryPathing (unit_tag = p1 .tag , end_pos = p2 .as_Point2D ) for p1 , p2 in zipped_list )
263+ for p1 , p2 in zipped_list
264+ )
265+ # pyrefly: ignore
259266 results = await self ._execute (query = query_pb .RequestQuery (pathing = path ))
260267 return [float (d .distance ) for d in results .query .pathing ]
261268
@@ -271,6 +278,7 @@ async def _query_building_placement_fast(
271278 """
272279 result = await self ._execute (
273280 query = query_pb .RequestQuery (
281+ # pyrefly: ignore
274282 placements = (
275283 query_pb .RequestQueryBuildingPlacement (ability_id = ability .value , target_pos = position .as_Point2D )
276284 for position in positions
@@ -296,6 +304,7 @@ async def query_building_placement(
296304 assert isinstance (ability , AbilityData )
297305 result = await self ._execute (
298306 query = query_pb .RequestQuery (
307+ # pyrefly: ignore
299308 placements = (
300309 query_pb .RequestQueryBuildingPlacement (ability_id = ability .id .value , target_pos = position .as_Point2D )
301310 for position in positions
@@ -319,23 +328,24 @@ async def query_available_abilities(
319328 assert units
320329 result = await self ._execute (
321330 query = query_pb .RequestQuery (
331+ # pyrefly: ignore
322332 abilities = (query_pb .RequestQueryAvailableAbilities (unit_tag = unit .tag ) for unit in units ),
323333 ignore_resource_requirements = ignore_resource_requirements ,
324334 )
325335 )
326336 """ Fix for bots that only query a single unit, may be removed soon """
327337 if not input_was_a_list :
328- # pyre-fixme[7]
338+ # pyrefly: ignore
329339 return [[AbilityId (a .ability_id ) for a in b .abilities ] for b in result .query .abilities ][0 ]
330340 return [[AbilityId (a .ability_id ) for a in b .abilities ] for b in result .query .abilities ]
331341
332342 async def query_available_abilities_with_tag (
333343 self , units : list [Unit ] | Units , ignore_resource_requirements : bool = False
334344 ) -> dict [int , set [AbilityId ]]:
335345 """Query abilities of multiple units"""
336-
337346 result = await self ._execute (
338347 query = query_pb .RequestQuery (
348+ # pyrefly: ignore
339349 abilities = (query_pb .RequestQueryAvailableAbilities (unit_tag = unit .tag ) for unit in units ),
340350 ignore_resource_requirements = ignore_resource_requirements ,
341351 )
@@ -367,30 +377,28 @@ async def toggle_autocast(self, units: list[Unit] | Units, ability: AbilityId) -
367377 sc_pb .Action (
368378 action_raw = raw_pb .ActionRaw (
369379 toggle_autocast = raw_pb .ActionRawToggleAutocast (
370- ability_id = ability .value , unit_tags = (u .tag for u in units )
380+ ability_id = ability .value ,
381+ # pyrefly: ignore
382+ unit_tags = (u .tag for u in units ),
371383 )
372384 )
373385 )
374386 ]
375387 )
376388 )
377389
378- async def debug_create_unit (self , unit_spawn_commands : list [list [UnitTypeId | int | Point2 | Point3 ]]) -> None :
390+ async def debug_create_unit (
391+ self , unit_spawn_commands : list [tuple [UnitTypeId , int , Point2 | Point3 , Literal [1 , 2 ]]]
392+ ) -> None :
379393 """Usage example (will spawn 5 marines in the center of the map for player ID 1):
380394 await self._client.debug_create_unit([[UnitTypeId.MARINE, 5, self._game_info.map_center, 1]])
381395
382396 :param unit_spawn_commands:"""
383- assert isinstance (unit_spawn_commands , list )
384- assert unit_spawn_commands
385- assert isinstance (unit_spawn_commands [0 ], list )
386- assert len (unit_spawn_commands [0 ]) == 4
387- assert isinstance (unit_spawn_commands [0 ][0 ], UnitTypeId )
388- assert unit_spawn_commands [0 ][1 ] > 0 # careful, in realtime=True this function may create more units
389- assert isinstance (unit_spawn_commands [0 ][2 ], (Point2 , Point3 ))
390- assert 1 <= unit_spawn_commands [0 ][3 ] <= 2
397+ assert unit_spawn_commands , "List is empty"
391398
392399 await self ._execute (
393400 debug = sc_pb .RequestDebug (
401+ # pyrefly: ignore
394402 debug = (
395403 debug_pb .DebugCommand (
396404 create_unit = debug_pb .DebugCreateUnit (
@@ -416,6 +424,7 @@ async def debug_kill_unit(self, unit_tags: Unit | Units | list[int] | set[int])
416424 assert unit_tags
417425
418426 await self ._execute (
427+ # pyrefly: ignore
419428 debug = sc_pb .RequestDebug (debug = [debug_pb .DebugCommand (kill_unit = debug_pb .DebugKillUnit (tag = unit_tags ))])
420429 )
421430
@@ -635,15 +644,19 @@ async def _send_debug(self) -> None:
635644 debug = [
636645 debug_pb .DebugCommand (
637646 draw = debug_pb .DebugDraw (
647+ # pyrefly: ignore
638648 text = [text .to_proto () for text in self ._debug_texts ]
639649 if self ._debug_texts
640650 else None ,
651+ # pyrefly: ignore
641652 lines = [line .to_proto () for line in self ._debug_lines ]
642653 if self ._debug_lines
643654 else None ,
655+ # pyrefly: ignore
644656 boxes = [box .to_proto () for box in self ._debug_boxes ]
645657 if self ._debug_boxes
646658 else None ,
659+ # pyrefly: ignore
647660 spheres = [sphere .to_proto () for sphere in self ._debug_spheres ]
648661 if self ._debug_spheres
649662 else None ,
@@ -665,6 +678,7 @@ async def _send_debug(self) -> None:
665678 await self ._execute (
666679 debug = sc_pb .RequestDebug (
667680 debug = [
681+ # pyrefly: ignore
668682 debug_pb .DebugCommand (draw = debug_pb .DebugDraw (text = None , lines = None , boxes = None , spheres = None ))
669683 ]
670684 )
@@ -696,6 +710,7 @@ async def debug_set_unit_value(
696710 assert value >= 0 , "Value can't be negative"
697711 await self ._execute (
698712 debug = sc_pb .RequestDebug (
713+ # pyrefly: ignore
699714 debug = (
700715 debug_pb .DebugCommand (
701716 unit_value = debug_pb .DebugSetUnitValue (
@@ -786,7 +801,7 @@ def to_debug_color(color: tuple[float, float, float] | list[float] | Point3 | No
786801 return debug_pb .Color (r = 255 , g = 255 , b = 255 )
787802 # Need to check if not of type Point3 because Point3 inherits from tuple
788803 if isinstance (color , (tuple , list )) or isinstance (color , Point3 ) and len (color ) == 3 :
789- return debug_pb .Color (r = color [0 ], g = color [1 ], b = color [2 ])
804+ return debug_pb .Color (r = int ( color [0 ]) , g = int ( color [1 ]) , b = int ( color [2 ]) )
790805 # In case color is of type Point3
791806 r = getattr (color , "r" , getattr (color , "x" , 255 ))
792807 g = getattr (color , "g" , getattr (color , "y" , 255 ))
@@ -818,6 +833,7 @@ def to_proto(self):
818833 color = self .to_debug_color (self ._color ),
819834 text = self ._text ,
820835 virtual_pos = self ._start_point .to3 .as_Point ,
836+ # pyrefly: ignore
821837 world_pos = None ,
822838 size = self ._font_size ,
823839 )
@@ -829,20 +845,21 @@ def __hash__(self) -> int:
829845class DrawItemWorldText (DrawItem ):
830846 def __init__ (
831847 self ,
832- start_point : Point3 = None ,
833- color : tuple [float , float , float ] | list [float ] | Point3 | None = None ,
848+ start_point : Point3 ,
849+ color : tuple [float , float , float ] | list [float ] | Point3 | None ,
834850 text : str = "" ,
835851 font_size : int = 8 ,
836852 ) -> None :
837- self ._start_point : Point3 = start_point
838- self ._color : Point3 = color
839- self ._text : str = text
840- self ._font_size : int = font_size
853+ self ._start_point = start_point
854+ self ._color = color
855+ self ._text = text
856+ self ._font_size = font_size
841857
842858 def to_proto (self ):
843859 return debug_pb .DebugText (
844860 color = self .to_debug_color (self ._color ),
845861 text = self ._text ,
862+ # pyrefly: ignore
846863 virtual_pos = None ,
847864 world_pos = self ._start_point .as_Point ,
848865 size = self ._font_size ,
@@ -855,13 +872,13 @@ def __hash__(self) -> int:
855872class DrawItemLine (DrawItem ):
856873 def __init__ (
857874 self ,
858- start_point : Point3 = None ,
859- end_point : Point3 = None ,
875+ start_point : Point3 ,
876+ end_point : Point3 ,
860877 color : tuple [float , float , float ] | list [float ] | Point3 | None = None ,
861878 ) -> None :
862- self ._start_point : Point3 = start_point
863- self ._end_point : Point3 = end_point
864- self ._color : Point3 = color
879+ self ._start_point = start_point
880+ self ._end_point = end_point
881+ self ._color = color
865882
866883 def to_proto (self ):
867884 return debug_pb .DebugLine (
@@ -876,13 +893,13 @@ def __hash__(self) -> int:
876893class DrawItemBox (DrawItem ):
877894 def __init__ (
878895 self ,
879- start_point : Point3 = None ,
880- end_point : Point3 = None ,
896+ start_point : Point3 ,
897+ end_point : Point3 ,
881898 color : tuple [float , float , float ] | list [float ] | Point3 | None = None ,
882899 ) -> None :
883- self ._start_point : Point3 = start_point
884- self ._end_point : Point3 = end_point
885- self ._color : Point3 = color
900+ self ._start_point = start_point
901+ self ._end_point = end_point
902+ self ._color = color
886903
887904 def to_proto (self ):
888905 return debug_pb .DebugBox (
@@ -898,13 +915,13 @@ def __hash__(self) -> int:
898915class DrawItemSphere (DrawItem ):
899916 def __init__ (
900917 self ,
901- start_point : Point3 = None ,
902- radius : float = None ,
918+ start_point : Point3 ,
919+ radius : float ,
903920 color : tuple [float , float , float ] | list [float ] | Point3 | None = None ,
904921 ) -> None :
905- self ._start_point : Point3 = start_point
906- self ._radius : float = radius
907- self ._color : Point3 = color
922+ self ._start_point = start_point
923+ self ._radius = radius
924+ self ._color = color
908925
909926 def to_proto (self ):
910927 return debug_pb .DebugSphere (
0 commit comments