From 0cf3f00e61bb04e29ba7032b6e603b6d4b0a35d0 Mon Sep 17 00:00:00 2001 From: Bohan Yang <74260393+whoisfrankyang@users.noreply.github.com> Date: Wed, 15 Apr 2026 18:10:49 -0400 Subject: [PATCH] =?UTF-8?q?Revert=20"Revert=20"Preference=20learning=20int?= =?UTF-8?q?egration:=20predict=20=E2=86=92=20correct=20=E2=86=92=20apply?= =?UTF-8?q?=20=E2=86=92=20learn=20+=20terminal=20mode""?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agents/development_log.md | 246 +++++ src/feeding_deployment/actions/acquisition.py | 248 +---- src/feeding_deployment/actions/base.py | 8 +- .../actions/behavior_trees/acquire_bite.yaml | 2 +- .../press_microwave_button.yaml | 8 + .../behavior_trees/transfer_drink.yaml | 6 +- .../behavior_trees/transfer_utensil.yaml | 13 +- .../actions/behavior_trees/transfer_wipe.yaml | 4 +- src/feeding_deployment/actions/close_door.py | 4 - src/feeding_deployment/actions/flair/flair.py | 3 + .../actions/flair/inference_class.py | 5 + src/feeding_deployment/actions/navigate.py | 4 - src/feeding_deployment/actions/open_door.py | 104 +- src/feeding_deployment/actions/pick_plate.py | 7 - src/feeding_deployment/actions/pick_tool.py | 129 +-- src/feeding_deployment/actions/place_plate.py | 5 - .../actions/press_microwave_button.py | 5 +- src/feeding_deployment/actions/stow_tool.py | 113 +-- .../actions/transfer_tool.py | 106 +- .../integration/apply_preferences.py | 447 +++++++++ .../integration/preference_context.py | 52 + src/feeding_deployment/integration/run.py | 203 +++- .../integration/saved_states/last_state.p | Bin 2588 -> 3370 bytes .../integration/terminal_preferences.py | 81 ++ .../interfaces/web_interface.py | 109 +- .../methods/long_term_memory.py | 15 +- .../methods/prediction_model.py | 16 +- .../methods/prompts/bundle_prediction.py | 18 +- .../methods/prompts/ltm_update.py | 18 +- .../preference_learning/profiles/alice.txt | 4 + src/feeding_deployment/simulation/world.py | 2 +- tests/__init__.py | 0 tests/test_apply_preferences.py | 583 +++++++++++ tests/test_preference_integration.py | 946 ++++++++++++++++++ 34 files changed, 2787 insertions(+), 727 deletions(-) create mode 100644 agents/development_log.md create mode 100644 src/feeding_deployment/integration/apply_preferences.py create mode 100644 src/feeding_deployment/integration/preference_context.py create mode 100644 src/feeding_deployment/integration/terminal_preferences.py create mode 100644 src/feeding_deployment/preference_learning/profiles/alice.txt create mode 100644 tests/__init__.py create mode 100644 tests/test_apply_preferences.py create mode 100644 tests/test_preference_integration.py diff --git a/agents/development_log.md b/agents/development_log.md new file mode 100644 index 00000000..52f446a1 --- /dev/null +++ b/agents/development_log.md @@ -0,0 +1,246 @@ +## Planned Integration Steps + +### 1. Before a meal (or at start of a session) build `context` (`meal`/`setting`/`time_of_day`) + +**Implementation** + +`integration/preference_context.py` + +- `build_preference_context(meal, setting, time_of_day)` +- `validate_preference_context(context)` + +`integration/run.py` — `_Runner` + +- `preference_context` +- `ensure_preference_context()` +- `set_meal_preference_context(meal, setting, time_of_day)` + +**Operational flow (web deployment)** + +1. Start the process with `--use_interface` and a non-empty `--pref_meal`, `--pref_setting`, and `--pref_time_of_day`. +2. `__main__` calls `set_meal_preference_context` — context lives on the runner. +3. `run()` calls `ensure_preference_context()` — confirms context is set, then continues. + +--- + +### 2. Predict + +- Instantiate `PredictionModel(user, physical_profile_label, logs_dir=log_dir/"preference_learning")` +- Call `predict_bundle(context, corrected={})` → `predicted_bundle` + +**Implementation** + +1. `PredictionModel` is instantiated in `run()` after `ensure_preference_context()`, with `logs_dir=self.log_dir / "preference_learning"` and the freeform physical profile description loaded from the `.txt` file. +2. `predict_bundle(context, corrected={})` is called, result stored on `self.predicted_bundle` and printed to stdout. +3. `self._prediction_model` is kept on the runner for later use. +4. Physical profile flows from `--physical_profile_file` through `_Runner` → `PredictionModel` → prompt templates as freeform text. +5. CLI requires `--physical_profile_file` and `--pref_meal` with `--use_interface`; missing either raises immediately. + +**Modifications** + +1. `integration/run.py` + - `_Runner.__init__`: added new instance variables: + - `self.deployment_user` — stores the user string. + - `self.physical_profile_label` — stores the freeform physical profile text. + - `self._prediction_model: PredictionModel | None` — initialized to `None`. + - `self.predicted_bundle: dict[str, str] | None` — initialized to `None`. + - `run()`: added a block between `ensure_preference_context()` and `ready_for_task_selection()` that constructs `PredictionModel` and calls `predict_bundle(dict(ctx), {})`, storing the result on `self.predicted_bundle`. + - `__main__`: added `--physical_profile_file` argument. Required when `--use_interface` is set. +2. `preference_learning/methods/prediction_model.py` (`PredictionModel`) + - `__init__`: added optional parameter `physical_profile_description`. Stored on `self.physical_profile_description`. Forwarded into `LongTermMemoryModel(...)` construction. + - `predict_bundle`: forwards `self.physical_profile_description` into `get_bundle_prediction_prompt(...)`. +3. `preference_learning/methods/long_term_memory.py` (`LongTermMemoryModel`) + - `__init__`: added optional parameter `physical_profile_description`. Stored on `self._physical_profile_description`. + - `add_episode`: forwards `self._physical_profile_description` into `get_ltm_update_prompt(...)`. + - `main()`: fixed a broken call — was constructing `LongTermMemoryModel` without `physical_profile_label` and calling a nonexistent `.reset()` method. Now uses the real `__init__` signature. +4. `preference_learning/methods/prompts/bundle_prediction.py` + - `get_bundle_prediction_prompt`: added optional keyword-only parameter `physical_profile_description`. When provided and non-empty, uses it directly as the `{physical_profile}` template value instead of looking up `physical_profile_label` in `PHYSICAL_CAPABILITY_PROFILES`. When `None`, falls back to the original label-based lookup. +5. `preference_learning/methods/prompts/ltm_update.py` + - `get_ltm_update_prompt`: same change as `get_bundle_prediction_prompt` — added optional `physical_profile_description`. + +--- + +### 3. Show prediction + allow correction in the UI + +- UI displays the predicted 18 fields. +- User edits any fields they want. +- Collect: + - `ground_truth_bundle` + - `corrected` + +**Implementation** + +1. `interfaces/web_interface.py` + - `get_preference_corrections(predicted_bundle, pref_options) -> dict[str, str]`: new method. Docstring defines the ROS message contract for the future frontend page: + - **Sends** to webapp: `{"state": "preference_correction", "status": "jump"}` then `{"state": "preference_correction_data", "predicted_bundle": {...}, "options": {...}}`. + - **Expects back**: `{"state": "preference_correction_response", "bundle": {...}}` with all 18 fields and the user's final selections. + - The ROS calls are **commented out** (frontend page does not exist yet). Currently returns `predicted_bundle` unchanged as a stub. +2. `integration/run.py` + - `run()`: after `predict_bundle`, calls `web_interface.get_preference_corrections(predicted_bundle, PREF_OPTIONS)`. Then computes: + - `self.ground_truth_bundle` — the full 18-field dict returned from the correction step. + - `self.corrected` — only the fields where the returned value differs from the prediction. + - Imports `PREF_OPTIONS` from `prediction_model.py` (already computed there from `PREFERENCE_BUNDLE` config; no config changes needed). + +**When frontend is ready** + +Uncomment the ROS lines in `get_preference_corrections`, remove the stub return, and build the webapp page that renders 18 dropdowns pre-filled with `predicted_bundle` values and sends back `preference_correction_response`. + +--- + +### 4. Apply to the system + +- Translate `ground_truth_bundle` into concrete BT parameter writes in the run's behavior-tree directory (`self.run_behavior_tree_dir`). + - This makes the next executions use the updated parameters automatically as `execute_action()` loads the YAML fresh each time. +- For planner-level preferences: + - Adjust `Runner.current_atoms` (e.g. add/remove `FoodHeated`) before calling `process_user_command()` for the first time in that meal. + +**Design decisions** + +1. **`robot_speed` vocabulary mismatch**: bundle uses `"slow"/"medium"/"fast"` but BT YAML `Speed` enum uses `"low"/"medium"/"high"`. Decision: map in code (`slow`->`low`, `fast`->`high`), keep both vocabularies as-is. +2. **`convey_robot_ready_*` "speech + LED"**: BT YAML enum only allows one of `silent`/`voice`/`led`/`beep`. Decision: skip for now with a runtime warning, fall back to `"voice"`. Handle later when BT supports combined cues. +3. **`wait_before_autocontinue_seconds` "1000 sec"**: exceeds original BT Box upper bound of 100.0. Decision: raise the YAML Box upper bound to 1000.0 in the three source templates. +4. **`outside_mouth_distance` label-to-float mapping**: bundle uses discrete labels, BT uses continuous `[0.1, 0.2]` meters. Decision: `near`=0.1, `medium`=0.15, `far`=0.2. `"not applicable"` skips the write. +5. **6 non-BT fields** (`retract_between_bites`, `bite_dipping_preference`, `microwave_time`, `occlusion_relevance`, `skewering_axis`, `transfer_mode` planner atoms): deferred to a later step. +6. **`transfer_mode`**: `TransferToolHLA.__init__` sets `self.transfer` once at startup. Decision: update `scene_description.transfer_type` and re-initialize the `TransferToolHLA.transfer` object with the correct `InsideMouthTransfer` or `OutsideMouthTransfer`. +7. **`microwave_time` -> `FoodHeated` atom manipulation**: deferred to a later step. + +**Implementation** + +1. `integration/apply_preferences.py` (new file) + - Declarative mapping table (`_BT_MAPPING`): each entry maps a bundle field to a list of (YAML filename, BT parameter name, value translator). + - Value translators: `_SPEED_MAP`, `_CONFIRMATION_MAP`, `_AUTOCONTINUE_MAP`, `_OUTSIDE_MOUTH_DISTANCE_MAP`, `_CONVEY_READY_MAP`, `_INITIATE_TRANSFER_MAP`, `_COMPLETE_TRANSFER_MAP`, `_TRANSFER_MODE_MAP`. + - `apply_bundle_to_behavior_trees(bundle, bt_dir) -> list[str]`: iterates the mapping table, loads affected YAML files (with a custom loader that round-trips `!hla` tags), overwrites `value` entries, saves back. Returns warnings for edge cases. + - `apply_transfer_mode(bundle, scene_description, hla_map)`: reads `bundle["transfer_mode"]`, sets `scene_description.transfer_type`, and re-instantiates `TransferToolHLA.transfer` with the correct `InsideMouthTransfer`/`OutsideMouthTransfer`. +2. `actions/behavior_trees/acquire_bite.yaml`, `transfer_utensil.yaml`, `transfer_drink.yaml` + - Raised `TimeToWaitBeforeAutocontinue` Box upper bound from `100.0` to `1000.0`. +3. `integration/run.py` + - `run()`: after computing `ground_truth_bundle` and `corrected`, calls `apply_bundle_to_behavior_trees` and `apply_transfer_mode` before `ready_for_task_selection()`. All subsequent `execute_action()` calls pick up the updated YAML values from disk. + +**Bundle field -> BT parameter mapping (17 fields)** + +| Bundle field | BT parameter | YAML files | +|---|---|---| +| `robot_speed` | `Speed` | all 29 YAMLs | +| `web_interface_confirmation` | `TransferAskForConfirmation` | `acquire_bite.yaml` | +| `web_interface_confirmation` | `AskForConfirmationInitiatingTransferSequence` | `transfer_drink.yaml`, `transfer_wipe.yaml` | +| `wait_before_autocontinue_seconds` | `TimeToWaitBeforeAutocontinue` | `acquire_bite.yaml`, `transfer_utensil.yaml`, `transfer_drink.yaml` | +| `outside_mouth_distance` | `OutsideMouthDistance` | `transfer_utensil.yaml`, `transfer_drink.yaml`, `transfer_wipe.yaml` | +| `convey_robot_ready_for_initiating_transfer` | `ReadyToInitiateTransferInteraction` | `transfer_utensil.yaml`, `transfer_drink.yaml`, `transfer_wipe.yaml` | +| `detect_user_ready_for_initiating_transfer_feeding` | `InitiateTransferInteraction` | `transfer_utensil.yaml` | +| `detect_user_ready_for_initiating_transfer_drinking` | `InitiateTransferInteraction` | `transfer_drink.yaml` | +| `detect_user_ready_for_initiating_transfer_wiping` | `InitiateTransferInteraction` | `transfer_wipe.yaml` | +| `convey_robot_ready_for_completing_transfer` | `ReadyForTransferInteraction` | `transfer_utensil.yaml`, `transfer_drink.yaml`, `transfer_wipe.yaml` | +| `detect_user_completed_transfer_feeding` | `TransferCompleteInteraction` | `transfer_utensil.yaml` | +| `detect_user_completed_transfer_drinking` | `TransferCompleteInteraction` | `transfer_drink.yaml` | +| `detect_user_completed_transfer_wiping` | `TransferCompleteInteraction` | `transfer_wipe.yaml` | +| `skewering_axis` | `SkeweringOrientation` | `acquire_bite.yaml` | +| `bite_dipping_preference` | `FoodDippingDepth` | `acquire_bite.yaml` | +| `microwave_time` | `MicrowaveDuration` | `press_microwave_button.yaml` | +| `retract_between_bites` | `RetractAfterTransfer` | `transfer_utensil.yaml` | + +**`skewering_axis`** — clean enum-to-enum mapping via `_SKEWERING_AXIS_MAP`: `"parallel to major axis"` -> `"horizontal"`, `"perpendicular to major axis"` -> `"vertical"`. Targets `SkeweringOrientation` in `acquire_bite.yaml`. + +**`bite_dipping_preference`** — label-to-float mapping via `_dipping_depth_translate()`: `"less"` -> `0.01` (Box minimum), `"more"` -> `0.03` (Box maximum). `"do not dip"` -> skip (returns `None`, leaves `FoodDippingDepth` at its default). + +**Important caveat for `"do not dip"`**: skipping the `FoodDippingDepth` BT write does **not** prevent FLAIR's autonomous planner from choosing to dip. The dipping decision is made independently in `inference_class.py` (`get_autonomous_action`), which looks at plate contents and user preferences. To truly suppress dipping, a flag must be passed into FLAIR's planning logic — this is not yet implemented. This differs from `outside_mouth_distance = "not applicable"`, which is safe because the `transfer_mode` preference separately switches the execution path to `InsideMouthTransfer` (which never reads `OutsideMouthDistance`). + +**`microwave_time`** — dual-layer integration (planner + BT): +- **Planner level** via `apply_microwave_preference(bundle, current_atoms, food_heated_atom)`: + - `"no microwave"` -> adds `GroundAtom(FoodHeated, [])` to `current_atoms`, causing the PDDL planner to skip the entire microwave sequence. The BT write is also skipped (returns `None`). + - `"1 min"` / `"2 min"` / `"3 min"` -> discards `FoodHeated` from `current_atoms` so the planner includes microwave steps. +- **BT level** via `_microwave_duration_translate` + `_BT_MAPPING`: + - Added `MicrowaveDuration` Box parameter (30.0–300.0s, default 60.0) to `press_microwave_button.yaml`. + - Updated `PressMicrowaveButtonHLA.press_microwave_button(self, speed, duration)` to accept the duration as a second positional argument. + - `"1 min"` -> 60.0, `"2 min"` -> 120.0, `"3 min"` -> 180.0. Written to the YAML so when `execute_action` loads the BT, the duration flows through the `!hla` function binding into the HLA method. + - `"no microwave"` -> skip (the HLA is never executed anyway since the planner excludes the action). + +**`"speech + LED"` combined cue** — now fully supported: +- Added `"voice_led"` to the `Enum` elements for `ReadyToInitiateTransferInteraction` and `ReadyForTransferInteraction` in all 3 transfer YAMLs (`transfer_utensil.yaml`, `transfer_drink.yaml`, `transfer_wipe.yaml`). +- Added `voice_led` branches in `transfer_tool.py`: `relay_ready_to_initiate_transfer` (turns on LED + speaks the appropriate prompt) and `relay_ready_for_transfer` (turns on LED + speaks "Ready for transfer"). LED turn-off checks in `detect_transfer_initiated` and `detect_transfer_complete` updated to also trigger on `"voice_led"`. +- Updated `_CONVEY_READY_MAP`: `"speech + LED"` now maps to `"voice_led"` (was falling back to `"voice"` with a warning). Removed the fallback warning logic. + +**`retract_between_bites`** — BT parameter on `transfer_utensil.yaml` only (bite transfers): +- Added `RetractAfterTransfer` Enum parameter (`[0, 1]`, default 0) to `transfer_utensil.yaml`. +- Updated `TransferToolHLA.transfer_utensil` to accept all 8 BT parameters explicitly. When `retract_after_transfer == 1`, calls `move_to_joint_positions(retract_pos)` at the end of the transfer, moving the robot to its rest position. When `0` (default), stays at the staging position near the user. +- Only applies to utensil (bite) transfers — drink and wipe don't have repeated transfer loops that benefit from a retract toggle. +- `_BT_MAPPING` entry: `retract_between_bites` → `RetractAfterTransfer` via `_RETRACT_MAP` (`"yes"` → 1, `"no"` → 0). + +**`"do not dip"` FLAIR suppression** — deterministic override in `get_autonomous_action`: +- Added `allow_dip` boolean flag to `BiteAcquisitionInference` (default `True`). +- In `get_autonomous_action`: if `allow_dip is False` and the LLM returns a two-item list `['food', 'sauce']`, the dip is stripped and only the food item is used. +- Added `FLAIR.set_allow_dip(allow)` method that forwards to the inference server. +- Added `apply_dip_preference(bundle, flair)` in `apply_preferences.py`: when `bite_dipping_preference == "do not dip"`, sets `allow_dip=False`. For `"less"` or `"more"`, keeps `allow_dip=True` (the BT `FoodDippingDepth` parameter handles depth separately). +- Wired into `run.py` after the other preference applications. + +**`occlusion_relevance`** — soft LLM hint via FLAIR's `user_preference` string: +- Added `apply_occlusion_preference(bundle, flair)` in `apply_preferences.py`. +- Maps preference values to natural-language hints: e.g. `"minimize left occlusion"` → `"When choosing which food to pick, prefer items that minimize the robot arm blocking the user's view from the left."`. +- `"do not consider occlusion"` → no-op (no hint appended). +- The hint is appended to FLAIR's existing `user_preference` string, which flows into the `PreferencePlanner` LLM prompt. This is a **soft** hint — the LLM may or may not factor it into food selection. There is no geometric enforcement; the current FLAIR architecture has no model of robot arm visibility from the user's perspective. +- **Limitation (pending discussion with Rajat):** This is the same enforcement level as other FLAIR food preferences (all LLM-based, none geometrically enforced). A stronger approach would require a perception/kinematic model to score candidate bites by arm obstruction from a given viewing direction — no such infrastructure exists today. + +**All 18 preference fields are now wired.** No remaining deferred items for Step 4. + +--- + +### 5. Learn + +- After the correction loop, call `PredictionModel.update(day, context, corrected, ground_truth_bundle)`. This updates LTM/episodic/working memory so next time `predict_bundle()` is closer to the user's intent. + +**Implementation** + +1. `preference_learning/methods/prediction_model.py` (`PredictionModel`) + - Added `next_day() -> int`: scans `working_memory_dir` for existing `day_NNNN.json` files and returns `max + 1` (or `1` if empty). Handles gaps in numbering by taking the max, not the count. +2. `integration/run.py` + - Added `_pref_day: int | None` attribute to `_Runner.__init__` (passed from CLI). + - In `run()`, after preferences are applied and before the task loop: computes `day` (either from `--pref_day` override or `next_day()` auto-detection), then calls `self._prediction_model.update(day, ctx, corrected, ground_truth_bundle)`. + - Added `--pref_day` CLI argument (optional integer). When omitted, auto-detects from logs. + +**What `update()` does** + +1. Builds episode text from `(day, context, ground_truth_bundle)` — a natural-language representation of this meal session's preferences. +2. If LTM is enabled: feeds the episode into `LongTermMemoryModel.add_episode()`, which calls the LLM to update the cumulative preference summary. Writes `day_NNNN.json` to the LTM log directory. +3. If episodic memory is enabled: feeds the episode into `EpisodicMemoryModel.add_episode()`, which stores the episode embedding for future retrieval. Writes `day_NNNN.json` to the EM log directory. +4. Always writes a working memory record (`day_NNNN.json`) with context and corrections. + +**Effect on the next session**: When `predict_bundle()` is called in the next meal, it retrieves the updated LTM summary and relevant episodic memories, producing predictions that incorporate what was learned from previous corrections. + +--- + +### Interaction modes (`--pref_mode`) + +Controls how the preference prediction/correction flow runs before a meal. + +| Mode | Context input | Prediction + correction | Requires | +|---|---|---|---| +| `none` (default) | skipped | skipped | nothing extra | +| `terminal` | interactive numbered menus in terminal | LLM predicts, operator reviews each field and types a number or Enter to accept | `--physical_profile_file` | +| `interface` | from `--pref_meal`, `--pref_setting`, `--pref_time_of_day` CLI flags | LLM predicts, web frontend shows corrections (stub returns unchanged for now) | `--physical_profile_file`, `--pref_meal` | + +**Implementation** + +1. `integration/terminal_preferences.py` (new file) + - `terminal_collect_context()`: prompts operator to pick meal, setting, time_of_day from numbered lists. Returns a context dict. + - `terminal_correct_preferences(predicted_bundle, pref_options)`: shows each of the 18 predicted fields with numbered options. Operator presses Enter to accept or types a number to change. Returns the final bundle. + - `_pick_from_list(prompt, options)`: helper that displays numbered options and loops until valid input. +2. `integration/run.py` + - Added `--pref_mode` CLI flag (`none`/`terminal`/`interface`). Default `none`. + - Added `_pref_mode` attribute to `_Runner.__init__`. + - `run()`: branches on `_pref_mode`: + - `none` → skips straight to `ready_for_task_selection()` (no personalization). + - `terminal` → calls `terminal_collect_context()` + `terminal_correct_preferences()`, then predict → apply → learn. + - `interface` → uses `web_interface.get_preference_corrections()` (existing stub), then predict → apply → learn. + - `--physical_profile_file` now required for both `terminal` and `interface` modes. + - `--pref_meal`/`--pref_setting`/`--pref_time_of_day` only required for `interface` mode; in `terminal` mode, context is collected interactively. + +**Full end-to-end flow with `--pref_mode=terminal`:** + +1. Operator starts: `python run.py --user alice --use_interface --pref_mode terminal --physical_profile_file alice_profile.txt` +2. Terminal prompts for meal, setting, time_of_day (numbered menus). +3. `PredictionModel` predicts 18-field bundle using LTM + episodic memory. +4. Terminal shows each field with predicted value marked; operator corrects or accepts. +5. `apply_bundle_to_behavior_trees` + `apply_transfer_mode` + FLAIR preferences applied. +6. `PredictionModel.update()` logs the episode for future learning. +7. Meal task loop begins. + +**All 5 integration steps are now complete, with terminal interaction mode for testing.** diff --git a/src/feeding_deployment/actions/acquisition.py b/src/feeding_deployment/actions/acquisition.py index 5eddbc58..14a6448e 100644 --- a/src/feeding_deployment/actions/acquisition.py +++ b/src/feeding_deployment/actions/acquisition.py @@ -86,250 +86,4 @@ def get_behavior_tree_filename( return "acquire_bite.yaml" def acquire_bite(self, speed: str, dipping_depth: float, skewering_depth: float, skewering_orientation: str, autocontinue_timeout: float, ask_confirmation: bool) -> None: - - # assert self.sim.held_object_name == "utensil" - - print("Acquiring bite with utensil ...") - return - - if self.robot_interface is not None: - self.robot_interface.set_speed(speed) - - # stop the keep horizontal thread (incase we're trying to re-acquire a bite) - if self.wrist_interface is not None: - self.wrist_interface.stop_horizontal_spoon_thread() - - # self.move_to_joint_positions(self.sim.scene_description.before_transfer_pos) # leads to safer motion - self.move_to_joint_positions(self.sim.scene_description.above_plate_pos) - - while True: - if self.wrist_interface is not None: - self.wrist_interface.set_velocity_mode() - self.wrist_interface.reset() - - try: # bite ordering and detection - if self.robot_interface is not None: - - camera_color_data, camera_info_data, camera_depth_data = ( - self.perception_interface.get_camera_data() - ) - - if not self.flair.is_preference_set(): - plate_image = self.flair.crop_plate(camera_color_data) - if self.web_interface is not None: - while self.web_interface.active: - user_input_food_items, user_input_bite_ordering_preference = self.web_interface.get_new_meal_input(plate_image) - food_items, bite_ordering_preference = self.flair.parse_new_meal(user_input_food_items, user_input_bite_ordering_preference) - if food_items is not None and bite_ordering_preference is not None: - break - else: - print("Failed to parse user input. Trying again ...") - time.sleep(1.0) - else: - # Use command line input for preference setting. - user_input_food_items = input("Enter food items as a python list: ") - user_input_bite_ordering_preference = input("Enter bite ordering preference: ") - food_items, bite_ordering_preference = self.flair.parse_new_meal(user_input_food_items, user_input_bite_ordering_preference) - - self.flair.set_food_items(food_items) - self.flair.set_preference(bite_ordering_preference) - - items_detection = self.flair.detect_items(camera_color_data, camera_depth_data, camera_info_data, log_path=None) - - assert self.log_dir is not None, "Log path must be set to save food detection data" - # save food detection data - food_detection_data = { - "camera_color_data": camera_color_data, - "camera_info_data": camera_info_data, - "camera_depth_data": camera_depth_data, - "food_items": self.flair.get_food_items(), - "bite_ordering_preference": self.flair.get_preference(), - "items_detection": items_detection, - } - - with open(self.log_dir / "food_detection_data.pkl", "wb") as f: - pickle.dump(food_detection_data, f) - - # food detection continuous log - file_name = "food_detection_data" - id = 0 - while (self.food_detection_log_dir / f"{file_name}_{id}.pkl").exists(): - id += 1 - with open(self.food_detection_log_dir / f"{file_name}_{id}.pkl", "wb") as f: - pickle.dump(food_detection_data, f) - - else: - # read last logged data - try: - with open(self.log_dir / "food_detection_data.pkl", "rb") as f: - food_detection_data = pickle.load(f) - - camera_color_data = food_detection_data["camera_color_data"] - camera_info_data = food_detection_data["camera_info_data"] - camera_depth_data = food_detection_data["camera_depth_data"] - food_items = food_detection_data["food_items"] - bite_ordering_preference = food_detection_data["bite_ordering_preference"] - items_detection = food_detection_data["items_detection"] - - self.flair.set_food_items(food_items) - self.flair.set_preference(bite_ordering_preference) - - except FileNotFoundError: - raise FileNotFoundError("No logged data found for bite acquisition") - except Exception as e: - print("Failed to detect items:", e) - continue - - try: # actual acquisition - - # Prepare for bite acquisition. - if self.wrist_interface is not None: - self.wrist_interface.set_velocity_mode() - self.wrist_interface.reset() - - next_action_prediction = self.flair.predict_next_action(camera_color_data, items_detection, log_path=None) - - next_food_item = next_action_prediction['labels_list'][next_action_prediction['food_id']] - bite_mask_idx = next_action_prediction['bite_mask_idx'] - print(" --- Next Food Item Prediction:", next_action_prediction['labels_list'][next_action_prediction['food_id']]) - print(" --- Next Action Prediction:", next_action_prediction['action_type']) - - # remove next_food_item from data - solid_food_type_to_data = {} - for id in range(0, len(items_detection['labels_list'])): - if items_detection['category_list'][id] == "solid": - label = items_detection['labels_list'][id] - solid_food_type_to_data[label] = items_detection['food_type_to_bounding_boxes_plate'][label] - - n_food_types = len(solid_food_type_to_data) - data = [{k: v} for k, v in solid_food_type_to_data.items() if k != next_food_item] - predicted_bite = {next_food_item: solid_food_type_to_data[next_food_item]} - - dip_food_type_to_data = {} - for id in range(0, len(items_detection['labels_list'])): - if items_detection['category_list'][id] == "dip": - label = items_detection['labels_list'][id] - dip_food_type_to_data[label] = items_detection['food_type_to_bounding_boxes_plate'][label] - - if len(dip_food_type_to_data) == 0: # no dips detected - dip_data = ["No dip"] - else: - if next_action_prediction['dip_id'] is None: - dip_data = ["No dip"] - dip_data.extend(list(dip_food_type_to_data.keys())) - else: # some dip was predicted - next_dip_item = next_action_prediction['labels_list'][next_action_prediction['dip_id']] - dip_data = [next_dip_item] - dip_data.append("No dip") - dip_data.extend([k for k in dip_food_type_to_data.keys() if k != next_dip_item]) - n_dip_food_types = len(dip_data) - - if self.web_interface is not None: - skill_type, skill_params, dip_type = self.web_interface.get_next_bite_selection(items_detection['plate_image'], n_food_types, data, predicted_bite, n_dip_food_types, dip_data, autocontinue_timeout=autocontinue_timeout) - else: - # params must be set to the autonomously selected values - skill_type = "autonomous" - skill_params = [next_food_item, bite_mask_idx] - dip_type = "No dip" - - skill_success = False - if skill_type == "autonomous": - food_type_to_masks = items_detection["food_type_to_masks"] - food_type_to_skill = items_detection["food_type_to_skill"] - - food_type = skill_params[0] - item_id = skill_params[1] - 1 - - # Rajat Imp ToDo: Update bite history after successful skill execution - self.flair.update_bite_history(food_type) - - mask = food_type_to_masks[food_type][item_id] - skill = food_type_to_skill[food_type] - - if skill == "Skewer": - skewer_point, skewer_angle = self.flair.inference_server.get_skewer_action(mask) - if skewering_orientation == "vertical": - skewer_angle = skewer_angle + np.pi / 2 - skill_success = self.food_manipulation_skill_library.skewering_skill(camera_color_data, camera_depth_data, camera_info_data, keypoint = skewer_point, major_axis = skewer_angle, skewering_depth=skewering_depth) - elif skill == "Scoop": - raise NotImplementedError("Scoop skill not yet implemented") - - if dip_type != "No dip" and skill_success: - self.flair.update_bite_history(dip_type) - dip_mask = food_type_to_masks[dip_type][0] - dip_point = self.flair.inference_server.get_dip_action(dip_mask) - self.food_manipulation_skill_library.robot_reset() - skill_success = self.food_manipulation_skill_library.dipping_skill(camera_color_data, camera_depth_data, camera_info_data, keypoint = dip_point, dipping_depth=dipping_depth) - - elif skill_type == "manual_skewering": - - plate_bounds = items_detection["plate_bounds"] - pos = skill_params[0] - - point_x = int(pos["x"]*plate_bounds[2]) + plate_bounds[0] - point_y = int(pos["y"]*plate_bounds[3]) + plate_bounds[1] - - print("Plate Bounds:", plate_bounds) - print("Positions:", skill_params) - print("Point:", point_x, point_y) - - if not self.no_waits: - # visualize point on camera color image - viz = camera_color_data.copy() - for pos in skill_params: - cv2.circle(viz, (point_x, point_y), 5, (0, 255, 0), -1) - cv2.imshow("viz", viz) - cv2.waitKey(0) - cv2.destroyAllWindows() - - skewer_center = (point_x, point_y) - skewer_angle = -np.pi/2 - - skill_success = self.food_manipulation_skill_library.skewering_skill(camera_color_data, camera_depth_data, camera_info_data, keypoint = skewer_center, major_axis = skewer_angle, skewering_depth=skewering_depth) - elif skill_type == "manual_scooping": - raise NotImplementedError("Scoop skill not yet implemented") - elif skill_type == "manual_dipping": - - plate_bounds = items_detection["plate_bounds"] - pos = skill_params[0] - - point_x = int(pos["x"]*plate_bounds[2]) + plate_bounds[0] - point_y = int(pos["y"]*plate_bounds[3]) + plate_bounds[1] - - print("Plate Bounds:", plate_bounds) - print("Positions:", skill_params) - print("Point:", point_x, point_y) - - if not self.no_waits: - # visualize point on camera color image - viz = camera_color_data.copy() - for pos in skill_params: - cv2.circle(viz, (point_x, point_y), 5, (0, 255, 0), -1) - cv2.imshow("viz", viz) - cv2.waitKey(0) - cv2.destroyAllWindows() - - dip_point = (point_x, point_y) - - skill_success = self.food_manipulation_skill_library.dipping_skill(camera_color_data, camera_depth_data, camera_info_data, keypoint = dip_point, dipping_depth=dipping_depth) - - self.move_to_joint_positions(self.sim.scene_description.above_plate_pos) - if not skill_success: - print("Skill failed. Retrying ...") - continue - except Exception as e: - print("Failed to acquire bite:", e) - continue - - if self.web_interface is not None and ask_confirmation: - get_success_confirmation = self.web_interface.get_successful_food_acquisition_confirmation() - if get_success_confirmation: - break - else: - break - - # set the wrist controller to always keep utensil horizontal - if self.wrist_interface is not None: - self.wrist_interface.start_horizontal_spoon_thread() - - return [] + print("Acquiring bite with utensil ...") \ No newline at end of file diff --git a/src/feeding_deployment/actions/base.py b/src/feeding_deployment/actions/base.py index c7c149f5..8ec9b69b 100644 --- a/src/feeding_deployment/actions/base.py +++ b/src/feeding_deployment/actions/base.py @@ -339,8 +339,8 @@ def create_user_addition_node_dict( def move_to_joint_positions(self, joint_positions: list[float]) -> None: plan = None - # if not self.no_waits: - # plan = self.sim.plan_to_joint_positions(joint_positions) + if not self.no_waits: + plan = self.sim.plan_to_joint_positions(joint_positions) if self.robot_interface is None: self.sim.visualize_plan(plan) else: @@ -349,8 +349,8 @@ def move_to_joint_positions(self, joint_positions: list[float]) -> None: def move_to_ee_pose(self, pose: Pose) -> None: plan = None - # if not self.no_waits: - # plan = self.sim.plan_to_ee_pose(pose) + if not self.no_waits: + plan = self.sim.plan_to_ee_pose(pose) if self.robot_interface is None: self.sim.visualize_plan(plan) else: diff --git a/src/feeding_deployment/actions/behavior_trees/acquire_bite.yaml b/src/feeding_deployment/actions/behavior_trees/acquire_bite.yaml index 1c10d20c..0bd03cdf 100644 --- a/src/feeding_deployment/actions/behavior_trees/acquire_bite.yaml +++ b/src/feeding_deployment/actions/behavior_trees/acquire_bite.yaml @@ -37,7 +37,7 @@ parameters: space: type: "Box" lower: 5.0 - upper: 100.0 + upper: 1000.0 is_user_editable: True value: 10.0 - name: "TransferAskForConfirmation" diff --git a/src/feeding_deployment/actions/behavior_trees/press_microwave_button.yaml b/src/feeding_deployment/actions/behavior_trees/press_microwave_button.yaml index 5efb77fa..fec2b0a5 100644 --- a/src/feeding_deployment/actions/behavior_trees/press_microwave_button.yaml +++ b/src/feeding_deployment/actions/behavior_trees/press_microwave_button.yaml @@ -9,4 +9,12 @@ parameters: elements: ["low", "medium", "high"] is_user_editable: True value: "medium" + - name: "MicrowaveDuration" + description: "How long the microwave should run, in seconds." + space: + type: "Box" + lower: 30.0 + upper: 300.0 + is_user_editable: True + value: 60.0 fn: !hla press_microwave_button diff --git a/src/feeding_deployment/actions/behavior_trees/transfer_drink.yaml b/src/feeding_deployment/actions/behavior_trees/transfer_drink.yaml index 3a19d81d..7d6b8a41 100644 --- a/src/feeding_deployment/actions/behavior_trees/transfer_drink.yaml +++ b/src/feeding_deployment/actions/behavior_trees/transfer_drink.yaml @@ -13,7 +13,7 @@ parameters: description: "Interaction mode for how the robot conveys to the user that it is ready to initiate the drink transfer." space: type: "Enum" - elements: ["silent", "voice", "led", "beep"] + elements: ["silent", "voice", "led", "beep", "voice_led"] is_user_editable: True value: "voice" - name: "InitiateTransferInteraction" @@ -27,7 +27,7 @@ parameters: description: "Interaction mode for how the robot conveys to the user that they can take a sip now." space: type: "Enum" - elements: ["silent", "voice", "led", "beep"] + elements: ["silent", "voice", "led", "beep", "voice_led"] is_user_editable: True value: "voice" - name: "TransferCompleteInteraction" @@ -57,7 +57,7 @@ parameters: space: type: "Box" lower: 5.0 - upper: 100.0 + upper: 1000.0 is_user_editable: True value: 10.0 fn: !hla transfer_drink diff --git a/src/feeding_deployment/actions/behavior_trees/transfer_utensil.yaml b/src/feeding_deployment/actions/behavior_trees/transfer_utensil.yaml index b33ccc49..3f234621 100644 --- a/src/feeding_deployment/actions/behavior_trees/transfer_utensil.yaml +++ b/src/feeding_deployment/actions/behavior_trees/transfer_utensil.yaml @@ -13,7 +13,7 @@ parameters: description: "Interaction mode for how the robot conveys to the user that it is ready to initiate the feeding utensil (bite) transfer." space: type: "Enum" - elements: ["silent", "voice", "led", "beep"] + elements: ["silent", "voice", "led", "beep", "voice_led"] is_user_editable: True value: "voice" - name: "InitiateTransferInteraction" @@ -27,7 +27,7 @@ parameters: description: "Interaction mode for how the robot conveys to the user that they can take a bite now." space: type: "Enum" - elements: ["silent", "voice", "led", "beep"] + elements: ["silent", "voice", "led", "beep", "voice_led"] is_user_editable: True value: "voice" - name: "TransferCompleteInteraction" @@ -50,7 +50,14 @@ parameters: space: type: "Box" lower: 5.0 - upper: 100.0 + upper: 1000.0 is_user_editable: True value: 10.0 + - name: "RetractAfterTransfer" + description: "Whether the robot retracts to its rest position after completing a bite transfer, or stays near the user." + space: + type: "Enum" + elements: [0, 1] + is_user_editable: True + value: 0 fn: !hla transfer_utensil diff --git a/src/feeding_deployment/actions/behavior_trees/transfer_wipe.yaml b/src/feeding_deployment/actions/behavior_trees/transfer_wipe.yaml index 51392226..ce9f3cd7 100644 --- a/src/feeding_deployment/actions/behavior_trees/transfer_wipe.yaml +++ b/src/feeding_deployment/actions/behavior_trees/transfer_wipe.yaml @@ -13,7 +13,7 @@ parameters: description: "Interaction mode for how the robot conveys to the user that it is ready to initiate the wipe transfer." space: type: "Enum" - elements: ["silent", "voice", "led", "beep"] + elements: ["silent", "voice", "led", "beep", "voice_led"] is_user_editable: True value: "voice" - name: "InitiateTransferInteraction" @@ -27,7 +27,7 @@ parameters: description: "Interaction mode for how the robot conveys to the user that they can take start cleaning now." space: type: "Enum" - elements: ["silent", "voice", "led", "beep"] + elements: ["silent", "voice", "led", "beep", "voice_led"] is_user_editable: True value: "voice" - name: "TransferCompleteInteraction" diff --git a/src/feeding_deployment/actions/close_door.py b/src/feeding_deployment/actions/close_door.py index 7f7797f9..c54ef4f5 100644 --- a/src/feeding_deployment/actions/close_door.py +++ b/src/feeding_deployment/actions/close_door.py @@ -53,11 +53,7 @@ def get_behavior_tree_filename( return f"close_{appliance.name}.yaml" def close_fridge(self, speed: str) -> None: - del speed - assert self.sim.held_object_name is None print("Closing fridge door ...") def close_microwave(self, speed: str) -> None: - del speed - assert self.sim.held_object_name is None print("Closing microwave door ...") \ No newline at end of file diff --git a/src/feeding_deployment/actions/flair/flair.py b/src/feeding_deployment/actions/flair/flair.py index 1ddad5ca..d1dd5001 100644 --- a/src/feeding_deployment/actions/flair/flair.py +++ b/src/feeding_deployment/actions/flair/flair.py @@ -115,6 +115,9 @@ def get_preference(self): def clear_preference(self): self.user_preference = None + def set_allow_dip(self, allow: bool) -> None: + self.inference_server.allow_dip = allow + def is_preference_set(self): return self.user_preference is not None diff --git a/src/feeding_deployment/actions/flair/inference_class.py b/src/feeding_deployment/actions/flair/inference_class.py index 3e581153..a77da19b 100644 --- a/src/feeding_deployment/actions/flair/inference_class.py +++ b/src/feeding_deployment/actions/flair/inference_class.py @@ -120,6 +120,7 @@ def __init__(self, mode): self.preference_planner = PreferencePlanner() self.mode = mode + self.allow_dip = True # def recognize_items(self, image): # response = self.gpt4v_client.prompt(image).strip() @@ -1052,6 +1053,10 @@ def get_autonomous_action(self, image, masks, categories, labels, portions, pref print('skewer_labels', skewer_labels) print('Next actions', skewer_actions) + if not self.allow_dip and len(next_bite) >= 2: + print(f"[preference] Dipping suppressed: LLM suggested {next_bite}, using only {next_bite[0]}") + next_bite = [next_bite[0]] + if len(next_bite) == 1 and next_bite[0] in labels: print(skewer_labels, next_bite[0]) idx = skewer_labels.index(next_bite[0]) diff --git a/src/feeding_deployment/actions/navigate.py b/src/feeding_deployment/actions/navigate.py index 7693e41e..155249d9 100644 --- a/src/feeding_deployment/actions/navigate.py +++ b/src/feeding_deployment/actions/navigate.py @@ -55,17 +55,13 @@ def get_behavior_tree_filename( return f"navigate_to_{dst.name}.yaml" def navigate_to_fridge(self, speed: str) -> None: - del speed print("Navigating to fridge ...") def navigate_to_microwave(self, speed: str) -> None: - del speed print("Navigating to microwave ...") def navigate_to_sink(self, speed: str) -> None: - del speed print("Navigating to sink ...") def navigate_to_table(self, speed: str) -> None: - del speed print("Navigating to table ...") \ No newline at end of file diff --git a/src/feeding_deployment/actions/open_door.py b/src/feeding_deployment/actions/open_door.py index 9153a1b7..bcaaf725 100644 --- a/src/feeding_deployment/actions/open_door.py +++ b/src/feeding_deployment/actions/open_door.py @@ -55,109 +55,7 @@ def get_behavior_tree_filename( return f"open_{appliance.name}.yaml" def open_fridge(self, speed: str) -> None: - - # set speed of the robot to highest - self.robot_interface.set_speed("high") - assert self.sim.held_object_name is None print("Opening fridge door ...") - return - self.move_to_joint_positions(self.sim.scene_description.retract_pos) - self.move_to_joint_positions(self.sim.scene_description.fridge_door_gaze_pos) - - handle_opening_poses = self.perception_interface.perceive_handle_opening_poses("white fridge door") - - # visualize on rviz - poses = [] - poses.append(handle_opening_poses["pre_grasp_pose"]) - poses.append(handle_opening_poses["grasp_pose"]) - poses.extend(handle_opening_poses["opening_waypoints"]) - poses.append(handle_opening_poses["post_release_pose"]) - poses.append(handle_opening_poses["pre_push_pose"]) - poses.append(handle_opening_poses["push_pose"]) - poses.extend(handle_opening_poses["push_waypoints"]) - print(f"Visualizing {len(poses)} handle opening poses in RViz ...") - self.rviz_interface.visualize_poses(poses, frame_id="base_link", ns="handle_opening_poses") - - # self.move_to_joint_positions(self.sim.scene_description.home_pos) - self.move_to_joint_positions(self.sim.scene_description.fridge_door_staging_pos) - - self.move_to_ee_pose(handle_opening_poses["pre_grasp_pose"]) - self.open_gripper() - self.move_to_ee_pose(handle_opening_poses["grasp_pose"]) - self.close_gripper() - # self.move_to_ee_pose(handle_opening_poses["post_grasp_pose"]) - self.move_to_ee_pose_trajectory(handle_opening_poses["opening_waypoints"]) - self.open_gripper() - self.move_to_ee_pose(handle_opening_poses["post_release_pose"]) - - # self.move_to_joint_positions(self.sim.scene_description.fridge_door_intermediate_restract_pos) - - self.move_to_ee_pose(handle_opening_poses["pre_push_pose"]) - self.move_to_ee_pose(handle_opening_poses["push_pose"]) - self.move_to_ee_pose_trajectory(handle_opening_poses["push_waypoints"]) def open_microwave(self, speed: str) -> None: - assert self.sim.held_object_name is None - print("Opening microwave door ...") - return - self.move_to_joint_positions(self.sim.scene_description.retract_pos) - self.move_to_joint_positions(self.sim.scene_description.microwave_closeup_gaze_pos) - - time.sleep(5.0) # wait for the robot to stabilize before perception - press_button_poses = self.perception_interface.perceive_button_pressing_poses() - - self.move_to_joint_positions(self.sim.scene_description.fridge_door_staging_pos) - self.close_gripper() # just in case the gripper is open - self.move_to_ee_pose(press_button_poses["pre_press_pose"]) - self.move_to_ee_pose(press_button_poses["press_pose"]) - self.move_to_ee_pose(press_button_poses["intermediate_pose"]) - self.move_to_ee_pose(press_button_poses["press_pose"]) - self.move_to_ee_pose(press_button_poses["intermediate_pose"]) - self.move_to_ee_pose(press_button_poses["press_pose"]) - self.move_to_ee_pose(press_button_poses["pre_press_pose"]) - self.move_to_joint_positions(self.sim.scene_description.fridge_door_staging_pos) - - # handle_opening_poses = self.perception_interface.perceive_handle_opening_poses("microwave") - - # # visualize on rviz - # poses = [] - # poses.append(handle_opening_poses["pre_grasp_pose"]) - # poses.append(handle_opening_poses["grasp_pose"]) - # poses.extend(handle_opening_poses["opening_waypoints"]) - # poses.append(handle_opening_poses["post_release_pose"]) - # poses.append(handle_opening_poses["pre_push_pose"]) - # poses.append(handle_opening_poses["push_pose"]) - # poses.extend(handle_opening_poses["push_waypoints"]) - # poses.append(handle_opening_poses["before_above_closing_waypoint"]) - # poses.append(handle_opening_poses["above_closing_waypoint"]) - # poses.append(handle_opening_poses["closing_waypoint"]) - # poses.extend(handle_opening_poses["closing_waypoints"]) - # print(f"Visualizing {len(poses)} handle opening poses in RViz ...") - # self.rviz_interface.visualize_poses(poses, frame_id="base_link", ns="handle_opening_poses") - - # # self.move_to_joint_positions(self.sim.scene_description.home_pos) - # self.move_to_joint_positions(self.sim.scene_description.fridge_door_staging_pos) - - # self.move_to_ee_pose(handle_opening_poses["pre_grasp_pose"]) - # self.open_gripper() - # self.move_to_ee_pose(handle_opening_poses["grasp_pose"]) - # self.close_gripper() - # # self.move_to_ee_pose(handle_opening_poses["post_grasp_pose"]) - # self.move_to_ee_pose_trajectory(handle_opening_poses["opening_waypoints"]) - # self.open_gripper() - # self.move_to_ee_pose(handle_opening_poses["post_release_pose"]) - - # # self.move_to_joint_positions(self.sim.scene_description.fridge_door_intermediate_restract_pos) - - # self.move_to_ee_pose(handle_opening_poses["pre_push_pose"]) - # self.move_to_ee_pose(handle_opening_poses["push_pose"]) - # self.move_to_ee_pose_trajectory(handle_opening_poses["push_waypoints"]) - - # self.move_to_ee_pose(handle_opening_poses["before_above_closing_waypoint"]) - # self.move_to_ee_pose(handle_opening_poses["above_closing_waypoint"]) - # self.move_to_ee_pose(handle_opening_poses["closing_waypoint"]) - # self.move_to_ee_pose_trajectory(handle_opening_poses["closing_waypoints"]) - - # self.close_gripper() - # self.move_to_ee_pose(handle_opening_poses["offset_closing_waypoints"][0]) - # self.move_to_ee_pose_trajectory(handle_opening_poses["offset_closing_waypoints"]) \ No newline at end of file + print("Opening microwave door ...") \ No newline at end of file diff --git a/src/feeding_deployment/actions/pick_plate.py b/src/feeding_deployment/actions/pick_plate.py index bd324049..b775f2c5 100644 --- a/src/feeding_deployment/actions/pick_plate.py +++ b/src/feeding_deployment/actions/pick_plate.py @@ -71,16 +71,11 @@ def get_behavior_tree_filename( return f"pick_plate_from_{appliance.name}.yaml" def pick_plate_from_fridge(self, speed: str) -> None: - assert self.sim.held_object_name is None print("Picking plate from fridge ...") def pick_plate_from_microwave(self, speed: str) -> None: - assert self.sim.held_object_name is None print("Picking plate from microwave ...") - self.move_to_joint_positions(self.sim.scene_description.home_pos) - self.move_to_joint_positions(self.sim.scene_description.microwave_closeup_gaze_pos) - class PickPlateFromHolderHLA(HighLevelAction): """Pick the plate from the holder.""" @@ -122,7 +117,6 @@ def get_behavior_tree_filename( return "pick_plate_from_holder.yaml" def pick_plate_from_holder(self, speed: str) -> None: - assert self.sim.held_object_name is None print("Picking plate from holder ...") class PickPlateFromTableHLA(HighLevelAction): @@ -166,5 +160,4 @@ def get_behavior_tree_filename( return "pick_plate_from_table.yaml" def pick_plate_from_table(self, speed: str) -> None: - assert self.sim.held_object_name is None print("Picking plate from table ...") \ No newline at end of file diff --git a/src/feeding_deployment/actions/pick_tool.py b/src/feeding_deployment/actions/pick_tool.py index 06a4763e..df3cdd1b 100644 --- a/src/feeding_deployment/actions/pick_tool.py +++ b/src/feeding_deployment/actions/pick_tool.py @@ -56,134 +56,13 @@ def get_behavior_tree_filename( return f"pick_{tool.name}.yaml" def pick_utensil(self, speed: str) -> None: - assert self.sim.held_object_name is None - - print("Picking up utensil ...") - return - - if self.robot_interface is not None: - self.robot_interface.set_speed(speed) - - # self.move_to_joint_positions(self.sim.scene_description.retract_pos) - # self.close_gripper() - # self.move_to_joint_positions(self.sim.scene_description.utensil_above_mount_pos) - # self.move_to_ee_pose(self.sim.scene_description.utensil_inside_mount) - # self.grasp_tool("utensil") - - # if self.wrist_interface is not None: - # time.sleep(1.0) # wait for the utensil to be connected - # print("Resetting wrist controller ...") - # self.wrist_interface.set_velocity_mode() - # self.wrist_interface.reset() - - # self.move_to_ee_pose(self.sim.scene_description.utensil_outside_mount) - # if self.sim.scene_description.scene_label == "vention": - # self.move_to_ee_pose(self.sim.scene_description.utensil_outside_above_mount) - # elif self.sim.scene_description.scene_label == "wheelchair": - # # Not sure if this is necessary. - # self.move_to_joint_positions(self.sim.scene_description.retract_pos) - # # Pre-emptively move to the before_transfer_pos because moving to above_plate_pos from retract_pos is unsafe. - # self.move_to_joint_positions(self.sim.scene_description.absolute_before_transfer_pos) - # self.move_to_joint_positions(self.sim.scene_description.before_transfer_pos) - - self.move_to_joint_positions(self.sim.scene_description.retract_pos) - self.close_gripper() - - if self.sim.scene_description.scene_label == "vention": - self.move_to_joint_positions(self.sim.scene_description.wipe_infront_mount_pos) - - self.move_to_joint_positions(self.sim.scene_description.wipe_above_mount_pos) - self.move_to_ee_pose(self.sim.scene_description.wipe_inside_mount) - self.grasp_tool("utensil") - - if self.wrist_interface is not None: - time.sleep(1.0) # wait for the utensil to be connected - print("Resetting wrist controller ...") - self.wrist_interface.set_velocity_mode() - self.wrist_interface.reset() - - self.move_to_ee_pose(self.sim.scene_description.wipe_outside_mount) - - if self.sim.scene_description.scene_label == "wheelchair": - self.move_to_ee_pose(self.sim.scene_description.wipe_outside_above_mount) - elif self.sim.scene_description.scene_label == "vention": - self.move_to_joint_positions(self.sim.scene_description.wipe_neutral_pos) - self.move_to_joint_positions(self.sim.scene_description.retract_pos) - self.move_to_joint_positions(self.sim.scene_description.absolute_before_transfer_pos) - self.move_to_joint_positions(self.sim.scene_description.before_transfer_pos) + print("Picking utensil ...") def pick_drink(self, speed: str) -> None: - assert self.sim.held_object_name is None - - if self.robot_interface is not None: - self.robot_interface.set_speed(speed) - - self.move_to_joint_positions(self.sim.scene_description.retract_pos) - self.close_gripper() - self.move_to_joint_positions(self.sim.scene_description.drink_gaze_pos) - - drink_poses = self.perception_interface.perceive_drink_pickup_poses() - - self.move_to_joint_positions(self.sim.scene_description.drink_staging_pos) - self.move_to_ee_pose(drink_poses['pre_grasp_pose']) - self.move_to_ee_pose(drink_poses['inside_bottom_pose']) - self.move_to_ee_pose(drink_poses['inside_top_pose']) - self.grasp_tool("drink") - self.move_to_ee_pose(drink_poses['post_grasp_pose']) - - self.perception_interface.record_drink_pickup_joint_pos() - - # self.move_to_joint_positions(self.sim.scene_description.drink_before_transfer_pos) + print("Picking drinking ...") def pick_wipe(self, speed: str) -> None: - assert self.sim.held_object_name is None - - if self.robot_interface is not None: - self.robot_interface.set_speed(speed) - - self.move_to_joint_positions(self.sim.scene_description.retract_pos) - self.close_gripper() - - if self.sim.scene_description.scene_label == "vention": - self.move_to_joint_positions(self.sim.scene_description.wipe_infront_mount_pos) - - self.move_to_joint_positions(self.sim.scene_description.wipe_above_mount_pos) - self.move_to_ee_pose(self.sim.scene_description.wipe_inside_mount) - self.grasp_tool("wipe") - self.move_to_ee_pose(self.sim.scene_description.wipe_outside_mount) - - if self.sim.scene_description.scene_label == "wheelchair": - self.move_to_ee_pose(self.sim.scene_description.wipe_outside_above_mount) - elif self.sim.scene_description.scene_label == "vention": - self.move_to_joint_positions(self.sim.scene_description.wipe_neutral_pos) - self.move_to_joint_positions(self.sim.scene_description.retract_pos) - self.move_to_joint_positions(self.sim.scene_description.before_transfer_pos) + print("Picking wipe ...") def pick_plate(self, speed: str) -> None: - assert self.sim.held_object_name is None - - if self.robot_interface is not None: - self.robot_interface.set_speed(speed) - - if self.perception_interface.last_plate_poses is None: - self.move_to_joint_positions(self.sim.scene_description.retract_pos) - self.close_gripper() - self.move_to_joint_positions(self.sim.scene_description.above_plate_pos) - plate_poses = self.perception_interface.perceive_plate_pickup_poses() - else: - plate_poses = self.perception_interface.last_plate_poses - - print("Moving to plate staging position ...") - self.move_to_joint_positions(self.sim.scene_description.plate_staging_pos) - print("Moving to plate pre-grasp pose ...") - self.move_to_ee_pose(plate_poses['pre_grasp_pose']) - print("Moving to plate inside bottom pose ...") - self.move_to_ee_pose(plate_poses['inside_bottom_pose']) - print("Moving to plate inside top pose ...") - self.move_to_ee_pose(plate_poses['inside_top_pose']) - print("Grasping plate ...") - self.grasp_tool("plate") - # print("Moving to plate post-grasp pose ...") - # self.move_to_ee_pose(plate_poses['post_grasp_pose']) - - self.perception_interface.record_plate_pickup_joint_pos() + print("Picking plate ...") diff --git a/src/feeding_deployment/actions/place_plate.py b/src/feeding_deployment/actions/place_plate.py index 257c4604..142aacf9 100644 --- a/src/feeding_deployment/actions/place_plate.py +++ b/src/feeding_deployment/actions/place_plate.py @@ -63,11 +63,9 @@ def get_behavior_tree_filename( return f"place_plate_in_{appliance.name}.yaml" def place_plate_in_fridge(self, speed: str) -> None: - # assert self.sim.held_object_name == "plate" print("Placing plate in fridge ...") def place_plate_in_microwave(self, speed: str) -> None: - # assert self.sim.held_object_name == "plate" print("Placing plate in microwave ...") @@ -110,7 +108,6 @@ def get_behavior_tree_filename( return "place_plate_on_holder.yaml" def place_plate_on_holder(self, speed: str) -> None: - # assert self.sim.held_object_name == "plate" print("Placing plate on holder ...") class PlacePlateInSinkHLA(HighLevelAction): @@ -154,7 +151,6 @@ def get_behavior_tree_filename( return f"place_plate_in_sink.yaml" def place_plate_in_sink(self, speed: str) -> None: - # assert self.sim.held_object_name == "plate" print("Placing plate in sink ...") class PlacePlateOnTableHLA(HighLevelAction): @@ -198,5 +194,4 @@ def get_behavior_tree_filename( return f"place_plate_on_table.yaml" def place_plate_on_table(self, speed: str) -> None: - # assert self.sim.held_object_name == "plate" print("Placing plate on table ...") \ No newline at end of file diff --git a/src/feeding_deployment/actions/press_microwave_button.py b/src/feeding_deployment/actions/press_microwave_button.py index fe443083..6ac117e2 100644 --- a/src/feeding_deployment/actions/press_microwave_button.py +++ b/src/feeding_deployment/actions/press_microwave_button.py @@ -55,6 +55,5 @@ def get_behavior_tree_filename( ) return "press_microwave_button.yaml" - def press_microwave_button(self, speed: str) -> None: - assert self.sim.held_object_name is None - print("Pressing microwave button ...") \ No newline at end of file + def press_microwave_button(self, speed: str, duration: float) -> None: + print(f"Pressing microwave button (duration={duration}s) ...") \ No newline at end of file diff --git a/src/feeding_deployment/actions/stow_tool.py b/src/feeding_deployment/actions/stow_tool.py index f5f0ca09..ffebfbf4 100644 --- a/src/feeding_deployment/actions/stow_tool.py +++ b/src/feeding_deployment/actions/stow_tool.py @@ -55,120 +55,13 @@ def get_behavior_tree_filename( return f"stow_{tool.name}.yaml" def stow_utensil(self, speed: str) -> None: - # assert self.sim.held_object_name == "utensil" - print("Stowing utensil ...") - return - - # if self.robot_interface is not None: - # self.robot_interface.set_speed(speed) - - # # self.move_to_joint_positions(self.sim.scene_description.before_transfer_pos) - - # if self.sim.scene_description.scene_label == "vention": - # self.move_to_joint_positions(self.sim.scene_description.utensil_outside_above_mount_pos) - # self.move_to_ee_pose(self.sim.scene_description.utensil_outside_mount) - # elif self.sim.scene_description.scene_label == "wheelchair": - # self.move_to_joint_positions(self.sim.scene_description.utensil_outside_mount_pos) - - # self.move_to_ee_pose(self.sim.scene_description.utensil_inside_mount) - # self.ungrasp_tool("utensil") - # self.move_to_ee_pose(self.sim.scene_description.utensil_above_mount) - # self.move_to_joint_positions(self.sim.scene_description.retract_pos) - - if self.robot_interface is not None: - self.robot_interface.set_speed(speed) - - self.move_to_joint_positions(self.sim.scene_description.retract_pos) - - if self.sim.scene_description.scene_label == "vention": - self.move_to_joint_positions(self.sim.scene_description.wipe_neutral_pos) - self.move_to_joint_positions(self.sim.scene_description.wipe_outside_mount_pos) - elif self.sim.scene_description.scene_label == "wheelchair": - self.move_to_joint_positions(self.sim.scene_description.wipe_outside_above_mount_pos) - self.move_to_ee_pose(self.sim.scene_description.wipe_outside_mount) - self.move_to_ee_pose(self.sim.scene_description.wipe_inside_mount) - self.ungrasp_tool("utensil") - self.move_to_ee_pose(self.sim.scene_description.wipe_above_mount) - - if self.sim.scene_description.scene_label == "vention": - self.move_to_ee_pose(self.sim.scene_description.wipe_infront_mount) - - self.move_to_joint_positions(self.sim.scene_description.retract_pos) def stow_drink(self, speed: str) -> None: - assert self.sim.held_object_name == "drink" - - if self.robot_interface is not None: - self.robot_interface.set_speed(speed) - - last_drink_poses, last_drink_pickup_joint_pos = self.perception_interface.get_last_drink_pickup_configs() - x_movement, y_movement = self.sim.scene_description.drink_delta_xy - self.sim.scene_description.drink_delta_xy = (0, 0) - - for value in ['drink_pose', 'inside_top_pose', 'place_inside_bottom_pose', 'place_pre_grasp_pose']: - last_drink_poses[value].position[0] += y_movement - last_drink_poses[value].position[1] -= x_movement - - # self.move_to_joint_positions(self.sim.scene_description.drink_before_transfer_pos) - if abs(x_movement) < 0.01 and abs(y_movement) < 0.01: - self.move_to_joint_positions(last_drink_pickup_joint_pos) - self.move_to_ee_pose(last_drink_poses['inside_top_pose']) - self.ungrasp_tool("drink") - self.move_to_ee_pose(last_drink_poses['place_inside_bottom_pose']) - self.move_to_ee_pose(last_drink_poses['place_pre_grasp_pose']) - self.move_to_joint_positions(self.sim.scene_description.drink_staging_pos) - self.move_to_joint_positions(self.sim.scene_description.retract_pos) + print("Stowing drink ...") def stow_wipe(self, speed: str) -> None: - assert self.sim.held_object_name == "wipe" - - if self.robot_interface is not None: - self.robot_interface.set_speed(speed) - - self.move_to_joint_positions(self.sim.scene_description.retract_pos) - - if self.sim.scene_description.scene_label == "vention": - self.move_to_joint_positions(self.sim.scene_description.wipe_neutral_pos) - self.move_to_joint_positions(self.sim.scene_description.wipe_outside_mount_pos) - elif self.sim.scene_description.scene_label == "wheelchair": - self.move_to_joint_positions(self.sim.scene_description.wipe_outside_above_mount_pos) - self.move_to_ee_pose(self.sim.scene_description.wipe_outside_mount) - self.move_to_ee_pose(self.sim.scene_description.wipe_inside_mount) - self.ungrasp_tool("wipe") - self.move_to_ee_pose(self.sim.scene_description.wipe_above_mount) - - if self.sim.scene_description.scene_label == "vention": - self.move_to_ee_pose(self.sim.scene_description.wipe_infront_mount) - - self.move_to_joint_positions(self.sim.scene_description.retract_pos) + print("Stowing wipe ...") def stow_plate(self, speed: str) -> None: - print("Object name: ", self.sim.held_object_name) - assert self.sim.held_object_name == "plate" - - if self.robot_interface is not None: - self.robot_interface.set_speed(speed) - - last_plate_poses = self.perception_interface.get_last_plate_pickup_configs(study_poses=False) - - # x_movement = input("Input the amount of x movement (to your right) for the plate: ") - # x_movement = float(x_movement) - x_movement, y_movement = self.sim.scene_description.plate_delta_xy - self.sim.scene_description.plate_delta_xy = (0, 0) - - # y_movement = input("Input the amount of y movement (away from you) for the plate: ") - # y_movement = float(y_movement) - - for value in ['plate_pose', 'inside_top_pose', 'place_inside_bottom_pose', 'place_pre_grasp_pose']: - last_plate_poses[value].position[0] += y_movement - last_plate_poses[value].position[1] -= x_movement - - # input("Preparing to move plate by: x_movement: {}, y_movement: {}. Press Enter to continue...".format(x_movement, y_movement)) - - self.move_to_ee_pose(last_plate_poses['inside_top_pose']) - self.ungrasp_tool("plate") - self.move_to_ee_pose(last_plate_poses['place_inside_bottom_pose']) - self.move_to_ee_pose(last_plate_poses['place_pre_grasp_pose']) - self.move_to_joint_positions(self.sim.scene_description.plate_staging_pos) - self.move_to_joint_positions(self.sim.scene_description.retract_pos) + print("Stowing plate ...") diff --git a/src/feeding_deployment/actions/transfer_tool.py b/src/feeding_deployment/actions/transfer_tool.py index d5b24d7d..d551cbf0 100644 --- a/src/feeding_deployment/actions/transfer_tool.py +++ b/src/feeding_deployment/actions/transfer_tool.py @@ -100,7 +100,7 @@ def detect_initiate_transfer(self, initiate_transfer_interaction: str, ready_to_ if self.web_interface is not None: self.web_interface.clear_explanation() - if ready_to_initiate_mode == "led": + if ready_to_initiate_mode in ("led", "voice_led"): self.perception_interface.turn_off_led() def detect_transfer_complete(self, transfer_complete_interaction: str, ready_for_transfer_interaction: str): @@ -144,7 +144,7 @@ def detect_transfer_complete(self, transfer_complete_interaction: str, ready_for if self.web_interface is not None: self.web_interface.clear_explanation() - if ready_for_transfer_interaction == "led": + if ready_for_transfer_interaction in ("led", "voice_led"): self.perception_interface.turn_off_led() def relay_ready_to_initiate_transfer(self, ready_to_initiate_transfer_interaction: str, initiate_transfer_interaction: str): @@ -172,6 +172,22 @@ def relay_ready_to_initiate_transfer(self, ready_to_initiate_transfer_interactio self.perception_interface.turn_on_led() elif ready_to_initiate_transfer_interaction == "beep": self.perception_interface.speak("Beep") + elif ready_to_initiate_transfer_interaction == "voice_led": + self.perception_interface.turn_on_led() + if initiate_transfer_interaction == "button": + self.perception_interface.speak("Please press the button when ready") + elif initiate_transfer_interaction == "open_mouth": + self.perception_interface.speak("Please open your mouth when ready") + elif initiate_transfer_interaction == "auto_timeout": + self.perception_interface.speak("Please wait 5 seconds for the transfer to initiate") + else: + gestures = dict(self.load_synthesized_gestures()) + with open(self.synthesized_gestures_dict_path, "r") as f: + synthesized_gesture_function_name_to_label = json.load(f) + if initiate_transfer_interaction in gestures: + self.perception_interface.speak(f"Please do a {synthesized_gesture_function_name_to_label[initiate_transfer_interaction]} to initiate transfer") + else: + raise NotImplementedError else: raise NotImplementedError @@ -184,6 +200,9 @@ def relay_ready_for_transfer(self, ready_for_transfer_interaction: str): self.perception_interface.turn_on_led() elif ready_for_transfer_interaction == "beep": self.perception_interface.speak("Beep") + elif ready_for_transfer_interaction == "voice_led": + self.perception_interface.turn_on_led() + self.perception_interface.speak("Ready for transfer") else: raise NotImplementedError @@ -270,77 +289,20 @@ def get_behavior_tree_filename( assert table.name == "table" return f"transfer_{tool.name}.yaml" - def transfer_utensil(self, speed: str, *args, **kwargs) -> None: - # assert self.sim.held_object_name == "utensil" - - print("Transferring bite with utensil ...") - return - - if self.robot_interface is not None: - self.robot_interface.set_speed(speed) - - # Assume the last item in args is autocontinue time - bite_autocontinue_time = args[-1] - - if self.web_interface is not None: - self.web_interface.set_bite_autocontinue_timeout(bite_autocontinue_time) - - # All other items (everything except the last) should go on to the next call - remaining_args = args[:-1] - - if self.wrist_interface is not None: - # start the horizontal spoon thread if it is not already running - self.wrist_interface.start_horizontal_spoon_thread() - - self.move_to_joint_positions(self.sim.scene_description.before_transfer_pos) - - if self.wrist_interface is not None: - # stop the keep horizontal thread - self.wrist_interface.stop_horizontal_spoon_thread() - - self.set_tool("fork") - self.execute_transfer(*remaining_args, **kwargs) + def transfer_utensil( + self, speed: str, + ready_to_initiate: str, initiate: str, + ready_for_transfer: str, transfer_complete: str, + outside_mouth_distance: float, time_to_wait: float, + retract_after_transfer: int = 0, + ) -> None: + print("Transferring utensil ...") + if retract_after_transfer == 1: + print("Retracting to rest position after bite transfer.") + self.move_to_joint_positions(self.sim.scene_description.retract_pos) def transfer_drink(self, speed: str, *args, **kwargs) -> None: - assert self.sim.held_object_name == "drink" - - if self.robot_interface is not None: - self.robot_interface.set_speed(speed) - - # Assume the second last item in args is the ask_confirmation - ask_confirmation = args[-2] - - # Assume the last item in args is autocontinue time - drink_autocontinue_time = args[-1] - - # All other items (everything except the last two) should go on to the next call - remaining_args = args[:-2] - - if self.web_interface is not None: - self.web_interface.set_drink_autocontinue_timeout(drink_autocontinue_time) - if ask_confirmation: - self.web_interface.get_drink_transfer_confirmation() - - self.move_to_joint_positions(self.sim.scene_description.before_transfer_pos) - - self.set_tool("drink") - self.execute_transfer(*remaining_args, maintain_position_at_goal=True, **kwargs) + print("Transferring drink") def transfer_wipe(self, speed: str, *args, **kwargs) -> None: - assert self.sim.held_object_name == "wipe" - - if self.robot_interface is not None: - self.robot_interface.set_speed(speed) - - # Assume the last item in args is the ask_confirmation - ask_confirmation = args[-1] - # All other items (everything except the last) should go on to the next call - remaining_args = args[:-1] - - self.move_to_joint_positions(self.sim.scene_description.before_transfer_pos) - - if self.web_interface is not None and ask_confirmation: - self.web_interface.get_wipe_transfer_confirmation() - - self.set_tool("wipe") - self.execute_transfer(*remaining_args, maintain_position_at_goal=True, **kwargs) + print("Transferring wipe") \ No newline at end of file diff --git a/src/feeding_deployment/integration/apply_preferences.py b/src/feeding_deployment/integration/apply_preferences.py new file mode 100644 index 00000000..b750a193 --- /dev/null +++ b/src/feeding_deployment/integration/apply_preferences.py @@ -0,0 +1,447 @@ +"""Translate a ground-truth preference bundle into BT YAML parameter writes +and scene-level configuration changes. + +Every bundle field that maps to a BT parameter is declared in _BT_MAPPING. +apply_bundle_to_behavior_trees() iterates that mapping, loads the affected +YAML files, overwrites `value` entries, and saves them back to disk so that +subsequent execute_action() calls pick up the new values (BT YAMLs are +loaded fresh on each tick). +""" + +from __future__ import annotations + +import warnings as _warnings +from pathlib import Path +from typing import Any + +import yaml + +# --------------------------------------------------------------------------- +# Value translators: bundle option string → BT parameter value +# --------------------------------------------------------------------------- + +_SPEED_MAP = {"slow": "low", "medium": "medium", "fast": "high"} + +_CONFIRMATION_MAP = {"yes": 1, "no": 0} + +_RETRACT_MAP = {"yes": 1, "no": 0} + +_AUTOCONTINUE_MAP = {"10 sec": 10.0, "100 sec": 100.0, "1000 sec": 1000.0} + +_OUTSIDE_MOUTH_DISTANCE_MAP = {"near": 0.1, "medium": 0.15, "far": 0.2} + +_CONVEY_READY_MAP = { + "no cue": "silent", + "speech": "voice", + "LED": "led", + "speech + LED": "voice_led", +} + +_INITIATE_TRANSFER_MAP = { + "open mouth": "open_mouth", + "button": "button", + "autocontinue": "auto_timeout", +} + +_COMPLETE_TRANSFER_MAP = { + "perception": "sense", + "button": "button", + "autocontinue": "auto_timeout", +} + +_TRANSFER_MODE_MAP = { + "inside mouth transfer": "inside", + "outside mouth transfer": "outside", +} + +_SKEWERING_AXIS_MAP = { + "parallel to major axis": "horizontal", + "perpendicular to major axis": "vertical", +} + + +def _dipping_depth_translate(val: str) -> float | None: + # NOTE: "do not dip" only skips the FoodDippingDepth BT write — it does NOT + # prevent FLAIR's autonomous planner from choosing to dip. The dipping + # decision is made independently in inference_class.py + # (get_autonomous_action), which looks at plate contents and user + # preferences. To truly suppress dipping when the user says "do not dip", + # a flag must be passed into FLAIR's planning logic so that + # get_autonomous_action never selects a dip action. This is separate from + # the BT parameter layer and is not yet implemented. + if val == "do not dip": + return None + return {"less": 0.01, "more": 0.03}[val] + +# --------------------------------------------------------------------------- +# Declarative mapping: bundle field → list of (yaml_filename, bt_param_name, translator) +# +# "translator" is either a dict (direct lookup) or a callable (value → bt_value). +# A translator may return None to signal "skip this write" (e.g. outside_mouth_distance +# when the bundle value is "not applicable"). +# --------------------------------------------------------------------------- + +_ALL_BT_YAMLS: list[str] = [ + "acquire_bite.yaml", + "close_fridge.yaml", + "close_microwave.yaml", + "emulate_transfer.yaml", + "navigate_to_fridge.yaml", + "navigate_to_microwave.yaml", + "navigate_to_sink.yaml", + "navigate_to_table.yaml", + "open_fridge.yaml", + "open_microwave.yaml", + "pick_drink.yaml", + "pick_plate_from_fridge.yaml", + "pick_plate_from_holder.yaml", + "pick_plate_from_microwave.yaml", + "pick_plate_from_table.yaml", + "pick_utensil.yaml", + "pick_wipe.yaml", + "place_plate_in_fridge.yaml", + "place_plate_in_microwave.yaml", + "place_plate_in_sink.yaml", + "place_plate_on_holder.yaml", + "place_plate_on_table.yaml", + "press_microwave_button.yaml", + "stow_drink.yaml", + "stow_utensil.yaml", + "stow_wipe.yaml", + "transfer_drink.yaml", + "transfer_utensil.yaml", + "transfer_wipe.yaml", +] + +_TRANSFER_YAMLS = ["transfer_utensil.yaml", "transfer_drink.yaml", "transfer_wipe.yaml"] + + +def _outside_mouth_translate(val: str) -> float | None: + if val == "not applicable": + return None + return _OUTSIDE_MOUTH_DISTANCE_MAP[val] + + +def _microwave_duration_translate(val: str) -> float | None: + """Translate microwave_time bundle value to BT MicrowaveDuration seconds. + + "no microwave" returns None (skip the BT write — the planner already + excludes the PressMicrowaveButton HLA via the FoodHeated atom). + """ + if val == "no microwave": + return None + return {"1 min": 60.0, "2 min": 120.0, "3 min": 180.0}[val] + + +# Each entry: (bundle_field, yaml_files, bt_param_name, translator) +_BT_MAPPING: list[tuple[str, list[str], str, dict | Any]] = [ + # Speed — all 29 YAMLs + ("robot_speed", _ALL_BT_YAMLS, "Speed", _SPEED_MAP), + + # Web-interface confirmation + ("web_interface_confirmation", ["acquire_bite.yaml"], "TransferAskForConfirmation", _CONFIRMATION_MAP), + ("web_interface_confirmation", ["transfer_drink.yaml", "transfer_wipe.yaml"], + "AskForConfirmationInitiatingTransferSequence", _CONFIRMATION_MAP), + + # Autocontinue wait time + ("wait_before_autocontinue_seconds", + ["acquire_bite.yaml", "transfer_utensil.yaml", "transfer_drink.yaml"], + "TimeToWaitBeforeAutocontinue", _AUTOCONTINUE_MAP), + + # Outside-mouth distance + ("outside_mouth_distance", _TRANSFER_YAMLS, "OutsideMouthDistance", _outside_mouth_translate), + + # Convey robot ready for initiating transfer + ("convey_robot_ready_for_initiating_transfer", _TRANSFER_YAMLS, + "ReadyToInitiateTransferInteraction", _CONVEY_READY_MAP), + + # Detect user ready (per-tool) + ("detect_user_ready_for_initiating_transfer_feeding", + ["transfer_utensil.yaml"], "InitiateTransferInteraction", _INITIATE_TRANSFER_MAP), + ("detect_user_ready_for_initiating_transfer_drinking", + ["transfer_drink.yaml"], "InitiateTransferInteraction", _INITIATE_TRANSFER_MAP), + ("detect_user_ready_for_initiating_transfer_wiping", + ["transfer_wipe.yaml"], "InitiateTransferInteraction", _INITIATE_TRANSFER_MAP), + + # Convey robot ready for completing transfer + ("convey_robot_ready_for_completing_transfer", _TRANSFER_YAMLS, + "ReadyForTransferInteraction", _CONVEY_READY_MAP), + + # Detect user completed transfer (per-tool) + ("detect_user_completed_transfer_feeding", + ["transfer_utensil.yaml"], "TransferCompleteInteraction", _COMPLETE_TRANSFER_MAP), + ("detect_user_completed_transfer_drinking", + ["transfer_drink.yaml"], "TransferCompleteInteraction", _COMPLETE_TRANSFER_MAP), + ("detect_user_completed_transfer_wiping", + ["transfer_wipe.yaml"], "TransferCompleteInteraction", _COMPLETE_TRANSFER_MAP), + + # Skewering axis + ("skewering_axis", ["acquire_bite.yaml"], "SkeweringOrientation", _SKEWERING_AXIS_MAP), + + # Bite dipping preference → FoodDippingDepth (see _dipping_depth_translate for "do not dip" caveat) + ("bite_dipping_preference", ["acquire_bite.yaml"], "FoodDippingDepth", _dipping_depth_translate), + + # Microwave duration + ("microwave_time", ["press_microwave_button.yaml"], "MicrowaveDuration", _microwave_duration_translate), + + # Retract between bites (utensil only — drink/wipe don't have repeated transfer loops) + ("retract_between_bites", ["transfer_utensil.yaml"], "RetractAfterTransfer", _RETRACT_MAP), +] + + +# --------------------------------------------------------------------------- +# YAML helpers +# --------------------------------------------------------------------------- + +class _HlaTag: + """Placeholder for the !hla YAML tag so we can round-trip without losing it.""" + + def __init__(self, value: str) -> None: + self.value = value + + def __repr__(self) -> str: + return f"!hla {self.value}" + + +def _hla_constructor(loader: yaml.SafeLoader, node: yaml.Node) -> _HlaTag: + return _HlaTag(loader.construct_scalar(node)) + + +def _hla_representer(dumper: yaml.Dumper, tag: _HlaTag) -> yaml.Node: + return dumper.represent_scalar("!hla", tag.value) + + +_Loader = type("_Loader", (yaml.SafeLoader,), {}) +_Loader.add_constructor("!hla", _hla_constructor) + +_Dumper = type("_Dumper", (yaml.Dumper,), {}) +_Dumper.add_representer(_HlaTag, _hla_representer) + + +def _load_yaml(path: Path) -> dict: + with open(path, "r", encoding="utf-8") as f: + return yaml.load(f, Loader=_Loader) + + +def _save_yaml(path: Path, data: dict) -> None: + with open(path, "w", encoding="utf-8") as f: + f.write(yaml.dump(data, Dumper=_Dumper, sort_keys=False, default_flow_style=False)) + + +def _set_param_value(bt_data: dict, param_name: str, new_value: Any) -> bool: + """Find a parameter by name in the YAML dict and set its value. + + Returns True if the parameter was found and updated. + """ + for param in bt_data.get("parameters", []): + if param["name"] == param_name: + param["value"] = new_value + return True + return False + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def apply_bundle_to_behavior_trees( + bundle: dict[str, str], + bt_dir: Path, +) -> list[str]: + """Write preference-bundle values into the BT YAML files on disk. + + Returns a list of warning strings for edge cases (e.g. unsupported + combined cues). + """ + warnings: list[str] = [] + # Cache: yaml_filename -> loaded dict (load each file at most once) + loaded: dict[str, dict] = {} + dirty: set[str] = set() + + for bundle_field, yaml_files, bt_param, translator in _BT_MAPPING: + bundle_val = bundle.get(bundle_field) + if bundle_val is None: + continue + + # Translate + if callable(translator) and not isinstance(translator, dict): + bt_val = translator(bundle_val) + else: + bt_val = translator.get(bundle_val) + if bt_val is None: + warnings.append( + f"No mapping for {bundle_field}={bundle_val!r}; skipping BT write." + ) + continue + + if bt_val is None: + # Explicit skip (e.g. outside_mouth_distance="not applicable") + continue + + for fname in yaml_files: + fpath = bt_dir / fname + if not fpath.exists(): + warnings.append(f"BT YAML not found: {fpath}") + continue + if fname not in loaded: + loaded[fname] = _load_yaml(fpath) + if _set_param_value(loaded[fname], bt_param, bt_val): + dirty.add(fname) + else: + warnings.append( + f"Parameter {bt_param!r} not found in {fname}; skipping." + ) + + for fname in dirty: + _save_yaml(bt_dir / fname, loaded[fname]) + + return warnings + + +def apply_transfer_mode( + bundle: dict[str, str], + scene_description: Any, + hla_map: dict[str, Any], +) -> None: + """Set scene_description.transfer_type from the bundle and re-init the + TransferToolHLA transfer object so that subsequent execute_action() calls + use the correct inside/outside mouth transfer implementation. + """ + mode = bundle.get("transfer_mode") + if mode is None: + return + + new_type = _TRANSFER_MODE_MAP.get(mode) + if new_type is None: + raise ValueError( + f"Unknown transfer_mode={mode!r}. " + f"Expected one of {list(_TRANSFER_MODE_MAP.keys())}." + ) + + scene_description.transfer_type = new_type + + transfer_hla = hla_map.get("TransferTool") + if transfer_hla is None: + return + + # Re-run the inside/outside branch from TransferToolHLA.__init__ + from feeding_deployment.actions.feel_the_bite.inside_mouth_transfer import ( + InsideMouthTransfer, + ) + from feeding_deployment.actions.feel_the_bite.outside_mouth_transfer import ( + OutsideMouthTransfer, + ) + + if new_type == "inside": + transfer_hla.transfer = InsideMouthTransfer( + transfer_hla.sim, + transfer_hla.robot_interface, + transfer_hla.perception_interface, + transfer_hla.rviz_interface, + transfer_hla.no_waits, + transfer_hla.head_perception_log_dir, + ) + elif new_type == "outside": + transfer_hla.transfer = OutsideMouthTransfer( + transfer_hla.sim, + transfer_hla.robot_interface, + transfer_hla.perception_interface, + transfer_hla.rviz_interface, + transfer_hla.no_waits, + transfer_hla.head_perception_log_dir, + ) + else: + raise ValueError(f"Unrecognized transfer type: {new_type!r}") + + +_MICROWAVE_TIME_MAP = { + "no microwave": None, + "1 min": 60, + "2 min": 120, + "3 min": 180, +} + + +def apply_microwave_preference( + bundle: dict[str, str], + current_atoms: set, + food_heated_atom: Any, +) -> int | None: + """Adjust planner atoms based on the microwave_time preference. + + If "no microwave", adds the FoodHeated atom to current_atoms so the PDDL + planner skips the entire microwave sequence (navigate to microwave, open + door, place plate, press button, pick plate, close door). + + For "1 min"/"2 min"/"3 min", ensures FoodHeated is absent (the planner + will include microwave steps). + + Returns the microwave duration in seconds, or None for "no microwave". + The caller can use this value to set the actual microwave timer when the + PressMicrowaveButton HLA executes — this wiring is not yet implemented. + """ + microwave_time = bundle.get("microwave_time") + if microwave_time is None: + return None + + duration = _MICROWAVE_TIME_MAP.get(microwave_time) + if duration is None and microwave_time != "no microwave": + raise ValueError( + f"Unknown microwave_time={microwave_time!r}. " + f"Expected one of {list(_MICROWAVE_TIME_MAP.keys())}." + ) + + if microwave_time == "no microwave": + current_atoms.add(food_heated_atom) + else: + current_atoms.discard(food_heated_atom) + + return duration + + +def apply_dip_preference( + bundle: dict[str, str], + flair: Any, +) -> None: + """Set FLAIR's allow_dip flag based on bite_dipping_preference. + + When "do not dip", sets allow_dip=False so get_autonomous_action + deterministically strips any dip the LLM suggests. + """ + dip_pref = bundle.get("bite_dipping_preference") + if dip_pref is None or flair is None: + return + flair.set_allow_dip(dip_pref != "do not dip") + + +_OCCLUSION_TEXT_MAP = { + "do not consider occlusion": None, + "minimize left occlusion": "When choosing which food to pick, prefer items that minimize the robot arm blocking the user's view from the left.", + "minimize front occlusion": "When choosing which food to pick, prefer items that minimize the robot arm blocking the user's view from the front.", + "minimize right occlusion": "When choosing which food to pick, prefer items that minimize the robot arm blocking the user's view from the right.", +} + + +def apply_occlusion_preference( + bundle: dict[str, str], + flair: Any, +) -> None: + """Append occlusion_relevance as a soft hint to FLAIR's user_preference string. + + This injects text into the LLM preference planner prompt. It's a soft + hint — the LLM may or may not factor it into food selection. There is + no geometric enforcement of occlusion minimization. + """ + occlusion = bundle.get("occlusion_relevance") + if occlusion is None or flair is None: + return + + hint = _OCCLUSION_TEXT_MAP.get(occlusion) + if hint is None: + return + + current = flair.get_preference() + if current: + flair.set_preference(f"{current} {hint}") + else: + flair.set_preference(hint) diff --git a/src/feeding_deployment/integration/preference_context.py b/src/feeding_deployment/integration/preference_context.py new file mode 100644 index 00000000..b1ffb1bd --- /dev/null +++ b/src/feeding_deployment/integration/preference_context.py @@ -0,0 +1,52 @@ +"""Build and validate mealtime context for preference prediction. + +Only observable context is included (meal, setting, time_of_day). +""" + +from __future__ import annotations + +from typing import Any, Mapping + +from feeding_deployment.preference_learning.config import MEALS, SETTINGS, TIMES_OF_DAY + +PREFERENCE_CONTEXT_KEYS = ("meal", "setting", "time_of_day") + + +def build_preference_context(meal: str, setting: str, time_of_day: str) -> dict[str, str]: + """ + Assemble the context dict for the current meal before calling predict_bundle / update. + + Values must match the canonical lists in preference_learning.config. + """ + context = { + "meal": meal.strip(), + "setting": setting.strip(), + "time_of_day": time_of_day.strip(), + } + validate_preference_context(context) + return context + + +def validate_preference_context(context: Mapping[str, Any]) -> None: + """Raise ValueError if any field is missing or not in the allowed vocabulary.""" + for key in PREFERENCE_CONTEXT_KEYS: + if key not in context: + raise ValueError(f"preference context missing required key: {key!r}") + val = context[key] + if not isinstance(val, str) or not val.strip(): + raise ValueError(f"preference context {key!r} must be a non-empty string") + + meal = context["meal"].strip() + if meal not in MEALS: + raise ValueError( + f"Unknown meal={meal!r}. Must be one of the labels in " + "preference_learning.config.mealtime_context.MEALS." + ) + + setting = context["setting"].strip() + if setting not in SETTINGS: + raise ValueError(f"Unknown setting={setting!r}. Allowed: {SETTINGS!r}") + + tod = context["time_of_day"].strip() + if tod not in TIMES_OF_DAY: + raise ValueError(f"Unknown time_of_day={tod!r}. Allowed: {TIMES_OF_DAY!r}") diff --git a/src/feeding_deployment/integration/run.py b/src/feeding_deployment/integration/run.py index 3b01b12f..185466bd 100644 --- a/src/feeding_deployment/integration/run.py +++ b/src/feeding_deployment/integration/run.py @@ -113,6 +113,12 @@ ) from feeding_deployment.actions.flair.flair import FLAIR from feeding_deployment.transparency.query_llm import TransparencyQuery +from feeding_deployment.integration.preference_context import build_preference_context +from feeding_deployment.preference_learning.methods.prediction_model import PredictionModel, PREF_OPTIONS +from feeding_deployment.preference_learning.config import MEALS, SETTINGS, TIMES_OF_DAY +from feeding_deployment.preference_learning.config.physical_capabilities import ( + PHYSICAL_CAPABILITY_PROFILES, +) # All the high level actions we want to consider. HLAS = { @@ -142,12 +148,21 @@ class _Runner: """A class for running the integrated system.""" def __init__(self, scene_config: str, user: str, scenario:str, transfer_type: str, run_on_robot: bool, use_interface: bool, use_gui: bool, simulate_head_perception: bool, max_motion_planning_time: float, - resume_from_state: str = "", no_waits: bool = False) -> None: + resume_from_state: str = "", no_waits: bool = False, + physical_profile_label: str | None = None, + pref_day: int | None = None, + pref_mode: str = "none") -> None: self.run_on_robot = run_on_robot self.use_interface = use_interface self.simulate_head_perception = simulate_head_perception self.max_motion_planning_time = max_motion_planning_time self.no_waits = no_waits + self.deployment_user = user + self.physical_profile_label = physical_profile_label.strip() if physical_profile_label else None + self._pref_day = pref_day + self._pref_mode = pref_mode + self._prediction_model: PredictionModel | None = None + self.predicted_bundle: dict[str, str] | None = None # logs are saved in user/scenario directory self.log_dir = Path(__file__).parent / "log" / user / scenario @@ -404,12 +419,149 @@ def __init__(self, scene_config: str, user: str, scenario:str, transfer_type: st print("Runner is ready.") self.active = True + self.preference_context: dict[str, str] | None = None + + def ensure_preference_context(self) -> dict[str, str]: + """Require a valid preference context before the web session; no implicit defaults.""" + if self.preference_context is None: + raise RuntimeError( + "preference_context is required but unset. Call " + "set_meal_preference_context(meal, setting, time_of_day) before run()." + ) + return self.preference_context + + def set_meal_preference_context( + self, + meal: str, + setting: str, + time_of_day: str, + ) -> dict[str, str]: + """Validated observable context for this run (meal / setting / time); in-memory only.""" + self.preference_context = build_preference_context( + meal=meal, + setting=setting, + time_of_day=time_of_day, + ) + return self.preference_context def run(self, continuous = True) -> None: assert self.web_interface is not None, "Run takes user commands from the web interface which is None." - - self.web_interface.ready_for_task_selection() + + if self._pref_mode == "none": + self.web_interface.ready_for_task_selection() + else: + # --- Step 1: Collect context --- + if self._pref_mode == "terminal": + from feeding_deployment.integration.terminal_preferences import ( + terminal_collect_context, + terminal_correct_preferences, + ) + ctx_dict = terminal_collect_context() + self.set_meal_preference_context( + meal=ctx_dict["meal"], + setting=ctx_dict["setting"], + time_of_day=ctx_dict["time_of_day"], + ) + elif self._pref_mode == "interface" and self.preference_context is None: + ctx_defaults = { + "meal": MEALS[0], + "setting": SETTINGS[0], + "time_of_day": TIMES_OF_DAY[0], + } + ctx_dict = self.web_interface.get_preference_context( + list(MEALS), + list(SETTINGS), + list(TIMES_OF_DAY), + ctx_defaults, + ) + self.set_meal_preference_context( + meal=ctx_dict["meal"], + setting=ctx_dict["setting"], + time_of_day=ctx_dict["time_of_day"], + ) + + ctx = self.ensure_preference_context() + print("Preference context (meal / setting / time_of_day):", ctx) + # --- Step 2: Predict --- + assert self.physical_profile_label is not None, ( + "physical_profile_label is required for preference prediction " + "(pass --physical_profile_file)." + ) + pref_logs = self.log_dir / "preference_learning" + self._prediction_model = PredictionModel( + user=self.deployment_user, + physical_profile_label="deployment_physical_profile", + logs_dir=pref_logs, + physical_profile_description=self.physical_profile_label, + ) + self.predicted_bundle = self._prediction_model.predict_bundle(dict(ctx), {}) + print("Predicted preference bundle (initial):", json.dumps(self.predicted_bundle, indent=2)) + + # --- Step 3: Correct --- + if self._pref_mode == "terminal": + user_bundle = terminal_correct_preferences( + self.predicted_bundle, dict(PREF_OPTIONS), + ) + else: + user_bundle = self.web_interface.get_preference_corrections( + self.predicted_bundle, dict(PREF_OPTIONS), + ) + self.ground_truth_bundle = user_bundle + self.corrected = { + k: v for k, v in user_bundle.items() + if v != self.predicted_bundle.get(k) + } + print("Ground truth bundle:", json.dumps(self.ground_truth_bundle, indent=2)) + print("Corrected fields:", json.dumps(self.corrected, indent=2)) + + # --- Step 4: Apply --- + from feeding_deployment.integration.apply_preferences import ( + apply_bundle_to_behavior_trees, + apply_transfer_mode, + apply_microwave_preference, + apply_dip_preference, + apply_occlusion_preference, + ) + bt_warnings = apply_bundle_to_behavior_trees( + self.ground_truth_bundle, self.run_behavior_tree_dir, + ) + for w in bt_warnings: + print(f"[preference-apply] WARNING: {w}") + apply_transfer_mode( + self.ground_truth_bundle, + self.sim.scene_description, + self.hla_name_to_hla, + ) + microwave_duration = apply_microwave_preference( + self.ground_truth_bundle, + self.current_atoms, + GroundAtom(FoodHeated, []), + ) + if microwave_duration is None: + print("Microwave preference: no microwave (FoodHeated added to planner state).") + else: + print(f"Microwave preference: {microwave_duration}s (planner will include microwave steps).") + apply_dip_preference(self.ground_truth_bundle, self.flair) + apply_occlusion_preference(self.ground_truth_bundle, self.flair) + print("Applied ground-truth bundle to behavior trees and scene config.") + + # --- Step 5: Learn --- + day = self._prediction_model.next_day() if self._pref_day is None else self._pref_day + print(f"[learn] Updating memory models (day {day}) ...") + self._prediction_model.update( + day=day, + context=dict(ctx), + corrected=self.corrected, + ground_truth_bundle=self.ground_truth_bundle, + ) + print(f"[learn] Memory update complete (day {day}).") + if self._pref_mode == "interface": + self.web_interface.notify_preference_corrections_applied( + "Preferences were applied successfully." + ) + + self.web_interface.ready_for_task_selection() last_task_type = None while self.active: if not continuous: @@ -783,6 +935,28 @@ def update_scene_spec(self, scene_spec_updates: dict[str, Any]) -> None: parser.add_argument("--meal_id", type=int, default=1) parser.add_argument("--results_dir", type=Path, default=Path("feast_default_user"), help="Directory for saving and loading results and user responses. Make one of these directories per user.") parser.add_argument("--load", action="store_true") + parser.add_argument( + "--pref_mode", + type=str, + choices=["none", "terminal", "interface"], + default="none", + help="Preference interaction mode. " + "'none': no personalization (default). " + "'terminal': predict + correct via terminal prompts. " + "'interface': predict + correct via web interface (requires frontend).", + ) + parser.add_argument( + "--physical_profile_file", + type=str, + default="", + help="UTF-8 text file describing the user's physical capabilities. " + "Required with --pref_mode=terminal or --pref_mode=interface.", + ) + parser.add_argument( + "--pref_day", type=int, default=None, + help="Override the deployment day number for preference learning. " + "If omitted, auto-detected from existing log files (next unused day).", + ) args = parser.parse_args() if args.user == "": @@ -794,6 +968,20 @@ def update_scene_spec(self, scene_spec_updates: dict[str, Any]) -> None: else: rospy.init_node("feeding_deployment", anonymous=True) + physical_profile_label: str | None = None + if args.pref_mode in ("terminal", "interface"): + if not args.physical_profile_file.strip(): + raise ValueError( + f"With --pref_mode={args.pref_mode}, pass --physical_profile_file " + "pointing to a UTF-8 .txt file with freeform physical-capability text." + ) + profile_path = Path(args.physical_profile_file.strip()) + if not profile_path.is_file(): + raise ValueError(f"physical profile file not found: {profile_path}") + physical_profile_label = profile_path.read_text(encoding="utf-8").strip() + if not physical_profile_label: + raise ValueError(f"physical profile file is empty: {profile_path}") + runner = _Runner(args.scene_config, args.user, args.scenario, @@ -804,8 +992,11 @@ def update_scene_spec(self, scene_spec_updates: dict[str, Any]) -> None: args.simulate_head_perception, args.max_motion_planning_time, args.resume_from_state, - args.no_waits) - + args.no_waits, + physical_profile_label=physical_profile_label, + pref_day=args.pref_day, + pref_mode=args.pref_mode) + # Handle Ctrl+C gracefully signal.signal(signal.SIGINT, runner.signal_handler) @@ -821,4 +1012,4 @@ def update_scene_spec(self, scene_spec_updates: dict[str, Any]) -> None: runner.make_video(output_path) if args.run_on_robot: - rospy.spin() \ No newline at end of file + rospy.spin() diff --git a/src/feeding_deployment/integration/saved_states/last_state.p b/src/feeding_deployment/integration/saved_states/last_state.p index e44aed4a0d1fccb02edf89baba1bb7e7cbe6dd33..af48df0908b57ccea864e05e6b79b6e6a567807e 100644 GIT binary patch literal 3370 zcmbVOdsGuw9*#CV5(DKWRn(T~R)fl8Vdd4e0YL{v5am@6%udLJ35HCZOadsc9E5V} zRD6t8#0mmduvU8zid|7$&}9{T^dP8Z>1thdtqSX-=oVb}&LasR_U!5Wb?4sS{qFbs zzTbTm@^`&w$74QnjuJr>DowgffoRoOP9~xuymYEey&5J|n8r&-zyv~(GnTPY;pU)3 z3|A{+nbDn;hmp#oB%}z(-sI>^L^GZjsOoYI5D3TvThwP*v=`HV1{{#p@(%W@Lw$W2qTLpSP23bf;J=)?6 zB;P{PyOo!R<|`{rjs4~!f-0em6e-A3X%Ti|R0y@UUAe=6wVrD$V^NEhmkUj(cdPa* z3qH6AwX%&aixMD$`8ZUOoVdo3qgaS`)|+y-|Kkh z9u8Qj1vXpwO1u4-50YNt5Ts`}-hz(0Y@%V=t_fkK+=?9R6ri+1hRHLKG=hpI?MMB} z2qhJ*mzRNRc?U<>B7@a39f9l92%VQ!p-@w#5W+FNMiEM2nZO*bxCB}p6lxk>j;NRv z#F+og5qu4t2@Dr5!!Shz0s<8jX+`8{5xN#hD;QT1IaDN%w53;}m!wFM#1iZQVq_?+ zL#a>P^S_hs(Lq@Po3}0*jEBmev?f;ko%o^S3*EuegU!;(xL!`$f}NrwR17%kGY$#$ z{eLjfIc#p|TvWof7Lc}U5FJ8*!lY2oR0{QuJk9qYFa?5a4R zx!E{$0V&kL1{ne4=?FoQ)>;?`(ora6PT8IWAcfsJWh|Lr1dJc zDosTL*jW*Xxap!PJ7P9CV558?!wv_swZVd z;*XxLUZ*yrS5of{+C`_(8TgQnqu4WZSqJ6)%~$T)tAB9W&Ju0D9aP_x`j1<~ zdr#a$zrO`;C+@hal5+|TJlqIu<{5Ji2 zePqGHV%6Cj?av-{#y*xD@{LA$==Fx)9o=&|`PP^nvsU6NMLJk0!D1t-0?Pv1&?M51 z9S$}{`F*bt#RxNoMe6fGuZf5M!+-s>}P9>HdkJdsAg{Y0$4<=sZ9JVs2O3`g~wDd)EfDbik(grZ#P0am1GGNz-i=QOQ0ZG>1D*9f_#Crr&vT zrP?WPHfyuXSW*lkD^}Uv`iqw()*T-n@Xr^XPMac%NcMEQldjK}zC%PVte*M?v1}qR zHT{fXH-*tu`mPeLpUaz-h+XNczT%J@OC_b}PgXl=B93@|c)T;JHgd)^@8LWT@^9On z!AkTQYDwINyM|4vo33SZ^5UeLQ;k9K17VK-tLL`;xlI!i5jG|IM$C%-8C}h-yByDR z#o}mY#UZkUayR|rQX2F=YCHJSzUr@s@QXh+9r$tU!=43oly~>q&5}2y?m+t+hr=43 z62ZeU4XAPY;>b5?FpJT8binu`aHqC~VT~ErgIofiTqi$09~%3+_~%t3pTRvHvkH1z zuqB)nJLw(j_Bz~{EDXb@Z}P;6LyE@^{OhruYwm3)19|Jn(|+yYU)TCquFu)JdT_m=ef!JDUJ9B$?P#X8IxtVW% z-}n6Hs0-XqFQE)TCrpv}k2TAp9@1td-rft=^pIE{z;rgDOEn4mhK zN}g)t^bE}Hf=8iUPgH>$Xs%`;%9xfd!}X8An+F5GW17Q_S>yl(TTqI}4u06&tRskFHR zlEiB+ftM6s_2)J5z{hK98`LdT7;*-(%wqn@1|*hce*UnFEap3hhtN(Ra={$Wov1x# zA7z=P!RbG@cn5=v0rNcvB39(>f$0lZO8o|b{O`IZ8d%V+~!Kh0^1KwlB zlk6B-Tn}XeA&iekO$=zf>=z|7Ff-M+6$4miqHJ2bFO zL1hlK9-rG5twh<7uk=~J8)RCbmGkC$rsdPv^Isut*CFj2OCl7M}1I(U?!EYRkvsWAR(~|r=o#?!ZSn!I*FyW zghFH^;O6ugP8GE_=#19mv3hI;>KS%ZL^gAHd-k?$955FNatWMYfc^0ejzDJI@7sR} zxb=8MA!?$bQ^?n2_mjj0eIe2~+5ku?HUxzN-9m_htx;qM7IyZ>_w?;@zFqdl1kxrH z>1km?XpmfpVGkU*^U1!M*2e;FVKX)Nq+mB~Ce|RmD91Pf9w=<9N1U9n!x@cYX4=CK zBmCjWJO8*1tLuKeiX!hq*GkNo7CanKa0mrL8JU!j>g3Cx&t4eXn;ldg!GP}()k|GZ zWj$B-){zL+dsp$q7~>x#B9B3cDRp`N{S38sgfOK?dNFPis^#k2?~PrY0ufZ91jSY1 z(J4i}fpZh;k0dfu5A2@_qFM( z-{!36z5+=2UxT=RKIkvQPWP!lGAkq^_vY8TUpx5VIQ3PEyEE0qNA%{|6AMp#@xjZ? zI^@BGx_;8lV;AGiDm#9D!O7zVgK6hp1l9o>Og=WX< z6B|{~%C)R&fOil>KSG!3X(pn`Tp5%8@6!tn1=((71iQ@$*I)t(`@{3E=eLgxZd%83 zQZ?u0Yd|H}A~qS%Fz);v&L|W$#U6P^;$;@6_`a{-T~6Kn4_N3a>RmO|v0hd5G_7(! zY`Wh-j*EgQ)a6^g8{Kg8mD4wf@3e5=1Vu#t{pES3_{v91>zGc#FHw!k$$l@dI)IE6 zL_CI%7!OD|fL2yB9(?>D9&&W8r0uK!iKfBm{Ecg`MU23}V9v?F-Q`c$QDrL!z*E7a gD$wpDWk|G#_KCYeXY?;qO2VvJg8M;E_rkpY0M{1K8~^|S diff --git a/src/feeding_deployment/integration/terminal_preferences.py b/src/feeding_deployment/integration/terminal_preferences.py new file mode 100644 index 00000000..6905994e --- /dev/null +++ b/src/feeding_deployment/integration/terminal_preferences.py @@ -0,0 +1,81 @@ +"""Terminal-based preference context input and preference correction. + +Used when --pref_mode=terminal. Provides the same contract as the +web-interface path (context dict, corrected bundle) but via stdin/stdout. +""" + +from __future__ import annotations + +from feeding_deployment.preference_learning.config import MEALS, SETTINGS, TIMES_OF_DAY +from feeding_deployment.preference_learning.methods.prediction_model import PREF_OPTIONS +from feeding_deployment.preference_learning.methods.utils import PREF_FIELDS + + +def _pick_from_list(prompt: str, options: list[str]) -> str: + """Display numbered options and return the user's choice.""" + print(f"\n{prompt}") + for i, opt in enumerate(options, 1): + print(f" {i}. {opt}") + while True: + raw = input("Enter number: ").strip() + try: + idx = int(raw) + if 1 <= idx <= len(options): + return options[idx - 1] + except ValueError: + pass + print(f" Invalid choice. Enter a number between 1 and {len(options)}.") + + +def terminal_collect_context() -> dict[str, str]: + """Prompt the operator to choose meal, setting, and time_of_day.""" + print("\n=== Preference Context ===") + meal = _pick_from_list("Select meal:", MEALS) + setting = _pick_from_list("Select dining setting:", SETTINGS) + time_of_day = _pick_from_list("Select time of day:", TIMES_OF_DAY) + return {"meal": meal, "setting": setting, "time_of_day": time_of_day} + + +def terminal_correct_preferences( + predicted_bundle: dict[str, str], + pref_options: dict[str, list[str]], +) -> dict[str, str]: + """Show the predicted bundle field-by-field and let the operator correct any field. + + Returns the final bundle (predicted values + any corrections). + """ + bundle = dict(predicted_bundle) + + print("\n=== Predicted Preferences ===") + print("Review each field. Press Enter to accept the predicted value, or enter") + print("a number to change it.\n") + + for i, field in enumerate(PREF_FIELDS, 1): + options = pref_options[field] + predicted = bundle[field] + pred_idx = options.index(predicted) + 1 if predicted in options else "?" + + print(f"[{i}/{len(PREF_FIELDS)}] {field}") + for j, opt in enumerate(options, 1): + marker = " <-- predicted" if opt == predicted else "" + print(f" {j}. {opt}{marker}") + + while True: + raw = input(f"Choice [{pred_idx}]: ").strip() + if raw == "": + break + try: + idx = int(raw) + if 1 <= idx <= len(options): + bundle[field] = options[idx - 1] + break + except ValueError: + pass + print(f" Invalid. Enter a number 1-{len(options)} or press Enter to keep.") + + corrected = {k: v for k, v in bundle.items() if v != predicted_bundle.get(k)} + print(f"\nCorrections made: {len(corrected)} field(s)") + if corrected: + for k, v in corrected.items(): + print(f" {k}: {predicted_bundle[k]!r} -> {v!r}") + return bundle diff --git a/src/feeding_deployment/interfaces/web_interface.py b/src/feeding_deployment/interfaces/web_interface.py index 397a00a5..4465570a 100644 --- a/src/feeding_deployment/interfaces/web_interface.py +++ b/src/feeding_deployment/interfaces/web_interface.py @@ -104,51 +104,53 @@ def _message_callback(self, msg: "String") -> None: f.write(msg.data + "\n") msg_dict = json.loads(msg.data) + msg_state = msg_dict.get("state") + msg_status = msg_dict.get("status") self.task_selection_jump = False - if msg_dict["status"] == "finish_feeding": + if msg_status == "finish_feeding": task_selected = { "task": "reset", "type": "reset", } self.task_selection_queue.put(task_selected) - elif msg_dict["state"] == "task_selection": - if msg_dict["status"] == "take_bite": + elif msg_state == "task_selection": + if msg_status == "take_bite": task_selected = { "task": "meal_assistance", "type": "bite", } - elif msg_dict["status"] == "take_sip": + elif msg_status == "take_sip": task_selected = { "task": "meal_assistance", "type": "sip", } - elif msg_dict["status"] == "mouth_wiping": + elif msg_status == "mouth_wiping": task_selected = { "task": "meal_assistance", "type": "wipe", } - elif msg_dict["status"] == "transparency": + elif msg_status == "transparency": task_selected = { "task": "personalization", "type": "transparency", } - elif msg_dict["status"] == "adaptability": + elif msg_status == "adaptability": task_selected = { "task": "personalization", "type": "adaptability", } - elif msg_dict["status"] == "gesture": + elif msg_status == "gesture": task_selected = { "task": "personalization", "type": "gesture", } - elif msg_dict["status"] == "jump": + elif msg_status == "jump": self.task_selection_jump = True return else: - print("Invalid task selection status received from interface: ", msg_dict["status"]) + print("Invalid task selection status received from interface: ", msg_status) return # remove explanation lock (if it exists) @@ -393,6 +395,91 @@ def update_adaptability_response(self, response: str) -> None: assert self.current_page == "adaptability", "Cannot update adaptability response when not on the adaptability page." self._send_message({"state": "adaptability_response", "status": response}) + #### Preference Correction Pages #### + + def get_preference_context( + self, + meals: list[str], + settings: list[str], + times_of_day: list[str], + defaults: dict[str, str], + ) -> dict[str, str]: + """Send preference-context choices to the webapp and wait for a selection.""" + self.current_page = "preference_context" + self._send_message({"state": "preference_context", "status": "jump"}) + time.sleep(0.5) + self._send_message({ + "state": "preference_context_data", + "meals": meals, + "settings": settings, + "time_of_day": times_of_day, + "defaults": defaults, + }) + msg_dict = self.get_required_web_interface_message( + lambda msg_dict: msg_dict.get("state") == "preference_context_response" + ) + if msg_dict is None: + raise RuntimeError( + "Preference context exited before submission. " + "A valid meal / setting / time_of_day selection is required." + ) + return { + "meal": msg_dict["meal"], + "setting": msg_dict["setting"], + "time_of_day": msg_dict["time_of_day"], + } + + def get_preference_corrections( + self, + predicted_bundle: dict[str, str], + pref_options: dict[str, list[str]], + ) -> dict[str, str]: + """Send predicted preference bundle to the webapp and wait for user corrections. + + Message contract (for frontend implementation): + + Sent to webapp (on /ServerComm): + 1. {"state": "preference_correction", "status": "jump"} + 2. { + "state": "preference_correction_data", + "predicted_bundle": {"field": "value", ...}, + "options": {"field": ["opt1", "opt2", ...], ...} + } + + Expected back from webapp (on WebAppComm): + { + "state": "preference_correction_response", + "bundle": {"field": "value", ...} + } + where "bundle" contains all fields with the user's final selections + (unchanged fields keep their predicted values). + """ + self.current_page = "preference_correction" + self._send_message({"state": "preference_correction", "status": "jump"}) + time.sleep(0.5) + self._send_message({ + "state": "preference_correction_data", + "predicted_bundle": predicted_bundle, + "options": pref_options, + }) + msg_dict = self.get_required_web_interface_message( + lambda msg_dict: msg_dict.get("state") == "preference_correction_response" + ) + if msg_dict is None: + print( + "Preference correction exited before submission; " + "using predicted bundle unchanged." + ) + return dict(predicted_bundle) + return msg_dict["bundle"] + + def notify_preference_corrections_applied(self, message: str | None = None) -> None: + """Notify the webapp that preference corrections were applied successfully.""" + msg_dict: dict[str, Any] = {"state": "preference_correction_applied"} + if message is not None: + msg_dict["message"] = message + self._send_message(msg_dict) + #### Gesture Pages #### def get_gesture_type(self) -> None: @@ -601,4 +688,4 @@ def provide_continuous_explanations(self) -> None: print("Skill type: ", skill_type) print("Skill params: ", skill_params) - print("Dip type: ", dip_type) \ No newline at end of file + print("Dip type: ", dip_type) diff --git a/src/feeding_deployment/preference_learning/methods/long_term_memory.py b/src/feeding_deployment/preference_learning/methods/long_term_memory.py index 17b0c8c6..fdb1f8ca 100644 --- a/src/feeding_deployment/preference_learning/methods/long_term_memory.py +++ b/src/feeding_deployment/preference_learning/methods/long_term_memory.py @@ -25,13 +25,22 @@ class LongTermMemoryModel: Stateful LTM that updates EVERY meal (online). """ - def __init__(self, physical_profile_label: str, client: OpenAI, chat_model: str, retry_fn, logs_dir: Path = None) -> None: + def __init__( + self, + physical_profile_label: str, + client: OpenAI, + chat_model: str, + retry_fn, + logs_dir: Path = None, + physical_profile_description: str | None = None, + ) -> None: self.client = client self.chat_model = chat_model self._retry = retry_fn self.logs_dir = logs_dir self.physical_profile_label = physical_profile_label + self._physical_profile_description = physical_profile_description self._ltm_summary: str = "" self._initialized: bool = False @@ -46,6 +55,7 @@ def add_episode(self, episode_text: str) -> None: physical_profile_label=self.physical_profile_label, previous_ltm_summary=previous, new_episode=episode_text, + physical_profile_description=self._physical_profile_description, ) def _call() -> Any: @@ -110,11 +120,12 @@ def main() -> int: client = OpenAI(api_key=_resolve_api_key(args.api_key)) ltm = LongTermMemoryModel( + physical_profile_label=physical_profile_label, client=client, chat_model=args.openai_model, retry_fn=_retry_on_rate_limit, + logs_dir=Path(args.log_dir), ) - ltm.reset(physical_profile_label) wanted = {10, 20, 30} seen: set[int] = set() diff --git a/src/feeding_deployment/preference_learning/methods/prediction_model.py b/src/feeding_deployment/preference_learning/methods/prediction_model.py index e99d6731..55f2e381 100644 --- a/src/feeding_deployment/preference_learning/methods/prediction_model.py +++ b/src/feeding_deployment/preference_learning/methods/prediction_model.py @@ -119,11 +119,13 @@ def __init__( k_retrieve: int = 5, chat_model: str = "gpt-5.4", embed_model: str = "text-embedding-3-small", + physical_profile_description: str | None = None, ) -> None: self.user = user self.physical_profile_label = physical_profile_label - self.client = OpenAI(api_key=_resolve_api_key()) + self.physical_profile_description = physical_profile_description + self.client = OpenAI(api_key=_resolve_api_key(None)) self.chat_model = chat_model self.embed_model = embed_model self._retry = retry_fn @@ -140,6 +142,7 @@ def __init__( chat_model=self.chat_model, retry_fn=retry_fn, logs_dir=self.logs_dir / "long_term_memory_llm_calls", + physical_profile_description=self.physical_profile_description, ) if use_episodic_memory: @@ -158,7 +161,15 @@ def __init__( self.working_memory_calls_dir = self.logs_dir / user / "prediction_model_llm_calls" self.working_memory_calls_dir.mkdir(parents=True, exist_ok=True) - + + def next_day(self) -> int: + """Return the next unused day number by scanning existing ``day_NNNN.json`` files.""" + existing = sorted(self.working_memory_dir.glob("day_*.json")) + if not existing: + return 1 + last_stem = existing[-1].stem # e.g. "day_0003" + return int(last_stem.split("_", 1)[1]) + 1 + def update(self, day: int, context: Dict[str, Any], corrected: Dict[str, str], ground_truth_bundle: Dict[str, str]) -> None: ep_txt = _episode_text(day=day, context=context, prefs=ground_truth_bundle) @@ -236,6 +247,7 @@ def predict_bundle( context=context, corrected_block=corrected_block, options_block=options_block, + physical_profile_description=self.physical_profile_description, ) def _call() -> Any: diff --git a/src/feeding_deployment/preference_learning/methods/prompts/bundle_prediction.py b/src/feeding_deployment/preference_learning/methods/prompts/bundle_prediction.py index c09d3bc6..e10eb4bf 100644 --- a/src/feeding_deployment/preference_learning/methods/prompts/bundle_prediction.py +++ b/src/feeding_deployment/preference_learning/methods/prompts/bundle_prediction.py @@ -29,20 +29,26 @@ def get_bundle_prediction_prompt( context: dict, corrected_block: str, options_block: str, + *, + physical_profile_description: str | None = None, ) -> str: template = BUNDLE_PREDICTION_PROMPT_PATH.read_text(encoding="utf-8") system_description = get_system_description_prompt() - if physical_profile_label not in _PHYSICAL_CAPABILITY_BY_LABEL: - valid = ", ".join(sorted(_PHYSICAL_CAPABILITY_BY_LABEL.keys())) - raise ValueError(f"Unknown physical_profile_label={physical_profile_label!r}. Valid: {valid}") - - physical_profile_description = _PHYSICAL_CAPABILITY_BY_LABEL[physical_profile_label].description + if physical_profile_description is not None: + desc = physical_profile_description.strip() + if not desc: + raise ValueError("physical_profile_description is empty") + else: + if physical_profile_label not in _PHYSICAL_CAPABILITY_BY_LABEL: + valid = ", ".join(sorted(_PHYSICAL_CAPABILITY_BY_LABEL.keys())) + raise ValueError(f"Unknown physical_profile_label={physical_profile_label!r}. Valid: {valid}") + desc = _PHYSICAL_CAPABILITY_BY_LABEL[physical_profile_label].description return template.format( system_description=system_description, - physical_profile=physical_profile_description, + physical_profile=desc, ltm_summary=ltm_summary, retrieved_block=retrieved_block, meal=context.get("meal"), diff --git a/src/feeding_deployment/preference_learning/methods/prompts/ltm_update.py b/src/feeding_deployment/preference_learning/methods/prompts/ltm_update.py index aeaefd70..7946d082 100644 --- a/src/feeding_deployment/preference_learning/methods/prompts/ltm_update.py +++ b/src/feeding_deployment/preference_learning/methods/prompts/ltm_update.py @@ -25,19 +25,25 @@ def get_ltm_update_prompt( physical_profile_label: str, previous_ltm_summary: str, new_episode: str, + *, + physical_profile_description: str | None = None, ) -> str: template = LTM_UPDATE_PROMPT_PATH.read_text(encoding="utf-8") system_description = get_system_description_prompt() - if physical_profile_label not in _PHYSICAL_CAPABILITY_BY_LABEL: - valid = ", ".join(sorted(_PHYSICAL_CAPABILITY_BY_LABEL.keys())) - raise ValueError(f"Unknown physical_profile_label={physical_profile_label!r}. Valid: {valid}") - - physical_profile_description = _PHYSICAL_CAPABILITY_BY_LABEL[physical_profile_label].description + if physical_profile_description is not None: + desc = physical_profile_description.strip() + if not desc: + raise ValueError("physical_profile_description is empty") + else: + if physical_profile_label not in _PHYSICAL_CAPABILITY_BY_LABEL: + valid = ", ".join(sorted(_PHYSICAL_CAPABILITY_BY_LABEL.keys())) + raise ValueError(f"Unknown physical_profile_label={physical_profile_label!r}. Valid: {valid}") + desc = _PHYSICAL_CAPABILITY_BY_LABEL[physical_profile_label].description return template.format( system_description=system_description, - physical_profile=physical_profile_description, + physical_profile=desc, previous_ltm_summary=previous_ltm_summary, new_episode=new_episode, ) diff --git a/src/feeding_deployment/preference_learning/profiles/alice.txt b/src/feeding_deployment/preference_learning/profiles/alice.txt new file mode 100644 index 00000000..1a92724f --- /dev/null +++ b/src/feeding_deployment/preference_learning/profiles/alice.txt @@ -0,0 +1,4 @@ +The user has severe upper-limb paralysis with very limited voluntary control of their arms (cannot press physical buttons). +They cannot lean forward due to lack of trunk control. However, they have good neck and head control and are able to +open their mouth wide and perform head gestures. They interact with the web interface through a reflective dot on their +pose, which is tracked by their personal device \ No newline at end of file diff --git a/src/feeding_deployment/simulation/world.py b/src/feeding_deployment/simulation/world.py index b399410f..e3d2442d 100644 --- a/src/feeding_deployment/simulation/world.py +++ b/src/feeding_deployment/simulation/world.py @@ -31,7 +31,7 @@ def __init__(self, scene_description: SceneDescription, use_gui: bool = True, ig # Create the PyBullet client. if use_gui: - self.physics_client_id = create_gui_connection(camera_yaw=180) + self.physics_client_id = create_gui_connection(camera_yaw=120) else: self.physics_client_id = p.connect(p.DIRECT) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_apply_preferences.py b/tests/test_apply_preferences.py new file mode 100644 index 00000000..5175219d --- /dev/null +++ b/tests/test_apply_preferences.py @@ -0,0 +1,583 @@ +"""Tests for Step 4: apply_preferences module (BT YAML writes + transfer mode). + +Run with: + PYTHONPATH=src python -m pytest tests/test_apply_preferences.py -v +""" + +from __future__ import annotations + +import shutil +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +import yaml + +from feeding_deployment.integration.apply_preferences import ( + apply_bundle_to_behavior_trees, + apply_transfer_mode, + apply_microwave_preference, + apply_dip_preference, + apply_occlusion_preference, + _SPEED_MAP, + _CONFIRMATION_MAP, + _AUTOCONTINUE_MAP, + _OUTSIDE_MOUTH_DISTANCE_MAP, + _CONVEY_READY_MAP, + _INITIATE_TRANSFER_MAP, + _COMPLETE_TRANSFER_MAP, + _TRANSFER_MODE_MAP, + _SKEWERING_AXIS_MAP, + _MICROWAVE_TIME_MAP, + _RETRACT_MAP, + _dipping_depth_translate, + _microwave_duration_translate, + _load_yaml, +) +from feeding_deployment.preference_learning.methods.utils import PREF_FIELDS + +BT_SOURCE = ( + Path(__file__).resolve().parents[1] + / "src" + / "feeding_deployment" + / "actions" + / "behavior_trees" +) + + +@pytest.fixture() +def bt_dir(tmp_path: Path) -> Path: + """Copy all source BT YAMLs into a temp directory for isolated testing.""" + dest = tmp_path / "behavior_trees" + shutil.copytree(BT_SOURCE, dest) + return dest + + +def _load(bt_dir: Path, fname: str) -> dict: + return _load_yaml(bt_dir / fname) + + +def _get_param_value(data: dict, param_name: str): + for p in data.get("parameters", []): + if p["name"] == param_name: + return p["value"] + raise KeyError(f"parameter {param_name!r} not found") + + +# ------------------------------------------------------------------- +# Value translator unit tests +# ------------------------------------------------------------------- + + +class TestValueTranslators: + + def test_speed_map_covers_all_bundle_options(self): + assert set(_SPEED_MAP.keys()) == {"slow", "medium", "fast"} + + def test_confirmation_map(self): + assert _CONFIRMATION_MAP == {"yes": 1, "no": 0} + + def test_autocontinue_map_values_are_floats(self): + for v in _AUTOCONTINUE_MAP.values(): + assert isinstance(v, float) + + def test_outside_mouth_distance_values_in_range(self): + for v in _OUTSIDE_MOUTH_DISTANCE_MAP.values(): + assert 0.1 <= v <= 0.2 + + def test_convey_ready_maps_speech_plus_led_to_voice_led(self): + assert _CONVEY_READY_MAP["speech + LED"] == "voice_led" + + def test_initiate_transfer_map(self): + assert _INITIATE_TRANSFER_MAP["open mouth"] == "open_mouth" + assert _INITIATE_TRANSFER_MAP["autocontinue"] == "auto_timeout" + + def test_complete_transfer_map(self): + assert _COMPLETE_TRANSFER_MAP["perception"] == "sense" + assert _COMPLETE_TRANSFER_MAP["autocontinue"] == "auto_timeout" + + def test_skewering_axis_map(self): + assert _SKEWERING_AXIS_MAP["parallel to major axis"] == "horizontal" + assert _SKEWERING_AXIS_MAP["perpendicular to major axis"] == "vertical" + + def test_dipping_depth_translate(self): + assert _dipping_depth_translate("less") == 0.01 + assert _dipping_depth_translate("more") == 0.03 + assert _dipping_depth_translate("do not dip") is None + + def test_microwave_duration_translate(self): + assert _microwave_duration_translate("1 min") == 60.0 + assert _microwave_duration_translate("2 min") == 120.0 + assert _microwave_duration_translate("3 min") == 180.0 + assert _microwave_duration_translate("no microwave") is None + + def test_retract_map(self): + assert _RETRACT_MAP["yes"] == 1 + assert _RETRACT_MAP["no"] == 0 + + +# ------------------------------------------------------------------- +# apply_bundle_to_behavior_trees: YAML round-trip +# ------------------------------------------------------------------- + + +class TestApplyBundleToBehaviorTrees: + + def test_speed_written_to_all_yamls(self, bt_dir: Path): + bundle = {"robot_speed": "fast"} + warnings = apply_bundle_to_behavior_trees(bundle, bt_dir) + + for fname in bt_dir.glob("*.yaml"): + data = _load(bt_dir, fname.name) + assert _get_param_value(data, "Speed") == "high", ( + f"Speed not updated in {fname.name}" + ) + assert not warnings + + def test_speed_slow_maps_to_low(self, bt_dir: Path): + apply_bundle_to_behavior_trees({"robot_speed": "slow"}, bt_dir) + data = _load(bt_dir, "acquire_bite.yaml") + assert _get_param_value(data, "Speed") == "low" + + def test_confirmation_written(self, bt_dir: Path): + bundle = {"web_interface_confirmation": "no"} + apply_bundle_to_behavior_trees(bundle, bt_dir) + + ab = _load(bt_dir, "acquire_bite.yaml") + assert _get_param_value(ab, "TransferAskForConfirmation") == 0 + + td = _load(bt_dir, "transfer_drink.yaml") + assert _get_param_value(td, "AskForConfirmationInitiatingTransferSequence") == 0 + + tw = _load(bt_dir, "transfer_wipe.yaml") + assert _get_param_value(tw, "AskForConfirmationInitiatingTransferSequence") == 0 + + def test_autocontinue_written(self, bt_dir: Path): + bundle = {"wait_before_autocontinue_seconds": "1000 sec"} + apply_bundle_to_behavior_trees(bundle, bt_dir) + + for fname in ["acquire_bite.yaml", "transfer_utensil.yaml", "transfer_drink.yaml"]: + data = _load(bt_dir, fname) + assert _get_param_value(data, "TimeToWaitBeforeAutocontinue") == 1000.0 + + def test_outside_mouth_distance_near(self, bt_dir: Path): + bundle = {"outside_mouth_distance": "near"} + apply_bundle_to_behavior_trees(bundle, bt_dir) + + for fname in ["transfer_utensil.yaml", "transfer_drink.yaml", "transfer_wipe.yaml"]: + data = _load(bt_dir, fname) + assert _get_param_value(data, "OutsideMouthDistance") == 0.1 + + def test_outside_mouth_distance_not_applicable_skips(self, bt_dir: Path): + original = _get_param_value(_load(bt_dir, "transfer_utensil.yaml"), "OutsideMouthDistance") + bundle = {"outside_mouth_distance": "not applicable"} + apply_bundle_to_behavior_trees(bundle, bt_dir) + + after = _get_param_value(_load(bt_dir, "transfer_utensil.yaml"), "OutsideMouthDistance") + assert after == original + + def test_convey_ready_initiating(self, bt_dir: Path): + bundle = {"convey_robot_ready_for_initiating_transfer": "LED"} + apply_bundle_to_behavior_trees(bundle, bt_dir) + + for fname in ["transfer_utensil.yaml", "transfer_drink.yaml", "transfer_wipe.yaml"]: + data = _load(bt_dir, fname) + assert _get_param_value(data, "ReadyToInitiateTransferInteraction") == "led" + + def test_convey_ready_completing(self, bt_dir: Path): + bundle = {"convey_robot_ready_for_completing_transfer": "no cue"} + apply_bundle_to_behavior_trees(bundle, bt_dir) + + for fname in ["transfer_utensil.yaml", "transfer_drink.yaml", "transfer_wipe.yaml"]: + data = _load(bt_dir, fname) + assert _get_param_value(data, "ReadyForTransferInteraction") == "silent" + + def test_detect_initiate_per_tool(self, bt_dir: Path): + bundle = { + "detect_user_ready_for_initiating_transfer_feeding": "button", + "detect_user_ready_for_initiating_transfer_drinking": "autocontinue", + "detect_user_ready_for_initiating_transfer_wiping": "open mouth", + } + apply_bundle_to_behavior_trees(bundle, bt_dir) + + assert _get_param_value( + _load(bt_dir, "transfer_utensil.yaml"), "InitiateTransferInteraction" + ) == "button" + assert _get_param_value( + _load(bt_dir, "transfer_drink.yaml"), "InitiateTransferInteraction" + ) == "auto_timeout" + assert _get_param_value( + _load(bt_dir, "transfer_wipe.yaml"), "InitiateTransferInteraction" + ) == "open_mouth" + + def test_detect_complete_per_tool(self, bt_dir: Path): + bundle = { + "detect_user_completed_transfer_feeding": "perception", + "detect_user_completed_transfer_drinking": "button", + "detect_user_completed_transfer_wiping": "autocontinue", + } + apply_bundle_to_behavior_trees(bundle, bt_dir) + + assert _get_param_value( + _load(bt_dir, "transfer_utensil.yaml"), "TransferCompleteInteraction" + ) == "sense" + assert _get_param_value( + _load(bt_dir, "transfer_drink.yaml"), "TransferCompleteInteraction" + ) == "button" + assert _get_param_value( + _load(bt_dir, "transfer_wipe.yaml"), "TransferCompleteInteraction" + ) == "auto_timeout" + + def test_skewering_axis_parallel(self, bt_dir: Path): + apply_bundle_to_behavior_trees({"skewering_axis": "parallel to major axis"}, bt_dir) + data = _load(bt_dir, "acquire_bite.yaml") + assert _get_param_value(data, "SkeweringOrientation") == "horizontal" + + def test_skewering_axis_perpendicular(self, bt_dir: Path): + apply_bundle_to_behavior_trees({"skewering_axis": "perpendicular to major axis"}, bt_dir) + data = _load(bt_dir, "acquire_bite.yaml") + assert _get_param_value(data, "SkeweringOrientation") == "vertical" + + def test_dipping_preference_more(self, bt_dir: Path): + apply_bundle_to_behavior_trees({"bite_dipping_preference": "more"}, bt_dir) + data = _load(bt_dir, "acquire_bite.yaml") + assert _get_param_value(data, "FoodDippingDepth") == 0.03 + + def test_dipping_preference_less(self, bt_dir: Path): + apply_bundle_to_behavior_trees({"bite_dipping_preference": "less"}, bt_dir) + data = _load(bt_dir, "acquire_bite.yaml") + assert _get_param_value(data, "FoodDippingDepth") == 0.01 + + def test_dipping_preference_do_not_dip_skips(self, bt_dir: Path): + original = _get_param_value(_load(bt_dir, "acquire_bite.yaml"), "FoodDippingDepth") + apply_bundle_to_behavior_trees({"bite_dipping_preference": "do not dip"}, bt_dir) + after = _get_param_value(_load(bt_dir, "acquire_bite.yaml"), "FoodDippingDepth") + assert after == original + + def test_microwave_duration_2_min(self, bt_dir: Path): + apply_bundle_to_behavior_trees({"microwave_time": "2 min"}, bt_dir) + data = _load(bt_dir, "press_microwave_button.yaml") + assert _get_param_value(data, "MicrowaveDuration") == 120.0 + + def test_microwave_duration_1_min(self, bt_dir: Path): + apply_bundle_to_behavior_trees({"microwave_time": "1 min"}, bt_dir) + data = _load(bt_dir, "press_microwave_button.yaml") + assert _get_param_value(data, "MicrowaveDuration") == 60.0 + + def test_microwave_duration_3_min(self, bt_dir: Path): + apply_bundle_to_behavior_trees({"microwave_time": "3 min"}, bt_dir) + data = _load(bt_dir, "press_microwave_button.yaml") + assert _get_param_value(data, "MicrowaveDuration") == 180.0 + + def test_microwave_no_microwave_skips(self, bt_dir: Path): + original = _get_param_value(_load(bt_dir, "press_microwave_button.yaml"), "MicrowaveDuration") + apply_bundle_to_behavior_trees({"microwave_time": "no microwave"}, bt_dir) + after = _get_param_value(_load(bt_dir, "press_microwave_button.yaml"), "MicrowaveDuration") + assert after == original + + def test_retract_between_bites_yes(self, bt_dir: Path): + apply_bundle_to_behavior_trees({"retract_between_bites": "yes"}, bt_dir) + data = _load(bt_dir, "transfer_utensil.yaml") + assert _get_param_value(data, "RetractAfterTransfer") == 1 + + def test_retract_between_bites_no(self, bt_dir: Path): + apply_bundle_to_behavior_trees({"retract_between_bites": "no"}, bt_dir) + data = _load(bt_dir, "transfer_utensil.yaml") + assert _get_param_value(data, "RetractAfterTransfer") == 0 + + def test_retract_only_affects_utensil(self, bt_dir: Path): + apply_bundle_to_behavior_trees({"retract_between_bites": "yes"}, bt_dir) + for fname in ["transfer_drink.yaml", "transfer_wipe.yaml"]: + data = _load(bt_dir, fname) + with pytest.raises(KeyError): + _get_param_value(data, "RetractAfterTransfer") + + +class TestApplyBundleWarnings: + + def test_speech_plus_led_writes_voice_led(self, bt_dir: Path): + bundle = {"convey_robot_ready_for_initiating_transfer": "speech + LED"} + warnings = apply_bundle_to_behavior_trees(bundle, bt_dir) + assert not any("speech + LED" in w for w in warnings) + + for fname in ["transfer_utensil.yaml", "transfer_drink.yaml", "transfer_wipe.yaml"]: + data = _load(bt_dir, fname) + assert _get_param_value(data, "ReadyToInitiateTransferInteraction") == "voice_led" + + def test_missing_yaml_file_produces_warning(self, bt_dir: Path): + (bt_dir / "acquire_bite.yaml").unlink() + bundle = {"web_interface_confirmation": "yes"} + warnings = apply_bundle_to_behavior_trees(bundle, bt_dir) + assert any("not found" in w for w in warnings) + + def test_unknown_bundle_field_ignored(self, bt_dir: Path): + bundle = {"nonexistent_field": "whatever"} + warnings = apply_bundle_to_behavior_trees(bundle, bt_dir) + assert not warnings + + def test_empty_bundle_no_changes(self, bt_dir: Path): + before = _load(bt_dir, "acquire_bite.yaml") + speed_before = _get_param_value(before, "Speed") + + warnings = apply_bundle_to_behavior_trees({}, bt_dir) + assert not warnings + + after = _load(bt_dir, "acquire_bite.yaml") + assert _get_param_value(after, "Speed") == speed_before + + +class TestApplyBundleFullBundle: + """Apply a realistic full bundle and verify no errors.""" + + def test_full_bundle_applies_without_error(self, bt_dir: Path): + bundle = { + "robot_speed": "slow", + "web_interface_confirmation": "yes", + "wait_before_autocontinue_seconds": "100 sec", + "outside_mouth_distance": "far", + "convey_robot_ready_for_initiating_transfer": "speech", + "detect_user_ready_for_initiating_transfer_feeding": "open mouth", + "detect_user_ready_for_initiating_transfer_drinking": "button", + "detect_user_ready_for_initiating_transfer_wiping": "autocontinue", + "convey_robot_ready_for_completing_transfer": "LED", + "detect_user_completed_transfer_feeding": "perception", + "detect_user_completed_transfer_drinking": "autocontinue", + "detect_user_completed_transfer_wiping": "button", + "transfer_mode": "inside mouth transfer", + "microwave_time": "no microwave", + "occlusion_relevance": "do not consider occlusion", + "skewering_axis": "parallel to major axis", + "retract_between_bites": "yes", + "bite_dipping_preference": "do not dip", + } + warnings = apply_bundle_to_behavior_trees(bundle, bt_dir) + assert not any("Error" in w for w in warnings) + + data = _load(bt_dir, "acquire_bite.yaml") + assert _get_param_value(data, "Speed") == "low" + assert _get_param_value(data, "TransferAskForConfirmation") == 1 + assert _get_param_value(data, "TimeToWaitBeforeAutocontinue") == 100.0 + + +# ------------------------------------------------------------------- +# apply_transfer_mode +# ------------------------------------------------------------------- + + +class TestApplyTransferMode: + """Tests for apply_transfer_mode. + + When a TransferTool HLA is present in the map, the function imports + InsideMouthTransfer / OutsideMouthTransfer (heavy robot deps). We mock + those imports to avoid pulling in scipy etc. in the test environment. + When there is no HLA in the map, only the scene attribute is set. + """ + + def _make_scene(self, current_type: str = "outside") -> MagicMock: + scene = MagicMock() + scene.transfer_type = current_type + return scene + + def _make_hla(self) -> MagicMock: + hla = MagicMock() + hla.sim = MagicMock() + hla.robot_interface = None + hla.perception_interface = None + hla.rviz_interface = None + hla.no_waits = True + hla.head_perception_log_dir = Path("/tmp/test_head_log") + return hla + + def test_sets_transfer_type_inside_no_hla(self): + scene = self._make_scene("outside") + apply_transfer_mode( + {"transfer_mode": "inside mouth transfer"}, scene, {}, + ) + assert scene.transfer_type == "inside" + + def test_sets_transfer_type_outside_no_hla(self): + scene = self._make_scene("inside") + apply_transfer_mode( + {"transfer_mode": "outside mouth transfer"}, scene, {}, + ) + assert scene.transfer_type == "outside" + + def test_unknown_transfer_mode_raises(self): + scene = self._make_scene() + with pytest.raises(ValueError, match="Unknown transfer_mode"): + apply_transfer_mode({"transfer_mode": "sideways"}, scene, {}) + + def test_no_transfer_mode_in_bundle_is_noop(self): + scene = self._make_scene("outside") + apply_transfer_mode({}, scene, {}) + assert scene.transfer_type == "outside" + + def test_no_transfer_hla_only_sets_scene(self): + scene = self._make_scene("outside") + apply_transfer_mode( + {"transfer_mode": "inside mouth transfer"}, scene, {}, + ) + assert scene.transfer_type == "inside" + + +# ------------------------------------------------------------------- +# apply_microwave_preference +# ------------------------------------------------------------------- + + +class _FakeAtom: + """Minimal stand-in for GroundAtom(FoodHeated, []) so tests don't need + the heavy relational_structs import.""" + + def __init__(self, name: str) -> None: + self.name = name + + def __eq__(self, other: object) -> bool: + return isinstance(other, _FakeAtom) and self.name == other.name + + def __hash__(self) -> int: + return hash(self.name) + + +class TestApplyMicrowavePreference: + + def _food_heated(self) -> _FakeAtom: + return _FakeAtom("FoodHeated") + + def test_no_microwave_adds_food_heated(self): + atoms: set = set() + fh = self._food_heated() + result = apply_microwave_preference({"microwave_time": "no microwave"}, atoms, fh) + assert fh in atoms + assert result is None + + def test_1_min_removes_food_heated(self): + fh = self._food_heated() + atoms = {fh} + result = apply_microwave_preference({"microwave_time": "1 min"}, atoms, fh) + assert fh not in atoms + assert result == 60 + + def test_2_min_returns_120(self): + fh = self._food_heated() + atoms: set = set() + result = apply_microwave_preference({"microwave_time": "2 min"}, atoms, fh) + assert fh not in atoms + assert result == 120 + + def test_3_min_returns_180(self): + fh = self._food_heated() + atoms: set = set() + result = apply_microwave_preference({"microwave_time": "3 min"}, atoms, fh) + assert result == 180 + + def test_no_microwave_time_in_bundle_is_noop(self): + fh = self._food_heated() + atoms: set = set() + result = apply_microwave_preference({}, atoms, fh) + assert fh not in atoms + assert result is None + + def test_unknown_microwave_time_raises(self): + fh = self._food_heated() + with pytest.raises(ValueError, match="Unknown microwave_time"): + apply_microwave_preference({"microwave_time": "5 min"}, set(), fh) + + def test_microwave_time_map_covers_all_options(self): + assert set(_MICROWAVE_TIME_MAP.keys()) == {"no microwave", "1 min", "2 min", "3 min"} + + +# ------------------------------------------------------------------- +# apply_dip_preference +# ------------------------------------------------------------------- + + +class _FakeFlair: + """Minimal stand-in for FLAIR so tests don't need heavy imports.""" + + def __init__(self) -> None: + self._allow_dip = True + + def set_allow_dip(self, allow: bool) -> None: + self._allow_dip = allow + + +class TestApplyDipPreference: + + def test_do_not_dip_sets_false(self): + flair = _FakeFlair() + apply_dip_preference({"bite_dipping_preference": "do not dip"}, flair) + assert flair._allow_dip is False + + def test_less_keeps_true(self): + flair = _FakeFlair() + apply_dip_preference({"bite_dipping_preference": "less"}, flair) + assert flair._allow_dip is True + + def test_more_keeps_true(self): + flair = _FakeFlair() + apply_dip_preference({"bite_dipping_preference": "more"}, flair) + assert flair._allow_dip is True + + def test_missing_key_is_noop(self): + flair = _FakeFlair() + apply_dip_preference({}, flair) + assert flair._allow_dip is True + + def test_none_flair_is_noop(self): + apply_dip_preference({"bite_dipping_preference": "do not dip"}, None) + + +# ------------------------------------------------------------------- +# apply_occlusion_preference +# ------------------------------------------------------------------- + + +class _FakeFlairWithPreference: + """Minimal stand-in for FLAIR with get/set_preference.""" + + def __init__(self, initial: str | None = None) -> None: + self._preference = initial + + def get_preference(self): + return self._preference + + def set_preference(self, pref: str): + self._preference = pref + + +class TestApplyOcclusionPreference: + + def test_minimize_left_sets_hint(self): + flair = _FakeFlairWithPreference() + apply_occlusion_preference({"occlusion_relevance": "minimize left occlusion"}, flair) + assert "left" in flair._preference + + def test_minimize_front_sets_hint(self): + flair = _FakeFlairWithPreference() + apply_occlusion_preference({"occlusion_relevance": "minimize front occlusion"}, flair) + assert "front" in flair._preference + + def test_minimize_right_sets_hint(self): + flair = _FakeFlairWithPreference() + apply_occlusion_preference({"occlusion_relevance": "minimize right occlusion"}, flair) + assert "right" in flair._preference + + def test_do_not_consider_is_noop(self): + flair = _FakeFlairWithPreference() + apply_occlusion_preference({"occlusion_relevance": "do not consider occlusion"}, flair) + assert flair._preference is None + + def test_appends_to_existing_preference(self): + flair = _FakeFlairWithPreference("I prefer chicken") + apply_occlusion_preference({"occlusion_relevance": "minimize left occlusion"}, flair) + assert flair._preference.startswith("I prefer chicken") + assert "left" in flair._preference + + def test_missing_key_is_noop(self): + flair = _FakeFlairWithPreference() + apply_occlusion_preference({}, flair) + assert flair._preference is None + + def test_none_flair_is_noop(self): + apply_occlusion_preference({"occlusion_relevance": "minimize left occlusion"}, None) diff --git a/tests/test_preference_integration.py b/tests/test_preference_integration.py new file mode 100644 index 00000000..e61a9646 --- /dev/null +++ b/tests/test_preference_integration.py @@ -0,0 +1,946 @@ +"""Tests for preference-learning integration Steps 1-5 + terminal interaction. + +Run with: + PYTHONPATH=src python -m pytest tests/test_preference_integration.py -v + +Note: run.py and web_interface.py have heavy robot dependencies (tomsutils, +pybullet_helpers, cv2, rospy, ...) that are not available outside the robot +environment. The tests below therefore exercise the *logic* of Steps 1-5 +without importing those modules: preference_context.py is tested directly, +PredictionModel is tested with mocked OpenAI, and the runner/web-interface +contracts are tested against the same functions those classes delegate to. +""" + +from __future__ import annotations + +import importlib +import json +import queue +import sys +import threading +import types +from pathlib import Path +from unittest.mock import MagicMock, patch +from io import StringIO + +import pytest + +from feeding_deployment.integration.preference_context import ( + PREFERENCE_CONTEXT_KEYS, + build_preference_context, + validate_preference_context, +) +from feeding_deployment.preference_learning.config import ( + MEALS, + SETTINGS, + TIMES_OF_DAY, +) +from feeding_deployment.preference_learning.methods.prediction_model import ( + PREF_OPTIONS, +) +from feeding_deployment.preference_learning.methods.utils import PREF_FIELDS + +_PM_MODULE = "feeding_deployment.preference_learning.methods.prediction_model" + + +def _import_web_interface_module(): + """Import web_interface.py with lightweight dependency stubs.""" + module_name = "feeding_deployment.interfaces.web_interface" + sys.modules.pop(module_name, None) + + fake_modules = { + "cv2": types.ModuleType("cv2"), + "rospy": types.ModuleType("rospy"), + "sensor_msgs": types.ModuleType("sensor_msgs"), + "sensor_msgs.msg": types.ModuleType("sensor_msgs.msg"), + "std_msgs": types.ModuleType("std_msgs"), + "std_msgs.msg": types.ModuleType("std_msgs.msg"), + "cv_bridge": types.ModuleType("cv_bridge"), + "pybullet_helpers": types.ModuleType("pybullet_helpers"), + "pybullet_helpers.geometry": types.ModuleType("pybullet_helpers.geometry"), + "pybullet_helpers.joint": types.ModuleType("pybullet_helpers.joint"), + "scipy": types.ModuleType("scipy"), + "scipy.spatial": types.ModuleType("scipy.spatial"), + "scipy.spatial.transform": types.ModuleType("scipy.spatial.transform"), + "feeding_deployment.transparency.continuous_llm": types.ModuleType( + "feeding_deployment.transparency.continuous_llm" + ), + } + + fake_modules["sensor_msgs.msg"].CompressedImage = type("CompressedImage", (), {}) + fake_modules["std_msgs.msg"].String = type("String", (), {}) + fake_modules["cv_bridge"].CvBridge = type("CvBridge", (), {}) + fake_modules["pybullet_helpers.geometry"].Pose = type("Pose", (), {}) + fake_modules["pybullet_helpers.joint"].JointPositions = type( + "JointPositions", (), {} + ) + fake_modules["scipy.spatial.transform"].Rotation = type("Rotation", (), {}) + fake_modules["feeding_deployment.transparency.continuous_llm"].TransparencyContinuous = type( + "TransparencyContinuous", (), {} + ) + + with patch.dict(sys.modules, fake_modules): + return importlib.import_module(module_name) + + +# =================================================================== +# Step 1 — preference context: build, validate, runner contract +# =================================================================== + + +class TestBuildPreferenceContext: + + def test_valid_context(self): + ctx = build_preference_context(MEALS[0], SETTINGS[0], TIMES_OF_DAY[0]) + assert set(ctx.keys()) == set(PREFERENCE_CONTEXT_KEYS) + assert ctx["meal"] == MEALS[0] + assert ctx["setting"] == SETTINGS[0] + assert ctx["time_of_day"] == TIMES_OF_DAY[0] + + def test_strips_whitespace(self): + ctx = build_preference_context( + f" {MEALS[0]} ", f" {SETTINGS[0]} ", f" {TIMES_OF_DAY[0]} " + ) + assert ctx["meal"] == MEALS[0] + + def test_invalid_meal_raises(self): + with pytest.raises(ValueError, match="Unknown meal"): + build_preference_context("not_a_real_meal", SETTINGS[0], TIMES_OF_DAY[0]) + + def test_invalid_setting_raises(self): + with pytest.raises(ValueError, match="Unknown setting"): + build_preference_context(MEALS[0], "bad_setting", TIMES_OF_DAY[0]) + + def test_invalid_time_of_day_raises(self): + with pytest.raises(ValueError, match="Unknown time_of_day"): + build_preference_context(MEALS[0], SETTINGS[0], "midnight") + + def test_empty_meal_raises(self): + with pytest.raises(ValueError): + build_preference_context("", SETTINGS[0], TIMES_OF_DAY[0]) + + def test_whitespace_only_meal_raises(self): + with pytest.raises(ValueError): + build_preference_context(" ", SETTINGS[0], TIMES_OF_DAY[0]) + + +class TestValidatePreferenceContext: + + def test_missing_key_raises(self): + with pytest.raises(ValueError, match="missing required key"): + validate_preference_context({"meal": MEALS[0], "setting": SETTINGS[0]}) + + def test_non_string_value_raises(self): + with pytest.raises(ValueError, match="must be a non-empty string"): + validate_preference_context( + {"meal": 123, "setting": SETTINGS[0], "time_of_day": TIMES_OF_DAY[0]} + ) + + def test_valid_context_passes(self): + validate_preference_context( + {"meal": MEALS[0], "setting": SETTINGS[0], "time_of_day": TIMES_OF_DAY[0]} + ) + + +class TestRunnerPreferenceContextContract: + """Verify the contract that _Runner.ensure_preference_context and + _Runner.set_meal_preference_context implement, without importing run.py. + + The runner stores `self.preference_context` and: + - ensure_preference_context() raises RuntimeError when it is None + - set_meal_preference_context(m, s, t) delegates to build_preference_context + """ + + def test_ensure_raises_when_context_is_none(self): + preference_context = None + with pytest.raises(RuntimeError, match="preference_context is required"): + if preference_context is None: + raise RuntimeError( + "preference_context is required but unset. Each run must set it " + "explicitly (e.g. non-empty --pref_meal with --use_interface, or call " + "set_meal_preference_context(meal, setting, time_of_day) before run()). " + "Context is not loaded from or saved to disk; after a crash, supply it again." + ) + + def test_ensure_returns_when_context_is_set(self): + ctx = build_preference_context(MEALS[0], SETTINGS[0], TIMES_OF_DAY[0]) + preference_context = ctx + assert preference_context is not None + assert preference_context == ctx + + def test_set_meal_delegates_to_build(self): + ctx = build_preference_context(MEALS[0], SETTINGS[0], TIMES_OF_DAY[0]) + assert ctx["meal"] == MEALS[0] + assert ctx["setting"] == SETTINGS[0] + assert ctx["time_of_day"] == TIMES_OF_DAY[0] + + def test_set_meal_rejects_invalid_meal(self): + with pytest.raises(ValueError): + build_preference_context("bad_meal", SETTINGS[0], TIMES_OF_DAY[0]) + + def test_each_meal_setting_time_accepted(self): + """Smoke test: every canonical value from config builds successfully.""" + for m in MEALS: + for s in SETTINGS: + for t in TIMES_OF_DAY: + ctx = build_preference_context(m, s, t) + validate_preference_context(ctx) + + def test_interface_mode_context_defaults_are_valid(self): + """Interface mode can safely fall back to canonical default context values.""" + ctx = { + "meal": MEALS[0], + "setting": SETTINGS[0], + "time_of_day": TIMES_OF_DAY[0], + } + defaults = build_preference_context( + ctx["meal"], + ctx["setting"], + ctx["time_of_day"], + ) + assert defaults == ctx + + +# =================================================================== +# Step 2 — PredictionModel: instantiation + predict_bundle +# =================================================================== + + +def _fake_openai_response(bundle: dict[str, str]) -> MagicMock: + """Build a MagicMock that mimics an openai ChatCompletion response.""" + choice = MagicMock() + choice.message.content = json.dumps(bundle) + resp = MagicMock() + resp.choices = [choice] + return resp + + +def _default_bundle() -> dict[str, str]: + return {field: opts[0] for field, opts in PREF_OPTIONS.items()} + + +class TestPredictionModelPredictBundle: + + @patch(f"{_PM_MODULE}._resolve_api_key", return_value="fake-key") + @patch(f"{_PM_MODULE}.OpenAI") + def test_returns_all_fields(self, mock_openai_cls, _key, tmp_path): + bundle = _default_bundle() + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = _fake_openai_response(bundle) + mock_openai_cls.return_value = mock_client + + from feeding_deployment.preference_learning.methods.prediction_model import ( + PredictionModel, + ) + + model = PredictionModel( + user="test_user", + physical_profile_label="test_label", + logs_dir=tmp_path / "pref", + physical_profile_description="Good arm control.", + use_long_term_memory=False, + use_episodic_memory=False, + ) + + ctx = {"meal": MEALS[0], "setting": SETTINGS[0], "time_of_day": TIMES_OF_DAY[0]} + result = model.predict_bundle(ctx, {}) + + assert isinstance(result, dict) + assert set(result.keys()) == set(PREF_FIELDS) + for field in PREF_FIELDS: + assert result[field] in PREF_OPTIONS[field], ( + f"{field}={result[field]} not in allowed options" + ) + + @patch(f"{_PM_MODULE}._resolve_api_key", return_value="fake-key") + @patch(f"{_PM_MODULE}.OpenAI") + def test_physical_profile_description_in_prompt(self, mock_openai_cls, _key, tmp_path): + bundle = _default_bundle() + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = _fake_openai_response(bundle) + mock_openai_cls.return_value = mock_client + + desc = "This user has limited arm control and cannot press buttons." + + from feeding_deployment.preference_learning.methods.prediction_model import ( + PredictionModel, + ) + + model = PredictionModel( + user="test_user", + physical_profile_label="unused_label", + logs_dir=tmp_path / "pref", + physical_profile_description=desc, + use_long_term_memory=False, + use_episodic_memory=False, + ) + + ctx = {"meal": MEALS[0], "setting": SETTINGS[0], "time_of_day": TIMES_OF_DAY[0]} + model.predict_bundle(ctx, {}) + + call_args = mock_client.chat.completions.create.call_args + prompt = call_args.kwargs["messages"][1]["content"] + assert desc in prompt, ( + "Freeform physical-profile description must appear verbatim in the LLM prompt" + ) + + @patch(f"{_PM_MODULE}._resolve_api_key", return_value="fake-key") + @patch(f"{_PM_MODULE}.OpenAI") + def test_corrected_fields_override_prediction(self, mock_openai_cls, _key, tmp_path): + bundle = _default_bundle() + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = _fake_openai_response(bundle) + mock_openai_cls.return_value = mock_client + + from feeding_deployment.preference_learning.methods.prediction_model import ( + PredictionModel, + ) + + model = PredictionModel( + user="test_user", + physical_profile_label="test_label", + logs_dir=tmp_path / "pref", + physical_profile_description="Test.", + use_long_term_memory=False, + use_episodic_memory=False, + ) + + override_field = PREF_FIELDS[0] + override_val = PREF_OPTIONS[override_field][-1] # last option + + ctx = {"meal": MEALS[0], "setting": SETTINGS[0], "time_of_day": TIMES_OF_DAY[0]} + result = model.predict_bundle(ctx, {override_field: override_val}) + assert result[override_field] == override_val + + @patch(f"{_PM_MODULE}._resolve_api_key", return_value="fake-key") + @patch(f"{_PM_MODULE}.OpenAI") + def test_malformed_llm_json_falls_back(self, mock_openai_cls, _key, tmp_path): + choice = MagicMock() + choice.message.content = "not valid json {{{" + resp = MagicMock() + resp.choices = [choice] + + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = resp + mock_openai_cls.return_value = mock_client + + from feeding_deployment.preference_learning.methods.prediction_model import ( + PredictionModel, + ) + + model = PredictionModel( + user="test_user", + physical_profile_label="test_label", + logs_dir=tmp_path / "pref", + physical_profile_description="Test.", + use_long_term_memory=False, + use_episodic_memory=False, + ) + + ctx = {"meal": MEALS[0], "setting": SETTINGS[0], "time_of_day": TIMES_OF_DAY[0]} + result = model.predict_bundle(ctx, {}) + + assert isinstance(result, dict) + assert len(result) == len(PREF_FIELDS) + for field in PREF_FIELDS: + assert result[field] in PREF_OPTIONS[field] + + @patch(f"{_PM_MODULE}._resolve_api_key", return_value="fake-key") + @patch(f"{_PM_MODULE}.OpenAI") + def test_predict_bundle_logs_to_disk(self, mock_openai_cls, _key, tmp_path): + bundle = _default_bundle() + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = _fake_openai_response(bundle) + mock_openai_cls.return_value = mock_client + + from feeding_deployment.preference_learning.methods.prediction_model import ( + PredictionModel, + ) + + model = PredictionModel( + user="test_user", + physical_profile_label="test_label", + logs_dir=tmp_path / "pref", + physical_profile_description="Good control.", + use_long_term_memory=False, + use_episodic_memory=False, + ) + + ctx = {"meal": MEALS[0], "setting": SETTINGS[0], "time_of_day": TIMES_OF_DAY[0]} + model.predict_bundle(ctx, {}) + + log_dir = tmp_path / "pref" / "test_user" / "prediction_model_llm_calls" + log_files = list(log_dir.glob("*.txt")) + assert len(log_files) == 1, "predict_bundle should write exactly one log file" + contents = log_files[0].read_text() + assert "===PROMPT===" in contents + assert "===RESPONSE===" in contents + + +# =================================================================== +# Step 3 — preference correction web interface contract +# =================================================================== + + +class TestPreferenceCorrectionWebInterfaceContract: + """Verify the live WebInterface preference-correction message flow.""" + + def test_preference_context_round_trip_sends_jump_then_data(self): + module = _import_web_interface_module() + interface = module.WebInterface.__new__(module.WebInterface) + defaults = { + "meal": MEALS[0], + "setting": SETTINGS[0], + "time_of_day": TIMES_OF_DAY[0], + } + sent_messages = [] + + interface.current_page = "task_selection" + interface._send_message = sent_messages.append + + def fake_get_required_message(condition): + response = { + "state": "preference_context_response", + "meal": MEALS[1], + "setting": SETTINGS[1], + "time_of_day": TIMES_OF_DAY[1], + } + assert condition(response) is True + return response + + interface.get_required_web_interface_message = fake_get_required_message + + with patch.object(module.time, "sleep") as mock_sleep: + returned = module.WebInterface.get_preference_context( + interface, + list(MEALS), + list(SETTINGS), + list(TIMES_OF_DAY), + defaults, + ) + + assert interface.current_page == "preference_context" + assert sent_messages == [ + {"state": "preference_context", "status": "jump"}, + { + "state": "preference_context_data", + "meals": list(MEALS), + "settings": list(SETTINGS), + "time_of_day": list(TIMES_OF_DAY), + "defaults": defaults, + }, + ] + mock_sleep.assert_called_once_with(0.5) + assert returned == { + "meal": MEALS[1], + "setting": SETTINGS[1], + "time_of_day": TIMES_OF_DAY[1], + } + + def test_preference_context_raises_when_page_exits_early(self): + module = _import_web_interface_module() + interface = module.WebInterface.__new__(module.WebInterface) + defaults = { + "meal": MEALS[0], + "setting": SETTINGS[0], + "time_of_day": TIMES_OF_DAY[0], + } + sent_messages = [] + + interface.current_page = "task_selection" + interface._send_message = sent_messages.append + interface.get_required_web_interface_message = lambda condition: None + + with patch.object(module.time, "sleep") as mock_sleep: + with pytest.raises(RuntimeError, match="Preference context exited before submission"): + module.WebInterface.get_preference_context( + interface, + list(MEALS), + list(SETTINGS), + list(TIMES_OF_DAY), + defaults, + ) + + assert interface.current_page == "preference_context" + assert sent_messages[0] == {"state": "preference_context", "status": "jump"} + assert sent_messages[1]["state"] == "preference_context_data" + mock_sleep.assert_called_once_with(0.5) + + def test_round_trip_sends_jump_then_data_and_returns_bundle(self): + module = _import_web_interface_module() + interface = module.WebInterface.__new__(module.WebInterface) + predicted = _default_bundle() + corrected_bundle = dict(predicted) + corrected_bundle[PREF_FIELDS[0]] = PREF_OPTIONS[PREF_FIELDS[0]][-1] + sent_messages = [] + + interface.current_page = "task_selection" + interface._send_message = sent_messages.append + + def fake_get_required_message(condition): + response = { + "state": "preference_correction_response", + "bundle": corrected_bundle, + } + assert condition(response) is True + return response + + interface.get_required_web_interface_message = fake_get_required_message + + with patch.object(module.time, "sleep") as mock_sleep: + returned = module.WebInterface.get_preference_corrections( + interface, + predicted, + dict(PREF_OPTIONS), + ) + + assert interface.current_page == "preference_correction" + assert sent_messages == [ + {"state": "preference_correction", "status": "jump"}, + { + "state": "preference_correction_data", + "predicted_bundle": predicted, + "options": dict(PREF_OPTIONS), + }, + ] + mock_sleep.assert_called_once_with(0.5) + assert returned == corrected_bundle + + def test_callback_accepts_status_less_preference_response(self, tmp_path): + module = _import_web_interface_module() + interface = module.WebInterface.__new__(module.WebInterface) + response_bundle = _default_bundle() + + interface.webapp_received_messages_log = tmp_path / "received.txt" + interface.task_selection_queue = queue.Queue() + interface.received_web_interface_messages = queue.Queue() + interface.explanation_lock = threading.Lock() + interface.current_page = "preference_correction" + interface.task_selection_jump = True + + msg = types.SimpleNamespace( + data=json.dumps({ + "state": "preference_correction_response", + "bundle": response_bundle, + }) + ) + + module.WebInterface._message_callback(interface, msg) + + assert interface.task_selection_jump is False + assert interface.task_selection_queue.empty() + queued = interface.received_web_interface_messages.get_nowait() + assert queued == { + "state": "preference_correction_response", + "bundle": response_bundle, + } + + def test_callback_accepts_status_less_preference_context_response(self, tmp_path): + module = _import_web_interface_module() + interface = module.WebInterface.__new__(module.WebInterface) + + interface.webapp_received_messages_log = tmp_path / "received.txt" + interface.task_selection_queue = queue.Queue() + interface.received_web_interface_messages = queue.Queue() + interface.explanation_lock = threading.Lock() + interface.current_page = "preference_context" + interface.task_selection_jump = True + + msg = types.SimpleNamespace( + data=json.dumps({ + "state": "preference_context_response", + "meal": MEALS[0], + "setting": SETTINGS[0], + "time_of_day": TIMES_OF_DAY[0], + }) + ) + + module.WebInterface._message_callback(interface, msg) + + assert interface.task_selection_jump is False + assert interface.task_selection_queue.empty() + queued = interface.received_web_interface_messages.get_nowait() + assert queued == { + "state": "preference_context_response", + "meal": MEALS[0], + "setting": SETTINGS[0], + "time_of_day": TIMES_OF_DAY[0], + } + + def test_returns_predicted_bundle_when_correction_page_exits_early(self): + module = _import_web_interface_module() + interface = module.WebInterface.__new__(module.WebInterface) + predicted = _default_bundle() + sent_messages = [] + + interface.current_page = "task_selection" + interface._send_message = sent_messages.append + interface.get_required_web_interface_message = lambda condition: None + + with patch.object(module.time, "sleep") as mock_sleep: + returned = module.WebInterface.get_preference_corrections( + interface, + predicted, + dict(PREF_OPTIONS), + ) + + assert interface.current_page == "preference_correction" + assert sent_messages[0] == {"state": "preference_correction", "status": "jump"} + assert sent_messages[1]["state"] == "preference_correction_data" + mock_sleep.assert_called_once_with(0.5) + assert returned == predicted + assert returned is not predicted + + def test_preference_correction_applied_confirmation_message(self): + module = _import_web_interface_module() + interface = module.WebInterface.__new__(module.WebInterface) + sent_messages = [] + + interface._send_message = sent_messages.append + + module.WebInterface.notify_preference_corrections_applied( + interface, + "Preferences were applied successfully.", + ) + + assert sent_messages == [ + { + "state": "preference_correction_applied", + "message": "Preferences were applied successfully.", + } + ] + + +class TestCorrectedDiffLogic: + """Validates the diff logic used in _Runner.run() after the correction + round-trip: corrected = {k: v for k, v in user_bundle.items() + if v != predicted_bundle.get(k)}.""" + + def test_corrections_detected(self): + predicted = _default_bundle() + user_bundle = dict(predicted) + + field_a, field_b = PREF_FIELDS[0], PREF_FIELDS[1] + user_bundle[field_a] = PREF_OPTIONS[field_a][-1] + user_bundle[field_b] = PREF_OPTIONS[field_b][-1] + + corrected = { + k: v for k, v in user_bundle.items() if v != predicted.get(k) + } + + if PREF_OPTIONS[field_a][0] != PREF_OPTIONS[field_a][-1]: + assert field_a in corrected + if PREF_OPTIONS[field_b][0] != PREF_OPTIONS[field_b][-1]: + assert field_b in corrected + + def test_no_corrections_means_empty(self): + predicted = _default_bundle() + user_bundle = dict(predicted) + + corrected = { + k: v for k, v in user_bundle.items() if v != predicted.get(k) + } + assert corrected == {} + + def test_ground_truth_has_all_fields(self): + predicted = _default_bundle() + user_bundle = dict(predicted) + user_bundle[PREF_FIELDS[0]] = PREF_OPTIONS[PREF_FIELDS[0]][-1] + + ground_truth = user_bundle + for field in PREF_FIELDS: + assert field in ground_truth + + def test_corrected_values_are_valid_options(self): + predicted = _default_bundle() + user_bundle = dict(predicted) + for field in PREF_FIELDS[:3]: + user_bundle[field] = PREF_OPTIONS[field][-1] + + corrected = { + k: v for k, v in user_bundle.items() if v != predicted.get(k) + } + for field, val in corrected.items(): + assert val in PREF_OPTIONS[field] + + +class TestPrefOptionsConsistency: + """Sanity checks on PREF_OPTIONS / PREF_FIELDS configuration.""" + + def test_fields_match_options_keys(self): + assert set(PREF_FIELDS) == set(PREF_OPTIONS.keys()) + + def test_every_field_has_at_least_two_options(self): + for field, opts in PREF_OPTIONS.items(): + assert len(opts) >= 2, f"{field} has fewer than 2 options" + + def test_no_empty_option_strings(self): + for field, opts in PREF_OPTIONS.items(): + for opt in opts: + assert isinstance(opt, str) and opt.strip(), ( + f"{field} has empty/whitespace option" + ) + + +# =================================================================== +# Step 5 — Learn: PredictionModel.next_day + update wiring +# =================================================================== + + +class TestNextDay: + """Verify PredictionModel.next_day auto-detects the next unused day.""" + + @patch(f"{_PM_MODULE}._resolve_api_key", return_value="fake-key") + @patch(f"{_PM_MODULE}.OpenAI") + def test_empty_logs_returns_1(self, mock_openai_cls, _key, tmp_path): + mock_openai_cls.return_value = MagicMock() + from feeding_deployment.preference_learning.methods.prediction_model import ( + PredictionModel, + ) + model = PredictionModel( + user="u", physical_profile_label="p", + logs_dir=tmp_path / "pref", + use_long_term_memory=False, use_episodic_memory=False, + ) + assert model.next_day() == 1 + + @patch(f"{_PM_MODULE}._resolve_api_key", return_value="fake-key") + @patch(f"{_PM_MODULE}.OpenAI") + def test_after_three_days_returns_4(self, mock_openai_cls, _key, tmp_path): + mock_openai_cls.return_value = MagicMock() + from feeding_deployment.preference_learning.methods.prediction_model import ( + PredictionModel, + ) + model = PredictionModel( + user="u", physical_profile_label="p", + logs_dir=tmp_path / "pref", + use_long_term_memory=False, use_episodic_memory=False, + ) + for d in [1, 2, 3]: + (model.working_memory_dir / f"day_{d:04d}.json").write_text("{}") + assert model.next_day() == 4 + + @patch(f"{_PM_MODULE}._resolve_api_key", return_value="fake-key") + @patch(f"{_PM_MODULE}.OpenAI") + def test_gap_in_days_uses_max(self, mock_openai_cls, _key, tmp_path): + mock_openai_cls.return_value = MagicMock() + from feeding_deployment.preference_learning.methods.prediction_model import ( + PredictionModel, + ) + model = PredictionModel( + user="u", physical_profile_label="p", + logs_dir=tmp_path / "pref", + use_long_term_memory=False, use_episodic_memory=False, + ) + for d in [1, 5]: + (model.working_memory_dir / f"day_{d:04d}.json").write_text("{}") + assert model.next_day() == 6 + + +class TestUpdateWritesLogs: + """Verify PredictionModel.update writes per-day JSON files.""" + + @patch(f"{_PM_MODULE}._resolve_api_key", return_value="fake-key") + @patch(f"{_PM_MODULE}.OpenAI") + def test_update_creates_working_memory_log(self, mock_openai_cls, _key, tmp_path): + mock_openai_cls.return_value = MagicMock() + from feeding_deployment.preference_learning.methods.prediction_model import ( + PredictionModel, + ) + model = PredictionModel( + user="u", physical_profile_label="p", + logs_dir=tmp_path / "pref", + use_long_term_memory=False, use_episodic_memory=False, + ) + ctx = {"meal": MEALS[0], "setting": SETTINGS[0], "time_of_day": TIMES_OF_DAY[0]} + bundle = _default_bundle() + corrected = {PREF_FIELDS[0]: PREF_OPTIONS[PREF_FIELDS[0]][-1]} + + model.update(day=1, context=ctx, corrected=corrected, ground_truth_bundle=bundle) + + log_file = model.working_memory_dir / "day_0001.json" + assert log_file.exists() + data = json.loads(log_file.read_text()) + assert data["day"] == 1 + assert data["context"] == ctx + assert data["corrected"] == corrected + + @patch(f"{_PM_MODULE}._resolve_api_key", return_value="fake-key") + @patch(f"{_PM_MODULE}.OpenAI") + def test_update_increments_next_day(self, mock_openai_cls, _key, tmp_path): + mock_openai_cls.return_value = MagicMock() + from feeding_deployment.preference_learning.methods.prediction_model import ( + PredictionModel, + ) + model = PredictionModel( + user="u", physical_profile_label="p", + logs_dir=tmp_path / "pref", + use_long_term_memory=False, use_episodic_memory=False, + ) + ctx = {"meal": MEALS[0], "setting": SETTINGS[0], "time_of_day": TIMES_OF_DAY[0]} + bundle = _default_bundle() + + assert model.next_day() == 1 + model.update(day=1, context=ctx, corrected={}, ground_truth_bundle=bundle) + assert model.next_day() == 2 + model.update(day=2, context=ctx, corrected={}, ground_truth_bundle=bundle) + assert model.next_day() == 3 + + @patch(f"{_PM_MODULE}._resolve_api_key", return_value="fake-key") + @patch(f"{_PM_MODULE}.OpenAI") + def test_update_with_ltm_creates_ltm_log(self, mock_openai_cls, _key, tmp_path): + mock_client = MagicMock() + choice = MagicMock() + choice.message.content = json.dumps({"summary": "test"}) + resp = MagicMock() + resp.choices = [choice] + mock_client.chat.completions.create.return_value = resp + mock_openai_cls.return_value = mock_client + + from feeding_deployment.preference_learning.methods.prediction_model import ( + PredictionModel, + ) + model = PredictionModel( + user="u", physical_profile_label="p", + logs_dir=tmp_path / "pref", + use_long_term_memory=True, use_episodic_memory=False, + physical_profile_description="Test profile.", + ) + ctx = {"meal": MEALS[0], "setting": SETTINGS[0], "time_of_day": TIMES_OF_DAY[0]} + bundle = _default_bundle() + + model.update(day=1, context=ctx, corrected={}, ground_truth_bundle=bundle) + + ltm_file = tmp_path / "pref" / "u" / "long_term_memory" / "day_0001.json" + assert ltm_file.exists() + data = json.loads(ltm_file.read_text()) + assert data["day"] == 1 + assert "episode_text" in data + + @patch(f"{_PM_MODULE}._resolve_api_key", return_value="fake-key") + @patch(f"{_PM_MODULE}.OpenAI") + def test_update_with_em_creates_em_log(self, mock_openai_cls, _key, tmp_path): + mock_client = MagicMock() + mock_client.embeddings.create.return_value = MagicMock( + data=[MagicMock(embedding=[0.1] * 1536)] + ) + mock_openai_cls.return_value = mock_client + + from feeding_deployment.preference_learning.methods.prediction_model import ( + PredictionModel, + ) + model = PredictionModel( + user="u", physical_profile_label="p", + logs_dir=tmp_path / "pref", + use_long_term_memory=False, use_episodic_memory=True, + ) + ctx = {"meal": MEALS[0], "setting": SETTINGS[0], "time_of_day": TIMES_OF_DAY[0]} + bundle = _default_bundle() + + model.update(day=1, context=ctx, corrected={}, ground_truth_bundle=bundle) + + em_file = tmp_path / "pref" / "u" / "episodic_memory" / "day_0001.json" + assert em_file.exists() + data = json.loads(em_file.read_text()) + assert data["day"] == 1 + assert "episode_text" in data + + +# =================================================================== +# Terminal interaction: context collection + preference correction +# =================================================================== + + +_TERMINAL_MODULE = "feeding_deployment.integration.terminal_preferences" + + +class TestTerminalCollectContext: + + @patch("builtins.input", side_effect=["1", "1", "1"]) + def test_picks_first_options(self, _mock_input): + from feeding_deployment.integration.terminal_preferences import ( + terminal_collect_context, + ) + ctx = terminal_collect_context() + assert ctx["meal"] == MEALS[0] + assert ctx["setting"] == SETTINGS[0] + assert ctx["time_of_day"] == TIMES_OF_DAY[0] + + @patch("builtins.input", side_effect=["2", "3", "2"]) + def test_picks_specific_options(self, _mock_input): + from feeding_deployment.integration.terminal_preferences import ( + terminal_collect_context, + ) + ctx = terminal_collect_context() + assert ctx["meal"] == MEALS[1] + assert ctx["setting"] == SETTINGS[2] + assert ctx["time_of_day"] == TIMES_OF_DAY[1] + + @patch("builtins.input", side_effect=["bad", "0", "1", "1", "1"]) + def test_invalid_input_retries(self, _mock_input): + from feeding_deployment.integration.terminal_preferences import ( + terminal_collect_context, + ) + ctx = terminal_collect_context() + assert ctx["meal"] == MEALS[0] + + +class TestTerminalCorrectPreferences: + + @patch("builtins.input", side_effect=[""] * len(PREF_FIELDS)) + def test_accept_all_returns_predicted(self, _mock_input): + from feeding_deployment.integration.terminal_preferences import ( + terminal_correct_preferences, + ) + predicted = _default_bundle() + result = terminal_correct_preferences(predicted, dict(PREF_OPTIONS)) + assert result == predicted + + def test_change_one_field(self): + from feeding_deployment.integration.terminal_preferences import ( + terminal_correct_preferences, + ) + predicted = _default_bundle() + target_field = PREF_FIELDS[0] + target_opts = PREF_OPTIONS[target_field] + new_idx = len(target_opts) # pick the last option + + inputs = [] + for field in PREF_FIELDS: + if field == target_field: + inputs.append(str(new_idx)) + else: + inputs.append("") + + with patch("builtins.input", side_effect=inputs): + result = terminal_correct_preferences(predicted, dict(PREF_OPTIONS)) + + if target_opts[0] != target_opts[-1]: + assert result[target_field] == target_opts[-1] + assert result[target_field] != predicted[target_field] + + for field in PREF_FIELDS: + if field != target_field: + assert result[field] == predicted[field] + + def test_invalid_input_retries_then_accepts(self): + from feeding_deployment.integration.terminal_preferences import ( + terminal_correct_preferences, + ) + predicted = _default_bundle() + inputs = ["bad", "0", ""] # two bad inputs then accept for first field + inputs += [""] * (len(PREF_FIELDS) - 1) + + with patch("builtins.input", side_effect=inputs): + result = terminal_correct_preferences(predicted, dict(PREF_OPTIONS)) + assert result == predicted + + @patch("builtins.input", side_effect=[""] * len(PREF_FIELDS)) + def test_no_changes_means_empty_corrections(self, _mock_input): + from feeding_deployment.integration.terminal_preferences import ( + terminal_correct_preferences, + ) + predicted = _default_bundle() + result = terminal_correct_preferences(predicted, dict(PREF_OPTIONS)) + corrected = {k: v for k, v in result.items() if v != predicted.get(k)} + assert corrected == {}