Skip to content

Commit a1801a5

Browse files
committed
Fix types for bot_ai.py
1 parent a722ff8 commit a1801a5

3 files changed

Lines changed: 80 additions & 45 deletions

File tree

pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,6 @@ project_includes = [
101101
project-excludes = [
102102
# Disable for those files and folders
103103
"sc2/data.py",
104-
# TODO Temp disable for those files and folders
105-
"sc2/bot_ai.py",
106104
]
107105

108106
[tool.pyrefly.errors]

sc2/bot_ai.py

Lines changed: 79 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -250,24 +250,24 @@ async def get_next_expansion(self) -> Point2 | None:
250250
"""Find next expansion location."""
251251

252252
closest = None
253-
distance = math.inf
254-
for el in self.expansion_locations_list:
253+
best_distance = math.inf
254+
start_position = self.game_info.player_start_location
255+
for position in self.expansion_locations_list:
255256

256257
def is_near_to_expansion(t):
257-
return t.distance_to(el) < self.EXPANSION_GAP_THRESHOLD
258+
return t.distance_to(position) < self.EXPANSION_GAP_THRESHOLD
258259

259260
if any(map(is_near_to_expansion, self.townhalls)):
260261
# already taken
261262
continue
262263

263-
startp = self.game_info.player_start_location
264-
d = await self.client.query_pathing(startp, el)
265-
if d is None:
264+
distance = await self.client.query_pathing(start_position, position)
265+
if distance is None:
266266
continue
267267

268-
if d < distance:
269-
distance = d
270-
closest = el
268+
if distance < best_distance:
269+
best_distance = distance
270+
closest = position
271271

272272
return closest
273273

@@ -492,7 +492,11 @@ def calculate_cost(self, item_id: UnitTypeId | UpgradeId | AbilityId) -> Cost:
492492
return self.calculate_unit_value(UnitTypeId.ARCHON)
493493
unit_data = self.game_data.units[item_id.value]
494494
# Cost of morphs is automatically correctly calculated by 'calculate_ability_cost'
495-
return self.game_data.calculate_ability_cost(unit_data.creation_ability.exact_id)
495+
creation_ability = unit_data.creation_ability
496+
if creation_ability is None:
497+
logger.error(f"Unknown item_id for calculate_cost: {item_id}")
498+
return Cost(0, 0)
499+
return self.game_data.calculate_ability_cost(creation_ability.exact_id)
496500

497501
if isinstance(item_id, UpgradeId):
498502
cost = self.game_data.upgrades[item_id.value].cost
@@ -566,11 +570,18 @@ async def can_cast(
566570
ability_target: int = self.game_data.abilities[ability_id.value]._proto.target
567571
# Check if target is in range (or is a self cast like stimpack)
568572
if (
573+
# Can't replace 1 with "Target.None.value" because ".None" doesn't seem to be a valid enum name
569574
ability_target == 1
570575
or ability_target == Target.PointOrNone.value
571-
and isinstance(target, Point2)
572-
and unit.distance_to(target) <= unit.radius + target.radius + cast_range
573-
): # cant replace 1 with "Target.None.value" because ".None" doesnt seem to be a valid enum name
576+
and (
577+
# Target is unit
578+
isinstance(target, Unit)
579+
and unit.distance_to(target) <= unit.radius + target.radius + cast_range
580+
# Target is position
581+
or isinstance(target, Point2)
582+
and unit.distance_to(target) <= unit.radius + cast_range
583+
)
584+
):
574585
return True
575586
# Check if able to use ability on a unit
576587
if (
@@ -620,8 +631,12 @@ def select_build_worker(self, pos: Unit | Point2, force: bool = False) -> Unit |
620631
async def can_place_single(self, building: AbilityId | UnitTypeId, position: Point2) -> bool:
621632
"""Checks the placement for only one position."""
622633
if isinstance(building, UnitTypeId):
623-
creation_ability = self.game_data.units[building.value].creation_ability.id
624-
return (await self.client._query_building_placement_fast(creation_ability, [position]))[0]
634+
creation_ability = self.game_data.units[building.value].creation_ability
635+
if creation_ability is None:
636+
logger.error(f"Unknown building for can_place_single: {building}")
637+
return False
638+
creation_ability_id = creation_ability.id
639+
return (await self.client._query_building_placement_fast(creation_ability_id, [position]))[0]
625640
return (await self.client._query_building_placement_fast(building, [position]))[0]
626641

627642
async def can_place(self, building: AbilityData | AbilityId | UnitTypeId, positions: list[Point2]) -> list[bool]:
@@ -637,24 +652,26 @@ async def can_place(self, building: AbilityData | AbilityId | UnitTypeId, positi
637652
638653
:param building:
639654
:param position:"""
640-
building_type = type(building)
641-
assert type(building) in {AbilityData, AbilityId, UnitTypeId}, f"{building}, {building_type}"
642-
if building_type == UnitTypeId:
643-
building = self.game_data.units[building.value].creation_ability.id
644-
elif building_type == AbilityData:
655+
if isinstance(building, UnitTypeId):
656+
creation_ability = self.game_data.units[building.value].creation_ability
657+
if creation_ability is None:
658+
return [False for _ in positions]
659+
building = creation_ability.id
660+
elif isinstance(building, AbilityData):
645661
warnings.warn(
646662
"Using AbilityData is deprecated and may be removed soon. Please use AbilityId or UnitTypeId instead.",
647663
DeprecationWarning,
648664
stacklevel=2,
649665
)
650-
building = building_type.id
666+
building = building.id
651667

652668
if isinstance(positions, (Point2, tuple)):
653669
warnings.warn(
654670
"The support for querying single entries will be removed soon. Please use either 'await self.can_place_single(building, position)' or 'await (self.can_place(building, [position]))[0]",
655671
DeprecationWarning,
656672
stacklevel=2,
657673
)
674+
# pyrefly: ignore
658675
return await self.can_place_single(building, positions)
659676
assert isinstance(positions, list), f"Expected an iterable (list, tuple), but was: {positions}"
660677
assert isinstance(positions[0], Point2), (
@@ -690,7 +707,10 @@ async def find_placement(
690707
assert isinstance(near, Point2), f"{near} is no Point2 object"
691708

692709
if isinstance(building, UnitTypeId):
693-
building = self.game_data.units[building.value].creation_ability.id
710+
creation_ability = self.game_data.units[building.value].creation_ability
711+
if creation_ability is None:
712+
return None
713+
building = creation_ability.id
694714

695715
if await self.can_place_single(building, near) and (
696716
not addon_place or await self.can_place_single(AbilityId.TERRANBUILD_SUPPLYDEPOT, near.offset((2.5, -0.5)))
@@ -749,10 +769,13 @@ def already_pending_upgrade(self, upgrade_type: UpgradeId) -> float:
749769
assert isinstance(upgrade_type, UpgradeId), f"{upgrade_type} is no UpgradeId"
750770
if upgrade_type in self.state.upgrades:
751771
return 1
752-
creationAbilityID = self.game_data.upgrades[upgrade_type.value].research_ability.exact_id
772+
research_ability = self.game_data.upgrades[upgrade_type.value].research_ability
773+
if research_ability is None:
774+
return 0
775+
creation_ability_id = research_ability.exact_id
753776
for structure in self.structures.filter(lambda unit: unit.is_ready):
754777
for order in structure.orders:
755-
if order.ability.exact_id == creationAbilityID:
778+
if order.ability.exact_id == creation_ability_id:
756779
return order.progress
757780
return 0
758781

@@ -798,7 +821,7 @@ def structure_type_build_progress(self, structure_type: UnitTypeId | int) -> flo
798821
# SUPPLYDEPOTDROP is not in self.game_data.units, so bot_ai should not check the build progress via creation ability (worker abilities)
799822
if structure_type_value not in self.game_data.units:
800823
return max((s.build_progress for s in self.structures if s._proto.unit_type in equiv_values), default=0)
801-
creation_ability_data: AbilityData = self.game_data.units[structure_type_value].creation_ability
824+
creation_ability_data = self.game_data.units[structure_type_value].creation_ability
802825
if creation_ability_data is None:
803826
return 0
804827
creation_ability: AbilityId = creation_ability_data.exact_id
@@ -863,24 +886,33 @@ def already_pending(self, unit_type: UpgradeId | UnitTypeId) -> float:
863886
"""
864887
if isinstance(unit_type, UpgradeId):
865888
return self.already_pending_upgrade(unit_type)
866-
try:
867-
ability = self.game_data.units[unit_type.value].creation_ability.exact_id
868-
except AttributeError:
869-
if unit_type in CREATION_ABILITY_FIX:
870-
# Hotfix for checking pending archons
871-
if unit_type == UnitTypeId.ARCHON:
872-
return self._abilities_count_and_build_progress[0][AbilityId.ARCHON_WARP_TARGET] / 2
873-
# Hotfix for rich geysirs
874-
return self._abilities_count_and_build_progress[0][CREATION_ABILITY_FIX[unit_type]]
875-
logger.error(f"Uncaught UnitTypeId: {unit_type}")
889+
creation_ability = self.game_data.units[unit_type.value].creation_ability
890+
if creation_ability is None:
876891
return 0
877-
return self._abilities_count_and_build_progress[0][ability]
892+
893+
if unit_type in CREATION_ABILITY_FIX:
894+
# Hotfix for checking pending archons and other abilities
895+
if unit_type == UnitTypeId.ARCHON:
896+
return self._abilities_count_and_build_progress[0][AbilityId.ARCHON_WARP_TARGET] / 2
897+
# Hotfix for rich geysirs
898+
return self._abilities_count_and_build_progress[0][CREATION_ABILITY_FIX[unit_type]]
899+
900+
creation_ability = self.game_data.units[unit_type.value].creation_ability
901+
if creation_ability is None:
902+
logger.error(f"Unknown unit_type for already_pending: {unit_type}")
903+
return 0
904+
ability_id = creation_ability.exact_id
905+
return self._abilities_count_and_build_progress[0][ability_id]
878906

879907
def worker_en_route_to_build(self, unit_type: UnitTypeId) -> float:
880908
"""This function counts how many workers are on the way to start the construction a building.
881909
882910
:param unit_type:"""
883-
ability = self.game_data.units[unit_type.value].creation_ability.exact_id
911+
creation_ability = self.game_data.units[unit_type.value].creation_ability
912+
if creation_ability is None:
913+
logger.error(f"Unknown unit_type for worker_en_route_to_build: {unit_type}")
914+
return 0
915+
ability = creation_ability.exact_id
884916
return self._worker_orders[ability]
885917

886918
@property_cache_once_per_frame
@@ -894,7 +926,7 @@ def structures_without_construction_SCVs(self) -> Units:
894926
continue
895927
for order in worker.orders:
896928
# When a construction is resumed, the worker.orders[0].target is the tag of the structure, else it is a Point2
897-
worker_targets.add(order.target)
929+
worker_targets.add(order.target) # pyrefly: ignore
898930
return self.structures.filter(
899931
lambda structure: structure.build_progress < 1
900932
# Redundant check?
@@ -928,15 +960,15 @@ async def build(
928960
assert isinstance(near, (Unit, Point2))
929961
if not self.can_afford(building):
930962
return False
931-
p = None
963+
position = None
932964
gas_buildings = {UnitTypeId.EXTRACTOR, UnitTypeId.ASSIMILATOR, UnitTypeId.REFINERY}
933965
if isinstance(near, Unit) and building not in gas_buildings:
934966
near = near.position
935967
if isinstance(near, Point2):
936968
near = near.to2
937969
if isinstance(near, Point2):
938-
p = await self.find_placement(building, near, max_distance, random_alternative, placement_step)
939-
if p is None:
970+
position = await self.find_placement(building, near, max_distance, random_alternative, placement_step)
971+
if position is None:
940972
return False
941973
builder = build_worker or self.select_build_worker(near)
942974
if builder is None:
@@ -945,7 +977,8 @@ async def build(
945977
assert isinstance(near, Unit)
946978
builder.build_gas(near)
947979
return True
948-
self.do(builder.build(building, p), subtract_cost=True, ignore_warning=True)
980+
# pyrefly: ignore
981+
self.do(builder.build(building, position), subtract_cost=True, ignore_warning=True)
949982
return True
950983

951984
def train(
@@ -1054,6 +1087,7 @@ def train(
10541087
else:
10551088
# Normal train a unit from larva or inside a structure
10561089
successfully_trained = self.do(
1090+
# pyrefly: ignore
10571091
structure.train(unit_type),
10581092
subtract_cost=True,
10591093
subtract_supply=True,
@@ -1077,6 +1111,7 @@ def train(
10771111
trained_amount += 1
10781112
# With one command queue=False and one queue=True, you can queue 2 marines in a reactored barracks in one frame
10791113
successfully_trained = self.do(
1114+
# pyrefly: ignore
10801115
structure.train(unit_type, queue=True),
10811116
subtract_cost=True,
10821117
subtract_supply=True,
@@ -1124,6 +1159,7 @@ def research(self, upgrade_type: UpgradeId) -> bool:
11241159

11251160
research_structure_type: UnitTypeId = UPGRADE_RESEARCHED_FROM[upgrade_type]
11261161

1162+
# pyrefly: ignore
11271163
required_tech_building: UnitTypeId | None = RESEARCH_INFO[research_structure_type][upgrade_type].get(
11281164
"required_building", None
11291165
)
@@ -1166,6 +1202,7 @@ def research(self, upgrade_type: UpgradeId) -> bool:
11661202
):
11671203
# Can_afford check was already done earlier in this function
11681204
successful_action: bool = self.do(
1205+
# pyrefly: ignore
11691206
structure.research(upgrade_type),
11701207
subtract_cost=True,
11711208
ignore_warning=True,

sc2/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ async def actions(self, actions: list[UnitCommand], return_successes: bool = Fal
226226
ActionResult(result) for result in response.action.result if ActionResult(result) != ActionResult.Success
227227
]
228228

229-
async def query_pathing(self, start: Unit | Point2 | Point3, end: Point2 | Point3) -> int | float | None:
229+
async def query_pathing(self, start: Unit | Point2 | Point3, end: Point2 | Point3) -> float | None:
230230
"""Caution: returns "None" when path not found
231231
Try to combine queries with the function below because the pathing query is generally slow.
232232

0 commit comments

Comments
 (0)