Skip to content

Commit a722ff8

Browse files
committed
Fix types for bot_ai_internal.py
1 parent f2a0df1 commit a722ff8

3 files changed

Lines changed: 38 additions & 31 deletions

File tree

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ project-excludes = [
102102
# Disable for those files and folders
103103
"sc2/data.py",
104104
# TODO Temp disable for those files and folders
105-
"sc2/bot_ai_internal.py",
106105
"sc2/bot_ai.py",
107106
]
108107

sc2/bot_ai_internal.py

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,17 @@
4444
if TYPE_CHECKING:
4545
from sc2.client import Client
4646
from sc2.game_info import GameInfo
47+
from sc2.bot_ai import BotAI
4748

4849

4950
class BotAIInternal(ABC):
5051
"""Base class for bots."""
5152

52-
def __init__(self) -> None:
53+
def __init__(self: BotAI) -> None:
5354
self._initialize_variables()
5455

5556
@final
56-
def _initialize_variables(self) -> None:
57+
def _initialize_variables(self: BotAI) -> None:
5758
"""Called from main.py internally"""
5859
self.cache: dict[str, Any] = {}
5960
# Specific opponent bot ID used in sc2ai ladder games http://sc2ai.net/ and on ai arena https://aiarena.net
@@ -103,7 +104,8 @@ def _initialize_variables(self) -> None:
103104
self.actions: list[UnitCommand] = []
104105
self.blips: set[Blip] = set()
105106

106-
self.race: Race = None
107+
# Will be set on AbstractPlayer init
108+
self.race: Race = None # pyrefly: ignore
107109
self.enemy_race: Race | None = None
108110
self._generated_frame = -100
109111
self._units_created: Counter = Counter()
@@ -161,7 +163,7 @@ def _client(self) -> Client:
161163

