Skip to content

Commit 4c9cf9d

Browse files
committed
black
1 parent 9f81a89 commit 4c9cf9d

8 files changed

Lines changed: 125 additions & 60 deletions

File tree

activitysim/abm/models/location_choice.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -793,7 +793,9 @@ def run_location_choice(
793793
continue
794794
# using land use rather than size terms in case something goes 0 base -> nonzero project, double
795795
# check if that would be in dest_size_terms as a zero
796-
alts_context = AltsContext.from_series(dest_size_terms.index) # index zone_id, not ALT_DEST_COL_NAME
796+
alts_context = AltsContext.from_series(
797+
dest_size_terms.index
798+
) # index zone_id, not ALT_DEST_COL_NAME
797799
# assumes that dest_size_terms will always contain zeros for non-attractive zones, i.e. it will have the
798800
# same length as land_use
799801

activitysim/abm/models/parking_location_choice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def choose_parking_location(
229229
chunk_size=chunk_size,
230230
trace_hh_id=trace_hh_id,
231231
trace_label=trace_label,
232-
alts_context=alts_context
232+
alts_context=alts_context,
233233
)
234234

235235
if want_sample_table:

activitysim/abm/models/trip_scheduling_choice.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,9 @@ def run_trip_scheduling_choice(
315315
estimator=None,
316316
chunk_sizer=chunk_sizer,
317317
compute_settings=model_settings.compute_settings,
318-
alts_context= AltsContext(schedules[SCHEDULE_ID].min(), schedules[SCHEDULE_ID].max()),
318+
alts_context=AltsContext(
319+
schedules[SCHEDULE_ID].min(), schedules[SCHEDULE_ID].max()
320+
),
319321
)
320322

321323
assert len(choices.index) == len(choosers.index)

activitysim/core/interaction_sample.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from activitysim.core.exceptions import SegmentedSpecificationError
2323
from activitysim.core.skim_dataset import DatasetWrapper
2424
from activitysim.core.skim_dictionary import SkimWrapper
25+
2526
if typing.TYPE_CHECKING:
2627
from activitysim.core.random import Random
2728

@@ -36,7 +37,7 @@ def _poisson_sample_alternatives_inner(
3637
poisson_inclusion_probs: pd.DataFrame,
3738
rng: Random,
3839
trace_label: str | None,
39-
chunk_sizer:ChunkSizer,
40+
chunk_sizer: ChunkSizer,
4041
) -> pd.DataFrame:
4142
rands = rng.random_for_df(probs, n=alternative_count)
4243
chunk_sizer.log_df(trace_label, "rands", rands)
@@ -54,8 +55,8 @@ def make_sample_choices_utility_based(
5455
alternative_count,
5556
alt_col_name,
5657
allow_zero_probs,
57-
trace_label:str,
58-
chunk_sizer:ChunkSizer,
58+
trace_label: str,
59+
chunk_sizer: ChunkSizer,
5960
):
6061
assert isinstance(utilities, pd.DataFrame)
6162
assert utilities.shape == (len(choosers), alternative_count)
@@ -87,8 +88,9 @@ def make_sample_choices_utility_based(
8788
overflow_protection=not allow_zero_probs,
8889
trace_choosers=choosers,
8990
)
90-
inclusion_probs, sampled_alternatives = _poisson_sample_alternatives(alternative_count, chunk_sizer, probs,
91-
sample_size, state, trace_label)
91+
inclusion_probs, sampled_alternatives = _poisson_sample_alternatives(
92+
alternative_count, chunk_sizer, probs, sample_size, state, trace_label
93+
)
9294

9395
# Stack removes the NaNs (the ones that weren't sampled)
9496
# and gives us a multi-index of (person_id, alt_id)
@@ -109,8 +111,14 @@ def make_sample_choices_utility_based(
109111
return choices_df, inclusion_probs
110112

111113

112-
def _poisson_sample_alternatives(alternative_count, chunk_sizer: ChunkSizer, probs: pd.DataFrame, sample_size,
113-
state: workflow.State, trace_label: str) -> tuple[pd.DataFrame, pd.DataFrame]:
114+
def _poisson_sample_alternatives(
115+
alternative_count,
116+
chunk_sizer: ChunkSizer,
117+
probs: pd.DataFrame,
118+
sample_size,
119+
state: workflow.State,
120+
trace_label: str,
121+
) -> tuple[pd.DataFrame, pd.DataFrame]:
114122
# compute the inclusion probability as the reciprocal of alt never being drawn
115123
# -- these are common, so compute once upfront
116124
exclusion_probs = (1 - probs) ** sample_size
@@ -119,21 +127,31 @@ def _poisson_sample_alternatives(alternative_count, chunk_sizer: ChunkSizer, pro
119127
n = 0
120128
probs_subset = probs
121129
inclusion_probs_subset = inclusion_probs
122-
sampled_alternatives = pd.DataFrame(0.0, index=inclusion_probs.index, columns=inclusion_probs.columns)
130+
sampled_alternatives = pd.DataFrame(
131+
0.0, index=inclusion_probs.index, columns=inclusion_probs.columns
132+
)
123133
while True:
124134
sampled_results_subset = _poisson_sample_alternatives_inner(
125-
alternative_count, probs_subset, inclusion_probs_subset, state.get_rn_generator(), trace_label, chunk_sizer
135+
alternative_count,
136+
probs_subset,
137+
inclusion_probs_subset,
138+
state.get_rn_generator(),
139+
trace_label,
140+
chunk_sizer,
126141
)
127142
no_alts_sampled_mask = sampled_results_subset.isna().all(axis=1)
128143
alts_with_sampled_alternatives = sampled_results_subset[~no_alts_sampled_mask]
129-
sampled_alternatives.loc[alts_with_sampled_alternatives.index, :] = alts_with_sampled_alternatives
144+
sampled_alternatives.loc[
145+
alts_with_sampled_alternatives.index, :
146+
] = alts_with_sampled_alternatives
130147
if no_alts_sampled_mask.any():
131148
# TODO if this happens in base but the project case is such that something is picked, random numbers won't
132149
# be consistent - we're asserting that this is very rare models where the sample size is not too small
133150
logger.info(f"Poisson sampling of alternatives failed with {n=}, retrying")
134151
# TODO put this behind a debug guard, because it will be slow
135152
logger.info(
136-
f"Sampled size was {sample_size}, poisson method mean expected sample size was {inclusion_probs.sum(axis=1).mean():.1f}, actual sampled mean was {(sampled_alternatives > 0).sum(axis=1).mean():.1f} and highest zero selection prob was {(exclusion_probs).product(axis=1).max():.2g}")
153+
f"Sampled size was {sample_size}, poisson method mean expected sample size was {inclusion_probs.sum(axis=1).mean():.1f}, actual sampled mean was {(sampled_alternatives > 0).sum(axis=1).mean():.1f} and highest zero selection prob was {(exclusion_probs).product(axis=1).max():.2g}"
154+
)
137155
probs_subset = probs[no_alts_sampled_mask]
138156
inclusion_probs_subset = inclusion_probs[no_alts_sampled_mask]
139157

@@ -143,8 +161,10 @@ def _poisson_sample_alternatives(alternative_count, chunk_sizer: ChunkSizer, pro
143161
n += 1
144162
if n == 10:
145163
choosers_no_alts_sampled = sampled_results_subset[no_alts_sampled_mask]
146-
msg = (f"Poisson choice set sampling failed after 10 attempts for these cases:\n"
147-
f"{choosers_no_alts_sampled}\n{probs_subset}")
164+
msg = (
165+
f"Poisson choice set sampling failed after 10 attempts for these cases:\n"
166+
f"{choosers_no_alts_sampled}\n{probs_subset}"
167+
)
148168
raise ValueError(msg)
149169

150170
chunk_sizer.log_df(trace_label, "sampled_alternatives", sampled_alternatives)
@@ -260,7 +280,7 @@ def _interaction_sample(
260280
locals_d=None,
261281
trace_label=None,
262282
zone_layer=None,
263-
chunk_sizer: ChunkSizer|None=None,
283+
chunk_sizer: ChunkSizer | None = None,
264284
compute_settings: ComputeSettings | None = None,
265285
):
266286
"""
@@ -325,7 +345,9 @@ def _interaction_sample(
325345
pick_count : int
326346
number of duplicate picks for chooser, alt
327347
"""
328-
assert chunk_sizer is not None, "chunk_sizer cannot be None but old nullable signature is preserved"
348+
assert (
349+
chunk_sizer is not None
350+
), "chunk_sizer cannot be None but old nullable signature is preserved"
329351
# TODO it's probably safe to reorder these arguments to make chunk_sizer mandatory since
330352
# _interaction_sample is private?
331353

activitysim/core/interaction_sample_simulate.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def _interaction_sample_simulate(
283283
if alts_context is not None:
284284
alt_nrs_df = pd.DataFrame(padded_alt_nrs, index=choosers.index)
285285
else:
286-
alt_nrs_df = None # if we don't provide the number of dense alternatives, assume that we'll use the old approach
286+
alt_nrs_df = None # if we don't provide the number of dense alternatives, assume that we'll use the old approach
287287
chunk_sizer.log_df(trace_label, "utilities_df", utilities_df)
288288

289289
del padded_utilities
@@ -330,8 +330,12 @@ def _interaction_sample_simulate(
330330
# positions is series with the chosen alternative represented as a column index in utilities_df
331331
# which is an integer between zero and num alternatives in the alternative sample
332332
positions, rands = logit.make_choices_utility_based(
333-
state, utilities_df, trace_label=trace_label, trace_choosers=choosers, alts_context=alts_context,
334-
alt_nrs_df=alt_nrs_df
333+
state,
334+
utilities_df,
335+
trace_label=trace_label,
336+
trace_choosers=choosers,
337+
alts_context=alts_context,
338+
alt_nrs_df=alt_nrs_df,
335339
)
336340

337341
del utilities_df

activitysim/core/logit.py

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -345,11 +345,14 @@ def utils_to_probs(
345345
return probs, logsums
346346
return probs
347347

348+
348349
FREEZE_RANDOM_NUMBERS_FOR_DENSE_ALTERNATIVE_SET = True
349350

351+
350352
@dataclass
351353
class AltsContext:
352354
"""Representation of the alternatives without carrying around that full array."""
355+
353356
min_alt_id: int
354357
max_alt_id: int
355358

@@ -359,55 +362,67 @@ def __post_init__(self):
359362
self.n_rands_to_sample = max(self.max_alt_id, self.n_alts_to_cover_max_id)
360363

361364
@classmethod
362-
def from_series(cls, ser:Union[pd.Series,pd.Index])->"AltsContext":
365+
def from_series(cls, ser: Union[pd.Series, pd.Index]) -> "AltsContext":
363366
min_alt_id = ser.min()
364367
max_alt_id = ser.max()
365368
return cls(min_alt_id, max_alt_id)
366369

367370
@classmethod
368-
def from_num_alts(cls, num_alts:int, zero_based:bool=True)->"AltsContext":
371+
def from_num_alts(cls, num_alts: int, zero_based: bool = True) -> "AltsContext":
369372
if zero_based:
370373
offset = -1
371374
else:
372-
offset =0
373-
return cls(min_alt_id=1+offset, max_alt_id=num_alts+offset )
374-
375+
offset = 0
376+
return cls(min_alt_id=1 + offset, max_alt_id=num_alts + offset)
375377

376378
@property
377379
def n_alts_to_cover_max_id(self) -> int:
378380
"""If zones were non-consecutive, this could be a big over-estimate."""
379-
return self.max_alt_id+1
381+
return self.max_alt_id + 1
380382

381383

382384
# TODO-EET: add doc string, tracing
383-
def add_ev1_random(state: workflow.State, df: pd.DataFrame, alt_info: AltsContext | None = None,
384-
alt_nrs_df: pd.DataFrame | None = None, ):
385+
def add_ev1_random(
386+
state: workflow.State,
387+
df: pd.DataFrame,
388+
alt_info: AltsContext | None = None,
389+
alt_nrs_df: pd.DataFrame | None = None,
390+
):
385391

386392
nest_utils_for_choice = df.copy()
387393
assert (alt_info is None) == (
388-
alt_nrs_df is None), "n_zones and alt_nrs_df must both be provided or omitted together"
394+
alt_nrs_df is None
395+
), "n_zones and alt_nrs_df must both be provided or omitted together"
389396

390397
if alt_nrs_df is not None and FREEZE_RANDOM_NUMBERS_FOR_DENSE_ALTERNATIVE_SET:
391-
assert alt_info is not None # narrowing for mypy
398+
assert alt_info is not None # narrowing for mypy
392399

393400
idx_array = alt_nrs_df.values
394401
mask = idx_array == -999
395-
safe_idx = np.where(mask, 1, idx_array) # replace -999 with a temp value inbounds
402+
safe_idx = np.where(
403+
mask, 1, idx_array
404+
) # replace -999 with a temp value inbounds
396405
# generate random number for all alts - this is wasteful, but ensures that the same zone
397406
# gets the same random number if the sampled choice set changes between base and project
398407
# (alternatively, one could seed a channel for (persons x zones) and use the zone seed to ensure consistency.
399408
# Trade off is needing to seed (persons x zones) rows and multiindex channels to
400409
# avoid extra random numbers generated here. Quick benchmark suggests seeding per row is likely slower
401-
rands_dense = state.get_rn_generator().gumbel_for_df(nest_utils_for_choice, n=alt_info.n_alts_to_cover_max_id)
410+
rands_dense = state.get_rn_generator().gumbel_for_df(
411+
nest_utils_for_choice, n=alt_info.n_alts_to_cover_max_id
412+
)
402413
# generate n=alt_info.max_alt_id+1 rather than n_alts so that indexing works
403414
# (this is drawing a random number for a redundant zeroth zone in 1 based zoning systems)
404415
# TODO deal with non 0->n-1 indexed land use more efficiently? ideally do where alt_nrs_df is constructed,
405416
# not on the fly here. Potentially via state.get_injectable('network_los').get_skim_dict('taz').zone_ids
406417
rands = np.take_along_axis(rands_dense, safe_idx, axis=1)
407-
rands[mask] = 0 # zero out the masked zones so they don't have the util adjustment of alt 0
418+
rands[
419+
mask
420+
] = 0 # zero out the masked zones so they don't have the util adjustment of alt 0
408421
else:
409422
# old behaviour, to remove
410-
rands = state.get_rn_generator().gumbel_for_df(nest_utils_for_choice, n=nest_utils_for_choice.shape[1])
423+
rands = state.get_rn_generator().gumbel_for_df(
424+
nest_utils_for_choice, n=nest_utils_for_choice.shape[1]
425+
)
411426

412427
nest_utils_for_choice += rands
413428
return nest_utils_for_choice
@@ -461,10 +476,13 @@ def make_choices_explicit_error_term_nl(
461476

462477

463478
# TODO-EET: add doc string, tracing
464-
def make_choices_explicit_error_term_mnl(state, utilities, trace_label,
465-
alts_context: AltsContext | None = None,
466-
alt_nrs_df: pd.DataFrame | None = None,
467-
):
479+
def make_choices_explicit_error_term_mnl(
480+
state,
481+
utilities,
482+
trace_label,
483+
alts_context: AltsContext | None = None,
484+
alt_nrs_df: pd.DataFrame | None = None,
485+
):
468486
utilities_incl_unobs = add_ev1_random(state, utilities, alts_context, alt_nrs_df)
469487
choices = np.argmax(utilities_incl_unobs.to_numpy(), axis=1)
470488
# TODO-EET: reporting like for zero probs
@@ -474,13 +492,19 @@ def make_choices_explicit_error_term_mnl(state, utilities, trace_label,
474492

475493

476494
def make_choices_explicit_error_term(
477-
state, utilities, alt_order_array, nest_spec=None, trace_label=None,
478-
alts_context: AltsContext | None = None,
479-
alt_nrs_df: pd.DataFrame | None = None,
495+
state,
496+
utilities,
497+
alt_order_array,
498+
nest_spec=None,
499+
trace_label=None,
500+
alts_context: AltsContext | None = None,
501+
alt_nrs_df: pd.DataFrame | None = None,
480502
):
481503
trace_label = tracing.extend_trace_label(trace_label, "make_choices_eet")
482504
if nest_spec is None:
483-
choices = make_choices_explicit_error_term_mnl(state, utilities, trace_label, alts_context, alt_nrs_df)
505+
choices = make_choices_explicit_error_term_mnl(
506+
state, utilities, trace_label, alts_context, alt_nrs_df
507+
)
484508
else:
485509
choices = make_choices_explicit_error_term_nl(
486510
state, utilities, alt_order_array, nest_spec, trace_label

activitysim/core/simulate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
SPEC_EXPRESSION_NAME,
4343
SPEC_LABEL_NAME,
4444
)
45+
4546
if typing.TYPE_CHECKING:
4647
from activitysim.core.estimation import Estimator
4748

0 commit comments

Comments
 (0)