Skip to content

Commit c8e3f95

Browse files
committed
Add statement-level query tag support via context
Previously, query tags could only be set at the session level via WithSessionParams during connection creation. This adds per-statement query tag support, allowing different tags for each query execution. Users pass query tags through context using the new driverctx.NewContextWithQueryTags function. The tags are serialized into the TExecuteStatementReq.ConfOverlay["query_tags"] field, consistent with the Python and NodeJS connector implementations. Co-authored-by: Isaac
1 parent 0b7d209 commit c8e3f95

7 files changed

Lines changed: 359 additions & 7 deletions

File tree

connection.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,17 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
323323
req.Parameters = parameters
324324
}
325325

326+
// Add per-statement query tags if provided via context
327+
if queryTags := driverctx.QueryTagsFromContext(ctx); len(queryTags) > 0 {
328+
serialized := SerializeQueryTags(queryTags)
329+
if serialized != "" {
330+
if req.ConfOverlay == nil {
331+
req.ConfOverlay = make(map[string]string)
332+
}
333+
req.ConfOverlay["query_tags"] = serialized
334+
}
335+
}
336+
326337
resp, err := c.client.ExecuteStatement(ctx, &req)
327338
var log *logger.DBSQLLogger
328339
log, ctx = client.LoggerAndContext(ctx, resp)

connection_test.go

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/apache/thrift/lib/go/thrift"
1111
"github.com/pkg/errors"
1212

13+
"github.com/databricks/databricks-sql-go/driverctx"
1314
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
1415
"github.com/databricks/databricks-sql-go/internal/cli_service"
1516
"github.com/databricks/databricks-sql-go/internal/client"
@@ -493,6 +494,121 @@ func TestConn_executeStatement_ProtocolFeatures(t *testing.T) {
493494
}
494495
}
495496

