@@ -10,6 +10,7 @@ import (
1010 "github.com/apache/thrift/lib/go/thrift"
1111 "github.com/pkg/errors"
1212
13+ "github.com/databricks/databricks-sql-go/driverctx"
1314 dbsqlerr "github.com/databricks/databricks-sql-go/errors"
1415 "github.com/databricks/databricks-sql-go/internal/cli_service"
1516 "github.com/databricks/databricks-sql-go/internal/client"
@@ -493,6 +494,121 @@ func TestConn_executeStatement_ProtocolFeatures(t *testing.T) {
493494 }
494495}
495496
497+ func TestConn_executeStatement_QueryTags (t * testing.T ) {
498+ t .Parallel ()
499+
500+ makeTestConn := func (captureReq * (* cli_service.TExecuteStatementReq )) * conn {
501+ executeStatement := func (ctx context.Context , req * cli_service.TExecuteStatementReq ) (r * cli_service.TExecuteStatementResp , err error ) {
502+ * captureReq = req
503+ return & cli_service.TExecuteStatementResp {
504+ Status : & cli_service.TStatus {
505+ StatusCode : cli_service .TStatusCode_SUCCESS_STATUS ,
506+ },
507+ OperationHandle : & cli_service.TOperationHandle {
508+ OperationId : & cli_service.THandleIdentifier {
509+ GUID : []byte {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 16 },
510+ Secret : []byte ("secret" ),
511+ },
512+ },
513+ DirectResults : & cli_service.TSparkDirectResults {
514+ OperationStatus : & cli_service.TGetOperationStatusResp {
515+ Status : & cli_service.TStatus {
516+ StatusCode : cli_service .TStatusCode_SUCCESS_STATUS ,
517+ },
518+ OperationState : cli_service .TOperationStatePtr (cli_service .TOperationState_FINISHED_STATE ),
519+ },
520+ },
521+ }, nil
522+ }
523+
524+ return & conn {
525+ session : getTestSession (),
526+ client : & client.TestClient {
527+ FnExecuteStatement : executeStatement ,
528+ },
529+ cfg : config .WithDefaults (),
530+ }
531+ }
532+
533+ t .Run ("query tags from context are set in ConfOverlay" , func (t * testing.T ) {
534+ var capturedReq * cli_service.TExecuteStatementReq
535+ testConn := makeTestConn (& capturedReq )
536+
537+ ctx := driverctx .NewContextWithQueryTags (context .Background (), map [string ]string {
538+ "team" : "engineering" ,
539+ "app" : "etl" ,
540+ })
541+
542+ _ , err := testConn .executeStatement (ctx , "SELECT 1" , nil )
543+ assert .NoError (t , err )
544+ assert .NotNil (t , capturedReq .ConfOverlay )
545+ // Map iteration is non-deterministic, so check both possible orderings
546+ queryTags := capturedReq .ConfOverlay ["query_tags" ]
547+ assert .True (t ,
548+ queryTags == "team:engineering,app:etl" || queryTags == "app:etl,team:engineering" ,
549+ "unexpected query_tags value: %s" , queryTags )
550+ })
551+
552+ t .Run ("no query tags in context means no ConfOverlay" , func (t * testing.T ) {
553+ var capturedReq * cli_service.TExecuteStatementReq
554+ testConn := makeTestConn (& capturedReq )
555+
556+ _ , err := testConn .executeStatement (context .Background (), "SELECT 1" , nil )
557+ assert .NoError (t , err )
558+ assert .Nil (t , capturedReq .ConfOverlay )
559+ })
560+
561+ t .Run ("empty query tags map means no ConfOverlay" , func (t * testing.T ) {
562+ var capturedReq * cli_service.TExecuteStatementReq
563+ testConn := makeTestConn (& capturedReq )
564+
565+ ctx := driverctx .NewContextWithQueryTags (context .Background (), map [string ]string {})
566+
567+ _ , err := testConn .executeStatement (ctx , "SELECT 1" , nil )
568+ assert .NoError (t , err )
569+ assert .Nil (t , capturedReq .ConfOverlay )
570+ })
571+
572+ t .Run ("single query tag" , func (t * testing.T ) {
573+ var capturedReq * cli_service.TExecuteStatementReq
574+ testConn := makeTestConn (& capturedReq )
575+
576+ ctx := driverctx .NewContextWithQueryTags (context .Background (), map [string ]string {
577+ "team" : "data-eng" ,
578+ })
579+
580+ _ , err := testConn .executeStatement (ctx , "SELECT 1" , nil )
581+ assert .NoError (t , err )
582+ assert .Equal (t , "team:data-eng" , capturedReq .ConfOverlay ["query_tags" ])
583+ })
584+
585+ t .Run ("query tags with special characters in values" , func (t * testing.T ) {
586+ var capturedReq * cli_service.TExecuteStatementReq
587+ testConn := makeTestConn (& capturedReq )
588+
589+ ctx := driverctx .NewContextWithQueryTags (context .Background (), map [string ]string {
590+ "url" : "http://host:8080" ,
591+ })
592+
593+ _ , err := testConn .executeStatement (ctx , "SELECT 1" , nil )
594+ assert .NoError (t , err )
595+ assert .Equal (t , `url:http\://host\:8080` , capturedReq .ConfOverlay ["query_tags" ])
596+ })
597+
598+ t .Run ("query tags with empty value" , func (t * testing.T ) {
599+ var capturedReq * cli_service.TExecuteStatementReq
600+ testConn := makeTestConn (& capturedReq )
601+
602+ ctx := driverctx .NewContextWithQueryTags (context .Background (), map [string ]string {
603+ "flag" : "" ,
604+ })
605+
606+ _ , err := testConn .executeStatement (ctx , "SELECT 1" , nil )
607+ assert .NoError (t , err )
608+ assert .Equal (t , "flag" , capturedReq .ConfOverlay ["query_tags" ])
609+ })
610+ }
611+
496612func TestConn_pollOperation (t * testing.T ) {
497613 t .Parallel ()
498614 t .Run ("pollOperation returns finished state response when query finishes" , func (t * testing.T ) {
0 commit comments