162164
@final
163165
@property_cache_once_per_frame
164-
def expansion_locations(self) -> dict[Point2, Units]:
166+
def expansion_locations(self: BotAI) -> dict[Point2, Units]:
165167
"""Same as the function above."""
166168
assert self._expansion_positions_list, (
167169
"self._find_expansion_locations() has not been run yet, so accessing the list of expansion locations is pointless."
@@ -190,7 +192,8 @@ def _cluster_center(self, group: list[Unit]) -> Point2:
190192
if not group:
191193
raise ValueError("Cannot calculate center of empty group")
192194

193-
total_x = total_y = 0
195+
total_x: float = 0
196+
total_y: float = 0
194197
for unit in group:
195198
total_x += unit.position.x
196199
total_y += unit.position.y
@@ -353,7 +356,7 @@ def _find_expansion_locations(self) -> None:
353356

354357
# Distance offsets we apply to center of each resource group to find expansion position
355358
offset_range: int = 7
356-
offsets = [
359+
offsets: list[tuple[float, float]] = [
357360
(x, y)
358361
for x, y in itertools.product(range(-offset_range, offset_range + 1), repeat=2)
359362
if 4 < math.hypot(x, y) <= 8
@@ -436,17 +439,20 @@ def _abilities_count_and_build_progress(self) -> tuple[Counter[AbilityId], dict[
436439
if unit.type_id in CREATION_ABILITY_FIX:
437440
if unit.type_id == UnitTypeId.ARCHON:
438441
# Hotfix for archons in morph state
439-
creation_ability = AbilityId.ARCHON_WARP_TARGET
440-
abilities_amount[creation_ability] += 2
442+
creation_ability_id = AbilityId.ARCHON_WARP_TARGET
443+
abilities_amount[creation_ability_id] += 2
441444
else:
442445
# Hotfix for rich geysirs
443-
creation_ability = CREATION_ABILITY_FIX[unit.type_id]
444-
abilities_amount[creation_ability] += 1
446+
creation_ability_id = CREATION_ABILITY_FIX[unit.type_id]
447+
abilities_amount[creation_ability_id] += 1
445448
else:
446-
creation_ability: AbilityId = self.game_data.units[unit.type_id.value].creation_ability.exact_id
447-
abilities_amount[creation_ability] += 1
448-
max_build_progress[creation_ability] = max(
449-
max_build_progress.get(creation_ability, 0), unit.build_progress
449+
creation_ability = self.game_data.units[unit.type_id.value].creation_ability
450+
if creation_ability is None:
451+
continue
452+
creation_ability_id = creation_ability.exact_id
453+
abilities_amount[creation_ability_id] += 1
454+
max_build_progress[creation_ability_id] = max(
455+
max_build_progress.get(creation_ability_id, 0), unit.build_progress
450456
)
451457

452458
return abilities_amount, max_build_progress
@@ -472,7 +478,7 @@ def _worker_orders(self) -> Counter[AbilityId]:
472478

473479
@final
474480
def do(
475-
self,
481+
self: BotAI,
476482
action: UnitCommand,
477483
subtract_cost: bool = False,
478484
subtract_supply: bool = False,
@@ -541,7 +547,7 @@ def do(
541547
return True
542548

543549
@final
544-
async def synchronous_do(self, action: UnitCommand):
550+
async def synchronous_do(self: BotAI, action: UnitCommand):
545551
"""
546552
Not recommended. Use self.do instead to reduce lag.
547553
This function is only useful for realtime=True in the first frame of the game to instantly produce a worker
@@ -553,7 +559,7 @@ async def synchronous_do(self, action: UnitCommand):
553559
if not self.can_afford(action.ability):
554560
logger.warning(f"Cannot afford action {action}")
555561
return ActionResult.Error
556-
r = await self.client.actions(action)
562+
r: ActionResult = await self.client.actions(action) # pyrefly: ignore
557563
if not r: # success
558564
cost = self.game_data.calculate_ability_cost(action.ability)
559565
self.minerals -= cost.minerals
@@ -593,11 +599,14 @@ def prevent_double_actions(action: UnitCommand) -> bool:
593599
# Different action, return True
594600
return True
595601
with suppress(AttributeError):
596-
if current_action.target == action.target.tag:
602+
if current_action.target == action.target.tag: # pyrefly: ignore
597603
# Same action, remove action if same target unit
598604
return False
599605
with suppress(AttributeError):
600-
if action.target.x == current_action.target.x and action.target.y == current_action.target.y:
606+
if (
607+
# pyrefly: ignore
608+
action.target.x == current_action.target.x and action.target.y == current_action.target.y
609+
):
601610
# Same action, remove action if same target position
602611
return False
603612
return True
@@ -648,7 +657,7 @@ def _prepare_first_step(self) -> None:
648657
self._time_before_step: float = time.perf_counter()
649658

650659
@final
651-
def _prepare_step(self, state: GameState, proto_game_info: sc_pb.Response) -> None:
660+
def _prepare_step(self: BotAI, state: GameState, proto_game_info: sc_pb.Response) -> None:
652661
"""
653662
:param state:
654663
:param proto_game_info:
@@ -689,7 +698,7 @@ def _prepare_step(self, state: GameState, proto_game_info: sc_pb.Response) -> No
689698
self.enemy_race = Race(self.all_enemy_units.first.race)
690699

691700
@final
692-
def _prepare_units(self) -> None:
701+
def _prepare_units(self: BotAI) -> None:
693702
# Set of enemy units detected by own sensor tower, as blips have less unit information than normal visible units
694703
self.blips: set[Blip] = set()
695704
self.all_units: Units = Units([], self)
@@ -815,7 +824,7 @@ async def _after_step(self) -> int:
815824
return self.state.game_loop
816825

817826
@final
818-
async def _advance_steps(self, steps: int) -> None:
827+
async def _advance_steps(self: BotAI, steps: int) -> None:
819828
"""Advances the game loop by amount of 'steps'. This function is meant to be used as a debugging and testing tool only.
820829
If you are using this, please be aware of the consequences, e.g. 'self.units' will be filled with completely new data."""
821830
await self._after_step()
@@ -828,7 +837,7 @@ async def _advance_steps(self, steps: int) -> None:
828837
await self.issue_events()
829838

830839
@final
831-
async def issue_events(self) -> None:
840+
async def issue_events(self: BotAI) -> None:
832841
"""This function will be automatically run from main.py and triggers the following functions:
833842
- on_unit_created
834843
- on_unit_destroyed
@@ -843,7 +852,7 @@ async def issue_events(self) -> None:
843852
await self._issue_vision_events()
844853

845854
@final
846-
async def _issue_unit_added_events(self) -> None:
855+
async def _issue_unit_added_events(self: BotAI) -> None:
847856
for unit in self.units:
848857
if unit.tag not in self._units_previous_map and unit.tag not in self._unit_tags_seen_this_game:
849858
self._unit_tags_seen_this_game.add(unit.tag)
@@ -860,14 +869,14 @@ async def _issue_unit_added_events(self) -> None:
860869
await self.on_unit_type_changed(unit, previous_frame_unit.type_id)
861870

862871
@final
863-
async def _issue_upgrade_events(self) -> None:
872+
async def _issue_upgrade_events(self: BotAI) -> None:
864873
difference = self.state.upgrades - self._previous_upgrades
865874
for upgrade_completed in difference:
866875
await self.on_upgrade_complete(upgrade_completed)
867876
self._previous_upgrades = self.state.upgrades
868877

869878
@final
870-
async def _issue_building_events(self) -> None:
879+
async def _issue_building_events(self: BotAI) -> None:
871880
for structure in self.structures:
872881
if structure.tag not in self._structures_previous_map:
873882
if structure.build_progress < 1:
@@ -899,7 +908,7 @@ async def _issue_building_events(self) -> None:
899908
await self.on_building_construction_complete(structure)
900909

901910
@final
902-
async def _issue_vision_events(self) -> None:
911+
async def _issue_vision_events(self: BotAI) -> None:
903912
# Call events for enemy unit entered vision
904913
for enemy_unit in self.enemy_units:
905914
if enemy_unit.tag not in self._enemy_units_previous_map:
@@ -917,7 +926,7 @@ async def _issue_vision_events(self) -> None:
917926
await self.on_enemy_unit_left_vision(enemy_structure_tag)
918927

919928
@final
920-
async def _issue_unit_dead_events(self) -> None:
929+
async def _issue_unit_dead_events(self: BotAI) -> None:
921930
for unit_tag in self.state.dead_units & set(self._all_units_previous_map):
922931
await self.on_unit_destroyed(unit_tag)
923932

sc2/unit.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@
6262

6363
if TYPE_CHECKING:
6464
from sc2.bot_ai import BotAI
65-
from sc2.bot_ai_internal import BotAIInternal
6665
from sc2.game_data import AbilityData, UnitTypeData
6766

6867

@@ -108,7 +107,7 @@ class Unit(HasPosition2D):
108107
def __init__(
109108
self,
110109
proto_data: raw_pb2.Unit,
111-
bot_object: BotAI | BotAIInternal,
110+
bot_object: BotAI,
112111
distance_calculation_index: int = -1,
113112
base_build: int = -1,
114113
) -> None:

0 commit comments

Comments
 (0)