Skip to content

Commit 0f69e19

Browse files
authored
refactor: use AST pattern matching instead of string matching for sta… (#262)
* refactor: use AST pattern matching instead of string matching for statement detection Fixes #159 The codebase was using starts_with() string checks to detect statement types (SET, SHOW, INSERT, etc.) which is fragile - it can break with leading whitespace, comments, or case variations. Since we already have the parsed Statement AST from sqlparser, this switches to using pattern matching directly on the Statement enum. Changes in permissions.rs: - Renamed check_query_permission to check_statement_permission, now takes &Statement instead of &str - Permission detection (SELECT/INSERT/UPDATE/DELETE/CREATE/DROP/ALTER) now uses match on Statement variants - Added should_skip_permission_check() helper using matches!() macro for SET, SHOW, and transaction statements - Removed all the to_lowercase().starts_with() chains Changes in handlers.rs: - INSERT detection now uses matches!(statement, Statement::Insert(_)) - Removed unnecessary query_lower variable construction - Fixed extended query handler to properly destructure statement from the portal tuple All 12 unit tests pass. The existing integration test failures (dbeaver, metabase, psql) are unrelated - they fail due to missing DataFusion functions like array_length and array_contains. * test: add pgAdmin startup queries test for issue #178 Adds a test file to verify pgAdmin startup queries work correctly. The test covers: - SELECT version() query - The CASE expression checking pg_extension for 'bdr' extension and pg_replication_slots for replication slot count Both pg_extension and pg_replication_slots tables are already implemented in pg_catalog, so these queries now pass. * refactor: avoid String allocation in extended query handler Use reference instead of calling .to_string() on query string as suggested by mjgarton - the query is only used for logging and doesn't need to be owned. * refactor: remove redundant comments in permissions hook Remove comments that don't fully describe what's being skipped. The function name should_skip_permission_check is self-documenting. Suggested by mjgarton.
1 parent 2b51166 commit 0f69e19

3 files changed

Lines changed: 84 additions & 60 deletions

File tree

datafusion-postgres/src/handlers.rs

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,7 @@ impl SimpleQueryHandler for DfSessionService {
139139

140140
let mut results = vec![];
141141
'stmt: for statement in statements {
142-
// TODO: improve statement check by using statement directly
143142
let query = statement.to_string();
144-
let query_lower = query.to_lowercase().trim().to_string();
145143

146144
// Call query hooks with the parsed statement
147145
for hook in &self.query_hooks {
@@ -179,7 +177,7 @@ impl SimpleQueryHandler for DfSessionService {
179177
}
180178
};
181179

182-
if query_lower.starts_with("insert into") {
180+
if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
183181
let resp = map_rows_affected_for_insert(&df).await?;
184182
results.push(resp);
185183
} else {
@@ -265,13 +263,7 @@ impl ExtendedQueryHandler for DfSessionService {
265263
where
266264
C: ClientInfo + Unpin + Send + Sync,
267265
{
268-
let query = portal
269-
.statement
270-
.statement
271-
.0
272-
.to_lowercase()
273-
.trim()
274-
.to_string();
266+
let query = &portal.statement.statement.0;
275267
log::debug!("Received execute extended query: {query}"); // Log for debugging
276268

277269
// Check query hooks first
@@ -302,7 +294,7 @@ impl ExtendedQueryHandler for DfSessionService {
302294
}
303295
}
304296

305-
if let (_, Some((_, plan))) = &portal.statement.statement {
297+
if let (_, Some((statement, plan))) = &portal.statement.statement {
306298
let param_types = plan
307299
.get_parameter_types()
308300
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
@@ -345,7 +337,7 @@ impl ExtendedQueryHandler for DfSessionService {
345337
}
346338
};
347339

348-
if query.starts_with("insert into") {
340+
if matches!(statement, sqlparser::ast::Statement::Insert(_)) {
349341
let resp = map_rows_affected_for_insert(&dataframe).await?;
350342

351343
Ok(resp)

datafusion-postgres/src/hooks/permissions.rs

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,12 @@ impl PermissionsHook {
2323
PermissionsHook { auth_manager }
2424
}
2525

26-
/// Check if the current user has permission to execute a query
27-
async fn check_query_permission<C>(&self, client: &C, query: &str) -> PgWireResult<()>
26+
/// Check if the current user has permission to execute a statement
27+
async fn check_statement_permission<C>(
28+
&self,
29+
client: &C,
30+
statement: &Statement,
31+
) -> PgWireResult<()>
2832
where
2933
C: ClientInfo + ?Sized,
3034
{
@@ -35,29 +39,19 @@ impl PermissionsHook {
3539
.map(|s| s.as_str())
3640
.unwrap_or("anonymous");
3741

38-
// Parse query to determine required permissions
39-
let query_lower = query.to_lowercase();
40-
let query_trimmed = query_lower.trim();
41-
42-
let (required_permission, resource) = if query_trimmed.starts_with("select") {
43-
(Permission::Select, ResourceType::All)
44-
} else if query_trimmed.starts_with("insert") {
45-
(Permission::Insert, ResourceType::All)
46-
} else if query_trimmed.starts_with("update") {
47-
(Permission::Update, ResourceType::All)
48-
} else if query_trimmed.starts_with("delete") {
49-
(Permission::Delete, ResourceType::All)
50-
} else if query_trimmed.starts_with("create table")
51-
|| query_trimmed.starts_with("create view")
52-
{
53-
(Permission::Create, ResourceType::All)
54-
} else if query_trimmed.starts_with("drop") {
55-
(Permission::Drop, ResourceType::All)
56-
} else if query_trimmed.starts_with("alter") {
57-
(Permission::Alter, ResourceType::All)
58-
} else {
59-
// For other queries (SHOW, EXPLAIN, etc.), allow all users
60-
return Ok(());
42+
// Determine required permissions based on Statement type
43+
let (required_permission, resource) = match statement {
44+
Statement::Query(_) => (Permission::Select, ResourceType::All),
45+
Statement::Insert(_) => (Permission::Insert, ResourceType::All),
46+
Statement::Update { .. } => (Permission::Update, ResourceType::All),
47+
Statement::Delete(_) => (Permission::Delete, ResourceType::All),
48+
Statement::CreateTable { .. } | Statement::CreateView { .. } => {
49+
(Permission::Create, ResourceType::All)
50+
}
51+
Statement::Drop { .. } => (Permission::Drop, ResourceType::All),
52+
Statement::AlterTable { .. } => (Permission::Alter, ResourceType::All),
53+
// For other statements (SET, SHOW, EXPLAIN, transactions, etc.), allow all users
54+
_ => return Ok(()),
6155
};
6256

6357
// Check permission
@@ -78,6 +72,21 @@ impl PermissionsHook {
7872

7973
Ok(())
8074
}
75+
76+
/// Check if a statement should skip permission checks
77+
fn should_skip_permission_check(statement: &Statement) -> bool {
78+
matches!(
79+
statement,
80+
Statement::Set { .. }
81+
| Statement::ShowVariable { .. }
82+
| Statement::ShowStatus { .. }
83+
| Statement::StartTransaction { .. }
84+
| Statement::Commit { .. }
85+
| Statement::Rollback { .. }
86+
| Statement::Savepoint { .. }
87+
| Statement::ReleaseSavepoint { .. }
88+
)
89+
}
8190
}
8291

8392
#[async_trait]
@@ -89,22 +98,13 @@ impl QueryHook for PermissionsHook {
8998
_session_context: &SessionContext,
9099
client: &mut (dyn ClientInfo + Send + Sync),
91100
) -> Option<PgWireResult<Response>> {
92-
let query_lower = statement.to_string().to_lowercase();
93-
94-
// Check permissions for the query (skip for SET, transaction, and SHOW statements)
95-
if !query_lower.starts_with("set")
96-
&& !query_lower.starts_with("begin")
97-
&& !query_lower.starts_with("commit")
98-
&& !query_lower.starts_with("rollback")
99-
&& !query_lower.starts_with("start")
100-
&& !query_lower.starts_with("end")
101-
&& !query_lower.starts_with("abort")
102-
&& !query_lower.starts_with("show")
103-
{
104-
let res = self.check_query_permission(&*client, &query_lower).await;
105-
if let Err(e) = res {
106-
return Some(Err(e));
107-
}
101+
if Self::should_skip_permission_check(statement) {
102+
return None;
103+
}
104+
105+
// Check permissions for other statements
106+
if let Err(e) = self.check_statement_permission(&*client, statement).await {
107+
return Some(Err(e));
108108
}
109109

110110
None
@@ -127,15 +127,15 @@ impl QueryHook for PermissionsHook {
127127
_session_context: &SessionContext,
128128
client: &mut (dyn ClientInfo + Send + Sync),
129129
) -> Option<PgWireResult<Response>> {
130-
let query = statement.to_string().to_lowercase();
130+
if Self::should_skip_permission_check(statement) {
131+
return None;
132+
}
131133

132-
// Check permissions for the query (skip for SET and SHOW statements)
133-
if !query.starts_with("set") && !query.starts_with("show") {
134-
let res = self.check_query_permission(&*client, &query).await;
135-
if let Err(e) = res {
136-
return Some(Err(e));
137-
}
134+
// Check permissions for other statements
135+
if let Err(e) = self.check_statement_permission(&*client, statement).await {
136+
return Some(Err(e));
138137
}
138+
139139
None
140140
}
141141
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
use pgwire::api::query::SimpleQueryHandler;
2+
3+
use datafusion_postgres::testing::*;
4+
5+
// pgAdmin startup queries from issue #178
6+
// https://github.com/datafusion-contrib/datafusion-postgres/issues/178
7+
const PGADMIN_QUERIES: &[&str] = &[
8+
// Basic version query (fixed by #179)
9+
"SELECT version()",
10+
// Query to check for BDR extension and replication slots
11+
r#"SELECT CASE
12+
WHEN (SELECT count(extname) FROM pg_catalog.pg_extension WHERE extname='bdr') > 0
13+
THEN 'pgd'
14+
WHEN (SELECT COUNT(*) FROM pg_replication_slots) > 0
15+
THEN 'log'
16+
ELSE NULL
17+
END as type"#,
18+
];
19+
20+
#[tokio::test]
21+
pub async fn test_pgadmin_startup_sql() {
22+
let service = setup_handlers();
23+
let mut client = MockClient::new();
24+
25+
for query in PGADMIN_QUERIES {
26+
SimpleQueryHandler::do_query(&service, &mut client, query)
27+
.await
28+
.unwrap_or_else(|e| {
29+
panic!("failed to run sql:\n--------------\n{query}\n--------------\n{e}")
30+
});
31+
}
32+
}

0 commit comments

Comments
 (0)