2222from activitysim .core .exceptions import SegmentedSpecificationError
2323from activitysim .core .skim_dataset import DatasetWrapper
2424from activitysim .core .skim_dictionary import SkimWrapper
25+
2526if typing .TYPE_CHECKING :
2627 from activitysim .core .random import Random
2728
@@ -36,7 +37,7 @@ def _poisson_sample_alternatives_inner(
3637 poisson_inclusion_probs : pd .DataFrame ,
3738 rng : Random ,
3839 trace_label : str | None ,
39- chunk_sizer :ChunkSizer ,
40+ chunk_sizer : ChunkSizer ,
4041) -> pd .DataFrame :
4142 rands = rng .random_for_df (probs , n = alternative_count )
4243 chunk_sizer .log_df (trace_label , "rands" , rands )
@@ -54,8 +55,8 @@ def make_sample_choices_utility_based(
5455 alternative_count ,
5556 alt_col_name ,
5657 allow_zero_probs ,
57- trace_label :str ,
58- chunk_sizer :ChunkSizer ,
58+ trace_label : str ,
59+ chunk_sizer : ChunkSizer ,
5960):
6061 assert isinstance (utilities , pd .DataFrame )
6162 assert utilities .shape == (len (choosers ), alternative_count )
@@ -87,8 +88,9 @@ def make_sample_choices_utility_based(
8788 overflow_protection = not allow_zero_probs ,
8889 trace_choosers = choosers ,
8990 )
90- inclusion_probs , sampled_alternatives = _poisson_sample_alternatives (alternative_count , chunk_sizer , probs ,
91- sample_size , state , trace_label )
91+ inclusion_probs , sampled_alternatives = _poisson_sample_alternatives (
92+ alternative_count , chunk_sizer , probs , sample_size , state , trace_label
93+ )
9294
9395 # Stack removes the NaNs (the ones that weren't sampled)
9496 # and gives us a multi-index of (person_id, alt_id)
@@ -109,8 +111,14 @@ def make_sample_choices_utility_based(
109111 return choices_df , inclusion_probs
110112
111113
112- def _poisson_sample_alternatives (alternative_count , chunk_sizer : ChunkSizer , probs : pd .DataFrame , sample_size ,
113- state : workflow .State , trace_label : str ) -> tuple [pd .DataFrame , pd .DataFrame ]:
114+ def _poisson_sample_alternatives (
115+ alternative_count ,
116+ chunk_sizer : ChunkSizer ,
117+ probs : pd .DataFrame ,
118+ sample_size ,
119+ state : workflow .State ,
120+ trace_label : str ,
121+ ) -> tuple [pd .DataFrame , pd .DataFrame ]:
114122 # compute the inclusion probability as the complement of the alt never being drawn
115123 # -- these are common, so compute once upfront
116124 exclusion_probs = (1 - probs ) ** sample_size
@@ -119,21 +127,31 @@ def _poisson_sample_alternatives(alternative_count, chunk_sizer: ChunkSizer, pro
119127 n = 0
120128 probs_subset = probs
121129 inclusion_probs_subset = inclusion_probs
122- sampled_alternatives = pd .DataFrame (0.0 , index = inclusion_probs .index , columns = inclusion_probs .columns )
130+ sampled_alternatives = pd .DataFrame (
131+ 0.0 , index = inclusion_probs .index , columns = inclusion_probs .columns
132+ )
123133 while True :
124134 sampled_results_subset = _poisson_sample_alternatives_inner (
125- alternative_count , probs_subset , inclusion_probs_subset , state .get_rn_generator (), trace_label , chunk_sizer
135+ alternative_count ,
136+ probs_subset ,
137+ inclusion_probs_subset ,
138+ state .get_rn_generator (),
139+ trace_label ,
140+ chunk_sizer ,
126141 )
127142 no_alts_sampled_mask = sampled_results_subset .isna ().all (axis = 1 )
128143 alts_with_sampled_alternatives = sampled_results_subset [~ no_alts_sampled_mask ]
129- sampled_alternatives .loc [alts_with_sampled_alternatives .index , :] = alts_with_sampled_alternatives
144+ sampled_alternatives .loc [
145+ alts_with_sampled_alternatives .index , :
146+ ] = alts_with_sampled_alternatives
130147 if no_alts_sampled_mask .any ():
131148 # TODO if this happens in base but the project case is such that something is picked, random numbers won't
132149 # be consistent - we're asserting that this is very rare in models where the sample size is not too small
133150 logger .info (f"Poisson sampling of alternatives failed with { n = } , retrying" )
134151 # TODO put this behind a debug guard, because it will be slow
135152 logger .info (
136- f"Sampled size was { sample_size } , poisson method mean expected sample size was { inclusion_probs .sum (axis = 1 ).mean ():.1f} , actual sampled mean was { (sampled_alternatives > 0 ).sum (axis = 1 ).mean ():.1f} and highest zero selection prob was { (exclusion_probs ).product (axis = 1 ).max ():.2g} " )
153+ f"Sampled size was { sample_size } , poisson method mean expected sample size was { inclusion_probs .sum (axis = 1 ).mean ():.1f} , actual sampled mean was { (sampled_alternatives > 0 ).sum (axis = 1 ).mean ():.1f} and highest zero selection prob was { (exclusion_probs ).product (axis = 1 ).max ():.2g} "
154+ )
137155 probs_subset = probs [no_alts_sampled_mask ]
138156 inclusion_probs_subset = inclusion_probs [no_alts_sampled_mask ]
139157
@@ -143,8 +161,10 @@ def _poisson_sample_alternatives(alternative_count, chunk_sizer: ChunkSizer, pro
143161 n += 1
144162 if n == 10 :
145163 choosers_no_alts_sampled = sampled_results_subset [no_alts_sampled_mask ]
146- msg = (f"Poisson choice set sampling failed after 10 attempts for these cases:\n "
147- f"{ choosers_no_alts_sampled } \n { probs_subset } " )
164+ msg = (
165+ f"Poisson choice set sampling failed after 10 attempts for these cases:\n "
166+ f"{ choosers_no_alts_sampled } \n { probs_subset } "
167+ )
148168 raise ValueError (msg )
149169
150170 chunk_sizer .log_df (trace_label , "sampled_alternatives" , sampled_alternatives )
@@ -260,7 +280,7 @@ def _interaction_sample(
260280 locals_d = None ,
261281 trace_label = None ,
262282 zone_layer = None ,
263- chunk_sizer : ChunkSizer | None = None ,
283+ chunk_sizer : ChunkSizer | None = None ,
264284 compute_settings : ComputeSettings | None = None ,
265285):
266286 """
@@ -325,7 +345,9 @@ def _interaction_sample(
325345 pick_count : int
326346 number of duplicate picks for chooser, alt
327347 """
328- assert chunk_sizer is not None , "chunk_sizer cannot be None but old nullable signature is preserved"
348+ assert (
349+ chunk_sizer is not None
350+ ), "chunk_sizer cannot be None but old nullable signature is preserved"
329351 # TODO it's probably safe to reorder these arguments to make chunk_sizer mandatory since
330352 # _interaction_sample is private?
331353
0 commit comments