497+
func TestConn_executeStatement_QueryTags(t *testing.T) {
498+
t.Parallel()
499+
500+
makeTestConn := func(captureReq *(*cli_service.TExecuteStatementReq)) *conn {
501+
executeStatement := func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) {
502+
*captureReq = req
503+
return &cli_service.TExecuteStatementResp{
504+
Status: &cli_service.TStatus{
505+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
506+
},
507+
OperationHandle: &cli_service.TOperationHandle{
508+
OperationId: &cli_service.THandleIdentifier{
509+
GUID: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
510+
Secret: []byte("secret"),
511+
},
512+
},
513+
DirectResults: &cli_service.TSparkDirectResults{
514+
OperationStatus: &cli_service.TGetOperationStatusResp{
515+
Status: &cli_service.TStatus{
516+
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
517+
},
518+
OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE),
519+
},
520+
},
521+
}, nil
522+
}
523+
524+
return &conn{
525+
session: getTestSession(),
526+
client: &client.TestClient{
527+
FnExecuteStatement: executeStatement,
528+
},
529+
cfg: config.WithDefaults(),
530+
}
531+
}
532+
533+
t.Run("query tags from context are set in ConfOverlay", func(t *testing.T) {
534+
var capturedReq *cli_service.TExecuteStatementReq
535+
testConn := makeTestConn(&capturedReq)
536+
537+
ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{
538+
"team": "engineering",
539+
"app": "etl",
540+
})
541+
542+
_, err := testConn.executeStatement(ctx, "SELECT 1", nil)
543+
assert.NoError(t, err)
544+
assert.NotNil(t, capturedReq.ConfOverlay)
545+
// Map iteration is non-deterministic, so check both possible orderings
546+
queryTags := capturedReq.ConfOverlay["query_tags"]
547+
assert.True(t,
548+
queryTags == "team:engineering,app:etl" || queryTags == "app:etl,team:engineering",
549+
"unexpected query_tags value: %s", queryTags)
550+
})
551+
552+
t.Run("no query tags in context means no ConfOverlay", func(t *testing.T) {
553+
var capturedReq *cli_service.TExecuteStatementReq
554+
testConn := makeTestConn(&capturedReq)
555+
556+
_, err := testConn.executeStatement(context.Background(), "SELECT 1", nil)
557+
assert.NoError(t, err)
558+
assert.Nil(t, capturedReq.ConfOverlay)
559+
})
560+
561+
t.Run("empty query tags map means no ConfOverlay", func(t *testing.T) {
562+
var capturedReq *cli_service.TExecuteStatementReq
563+
testConn := makeTestConn(&capturedReq)
564+
565+
ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{})
566+
567+
_, err := testConn.executeStatement(ctx, "SELECT 1", nil)
568+
assert.NoError(t, err)
569+
assert.Nil(t, capturedReq.ConfOverlay)
570+
})
571+
572+
t.Run("single query tag", func(t *testing.T) {
573+
var capturedReq *cli_service.TExecuteStatementReq
574+
testConn := makeTestConn(&capturedReq)
575+
576+
ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{
577+
"team": "data-eng",
578+
})
579+
580+
_, err := testConn.executeStatement(ctx, "SELECT 1", nil)
581+
assert.NoError(t, err)
582+
assert.Equal(t, "team:data-eng", capturedReq.ConfOverlay["query_tags"])
583+
})
584+
585+
t.Run("query tags with special characters in values", func(t *testing.T) {
586+
var capturedReq *cli_service.TExecuteStatementReq
587+
testConn := makeTestConn(&capturedReq)
588+
589+
ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{
590+
"url": "http://host:8080",
591+
})
592+
593+
_, err := testConn.executeStatement(ctx, "SELECT 1", nil)
594+
assert.NoError(t, err)
595+
assert.Equal(t, `url:http\://host\:8080`, capturedReq.ConfOverlay["query_tags"])
596+
})
597+
598+
t.Run("query tags with empty value", func(t *testing.T) {
599+
var capturedReq *cli_service.TExecuteStatementReq
600+
testConn := makeTestConn(&capturedReq)
601+
602+
ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{
603+
"flag": "",
604+
})
605+
606+
_, err := testConn.executeStatement(ctx, "SELECT 1", nil)
607+
assert.NoError(t, err)
608+
assert.Equal(t, "flag", capturedReq.ConfOverlay["query_tags"])
609+
})
610+
}
611+
496612
func TestConn_pollOperation(t *testing.T) {
497613
t.Parallel()
498614
t.Run("pollOperation returns finished state response when query finishes", func(t *testing.T) {

driverctx/ctx.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ const (
1515
QueryIdCallbackKey
1616
ConnIdCallbackKey
1717
StagingAllowedLocalPathKey
18+
QueryTagsContextKey
1819
)
1920

2021
type IdCallbackFunc func(string)
@@ -107,16 +108,40 @@ func NewContextWithStagingInfo(ctx context.Context, stagingAllowedLocalPath []st
107108
return context.WithValue(ctx, StagingAllowedLocalPathKey, stagingAllowedLocalPath)
108109
}
109110

111+
// NewContextWithQueryTags creates a new context with per-statement query tags.
112+
// These tags are serialized and passed via confOverlay as "query_tags" in TExecuteStatementReq.
113+
// They apply only to the statement executed with this context and do not persist across queries.
114+
func NewContextWithQueryTags(ctx context.Context, queryTags map[string]string) context.Context {
115+
return context.WithValue(ctx, QueryTagsContextKey, queryTags)
116+
}
117+
118+
// QueryTagsFromContext retrieves the per-statement query tags stored in context.
119+
func QueryTagsFromContext(ctx context.Context) map[string]string {
120+
if ctx == nil {
121+
return nil
122+
}
123+
124+
queryTags, ok := ctx.Value(QueryTagsContextKey).(map[string]string)
125+
if !ok {
126+
return nil
127+
}
128+
return queryTags
129+
}
130+
110131
func NewContextFromBackground(ctx context.Context) context.Context {
111132
connId := ConnIdFromContext(ctx)
112133
corrId := CorrelationIdFromContext(ctx)
113134
queryId := QueryIdFromContext(ctx)
114135
stagingPaths := StagingPathsFromContext(ctx)
136+
queryTags := QueryTagsFromContext(ctx)
115137

116138
newCtx := NewContextWithConnId(context.Background(), connId)
117139
newCtx = NewContextWithCorrelationId(newCtx, corrId)
118140
newCtx = NewContextWithQueryId(newCtx, queryId)
119141
newCtx = NewContextWithStagingInfo(newCtx, stagingPaths)
142+
if queryTags != nil {
143+
newCtx = NewContextWithQueryTags(newCtx, queryTags)
144+
}
120145

121146
return newCtx
122147
}

driverctx/ctx_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,57 @@ import (
88
"github.com/stretchr/testify/assert"
99
)
1010

11+
func TestNewContextWithQueryTags(t *testing.T) {
12+
t.Run("stores and retrieves query tags", func(t *testing.T) {
13+
tags := map[string]string{"team": "engineering", "app": "etl"}
14+
ctx := NewContextWithQueryTags(context.Background(), tags)
15+
result := QueryTagsFromContext(ctx)
16+
assert.Equal(t, tags, result)
17+
})
18+
19+
t.Run("returns nil for context without query tags", func(t *testing.T) {
20+
result := QueryTagsFromContext(context.Background())
21+
assert.Nil(t, result)
22+
})
23+
24+
t.Run("returns nil for nil context", func(t *testing.T) {
25+
result := QueryTagsFromContext(nil)
26+
assert.Nil(t, result)
27+
})
28+
29+
t.Run("it maintains timeout", func(t *testing.T) {
30+
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
31+
defer cancel()
32+
tags := map[string]string{"team": "eng"}
33+
ctx1 := NewContextWithQueryTags(ctx, tags)
34+
result := QueryTagsFromContext(ctx1)
35+
assert.Equal(t, tags, result)
36+
dl, ok := ctx.Deadline()
37+
dl1, ok1 := ctx1.Deadline()
38+
assert.Equal(t, dl, dl1)
39+
assert.True(t, ok)
40+
assert.True(t, ok1)
41+
})
42+
43+
t.Run("NewContextFromBackground preserves query tags", func(t *testing.T) {
44+
tags := map[string]string{"team": "eng"}
45+
ctx := NewContextWithConnId(context.Background(), "conn-1")
46+
ctx = NewContextWithCorrelationId(ctx, "corr-1")
47+
ctx = NewContextWithQueryTags(ctx, tags)
48+
49+
newCtx := NewContextFromBackground(ctx)
50+
assert.Equal(t, tags, QueryTagsFromContext(newCtx))
51+
assert.Equal(t, "conn-1", ConnIdFromContext(newCtx))
52+
assert.Equal(t, "corr-1", CorrelationIdFromContext(newCtx))
53+
})
54+
55+
t.Run("NewContextFromBackground without query tags", func(t *testing.T) {
56+
ctx := NewContextWithConnId(context.Background(), "conn-1")
57+
newCtx := NewContextFromBackground(ctx)
58+
assert.Nil(t, QueryTagsFromContext(newCtx))
59+
})
60+
}
61+
1162
func TestNewContextWithCorrelationId(t *testing.T) {
1263
t.Run("base case", func(t *testing.T) {
1364

examples/query_tags/main.go

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"strconv"
1010

1111
dbsql "github.com/databricks/databricks-sql-go"
12+
"github.com/databricks/databricks-sql-go/driverctx"
1213
"github.com/joho/godotenv"
1314
)
1415

@@ -21,6 +22,7 @@ func main() {
2122
log.Fatal(err.Error())
2223
}
2324

25+
// Session-level query tags: applied to all queries in this session.
2426
connector, err := dbsql.NewConnector(
2527
dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")),
2628
dbsql.WithPort(port),
@@ -38,16 +40,41 @@ func main() {
3840
db := sql.OpenDB(connector)
3941
defer db.Close()
4042

43+
// Example 1: Session-level query tags (set during connection)
44+
fmt.Println("=== Session-level query tags ===")
4145
ctx := context.Background()
4246
var result int
4347
err = db.QueryRowContext(ctx, "SELECT 1").Scan(&result)
4448
if err != nil {
45-
if err == sql.ErrNoRows {
46-
fmt.Println("not found")
47-
return
48-
} else {
49-
fmt.Printf("err: %v\n", err)
50-
}
49+
log.Printf("err: %v\n", err)
50+
} else {
51+
fmt.Println(result)
52+
}
53+
54+
// Example 2: Statement-level query tags (per-query override via context)
55+
fmt.Println("=== Statement-level query tags ===")
56+
ctx = driverctx.NewContextWithQueryTags(context.Background(), map[string]string{
57+
"team": "data-eng",
58+
"application": "etl-pipeline",
59+
"env": "production",
60+
})
61+
err = db.QueryRowContext(ctx, "SELECT 2").Scan(&result)
62+
if err != nil {
63+
log.Printf("err: %v\n", err)
64+
} else {
65+
fmt.Println(result)
66+
}
67+
68+
// Example 3: Different query tags for a different statement
69+
fmt.Println("=== Different statement-level query tags ===")
70+
ctx = driverctx.NewContextWithQueryTags(context.Background(), map[string]string{
71+
"team": "analytics",
72+
"job": "daily-report",
73+
})
74+
err = db.QueryRowContext(ctx, "SELECT 3").Scan(&result)
75+
if err != nil {
76+
log.Printf("err: %v\n", err)
77+
} else {
78+
fmt.Println(result)
5179
}
52-
fmt.Println(result)
5380
}

query_tags.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package dbsql
2+
3+
import "strings"
4+
5+
// SerializeQueryTags converts a map of query tags to the wire format string.
6+
// The format is comma-separated key:value pairs (e.g., "team:engineering,app:etl").
7+
//
8+
// Escaping rules (consistent with Python and NodeJS connectors):
9+
// - Keys: only backslashes are escaped
10+
// - Values: backslashes, colons, and commas are escaped with a leading backslash
11+
// - Empty string values result in just the key being emitted (no colon)
12+
//
13+
// Returns empty string if the map is nil or empty.
14+
func SerializeQueryTags(tags map[string]string) string {
15+
if len(tags) == 0 {
16+
return ""
17+
}
18+
19+
parts := make([]string, 0, len(tags))
20+
for k, v := range tags {
21+
escapedKey := strings.ReplaceAll(k, `\`, `\\`)
22+
if v == "" {
23+
parts = append(parts, escapedKey)
24+
} else {
25+
escapedValue := strings.ReplaceAll(v, `\`, `\\`)
26+
escapedValue = strings.ReplaceAll(escapedValue, `:`, `\:`)
27+
escapedValue = strings.ReplaceAll(escapedValue, `,`, `\,`)
28+
parts = append(parts, escapedKey+":"+escapedValue)
29+
}
30+
}
31+
return strings.Join(parts, ",")
32+
}

0 commit comments

Comments
 (0)