Skip to content

Commit 606f44e

Browse files
committed
Add WithQueryTags connector option for session-level map support
Addresses review feedback from jiabin-hu: 1. Add WithQueryTags(map[string]string) as a connector option that accepts a structured map and serializes it internally, consistent with the statement-level API and the Python connector approach (databricks/databricks-sql-python@e916f71). 2. Context values pattern is the idiomatic Go approach for per-request metadata in database/sql drivers (same pattern used by ConnId, CorrelationId, QueryId, and StagingInfo in this driver). Co-authored-by: Isaac
1 parent ba658f2 commit 606f44e

3 files changed

Lines changed: 84 additions & 6 deletions

File tree

connector.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,23 @@ func WithSessionParams(params map[string]string) ConnOption {
235235
}
236236
}
237237

238+
// WithQueryTags sets session-level query tags from a map.
239+
// Tags are serialized and passed as QUERY_TAGS in the session configuration.
240+
// All queries in the session will carry these tags unless overridden at the statement level.
241+
// This is the preferred way to set session-level query tags, as it handles serialization
242+
// and escaping automatically (consistent with the statement-level API).
243+
func WithQueryTags(tags map[string]string) ConnOption {
244+
return func(c *config.Config) {
245+
serialized := SerializeQueryTags(tags)
246+
if serialized != "" {
247+
if c.SessionParams == nil {
248+
c.SessionParams = make(map[string]string)
249+
}
250+
c.SessionParams["QUERY_TAGS"] = serialized
251+
}
252+
}
253+
}
254+
238255
// WithSkipTLSHostVerify disables the verification of the hostname in the TLS certificate.
239256
// WARNING:
240257
// When this option is used, TLS is susceptible to machine-in-the-middle attacks.

connector_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,65 @@ func TestNewConnector(t *testing.T) {
268268
})
269269
}
270270

271+
func TestWithQueryTags(t *testing.T) {
272+
t.Run("WithQueryTags serializes map into SessionParams QUERY_TAGS", func(t *testing.T) {
273+
con, err := NewConnector(
274+
WithQueryTags(map[string]string{
275+
"team": "data-eng",
276+
}),
277+
)
278+
require.NoError(t, err)
279+
coni, ok := con.(*connector)
280+
require.True(t, ok)
281+
assert.Equal(t, "team:data-eng", coni.cfg.SessionParams["QUERY_TAGS"])
282+
})
283+
284+
t.Run("WithQueryTags with multiple tags", func(t *testing.T) {
285+
con, err := NewConnector(
286+
WithQueryTags(map[string]string{
287+
"team": "eng",
288+
"app": "etl",
289+
}),
290+
)
291+
require.NoError(t, err)
292+
coni, ok := con.(*connector)
293+
require.True(t, ok)
294+
// Map iteration is non-deterministic
295+
qt := coni.cfg.SessionParams["QUERY_TAGS"]
296+
assert.True(t, qt == "team:eng,app:etl" || qt == "app:etl,team:eng", "got: %s", qt)
297+
})
298+
299+
t.Run("WithQueryTags with empty map does not set QUERY_TAGS", func(t *testing.T) {
300+
con, err := NewConnector(
301+
WithQueryTags(map[string]string{}),
302+
)
303+
require.NoError(t, err)
304+
coni, ok := con.(*connector)
305+
require.True(t, ok)
306+
_, exists := coni.cfg.SessionParams["QUERY_TAGS"]
307+
assert.False(t, exists)
308+
})
309+
310+
t.Run("WithQueryTags overrides WithSessionParams QUERY_TAGS", func(t *testing.T) {
311+
con, err := NewConnector(
312+
WithSessionParams(map[string]string{
313+
"QUERY_TAGS": "old:value",
314+
"ansi_mode": "false",
315+
}),
316+
WithQueryTags(map[string]string{
317+
"team": "new-team",
318+
}),
319+
)
320+
require.NoError(t, err)
321+
coni, ok := con.(*connector)
322+
require.True(t, ok)
323+
// WithQueryTags should override the QUERY_TAGS from WithSessionParams
324+
assert.Equal(t, "team:new-team", coni.cfg.SessionParams["QUERY_TAGS"])
325+
// Other session params should be preserved
326+
assert.Equal(t, "false", coni.cfg.SessionParams["ansi_mode"])
327+
})
328+
}
329+
271330
type mockRoundTripper struct{}
272331

273332
var _ http.RoundTripper = mockRoundTripper{}

examples/query_tags/main.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,17 @@ func main() {
2222
log.Fatal(err.Error())
2323
}
2424

25-
// Session-level query tags: applied to all queries in this session.
25+
// Connection-level query tags: applied to all queries in this session.
26+
// WithQueryTags accepts a map and handles serialization automatically.
2627
connector, err := dbsql.NewConnector(
2728
dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")),
2829
dbsql.WithPort(port),
2930
dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")),
3031
dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESSTOKEN")),
31-
dbsql.WithSessionParams(map[string]string{
32-
"QUERY_TAGS": "team:engineering,test:query-tags,driver:go",
33-
"ansi_mode": "false",
32+
dbsql.WithQueryTags(map[string]string{
33+
"team": "engineering",
34+
"test": "query-tags",
35+
"driver": "go",
3436
}),
3537
)
3638
if err != nil {
@@ -40,8 +42,8 @@ func main() {
4042
db := sql.OpenDB(connector)
4143
defer db.Close()
4244

43-
// Example 1: Session-level query tags (set during connection)
44-
fmt.Println("=== Session-level query tags ===")
45+
// Example 1: Connection-level query tags (set during connection)
46+
fmt.Println("=== Connection-level query tags ===")
4547
ctx := context.Background()
4648
var result int
4749
err = db.QueryRowContext(ctx, "SELECT 1").Scan(&result)

0 commit comments

Comments
 (0)