Skip to content

Commit 3ddfdb8

Browse files
committed
add check_policy node
1 parent aa0d5b7 commit 3ddfdb8

5 files changed

Lines changed: 119 additions & 85 deletions

File tree

bt_nodes/configuration/src/configuration/init_receptionist.cpp

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ InitReceptionist::InitReceptionist(
2323
{
2424
config().blackboard->get("node", node_);
2525

26-
node_->declare_parameter("cam_frame", "head_front_camera_rgb_optical_frame");
26+
node_->declare_parameter("cam_frame", "head_front_camera_color_optical_frame");
2727
node_->declare_parameter("manipulation_frame", "base_link");
2828
// node_->declare_parameter("party_wp", std::vector<double>{0.0, 0.0, 0.0});
2929
// node_->declare_parameter("entrance_wp", std::vector<double>{0.0, 0.0, 0.0});
3030
node_->declare_parameter("host_name", "John Doe");
3131
node_->declare_parameter("host_drink", "beer");
32-
// node_->declare_parameter("waypoints_names", std::vector<std::string>{});
32+
node_->declare_parameter("waypoints_names", std::vector<std::string>{});
3333

3434
tf_buffer_ = std::make_shared<tf2_ros::Buffer>(node_->get_clock());
3535
tf_listener_ = std::make_shared<tf2_ros::TransformListener>(*tf_buffer_);
@@ -43,47 +43,58 @@ BT::NodeStatus InitReceptionist::tick()
4343

4444
if (
4545
node_->has_parameter("cam_frame") && node_->has_parameter("manipulation_frame") &&
46-
node_->has_parameter("host_name") && node_->has_parameter("host_drink"))
47-
// && node_->has_parameter("waypoints_names"))
46+
node_->has_parameter("host_name") && node_->has_parameter("host_drink")
47+
&& node_->has_parameter("waypoints_names"))
4848
// node_->has_parameter("party_wp") && node_->has_parameter("entrance_wp"))
4949
{
5050
node_->get_parameter("cam_frame", cam_frame_);
5151
node_->get_parameter("manipulation_frame", manipulation_frame_);
5252
node_->get_parameter("host_name", host_name_);
5353
node_->get_parameter("host_drink", host_drink_);
54-
// node_->get_parameter("waypoints_names", wp_names_);
55-
56-
// for (auto wp : wp_names_) {
57-
// node_->declare_parameter("waypoints." + wp, std::vector<double>());
58-
// std::vector<double> wp_pos;
59-
// node_->get_parameter("waypoints." + wp, wp_pos);
60-
// geometry_msgs::msg::TransformStamped transform_msg;
61-
// tf2::Quaternion q;
62-
63-
// if (wp.find("entrance")!=std::string::npos) {
64-
// setOutput("entrance_wp", wp);
65-
// } else if (wp.find("party")!=std::string::npos) {
66-
// setOutput("party_wp", wp);
67-
// }
68-
69-
// q.setRPY(0, 0, wp_pos[2]);
70-
// transform_msg.header.frame_id = "map";
71-
72-
// transform_msg.child_frame_id = wp;
73-
// transform_msg.transform.translation.x = wp_pos[0];
74-
// transform_msg.transform.translation.y = wp_pos[1];
75-
// transform_msg.transform.rotation = tf2::toMsg(q);
76-
// tf_static_broadcaster_->sendTransform(transform_msg);
77-
// rclcpp::spin_some(node_);
78-
// }
54+
node_->get_parameter("waypoints_names", wp_names_);
55+
56+
RCLCPP_INFO(
57+
node_->get_logger(), "Waypoints to be initialized: [%s]",
58+
std::to_string(wp_names_.size()).c_str());
59+
60+
for (auto wp : wp_names_) {
61+
node_->declare_parameter("waypoints." + wp, std::vector<double>());
62+
std::vector<double> wp_pos;
63+
node_->get_parameter("waypoints." + wp, wp_pos);
64+
RCLCPP_INFO(
65+
node_->get_logger(), "Waypoint [%s] position: x: %.2f, y: %.2f, yaw: %.2f", wp.c_str(),
66+
wp_pos[0], wp_pos[1], wp_pos[2]);
67+
geometry_msgs::msg::TransformStamped transform_msg;
68+
tf2::Quaternion q;
69+
70+
if (wp.find("entrance")!=std::string::npos) {
71+
setOutput("entrance_wp", wp);
72+
RCLCPP_INFO(
73+
node_->get_logger(), "Entrance waypoint set to frame: [%s]", wp.c_str());
74+
} else if (wp.find("party")!=std::string::npos) {
75+
setOutput("party_wp", wp);
76+
}
77+
78+
q.setRPY(0, 0, wp_pos[2]);
79+
transform_msg.header.frame_id = "map";
80+
81+
transform_msg.child_frame_id = wp;
82+
transform_msg.transform.translation.x = wp_pos[0];
83+
transform_msg.transform.translation.y = wp_pos[1];
84+
transform_msg.transform.rotation = tf2::toMsg(q);
85+
tf_static_broadcaster_->sendTransform(transform_msg);
86+
RCLCPP_INFO(
87+
node_->get_logger(), "Published static transform for waypoint [%s]", wp.c_str());
88+
rclcpp::spin_some(node_->get_node_base_interface());
89+
}
7990

8091
geometry_msgs::msg::TransformStamped transform_msg;
8192

8293
transform_msg.header.frame_id = "base_link";
8394

8495
transform_msg.child_frame_id = "attention_home";
8596
transform_msg.transform.translation.x = 1.5;
86-
transform_msg.transform.translation.z = 1.5;
97+
transform_msg.transform.translation.z = 0.5;
8798
tf_static_broadcaster_->sendTransform(transform_msg);
8899
rclcpp::spin_some(node_->get_node_base_interface());
89100

bt_nodes/hri/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ find_package(std_srvs REQUIRED)
2121
find_package(sensor_msgs REQUIRED)
2222
find_package(gpsr_msgs REQUIRED)
2323
find_package(audio_common_msgs REQUIRED)
24+
find_package(perception_system_interfaces REQUIRED)
2425
set(CMAKE_CXX_STANDARD 17)
2526

2627
set(dependencies
@@ -38,6 +39,7 @@ set(dependencies
3839
gpsr_msgs
3940
std_srvs
4041
audio_common_msgs
42+
perception_system_interfaces
4143
)
4244

4345
include_directories(include)

bt_nodes/hri/include/hri/check_policy.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "rclcpp/rclcpp.hpp"
2727
#include "rclcpp_cascade_lifecycle/rclcpp_cascade_lifecycle.hpp"
2828
#include "std_msgs/msg/int8.hpp"
29+
#include "perception_system_interfaces/msg/detection_array.hpp"
2930

3031
namespace dialog
3132
{
@@ -56,8 +57,8 @@ class CheckPolicy
5657
std::string image_topic_;
5758
bool value_;
5859
sensor_msgs::msg::Image::SharedPtr image_;
59-
rclcpp::Subscription<sensor_msgs::msg::Image>::SharedPtr image_sub_;
60-
void image_callback(const sensor_msgs::msg::Image::SharedPtr msg);
60+
rclcpp::Subscription<perception_system_interfaces::msg::DetectionArray>::SharedPtr image_sub_;
61+
void image_callback(const perception_system_interfaces::msg::DetectionArray::SharedPtr msg);
6162

6263
};
6364

bt_nodes/hri/src/hri/check_policy.cpp

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "llama_msgs/action/generate_response.hpp"
2424
#include "std_msgs/msg/int8.hpp"
2525

26+
2627
namespace dialog
2728
{
2829

@@ -38,70 +39,81 @@ CheckPolicy::CheckPolicy(
3839
xml_tag_name, action_name, conf)
3940
{
4041
getInput("image_topic", image_topic_);
41-
image_sub_ = node_->create_subscription<sensor_msgs::msg::Image>(
42+
image_sub_ = node_->create_subscription<perception_system_interfaces::msg::DetectionArray>(
4243
image_topic_, 10, std::bind(&CheckPolicy::image_callback, this, _1));
4344
}
4445

4546
void CheckPolicy::on_tick()
4647
{
4748
rclcpp::spin_some(node_->get_node_base_interface());
4849
RCLCPP_DEBUG(node_->get_logger(), "CheckPolicy ticked");
50+
RCLCPP_INFO(node_->get_logger(), "CheckPolicy ticked");
4951
if (!image_) {
5052
RCLCPP_ERROR(node_->get_logger(), "No image received");
53+
RCLCPP_INFO(node_->get_logger(), "No image received, setting to IDLE");
5154
setStatus(BT::NodeStatus::IDLE);
5255
return;
5356
}
57+
RCLCPP_INFO(node_->get_logger(), "Image received, proceeding with CheckPolicy");
5458

5559
std::string text_;
5660
getInput("question", text_);
5761

58-
std::string prompt_ = text_ + ". Please answer only with 'yes' or 'no'";
62+
std::string prompt_ = text_;
5963
goal_.prompt = prompt_;
6064
goal_.images.push_back(*image_);
6165
goal_.reset = true;
6266
goal_.sampling_config.temp = 0.0;
63-
goal_.sampling_config.grammar =
64-
R"(root ::= object
65-
value ::= object | array | string | number | ("true" | "false" | "null") ws
66-
67-
object ::=
68-
"{" ws (
69-
string ":" ws value
70-
("," ws string ":" ws value)*
71-
)? "}" ws
72-
73-
array ::=
74-
"[" ws (
75-
value
76-
("," ws value)*
77-
)? "]" ws
78-
79-
string ::=
80-
"\"" (
81-
[^"\\] |
82-
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
83-
)* "\"" ws
84-
85-
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
86-
87-
# Optional space: by convention, applied in this grammar after literal chars when allowed
88-
ws ::= ([ \t\n] ws)?)";
67+
// goal_.sampling_config.grammar =
68+
// R"(root ::= object
69+
// value ::= object | array | string | number | ("true" | "false" | "null") ws
70+
71+
// object ::=
72+
// "{" ws (
73+
// string ":" ws value
74+
// ("," ws string ":" ws value)*
75+
// )? "}" ws
76+
77+
// array ::=
78+
// "[" ws (
79+
// value
80+
// ("," ws value)*
81+
// )? "]" ws
82+
83+
// string ::=
84+
// "\"" (
85+
// [^"\\] |
86+
// "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
87+
// )* "\"" ws
88+
89+
// number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
90+
91+
// # Optional space: by convention, applied in this grammar after literal chars when allowed
92+
// ws ::= ([ \t\n] ws)?)";
8993
}
9094

91-
void CheckPolicy::image_callback(const sensor_msgs::msg::Image::SharedPtr msg)
95+
void CheckPolicy::image_callback(const perception_system_interfaces::msg::DetectionArray::SharedPtr msg)
9296
{
93-
image_ = msg;
97+
image_ = std::make_shared<sensor_msgs::msg::Image>(msg->source_img);
98+
RCLCPP_INFO(node_->get_logger(), "Image received in CheckPolicy");
9499
}
95100

96101
BT::NodeStatus CheckPolicy::on_success()
97102
{
98103
fprintf(stderr, "%s\n", result_.result->response.text.c_str());
104+
RCLCPP_INFO(node_->get_logger(), "CheckPolicy succeeded");
105+
RCLCPP_INFO(
106+
node_->get_logger(), "LLM response: %s",
107+
result_.result->response.text.c_str());
99108

100109
if (result_.result->response.text.empty() || result_.result->response.text == "{}") {
101110
return BT::NodeStatus::FAILURE;
102111
}
103112
std::string answer = result_.result->response.text;
104113
setOutput("output_text", answer);
114+
RCLCPP_INFO(
115+
node_->get_logger(), "CheckPolicy extracted answer: %s",
116+
answer.c_str());
105117

106118
answer.erase(
107119
std::remove_if(
@@ -114,15 +126,7 @@ BT::NodeStatus CheckPolicy::on_success()
114126
if (answer.empty()) {
115127
return BT::NodeStatus::FAILURE;
116128
}
117-
if (answer.find("yes") != std::string::npos) {
118-
value_ = true;
119-
} else if (answer.find("no") != std::string::npos) {
120-
value_ = false;
121-
} else {
122-
RCLCPP_ERROR(node_->get_logger(), "Not a valid answer: %s", answer.c_str());
123-
return BT::NodeStatus::FAILURE;
124-
}
125-
setOutput("output", value_);
129+
126130
return BT::NodeStatus::SUCCESS;
127131
}
128132

@@ -131,7 +135,7 @@ BT::NodeStatus CheckPolicy::on_success()
131135
BT_REGISTER_NODES(factory)
132136
{
133137
BT::NodeBuilder builder = [](const std::string & name, const BT::NodeConfiguration & config) {
134-
return std::make_unique<dialog::CheckPolicy>(name, "/llava/generate_response", config);
138+
return std::make_unique<dialog::CheckPolicy>(name, "/llama/generate_response", config);
135139
};
136140

137141
factory.registerBuilder<dialog::CheckPolicy>("CheckPolicy", builder);

robocup_bringup/launch/dialog.launch.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,22 +54,37 @@ def generate_launch_description():
5454
# model_filename="Spaetzle-v60-7b_Q4_0.gguf",
5555

5656
# comment this for GPSR:
57-
model_repo='TheBloke/Marcoroni-7B-v3-GGUF',
58-
model_filename='marcoroni-7b-v3.Q3_K_L.gguf',
57+
model_repo='qwen/Qwen2.5-Coder-7B-Instruct-GGUF',
58+
model_filename='qwen2.5-coder-7b-instruct-q4_k_m-00001-of-00002.gguf',
59+
system_prompt_type= "ChatML"
60+
61+
)
62+
63+
llava_cmd = create_llama_launch(
64+
use_llava=True,
65+
n_ctx=2048,
66+
n_batch=256,
67+
n_gpu_layers=23,
68+
n_threads=4,
69+
n_predict=-1,
70+
71+
# uncomment this for GPSR:
72+
# model_repo="cstr/Spaetzle-v60-7b-Q4_0-GGUF",
73+
# model_filename="Spaetzle-v60-7b_Q4_0.gguf",
74+
75+
# comment this for GPSR:
76+
model_repo='bartowski/Qwen2-VL-2B-Instruct-GGUF',
77+
model_filename='Qwen2-VL-2B-Instruct-Q4_K_M.gguf',
78+
mmproj_repo= "bartowski/Qwen2-VL-2B-Instruct-GGUF",
79+
mmproj_filename= "mmproj-Qwen2-VL-2B-Instruct-f16.gguf",
80+
system_prompt_type= "ChatML"
5981

60-
prefix='\n\n### Instruction:\n',
61-
suffix='\n\n### Response:\n',
62-
stopping_words=["\n\n\n\n"],
6382
)
6483

6584
whisper_cmd = IncludeLaunchDescription(
6685
PythonLaunchDescriptionSource(
6786
os.path.join(whisper_dir, 'launch', 'whisper.launch.py')
68-
),
69-
launch_arguments={
70-
'model_repo': model_repo,
71-
'model_filename': model_filename
72-
}.items()
87+
)
7388
)
7489

7590
audio_common_tts_node = Node(
@@ -99,10 +114,11 @@ def generate_launch_description():
99114
ld = LaunchDescription()
100115
ld.add_action(declare_model_repo_cmd)
101116
ld.add_action(declare_model_filename_cmd)
102-
ld.add_action(whisper_cmd)
103-
ld.add_action(llama_cmd)
117+
#ld.add_action(whisper_cmd)
118+
#ld.add_action(llama_cmd)
119+
ld.add_action(llava_cmd)
104120
ld.add_action(audio_common_tts_node)
105121
ld.add_action(audio_common_player_node)
106-
ld.add_action(music_player_node)
122+
#ld.add_action(music_player_node)
107123

108124
return ld

0 commit comments

Comments
 (0)