Add global option to skip households on simulation failure #1023
Changes from 13 commits
@@ -307,6 +307,12 @@ def write_trip_matrices(
            .TAZ.tolist()
        )

        # print out number of households skipped due to failed choices
        if state.settings.skip_failed_choices:
            logger.info(
Member: The logging level for this should be higher than "info"; at least "warning", if not "error".

Author: I've moved the summary of the number of skipped households, total and total by model component, to …
                f"\n!!!\nATTENTION: Skipped households with failed choices during simulation. Number of households skipped: {state.get('num_skipped_households', 0)}.\n!!!"
            )


def annotate_trips(
    state: workflow.State,
@@ -393,6 +399,21 @@ def write_matrices(
    if not matrix_settings:
        logger.error("Missing MATRICES setting in write_trip_matrices.yaml")

    hh_weight_col = model_settings.HH_EXPANSION_WEIGHT_COL
    if hh_weight_col:
        if state.get("num_skipped_households", 0) > 0:
            logger.info(
                f"Adjusting household expansion weights in {hh_weight_col} to account for {state.get('num_skipped_households', 0)} skipped households."
            )
            # adjust the hh expansion weights to account for skipped households
            adjustment_factor = state.get_dataframe("households").shape[0] / (
Member: Should the adjustment to expansion weights be based not on the number of skipped households, but instead on the original expansion weight total and the total expansion weight of the households we dropped? E.g., dropping a household with a giant expansion weight matters more than dropping one with a small weight.

Author: I've updated the adjustment to expansion weights to be based on household weights instead of counts. I agree with you in theory that it should be based on weights; however, I don't think it will make a difference in application. Here's why. In simulation, 99% of the time ActivitySim has the same sample rate for all households and therefore the same expansion weight (calculated as 1/sample rate); this is especially true when ActivitySim uses an input synthetic population with a weight of 1 implicitly built into each record. In estimation, by contrast, the input households are very likely to have different weights; however, we don't use weights when estimating a model, and …
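The weight-based adjustment the author describes can be sketched as follows. This is a minimal illustration, not the PR's actual code: the function name, the stand-alone DataFrame, and the column name are all hypothetical.

```python
import pandas as pd

def adjust_expansion_weights(households, hh_weight_col, skipped_weight_total):
    """Scale the remaining households' expansion weights so the weighted
    total matches the original (pre-skip) weighted total."""
    remaining_weight = households[hh_weight_col].sum()
    factor = (remaining_weight + skipped_weight_total) / remaining_weight
    households[hh_weight_col] = households[hh_weight_col] * factor
    return households

# three surviving households, each with expansion weight 2.0; one skipped
# household carried a weight of 2.0
hh = pd.DataFrame({"sample_rate_weight": [2.0, 2.0, 2.0]})
hh = adjust_expansion_weights(hh, "sample_rate_weight", 2.0)
# weighted total is back to 8.0 (6.0 remaining + 2.0 skipped)
```

When every household shares the same weight, this reduces to the count-based factor in the diff, which is why the author expects no practical difference in typical simulation runs.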
                state.get_dataframe("households").shape[0]
                + state.get("num_skipped_households", 0)
            )
            aggregate_trips[hh_weight_col] = (
                aggregate_trips[hh_weight_col] * adjustment_factor
            )
    for matrix in matrix_settings:
        matrix_is_tap = matrix.is_tap
@@ -1043,6 +1043,10 @@ def force_escortee_trip_modes_to_match_chauffeur(state: workflow.State, trips):
        f"Changed {diff.sum()} trip modes of school escortees to match their chauffeur"
    )

    # trip_mode can be na if the run allows skipping failed choices and the trip mode choice has failed
    if state.settings.skip_failed_choices:
        return trips
Member: Rather than just being OK here, the code should: (1) warn that trip modes are missing for these trips, and (2) raise an error when the number of failed choices exceeds a threshold.

Author: Added the warning for item 1. For item 2, I implemented the check in core.workflow.state.update_table(), immediately after the households are dropped, to make sure it raises an error when N (or the weighted N) exceeds the threshold. Implementing the check in the state instead of in model components keeps the code base clean, i.e., less code duplication. The global threshold parameter is added as core.configuration.top.fraction_of_failed_choices_allowed.
    assert (
        ~trips.trip_mode.isna()
    ).all(), f"Missing trip mode for {trips[trips.trip_mode.isna()]}"
@@ -781,6 +781,13 @@ def _check_store_skims_in_shm(self):
    should catch many common errors early, including missing required configurations or specified coefficient labels without defined values.
    """

    skip_failed_choices: bool = True
    """
    Skip households that cause errors during processing instead of failing the model run.

    .. versionadded:: 1.6
    """
Member: Need additional setting[s] to set thresholds for how many skips are OK, and when it is too many and should be an error.

Author: Added …
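Assuming these options are read from the top-level settings.yaml like other global settings, a run that tolerates a small fraction of failures might configure them as follows. The option names come from this thread; the fraction value is purely illustrative.

```yaml
# settings.yaml (illustrative values, not defaults)
skip_failed_choices: true
# raise an error once more than 1% of households have failed choices
fraction_of_failed_choices_allowed: 0.01
```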
    other_settings: dict[str, Any] = None

    def _get_attr(self, attr):
@@ -30,6 +30,7 @@ def report_bad_choices(
    state: workflow.State,
    bad_row_map,
    df,
    skip_failed_choices,
    trace_label,
    msg,
    trace_choosers=None,
@@ -87,6 +88,27 @@ def report_bad_choices(
    logger.warning(row_msg)

    if skip_failed_choices:
        # update counter in state
        num_skipped_households = state.get("num_skipped_households", 0)
        skipped_household_ids = state.get("skipped_household_ids", set())
        for hh_id in df[trace_col].unique():
            if hh_id is None:
                continue
            if hh_id not in skipped_household_ids:
                skipped_household_ids.add(hh_id)
                num_skipped_households += 1
            else:
                continue

        state.set("num_skipped_households", num_skipped_households)
        state.set("skipped_household_ids", skipped_household_ids)

        logger.debug(
Member: The logger level here should be …

Author: Updated to …
            f"Skipping {bad_row_map.sum()} bad choices. Total skipped households so far: {num_skipped_households}. Skipped household IDs: {skipped_household_ids}"
        )

        return

    if raise_error:
        raise InvalidTravelError(msg_with_count)
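The deduplicating counter logic above can be illustrated with a plain dict standing in for ActivitySim's workflow.State. record_skipped_households is a hypothetical helper written for this sketch, not the PR's API.

```python
def record_skipped_households(state, hh_ids):
    """Add new household IDs to the skipped set and bump the counter,
    ignoring missing IDs and households already recorded."""
    skipped = state.setdefault("skipped_household_ids", set())
    count = state.get("num_skipped_households", 0)
    for hh_id in hh_ids:
        if hh_id is None or hh_id in skipped:
            continue  # skip missing IDs and already-counted households
        skipped.add(hh_id)
        count += 1
    state["num_skipped_households"] = count
    return state

state = {}
record_skipped_households(state, [101, 102, None, 101])
# num_skipped_households == 2 (101 counted once, None ignored)
record_skipped_households(state, [102, 103])
# num_skipped_households == 3 (102 already recorded)
```

Tracking IDs in a set is what keeps a household that fails in several model components from being counted more than once.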
@@ -136,6 +158,7 @@ def utils_to_probs(
    allow_zero_probs=False,
    trace_choosers=None,
    overflow_protection: bool = True,
    skip_failed_choices: bool = True,
    return_logsums: bool = False,
):
    """
@@ -176,6 +199,16 @@ def utils_to_probs(
        overflow_protection will have no benefit but impose a modest computational
        overhead cost.

    skip_failed_choices : bool, default True
        If True, when bad choices are detected (all zero probabilities or infinite
        probabilities), the entire household causing the bad choices will be skipped
        instead of being masked by overflow protection or causing an error.
        A counter is incremented for each skipped household. This is useful when
        running large simulations where occasional bad choices are encountered and
        should not halt the process. The counter can be accessed via
        `state.get("num_skipped_households", 0)`. The number of skipped households
        and their IDs are logged at the end of the simulation. When
        `skip_failed_choices` is True, `overflow_protection` is reverted to False
        to avoid conflicts.

    Returns
    -------
    probs : pandas.DataFrame
@@ -203,6 +236,13 @@ def utils_to_probs(
        utils_arr.dtype == np.float32 and utils_arr.max() > 85
    )

    if state.settings.skip_failed_choices is not None:
        skip_failed_choices = state.settings.skip_failed_choices
        # when skipping failed choices, we cannot use overflow protection
        # because it would mask the underlying issue causing bad choices
        if skip_failed_choices:
            overflow_protection = False

    if overflow_protection:
        # exponentiated utils will overflow, downshift them
        shifts = utils_arr.max(1, keepdims=True)
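A small numeric example of why the downshifting has to be disabled when skipping failed choices, as the inline comment notes (illustrative only; utility values are made up):

```python
import numpy as np

# A chooser row whose alternatives are all effectively unavailable.
utils = np.array([[-999.0, -999.0, -999.0]])

# Without protection, exponentiation underflows to zero everywhere,
# which the all-zero-probability check can detect and report.
raw = np.exp(utils)

# With overflow protection, subtracting the row max turns the row into
# all zeros, which exponentiates to all ones: the row now looks like a
# valid uniform-probability choice and the failure is silently masked.
shifted = np.exp(utils - utils.max(1, keepdims=True))

print(raw.sum())      # 0.0 -> flagged as a failed choice
print(shifted.sum())  # 3.0 -> looks like a valid choice row
```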
@@ -240,6 +280,7 @@ def utils_to_probs(
        state,
        zero_probs,
        utils,
        skip_failed_choices,
        trace_label=tracing.extend_trace_label(trace_label, "zero_prob_utils"),
        msg="all probabilities are zero",
        trace_choosers=trace_choosers,

@@ -251,6 +292,7 @@ def utils_to_probs(
        state,
        inf_utils,
        utils,
        skip_failed_choices,
        trace_label=tracing.extend_trace_label(trace_label, "inf_exp_utils"),
        msg="infinite exponentiated utilities",
        trace_choosers=trace_choosers,
@@ -281,6 +323,7 @@ def make_choices(
    trace_label: str = None,
    trace_choosers=None,
    allow_bad_probs=False,
    skip_failed_choices=True,
) -> tuple[pd.Series, pd.Series]:
    """
    Make choices for each chooser from among a set of alternatives.
@@ -316,11 +359,15 @@ def make_choices(
        np.ones(len(probs.index))
    ).abs() > BAD_PROB_THRESHOLD * np.ones(len(probs.index))

    if state.settings.skip_failed_choices is not None:
        skip_failed_choices = state.settings.skip_failed_choices

    if bad_probs.any() and not allow_bad_probs:
        report_bad_choices(
            state,
            bad_probs,
            probs,
            skip_failed_choices,
            trace_label=tracing.extend_trace_label(trace_label, "bad_probs"),
            msg="probabilities do not add up to 1",
            trace_choosers=trace_choosers,
@@ -329,6 +376,8 @@ def make_choices(
    rands = state.get_rn_generator().random_for_df(probs)

    choices = pd.Series(choice_maker(probs.values, rands), index=probs.index)
    # mark bad choices with -99
    choices[bad_probs] = -99

    rands = pd.Series(np.asanyarray(rands).flatten(), index=probs.index)
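The "-99 sentinel" pattern in this hunk can be sketched in isolation. The probability values and household index are made up, and the threshold value is an assumption for the sketch, not necessarily ActivitySim's actual BAD_PROB_THRESHOLD.

```python
import pandas as pd

BAD_PROB_THRESHOLD = 0.001  # assumed tolerance for this sketch

# two choosers: the first has valid probabilities, the second row is a
# failed choice (probabilities sum to 0 instead of 1)
probs = pd.DataFrame([[0.5, 0.5], [0.0, 0.0]], index=[101, 102])
bad_probs = (probs.sum(axis=1) - 1.0).abs() > BAD_PROB_THRESHOLD

# pretend a choice was drawn for every row, then overwrite the bad rows
# with the -99 placeholder so downstream code can identify them
choices = pd.Series([0, 1], index=probs.index)
choices[bad_probs] = -99
# choices: 101 -> 0, 102 -> -99
```

The sentinel lets the run continue past the bad rows while keeping them identifiable, so later steps (like the household-dropping in update_table) can act on them.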
Member: Should check here how many choices are getting dropped, and …

Author: Both the warning and the check have been implemented in core.workflow.state.update_table(). I implemented them in state.py instead of in model components to keep the code base clean, i.e., less code duplication.
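The kind of threshold guard described in this exchange can be sketched as follows. The function name, error type, and message are hypothetical; the real check lives in update_table() and, per the author's earlier comment, may compare the weighted N against the threshold as well.

```python
def check_failed_choices(num_skipped, num_total, fraction_allowed):
    """Raise once the share of skipped households exceeds the
    configured fraction_of_failed_choices_allowed."""
    fraction_failed = num_skipped / num_total
    if fraction_failed > fraction_allowed:
        raise RuntimeError(
            f"{fraction_failed:.1%} of households were skipped, "
            f"exceeding the allowed {fraction_allowed:.1%}"
        )

check_failed_choices(5, 1000, 0.01)       # 0.5% skipped: within tolerance
try:
    check_failed_choices(50, 1000, 0.01)  # 5% skipped: too many
except RuntimeError as err:
    print(err)
```

Placing the guard right where households are dropped means a run fails fast when something systematic is wrong, rather than silently discarding a large share of the population.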