diff --git a/CHANGES.md b/CHANGES.md index 57f0ccb8..70657f79 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -19,6 +19,7 @@ Release Notes. * Fix plugin interceptors bypassed on Windows. * Fix wrong tracing context switch when trace ignore plugin activated. +* Fix gRPC server streaming plugin panic on grpc-go v1.69+ where `transport.Stream` was replaced by `transport.ServerStream`. #### Issues and PR - All issues are [here](https://github.com/apache/skywalking/milestone/238?closed=1) diff --git a/docs/en/agent/support-plugins.md b/docs/en/agent/support-plugins.md index 7836e634..64786223 100644 --- a/docs/en/agent/support-plugins.md +++ b/docs/en/agent/support-plugins.md @@ -19,7 +19,7 @@ metrics based on the tracing data. * `dubbo`: [Dubbo](https://github.com/apache/dubbo-go) tested v3.0.1 to v3.0.5. * `kratosv2`: [Kratos](https://github.com/go-kratos/kratos) tested v2.3.1 to v2.6.2. * `microv4`: [Go-Micro](https://github.com/go-micro/go-micro) tested v4.6.0 to v4.10.2. - * `grpc` : [gRPC](https://github.com/grpc/grpc-go) tested v1.55.0 to v1.64.0. + * `grpc` : [gRPC](https://github.com/grpc/grpc-go) tested v1.55.0 to v1.78.0. * Database Client * `gorm`: [GORM](https://github.com/go-gorm/gorm) tested v1.22.0 to v1.25.10. * [MySQL Driver](https://github.com/go-gorm/mysql) diff --git a/plugins/core/instrument/enhance.go b/plugins/core/instrument/enhance.go index fbc9c1d9..06b8099a 100644 --- a/plugins/core/instrument/enhance.go +++ b/plugins/core/instrument/enhance.go @@ -81,8 +81,71 @@ func generateTypeNameByExp(exp dst.Expr) string { data = "..." + generateTypeNameByExp(n.Elt) case *dst.ArrayType: data = "[]" + generateTypeNameByExp(n.Elt) + case *dst.FuncType: + data = generateFuncTypeName(n) + case *dst.ChanType: + data = generateChanTypeName(n) default: return "" } return data } + +func generateFuncTypeName(n *dst.FuncType) string { + result := "func(" + if n.Params != nil { + result += joinFieldTypes(n.Params.List, ", ") + } + result += ")" + if n.Results != nil && len(n.Results.List) > 0 { + result += generateFuncResultTypes(n.Results.List) + } + return result +} + +func joinFieldTypes(fields []*dst.Field, sep string) string { + result := "" + first := true + for _, field := range fields { + count := len(field.Names) + if count == 0 { + count = 1 + } + for k := 0; k < count; k++ { + if !first { + result += sep + } + result += generateTypeNameByExp(field.Type) + first = false + } + } + return result +} + +func generateFuncResultTypes(fields []*dst.Field) string { + totalResults := 0 + for _, field := range fields { + if len(field.Names) == 0 { + totalResults++ + } else { + totalResults += len(field.Names) + } + } + if totalResults == 1 { + return " " + joinFieldTypes(fields, ", ") + } + return " (" + joinFieldTypes(fields, ", ") + ")" +} + +func generateChanTypeName(n *dst.ChanType) string { + switch n.Dir { + case dst.SEND | dst.RECV: + return "chan " + generateTypeNameByExp(n.Value) + case dst.SEND: + return "chan<- " + generateTypeNameByExp(n.Value) + case dst.RECV: + return "<-chan " + generateTypeNameByExp(n.Value) + default: + return "chan " + generateTypeNameByExp(n.Value) + } +} diff --git a/plugins/core/instrument/enhance_test.go b/plugins/core/instrument/enhance_test.go new file mode 100644 index 00000000..810418ba --- /dev/null +++ b/plugins/core/instrument/enhance_test.go @@ -0,0 +1,251 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package instrument + +import ( + "testing" + + "github.com/dave/dst" + "github.com/stretchr/testify/assert" +) + +func TestGenerateTypeNameByExp_BasicTypes(t *testing.T) { + tests := []struct { + name string + expr dst.Expr + expected string + }{ + {"ident", &dst.Ident{Name: "string"}, "string"}, + {"selector", &dst.SelectorExpr{X: dst.NewIdent("context"), Sel: dst.NewIdent("Context")}, "context.Context"}, + {"star selector", &dst.StarExpr{X: &dst.SelectorExpr{X: dst.NewIdent("http"), Sel: dst.NewIdent("Request")}}, "*http.Request"}, + {"star ident", &dst.StarExpr{X: dst.NewIdent("error")}, "*error"}, + {"ellipsis", &dst.Ellipsis{Elt: dst.NewIdent("string")}, "...string"}, + {"array", &dst.ArrayType{Elt: dst.NewIdent("int")}, "[]int"}, + {"interface", &dst.InterfaceType{}, "interface{}"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, generateTypeNameByExp(tt.expr)) + }) + } +} + +func TestGenerateTypeNameByExp_FuncType(t *testing.T) { + tests := []struct { + name string + expr dst.Expr + expected string + }{ + { + "func with no params no results", + &dst.FuncType{Params: &dst.FieldList{}}, + "func()", + }, + { + "func with single param", + &dst.FuncType{ + Params: &dst.FieldList{List: []*dst.Field{ + {Type: dst.NewIdent("int")}, + }}, + }, + "func(int)", + }, + { + "func with multiple params", + &dst.FuncType{ + Params: &dst.FieldList{List: []*dst.Field{ + {Type: &dst.SelectorExpr{X: dst.NewIdent("context"), Sel: dst.NewIdent("Context")}}, + {Type: dst.NewIdent("string")}, + }}, + }, + "func(context.Context, string)", + }, + { + "func with single unnamed result", + &dst.FuncType{ + Params: &dst.FieldList{List: []*dst.Field{ + {Type: dst.NewIdent("int")}, + }}, + Results: &dst.FieldList{List: []*dst.Field{ + {Type: dst.NewIdent("error")}, + }}, + }, + "func(int) error", + }, + { + "func with multiple results", + &dst.FuncType{ + Params: &dst.FieldList{List: []*dst.Field{ + {Type: dst.NewIdent("string")}, + }}, + Results: &dst.FieldList{List: []*dst.Field{ + {Type: dst.NewIdent("int")}, + {Type: dst.NewIdent("error")}, + }}, + }, + "func(string) (int, error)", + }, + { + "func with complex params", + &dst.FuncType{ + Params: &dst.FieldList{List: []*dst.Field{ + {Type: &dst.SelectorExpr{X: dst.NewIdent("context"), Sel: dst.NewIdent("Context")}}, + {Type: &dst.StarExpr{X: &dst.SelectorExpr{X: dst.NewIdent("primitive"), Sel: dst.NewIdent("SendResult")}}}, + {Type: dst.NewIdent("error")}, + }}, + }, + "func(context.Context, *primitive.SendResult, error)", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, generateTypeNameByExp(tt.expr)) + }) + } +} + +func TestGenerateTypeNameByExp_FuncType_NamedFields(t *testing.T) { + tests := []struct { + name string + expr dst.Expr + expected string + }{ + { + "func with named params ignores names", + &dst.FuncType{ + Params: &dst.FieldList{List: []*dst.Field{ + {Names: []*dst.Ident{dst.NewIdent("ctx")}, Type: &dst.SelectorExpr{X: dst.NewIdent("context"), Sel: dst.NewIdent("Context")}}, + {Names: []*dst.Ident{dst.NewIdent("err")}, Type: dst.NewIdent("error")}, + }}, + }, + "func(context.Context, error)", + }, + { + "func with multi-name field expands types", + &dst.FuncType{ + Params: &dst.FieldList{List: []*dst.Field{ + {Names: []*dst.Ident{dst.NewIdent("a"), dst.NewIdent("b")}, Type: dst.NewIdent("int")}, + {Names: []*dst.Ident{dst.NewIdent("s")}, Type: dst.NewIdent("string")}, + }}, + }, + "func(int, int, string)", + }, + { + "func with named single result ignores name", + &dst.FuncType{ + Params: &dst.FieldList{List: []*dst.Field{ + {Type: dst.NewIdent("int")}, + }}, + Results: &dst.FieldList{List: []*dst.Field{ + {Names: []*dst.Ident{dst.NewIdent("err")}, Type: dst.NewIdent("error")}, + }}, + }, + "func(int) error", + }, + { + "func with named multiple results ignores names", + &dst.FuncType{ + Params: &dst.FieldList{List: []*dst.Field{ + {Type: dst.NewIdent("int")}, + }}, + Results: &dst.FieldList{List: []*dst.Field{ + {Names: []*dst.Ident{dst.NewIdent("n")}, Type: dst.NewIdent("int")}, + {Names: []*dst.Ident{dst.NewIdent("err")}, Type: dst.NewIdent("error")}, + }}, + }, + "func(int) (int, error)", + }, + { + "func with multi-name result field expands types", + &dst.FuncType{ + Params: &dst.FieldList{List: []*dst.Field{ + {Type: dst.NewIdent("string")}, + }}, + Results: &dst.FieldList{List: []*dst.Field{ + {Names: []*dst.Ident{dst.NewIdent("x"), dst.NewIdent("y")}, Type: dst.NewIdent("int")}, + }}, + }, + "func(string) (int, int)", + }, + { + "func with named params and results from real code", + &dst.FuncType{ + Params: &dst.FieldList{List: []*dst.Field{ + {Names: []*dst.Ident{dst.NewIdent("ctx")}, Type: &dst.SelectorExpr{X: dst.NewIdent("context"), Sel: dst.NewIdent("Context")}}, + {Names: []*dst.Ident{dst.NewIdent("req")}, Type: &dst.StarExpr{X: &dst.SelectorExpr{X: dst.NewIdent("http"), Sel: dst.NewIdent("Request")}}}, + }}, + Results: &dst.FieldList{List: []*dst.Field{ + {Names: []*dst.Ident{dst.NewIdent("resp")}, Type: &dst.StarExpr{X: &dst.SelectorExpr{X: dst.NewIdent("http"), Sel: dst.NewIdent("Response")}}}, + {Names: []*dst.Ident{dst.NewIdent("err")}, Type: dst.NewIdent("error")}, + }}, + }, + "func(context.Context, *http.Request) (*http.Response, error)", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, generateTypeNameByExp(tt.expr)) + }) + } +} + +func TestGenerateTypeNameByExp_ChanType(t *testing.T) { + tests := []struct { + name string + expr dst.Expr + expected string + }{ + { + "bidirectional chan", + &dst.ChanType{ + Dir: dst.SEND | dst.RECV, + Value: dst.NewIdent("int"), + }, + "chan int", + }, + { + "send-only chan", + &dst.ChanType{ + Dir: dst.SEND, + Value: dst.NewIdent("int"), + }, + "chan<- int", + }, + { + "receive-only chan", + &dst.ChanType{ + Dir: dst.RECV, + Value: dst.NewIdent("Delivery"), + }, + "<-chan Delivery", + }, + { + "receive-only chan with selector type", + &dst.ChanType{ + Dir: dst.RECV, + Value: &dst.SelectorExpr{X: dst.NewIdent("amqp"), Sel: dst.NewIdent("Delivery")}, + }, + "<-chan amqp.Delivery", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, generateTypeNameByExp(tt.expr)) + }) + } +} diff --git a/plugins/grpc/instrument.go b/plugins/grpc/instrument.go index b6dc327f..b9387f4a 100644 --- a/plugins/grpc/instrument.go +++ b/plugins/grpc/instrument.go @@ -70,8 +70,16 @@ func (i *Instrument) Points() []*instrument.Point { { PackagePath: "", At: instrument.NewMethodEnhance("*Server", "handleStream", - instrument.WithArgType(0, "transport.ServerTransport")), - Interceptor: "ServerHandleStreamInterceptor ", + instrument.WithArgType(0, "transport.ServerTransport"), + instrument.WithArgType(1, "*transport.Stream")), + Interceptor: "ServerHandleStreamInterceptor", + }, + { + PackagePath: "", + At: instrument.NewMethodEnhance("*Server", "handleStream", + instrument.WithArgType(0, "transport.ServerTransport"), + instrument.WithArgType(1, "*transport.ServerStream")), + Interceptor: "ServerHandleStreamInterceptorV2", }, { PackagePath: "", diff --git a/plugins/grpc/server_handleStream_interceptor_test.go b/plugins/grpc/server_handleStream_interceptor_test.go new file mode 100644 index 00000000..d2f75a00 --- /dev/null +++ b/plugins/grpc/server_handleStream_interceptor_test.go @@ -0,0 +1,133 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package grpc + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/apache/skywalking-go/plugins/core" + "github.com/apache/skywalking-go/plugins/core/operator" +) + +func TestServerHandleStreamInterceptorBeforeInvoke(t *testing.T) { + defer core.ResetTracingContext() + + interceptor := &ServerHandleStreamInterceptor{} + stream := &nativeStream{ + ctx: context.Background(), + method: "/api.Echo/UnaryEcho", + } + invocation := operator.NewInvocation(nil, nil, stream) + + err := interceptor.BeforeInvoke(invocation) + assert.Nil(t, err) + assert.NotNil(t, invocation.GetContext()) +} + +func TestServerHandleStreamInterceptorAfterInvoke(t *testing.T) { + defer core.ResetTracingContext() + + interceptor := &ServerHandleStreamInterceptor{} + stream := &nativeStream{ + ctx: context.Background(), + method: "/api.Echo/UnaryEcho", + } + invocation := operator.NewInvocation(nil, nil, stream) + + err := interceptor.BeforeInvoke(invocation) + assert.Nil(t, err) + + time.Sleep(100 * time.Millisecond) + + err = interceptor.AfterInvoke(invocation) + assert.Nil(t, err) + + time.Sleep(100 * time.Millisecond) + spans := core.GetReportedSpans() + assert.NotNil(t, spans) + assert.Equal(t, 1, len(spans)) + assert.Equal(t, "api.Echo.UnaryEcho", spans[0].OperationName()) +} + +func TestServerHandleStreamInterceptorAfterInvokeWithNilContext(t *testing.T) { + defer core.ResetTracingContext() + + interceptor := &ServerHandleStreamInterceptor{} + invocation := operator.NewInvocation(nil, nil, nil) + + err := interceptor.AfterInvoke(invocation) + assert.Nil(t, err) +} + +func TestServerHandleStreamInterceptorV2BeforeInvokeWithServerStream(t *testing.T) { + defer core.ResetTracingContext() + + interceptor := &ServerHandleStreamInterceptorV2{} + stream := &nativeServerStream{ + nativeStream: nativeStream{ + ctx: context.Background(), + method: "/api.Echo/ServerStreamingEcho", + }, + } + invocation := operator.NewInvocation(nil, nil, stream) + + err := interceptor.BeforeInvoke(invocation) + assert.Nil(t, err) + assert.NotNil(t, invocation.GetContext()) +} + +func TestServerHandleStreamInterceptorV2AfterInvoke(t *testing.T) { + defer core.ResetTracingContext() + + interceptor := &ServerHandleStreamInterceptorV2{} + stream := &nativeServerStream{ + nativeStream: nativeStream{ + ctx: context.Background(), + method: "/api.Echo/ServerStreamingEcho", + }, + } + invocation := operator.NewInvocation(nil, nil, stream) + + err := interceptor.BeforeInvoke(invocation) + assert.Nil(t, err) + + time.Sleep(100 * time.Millisecond) + + err = interceptor.AfterInvoke(invocation) + assert.Nil(t, err) + + time.Sleep(100 * time.Millisecond) + spans := core.GetReportedSpans() + assert.NotNil(t, spans) + assert.Equal(t, 1, len(spans)) + assert.Equal(t, "api.Echo.ServerStreamingEcho", spans[0].OperationName()) +} + +func TestServerHandleStreamInterceptorV2AfterInvokeWithNilContext(t *testing.T) { + defer core.ResetTracingContext() + + interceptor := &ServerHandleStreamInterceptorV2{} + invocation := operator.NewInvocation(nil, nil, nil) + + err := interceptor.AfterInvoke(invocation) + assert.Nil(t, err) +} diff --git a/plugins/grpc/server_handleStream_v2_interceptor.go b/plugins/grpc/server_handleStream_v2_interceptor.go new file mode 100644 index 00000000..52718721 --- /dev/null +++ b/plugins/grpc/server_handleStream_v2_interceptor.go @@ -0,0 +1,59 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package grpc + +import ( + "google.golang.org/grpc/metadata" + + "github.com/apache/skywalking-go/plugins/core/operator" + "github.com/apache/skywalking-go/plugins/core/tracing" +) + +type ServerHandleStreamInterceptorV2 struct { +} + +func (h *ServerHandleStreamInterceptorV2) BeforeInvoke(invocation operator.Invocation) error { + stream := invocation.Args()[1].(*nativeServerStream) + method := stream.Method() + ctx := stream.Context() + md, _ := metadata.FromIncomingContext(ctx) + s, err := tracing.CreateEntrySpan(formatOperationName(method, ""), func(headerKey string) (string, error) { + Value := "" + vals := md.Get(headerKey) + if len(vals) > 0 { + Value = vals[0] + } + return Value, nil + }, tracing.WithLayer(tracing.SpanLayerRPCFramework), + tracing.WithTag(tracing.TagURL, method), + tracing.WithComponent(23), + ) + if err != nil { + return err + } + invocation.SetContext(s) + return nil +} + +func (h *ServerHandleStreamInterceptorV2) AfterInvoke(invocation operator.Invocation, result ...interface{}) error { + if invocation.GetContext() == nil { + return nil + } + invocation.GetContext().(tracing.Span).End() + return nil +} diff --git a/plugins/grpc/structures.go b/plugins/grpc/structures.go index 7574b33a..b8078dd9 100644 --- a/plugins/grpc/structures.go +++ b/plugins/grpc/structures.go @@ -35,6 +35,11 @@ func (s *nativeStream) Context() context.Context { return s.ctx } +//skywalking:native google.golang.org/grpc/internal/transport ServerStream +type nativeServerStream struct { + nativeStream +} + //skywalking:native google.golang.org/grpc ClientConn type nativeClientConn struct { } diff --git a/test/plugins/scenarios/grpc/plugin.yml b/test/plugins/scenarios/grpc/plugin.yml index 6da45fff..e9701f97 100644 --- a/test/plugins/scenarios/grpc/plugin.yml +++ b/test/plugins/scenarios/grpc/plugin.yml @@ -30,3 +30,6 @@ support-version: - v1.60.0 - v1.62.0 - v1.64.0 + - v1.69.0 + - v1.72.0 + - v1.78.0 diff --git a/tools/go-agent/instrument/plugins/instrument.go b/tools/go-agent/instrument/plugins/instrument.go index 8820a847..0af1733e 100644 --- a/tools/go-agent/instrument/plugins/instrument.go +++ b/tools/go-agent/instrument/plugins/instrument.go @@ -540,7 +540,9 @@ func (i *Instrument) validateMethodInsMatch(matcher *instrument.EnhanceMatcher, return false } var name = tools.GenerateTypeNameByExp(node.Recv.List[0].Type) - return name == matcher.Receiver + if name != matcher.Receiver { + return false + } } for _, filter := range matcher.MethodFilters { if !filter(node, allFiles) { diff --git a/tools/go-agent/instrument/plugins/instrument_test.go b/tools/go-agent/instrument/plugins/instrument_test.go index 57f56310..5ef71ffe 100644 --- a/tools/go-agent/instrument/plugins/instrument_test.go +++ b/tools/go-agent/instrument/plugins/instrument_test.go @@ -21,6 +21,8 @@ import ( "embed" "testing" + "github.com/dave/dst" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/apache/skywalking-go/plugins/core/instrument" @@ -117,3 +119,129 @@ func (i *TestInstrument) Points() []*instrument.Point { func (i *TestInstrument) FS() *embed.FS { return nil } + +func TestInstrument_validateMethodInsMatch_WithArgTypeFilters(t *testing.T) { + inst := &Instrument{} + + handleStreamV1 := &dst.FuncDecl{ + Name: dst.NewIdent("handleStream"), + Recv: &dst.FieldList{List: []*dst.Field{ + {Type: &dst.StarExpr{X: dst.NewIdent("Server")}}, + }}, + Type: &dst.FuncType{ + Params: &dst.FieldList{List: []*dst.Field{ + {Type: &dst.SelectorExpr{X: dst.NewIdent("transport"), Sel: dst.NewIdent("ServerTransport")}}, + {Type: &dst.StarExpr{X: &dst.SelectorExpr{X: dst.NewIdent("transport"), Sel: dst.NewIdent("Stream")}}}, + }}, + }, + } + + handleStreamV2 := &dst.FuncDecl{ + Name: dst.NewIdent("handleStream"), + Recv: &dst.FieldList{List: []*dst.Field{ + {Type: &dst.StarExpr{X: dst.NewIdent("Server")}}, + }}, + Type: &dst.FuncType{ + Params: &dst.FieldList{List: []*dst.Field{ + {Type: &dst.SelectorExpr{X: dst.NewIdent("transport"), Sel: dst.NewIdent("ServerTransport")}}, + {Type: &dst.StarExpr{X: &dst.SelectorExpr{X: dst.NewIdent("transport"), Sel: dst.NewIdent("ServerStream")}}}, + }}, + }, + } + + matcherV1 := &instrument.EnhanceMatcher{ + Type: instrument.EnhanceTypeMethod, + Name: "handleStream", + Receiver: "*Server", + MethodFilters: []instrument.MethodFilterOption{ + instrument.WithArgType(0, "transport.ServerTransport"), + instrument.WithArgType(1, "*transport.Stream"), + }, + } + + matcherV2 := &instrument.EnhanceMatcher{ + Type: instrument.EnhanceTypeMethod, + Name: "handleStream", + Receiver: "*Server", + MethodFilters: []instrument.MethodFilterOption{ + instrument.WithArgType(0, "transport.ServerTransport"), + instrument.WithArgType(1, "*transport.ServerStream"), + }, + } + + tests := []struct { + name string + matcher *instrument.EnhanceMatcher + node *dst.FuncDecl + want bool + }{ + {"V1 matcher matches V1 signature", matcherV1, handleStreamV1, true}, + {"V1 matcher does not match V2 signature", matcherV1, handleStreamV2, false}, + {"V2 matcher does not match V1 signature", matcherV2, handleStreamV1, false}, + {"V2 matcher matches V2 signature", matcherV2, handleStreamV2, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := inst.validateMethodInsMatch(tt.matcher, tt.node, nil) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestInstrument_validateMethodInsMatch_WithResultTypeFilter(t *testing.T) { + inst := &Instrument{} + + sendResponseDecl := &dst.FuncDecl{ + Name: dst.NewIdent("sendResponse"), + Recv: &dst.FieldList{List: []*dst.Field{ + {Type: &dst.StarExpr{X: dst.NewIdent("Server")}}, + }}, + Type: &dst.FuncType{ + Params: &dst.FieldList{List: []*dst.Field{}}, + Results: &dst.FieldList{List: []*dst.Field{ + {Type: dst.NewIdent("error")}, + }}, + }, + } + + matcher := &instrument.EnhanceMatcher{ + Type: instrument.EnhanceTypeMethod, + Name: "sendResponse", + Receiver: "*Server", + MethodFilters: []instrument.MethodFilterOption{ + instrument.WithResultType(0, "error"), + }, + } + + assert.True(t, inst.validateMethodInsMatch(matcher, sendResponseDecl, nil)) +} + +func TestInstrument_validateMethodInsMatch_ReceiverMismatch(t *testing.T) { + inst := &Instrument{} + + methodDecl := &dst.FuncDecl{ + Name: dst.NewIdent("handleStream"), + Recv: &dst.FieldList{List: []*dst.Field{ + {Type: &dst.StarExpr{X: dst.NewIdent("ClientConn")}}, + }}, + Type: &dst.FuncType{ + Params: &dst.FieldList{List: []*dst.Field{ + {Type: &dst.SelectorExpr{X: dst.NewIdent("transport"), Sel: dst.NewIdent("ServerTransport")}}, + {Type: &dst.StarExpr{X: &dst.SelectorExpr{X: dst.NewIdent("transport"), Sel: dst.NewIdent("Stream")}}}, + }}, + }, + } + + matcher := &instrument.EnhanceMatcher{ + Type: instrument.EnhanceTypeMethod, + Name: "handleStream", + Receiver: "*Server", + MethodFilters: []instrument.MethodFilterOption{ + instrument.WithArgType(0, "transport.ServerTransport"), + instrument.WithArgType(1, "*transport.Stream"), + }, + } + + assert.False(t, inst.validateMethodInsMatch(matcher, methodDecl, nil)) +}