Skip to content

Commit f2a0df1

Browse files
committed
Fix type hints for client.py
1 parent dd9e9db commit f2a0df1

2 files changed

Lines changed: 79 additions & 63 deletions

File tree

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ project-excludes = [
102102
# Disable for those files and folders
103103
"sc2/data.py",
104104
# TODO Temp disable for those files and folders
105-
"sc2/client.py",
106105
"sc2/bot_ai_internal.py",
107106
"sc2/bot_ai.py",
108107
]

sc2/client.py

Lines changed: 79 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Iterable
44
from pathlib import Path
5-
from typing import Any
5+
from typing import Any, Literal
66

77
from aiohttp import ClientWebSocketResponse
88
from loguru import logger
@@ -36,8 +36,10 @@ def __init__(self, ws: ClientWebSocketResponse, save_replay_path: str | None = N
3636
# How many frames will be waited between iterations before the next one is called
3737
self.game_step: int = 4
3838
self.save_replay_path: str | None = save_replay_path
39-
self._player_id: int = None
40-
self._game_result: dict[int, Result] = None
39+
# The following will be set on join_game()
40+
self._player_id: int = None # pyrefly: ignore
41+
# The following will be set on leave()
42+
self._game_result: dict[int, Result] = None # pyrefly: ignore
4143
# Store a hash value of all the debug requests to prevent sending the same ones again if they haven't changed last frame
4244
self._debug_hash_tuple_last_iteration: tuple[int, int, int, int] = (0, 0, 0, 0)
4345
self._debug_draw_last_frame = False
@@ -88,26 +90,26 @@ async def join_game(
8890
if race is None:
8991
assert isinstance(observed_player_id, int), f"observed_player_id is of type {type(observed_player_id)}"
9092
# join as observer
91-
req = sc_pb.RequestJoinGame(observed_player_id=observed_player_id, options=ifopts)
93+
request = sc_pb.RequestJoinGame(observed_player_id=observed_player_id, options=ifopts)
9294
else:
9395
assert isinstance(race, Race)
94-
req = sc_pb.RequestJoinGame(race=race.value, options=ifopts)
96+
request = sc_pb.RequestJoinGame(race=race.value, options=ifopts)
9597

9698
if portconfig:
97-
req.server_ports.game_port = portconfig.server[0]
98-
req.server_ports.base_port = portconfig.server[1]
99+
request.server_ports.game_port = portconfig.server[0]
100+
request.server_ports.base_port = portconfig.server[1]
99101

100102
for ppc in portconfig.players:
101-
p = req.client_ports.add()
103+
p = request.client_ports.add() # pyrefly: ignore
102104
p.game_port = ppc[0]
103105
p.base_port = ppc[1]
104106

105107
if name is not None:
106108
assert isinstance(name, str), f"name is of type {type(name)}"
107-
req.player_name = name
109+
request.player_name = name
108110

109-
result = await self._execute(join_game=req)
110-
self._game_result = None
111+
result = await self._execute(join_game=request)
112+
self._game_result = None # pyrefly: ignore
111113
self._player_id = result.join_game.player_id
112114
return result.join_game.player_id
113115

@@ -149,7 +151,7 @@ async def observation(self, game_loop: int | None = None):
149151
result = await self._execute(observation=sc_pb.RequestObservation())
150152
assert result.observation.player_result
151153

152-
player_id_to_result = {}
154+
player_id_to_result = dict[int, Result]()
153155
for pr in result.observation.player_result:
154156
player_id_to_result[pr.player_id] = Result(pr.result)
155157
self._game_result = player_id_to_result
@@ -210,14 +212,19 @@ async def actions(self, actions: list[UnitCommand], return_successes: bool = Fal
210212

211213
# On realtime=True, might get an error here: sc2.protocol.ProtocolError: ['Not in a game']
212214
try:
213-
res = await self._execute(
214-
action=sc_pb.RequestAction(actions=(sc_pb.Action(action_raw=a) for a in combine_actions(actions)))
215+
response = await self._execute(
216+
action=sc_pb.RequestAction(
217+
# pyrefly: ignore
218+
actions=(sc_pb.Action(action_raw=action) for action in combine_actions(actions))
219+
)
215220
)
216221
except ProtocolError:
217222
return []
218223
if return_successes:
219-
return [ActionResult(r) for r in res.action.result]
220-
return [ActionResult(r) for r in res.action.result if ActionResult(r) != ActionResult.Success]
224+
return [ActionResult(result) for result in response.action.result]
225+
return [
226+
ActionResult(result) for result in response.action.result if ActionResult(result) != ActionResult.Success
227+
]
221228

222229
async def query_pathing(self, start: Unit | Point2 | Point3, end: Point2 | Point3) -> int | float | None:
223230
"""Caution: returns "None" when path not found
@@ -237,25 +244,25 @@ async def query_pathing(self, start: Unit | Point2 | Point3, end: Point2 | Point
237244
return None
238245
return distance
239246

240-
async def query_pathings(self, zipped_list: list[list[Unit | Point2 | Point3]]) -> list[float]:
247+
async def query_pathings(self, zipped_list: list[tuple[Unit | Point2 | Point3, Point2 | Point3]]) -> list[float]:
241248
"""Usage: await self.query_pathings([[unit1, target2], [unit2, target2]])
242249
-> returns [distance1, distance2]
243250
Caution: returns 0 when path not found
244251
245252
:param zipped_list:
246253
"""
247-
assert zipped_list, "No zipped_list"
248-
assert isinstance(zipped_list, list), f"{type(zipped_list)}"
249-
assert isinstance(zipped_list[0], list), f"{type(zipped_list[0])}"
250-
assert len(zipped_list[0]) == 2, f"{len(zipped_list[0])}"
251-
assert isinstance(zipped_list[0][0], (Point2, Unit)), f"{type(zipped_list[0][0])}"
252-
assert isinstance(zipped_list[0][1], Point2), f"{type(zipped_list[0][1])}"
253-
if isinstance(zipped_list[0][0], Point2):
254-
path = (
255-
query_pb.RequestQueryPathing(start_pos=p1.as_Point2D, end_pos=p2.as_Point2D) for p1, p2 in zipped_list
254+
assert zipped_list, "No entry in zipped_list"
255+
path = (
256+
query_pb.RequestQueryPathing(
257+
# pyrefly: ignore
258+
unit_tag=p1.tag if isinstance(p1, Unit) else None,
259+
# pyrefly: ignore
260+
start_pos=None if isinstance(p1, Unit) else p1.as_Point2D,
261+
end_pos=p2.as_Point2D,
256262
)
257-
else:
258-
path = (query_pb.RequestQueryPathing(unit_tag=p1.tag, end_pos=p2.as_Point2D) for p1, p2 in zipped_list)
263+
for p1, p2 in zipped_list
264+
)
265+
# pyrefly: ignore
259266
results = await self._execute(query=query_pb.RequestQuery(pathing=path))
260267
return [float(d.distance) for d in results.query.pathing]
261268

@@ -271,6 +278,7 @@ async def _query_building_placement_fast(
271278
"""
272279
result = await self._execute(
273280
query=query_pb.RequestQuery(
281+
# pyrefly: ignore
274282
placements=(
275283
query_pb.RequestQueryBuildingPlacement(ability_id=ability.value, target_pos=position.as_Point2D)
276284
for position in positions
@@ -296,6 +304,7 @@ async def query_building_placement(
296304
assert isinstance(ability, AbilityData)
297305
result = await self._execute(
298306
query=query_pb.RequestQuery(
307+
# pyrefly: ignore
299308
placements=(
300309
query_pb.RequestQueryBuildingPlacement(ability_id=ability.id.value, target_pos=position.as_Point2D)
301310
for position in positions
@@ -319,23 +328,24 @@ async def query_available_abilities(
319328
assert units
320329
result = await self._execute(
321330
query=query_pb.RequestQuery(
331+
# pyrefly: ignore
322332
abilities=(query_pb.RequestQueryAvailableAbilities(unit_tag=unit.tag) for unit in units),
323333
ignore_resource_requirements=ignore_resource_requirements,
324334
)
325335
)
326336
""" Fix for bots that only query a single unit, may be removed soon """
327337
if not input_was_a_list:
328-
# pyre-fixme[7]
338+
# pyrefly: ignore
329339
return [[AbilityId(a.ability_id) for a in b.abilities] for b in result.query.abilities][0]
330340
return [[AbilityId(a.ability_id) for a in b.abilities] for b in result.query.abilities]
331341

332342
async def query_available_abilities_with_tag(
333343
self, units: list[Unit] | Units, ignore_resource_requirements: bool = False
334344
) -> dict[int, set[AbilityId]]:
335345
"""Query abilities of multiple units"""
336-
337346
result = await self._execute(
338347
query=query_pb.RequestQuery(
348+
# pyrefly: ignore
339349
abilities=(query_pb.RequestQueryAvailableAbilities(unit_tag=unit.tag) for unit in units),
340350
ignore_resource_requirements=ignore_resource_requirements,
341351
)
@@ -367,30 +377,28 @@ async def toggle_autocast(self, units: list[Unit] | Units, ability: AbilityId) -
367377
sc_pb.Action(
368378
action_raw=raw_pb.ActionRaw(
369379
toggle_autocast=raw_pb.ActionRawToggleAutocast(
370-
ability_id=ability.value, unit_tags=(u.tag for u in units)
380+
ability_id=ability.value,
381+
# pyrefly: ignore
382+
unit_tags=(u.tag for u in units),
371383
)
372384
)
373385
)
374386
]
375387
)
376388
)
377389

378-
async def debug_create_unit(self, unit_spawn_commands: list[list[UnitTypeId | int | Point2 | Point3]]) -> None:
390+
async def debug_create_unit(
391+
self, unit_spawn_commands: list[tuple[UnitTypeId, int, Point2 | Point3, Literal[1, 2]]]
392+
) -> None:
379393
"""Usage example (will spawn 5 marines in the center of the map for player ID 1):
380394
await self._client.debug_create_unit([[UnitTypeId.MARINE, 5, self._game_info.map_center, 1]])
381395
382396
:param unit_spawn_commands:"""
383-
assert isinstance(unit_spawn_commands, list)
384-
assert unit_spawn_commands
385-
assert isinstance(unit_spawn_commands[0], list)
386-
assert len(unit_spawn_commands[0]) == 4
387-
assert isinstance(unit_spawn_commands[0][0], UnitTypeId)
388-
assert unit_spawn_commands[0][1] > 0 # careful, in realtime=True this function may create more units
389-
assert isinstance(unit_spawn_commands[0][2], (Point2, Point3))
390-
assert 1 <= unit_spawn_commands[0][3] <= 2
397+
assert unit_spawn_commands, "List is empty"
391398

392399
await self._execute(
393400
debug=sc_pb.RequestDebug(
401+
# pyrefly: ignore
394402
debug=(
395403
debug_pb.DebugCommand(
396404
create_unit=debug_pb.DebugCreateUnit(
@@ -416,6 +424,7 @@ async def debug_kill_unit(self, unit_tags: Unit | Units | list[int] | set[int])
416424
assert unit_tags
417425

418426
await self._execute(
427+
# pyrefly: ignore
419428
debug=sc_pb.RequestDebug(debug=[debug_pb.DebugCommand(kill_unit=debug_pb.DebugKillUnit(tag=unit_tags))])
420429
)
421430

@@ -635,15 +644,19 @@ async def _send_debug(self) -> None:
635644
debug=[
636645
debug_pb.DebugCommand(
637646
draw=debug_pb.DebugDraw(
647+
# pyrefly: ignore
638648
text=[text.to_proto() for text in self._debug_texts]
639649
if self._debug_texts
640650
else None,
651+
# pyrefly: ignore
641652
lines=[line.to_proto() for line in self._debug_lines]
642653
if self._debug_lines
643654
else None,
655+
# pyrefly: ignore
644656
boxes=[box.to_proto() for box in self._debug_boxes]
645657
if self._debug_boxes
646658
else None,
659+
# pyrefly: ignore
647660
spheres=[sphere.to_proto() for sphere in self._debug_spheres]
648661
if self._debug_spheres
649662
else None,
@@ -665,6 +678,7 @@ async def _send_debug(self) -> None:
665678
await self._execute(
666679
debug=sc_pb.RequestDebug(
667680
debug=[
681+
# pyrefly: ignore
668682
debug_pb.DebugCommand(draw=debug_pb.DebugDraw(text=None, lines=None, boxes=None, spheres=None))
669683
]
670684
)
@@ -696,6 +710,7 @@ async def debug_set_unit_value(
696710
assert value >= 0, "Value can't be negative"
697711
await self._execute(
698712
debug=sc_pb.RequestDebug(
713+
# pyrefly: ignore
699714
debug=(
700715
debug_pb.DebugCommand(
701716
unit_value=debug_pb.DebugSetUnitValue(
@@ -786,7 +801,7 @@ def to_debug_color(color: tuple[float, float, float] | list[float] | Point3 | No
786801
return debug_pb.Color(r=255, g=255, b=255)
787802
# Need to check if not of type Point3 because Point3 inherits from tuple
788803
if isinstance(color, (tuple, list)) or isinstance(color, Point3) and len(color) == 3:
789-
return debug_pb.Color(r=color[0], g=color[1], b=color[2])
804+
return debug_pb.Color(r=int(color[0]), g=int(color[1]), b=int(color[2]))
790805
# In case color is of type Point3
791806
r = getattr(color, "r", getattr(color, "x", 255))
792807
g = getattr(color, "g", getattr(color, "y", 255))
@@ -818,6 +833,7 @@ def to_proto(self):
818833
color=self.to_debug_color(self._color),
819834
text=self._text,
820835
virtual_pos=self._start_point.to3.as_Point,
836+
# pyrefly: ignore
821837
world_pos=None,
822838
size=self._font_size,
823839
)
@@ -829,20 +845,21 @@ def __hash__(self) -> int:
829845
class DrawItemWorldText(DrawItem):
830846
def __init__(
831847
self,
832-
start_point: Point3 = None,
833-
color: tuple[float, float, float] | list[float] | Point3 | None = None,
848+
start_point: Point3,
849+
color: tuple[float, float, float] | list[float] | Point3 | None,
834850
text: str = "",
835851
font_size: int = 8,
836852
) -> None:
837-
self._start_point: Point3 = start_point
838-
self._color: Point3 = color
839-
self._text: str = text
840-
self._font_size: int = font_size
853+
self._start_point = start_point
854+
self._color = color
855+
self._text = text
856+
self._font_size = font_size
841857

842858
def to_proto(self):
843859
return debug_pb.DebugText(
844860
color=self.to_debug_color(self._color),
845861
text=self._text,
862+
# pyrefly: ignore
846863
virtual_pos=None,
847864
world_pos=self._start_point.as_Point,
848865
size=self._font_size,
@@ -855,13 +872,13 @@ def __hash__(self) -> int:
855872
class DrawItemLine(DrawItem):
856873
def __init__(
857874
self,
858-
start_point: Point3 = None,
859-
end_point: Point3 = None,
875+
start_point: Point3,
876+
end_point: Point3,
860877
color: tuple[float, float, float] | list[float] | Point3 | None = None,
861878
) -> None:
862-
self._start_point: Point3 = start_point
863-
self._end_point: Point3 = end_point
864-
self._color: Point3 = color
879+
self._start_point = start_point
880+
self._end_point = end_point
881+
self._color = color
865882

866883
def to_proto(self):
867884
return debug_pb.DebugLine(
@@ -876,13 +893,13 @@ def __hash__(self) -> int:
876893
class DrawItemBox(DrawItem):
877894
def __init__(
878895
self,
879-
start_point: Point3 = None,
880-
end_point: Point3 = None,
896+
start_point: Point3,
897+
end_point: Point3,
881898
color: tuple[float, float, float] | list[float] | Point3 | None = None,
882899
) -> None:
883-
self._start_point: Point3 = start_point
884-
self._end_point: Point3 = end_point
885-
self._color: Point3 = color
900+
self._start_point = start_point
901+
self._end_point = end_point
902+
self._color = color
886903

887904
def to_proto(self):
888905
return debug_pb.DebugBox(
@@ -898,13 +915,13 @@ def __hash__(self) -> int:
898915
class DrawItemSphere(DrawItem):
899916
def __init__(
900917
self,
901-
start_point: Point3 = None,
902-
radius: float = None,
918+
start_point: Point3,
919+
radius: float,
903920
color: tuple[float, float, float] | list[float] | Point3 | None = None,
904921
) -> None:
905-
self._start_point: Point3 = start_point
906-
self._radius: float = radius
907-
self._color: Point3 = color
922+
self._start_point = start_point
923+
self._radius = radius
924+
self._color = color
908925

909926
def to_proto(self):
910927
return debug_pb.DebugSphere(

0 commit comments

Comments
 (0)