diff --git a/go.mod b/go.mod index c1b50edcf..e05f1a421 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/golang/mock v1.6.0 github.com/google/go-cmp v0.5.9 github.com/google/uuid v1.3.0 + github.com/gorilla/websocket v1.5.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.15.0 github.com/hashicorp/go-hclog v1.4.0 github.com/hashicorp/go-plugin v1.4.8 @@ -33,7 +34,6 @@ require ( github.com/prometheus/client_model v0.3.0 github.com/prometheus/common v0.39.0 github.com/rs/zerolog v1.28.0 - github.com/tmc/grpc-websocket-proxy v0.0.0-20220101234140-673ab2c3ae75 go.buf.build/grpc/go/conduitio/conduit-connector-protocol v1.4.5 go.buf.build/protocolbuffers/go/grpc-ecosystem/grpc-gateway v1.3.50 go.uber.org/goleak v1.2.0 @@ -85,7 +85,6 @@ require ( github.com/golang/protobuf v1.5.2 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v2.0.0+incompatible // indirect - github.com/gorilla/websocket v1.4.2 // indirect github.com/hashicorp/yamux v0.1.1 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.13.0 // indirect @@ -111,7 +110,6 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/prometheus/procfs v0.8.0 // indirect github.com/segmentio/kafka-go v0.4.35 // indirect - github.com/sirupsen/logrus v1.8.1 // indirect github.com/xdg/scram v1.0.5 // indirect github.com/xdg/stringprep v1.0.3 // indirect github.com/xitongsys/parquet-go v1.6.2 // indirect diff --git a/go.sum b/go.sum index 206d1d5eb..8bcc5e993 100644 --- a/go.sum +++ b/go.sum @@ -284,8 +284,8 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= -github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= -github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/grpc-ecosystem/grpc-gateway/v2 v2.15.0 h1:1JYBfzqrWPcCclBwxFCPAou9n+q86mfnu7NAeHfte7A= github.com/grpc-ecosystem/grpc-gateway/v2 v2.15.0/go.mod h1:YDZoGHuwE+ov0c8smSH49WLF3F2LaWnYYuDVd+EWrc0= @@ -493,8 +493,6 @@ github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFR github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= -github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= -github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= @@ -520,8 +518,6 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/tmc/grpc-websocket-proxy v0.0.0-20220101234140-673ab2c3ae75 h1:6fotK7otjonDflCTK0BCfls4SPy3NcCVb5dqqmbRknE= -github.com/tmc/grpc-websocket-proxy v0.0.0-20220101234140-673ab2c3ae75/go.mod h1:KO6IkyS8Y3j8OdNO85qEYBsRPuteD+YciPomcXdrMnk= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/xdg/scram v1.0.5 h1:TuS0RFmt5Is5qm9Tm2SoD89OPqe4IRiFtyFY4iwWXsw= github.com/xdg/scram v1.0.5/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= @@ -662,7 +658,6 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20211123203042-d83791d6bcd9/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220706163947-c90051bbdb60/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.4.0 h1:Q5QPcMlvfxFTAPV0+07Xz/MpK9NTXu2VDUuy0FeMfaU= diff --git a/pkg/conduit/runtime.go b/pkg/conduit/runtime.go index 5beafbf4b..2abff4229 100644 --- a/pkg/conduit/runtime.go +++ b/pkg/conduit/runtime.go @@ -454,6 +454,7 @@ func (r *Runtime) serveHTTPAPI( grpcutil.WithDefaultGatewayMiddleware( r.logger, allowCORS(gwmux, "http://localhost:4200"), ), + r.logger, ) return r.serveHTTP( diff --git a/pkg/foundation/grpcutil/gateway.go b/pkg/foundation/grpcutil/gateway.go index 77c9c6275..8cc0a36af 100644 --- a/pkg/foundation/grpcutil/gateway.go +++ b/pkg/foundation/grpcutil/gateway.go @@ -21,7 +21,6 @@ import ( "github.com/conduitio/conduit/pkg/foundation/log" "github.com/google/uuid" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" - "github.com/tmc/grpc-websocket-proxy/wsproxy" "google.golang.org/protobuf/encoding/protojson" ) @@ -110,8 +109,8 @@ func WithHTTPEndpointHeader(h http.Handler) http.Handler { }) } -func WithWebsockets(h http.Handler) http.Handler { - return wsproxy.WebsocketProxy(h) +func WithWebsockets(h http.Handler, l log.CtxLogger) http.Handler { + return newWebSocketProxy(h, l) } func extractEndpoint(r *http.Request) string { diff --git a/pkg/foundation/grpcutil/websocket.go b/pkg/foundation/grpcutil/websocket.go new file mode 100644 index 000000000..f2b1e360a --- /dev/null +++ b/pkg/foundation/grpcutil/websocket.go @@ -0,0 +1,252 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package grpcutil + +import ( + "bufio" + "context" + "io" + "net/http" + "strings" + "time" + + "github.com/conduitio/conduit/pkg/foundation/log" + "github.com/gorilla/websocket" +) + +type inMemoryResponseWriter struct { + io.Writer + header http.Header +} + +func newInMemoryResponseWriter(writer io.Writer) *inMemoryResponseWriter { + return &inMemoryResponseWriter{ + Writer: writer, + header: http.Header{}, + } +} + +func (w *inMemoryResponseWriter) Write(b []byte) (int, error) { + return w.Writer.Write(b) +} +func (w *inMemoryResponseWriter) Header() http.Header { + return w.header +} +func (w *inMemoryResponseWriter) WriteHeader(int) { + // we don't have a use for the code +} +func (w *inMemoryResponseWriter) Flush() {} + +var ( + defaultWriteWait = 10 * time.Second + defaultPongWait = 60 * time.Second +) + +// webSocketProxy is a proxy around a http.Handler which +// redirects the response data from the http.Handler +// to a WebSocket connection. +type webSocketProxy struct { + handler http.Handler + logger log.CtxLogger + upgrader websocket.Upgrader + + // Time allowed to write a message to the peer. + writeWait time.Duration + // Time allowed to read the next pong message from the peer. + pongWait time.Duration + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod time.Duration +} + +func newWebSocketProxy(handler http.Handler, logger log.CtxLogger) *webSocketProxy { + proxy := &webSocketProxy{ + handler: handler, + logger: logger.WithComponent("grpcutil.webSocketProxy"), + upgrader: websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + }, + writeWait: defaultWriteWait, + pongWait: defaultPongWait, + pingPeriod: (defaultPongWait * 9) / 10, + } + + return proxy +} + +func (p *webSocketProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !websocket.IsWebSocketUpgrade(r) { + p.handler.ServeHTTP(w, r) + return + } + p.proxy(w, r) +} + +// proxy creates a "pipeline" from the underlying response +// to a WebSocket connection. The pipeline is constructed in +// the following way: +// +// underlying response +// -> inMemoryResponseWriter +// -> scanner +// -> messages channel +// -> connection writer +// +// In the case of an error due to which we need to abort the request +// and close the WebSocket connection, we need to cancel the request context +// and stop writing any data to the WebSocket connection. This will +// automatically halt all the "pipeline nodes" after the underlying response. +func (p *webSocketProxy) proxy(w http.ResponseWriter, r *http.Request) { + ctx, cancelFn := context.WithCancel(r.Context()) + defer cancelFn() + r = r.WithContext(ctx) + + // Upgrade connection to WebSocket + conn, err := p.upgrader.Upgrade(w, r, http.Header{}) + if err != nil { + p.logger.Err(ctx, err).Msg("error upgrading websocket") + return + } + defer conn.Close() + + // We use a pipe to read the data being written to the underlying http.Handler + // and then write it to the WebSocket connection. + responseR, responseW := io.Pipe() + response := newInMemoryResponseWriter(responseW) + + // Start the "underlying" http.Handler + go func() { + p.handler.ServeHTTP(response, r) + p.logger.Debug(ctx).Err(ctx.Err()).Msg("closing pipes") + responseW.CloseWithError(io.EOF) + }() + + messages := make(chan []byte) + // startWebSocketRead and startWebSocketWrite need to cancel the context + // if they encounter an error reading from or writing to the WS connection + go p.startWebSocketRead(ctx, conn, cancelFn) + go p.readFromHTTPResponse(ctx, responseR, messages) + p.startWebSocketWrite(ctx, messages, conn, cancelFn) +} + +// startWebSocketRead starts a read loop on the proxy's WebSocket connection. +// The read loop will stop if there's been an error reading a message. +func (p *webSocketProxy) startWebSocketRead(ctx context.Context, conn *websocket.Conn, onDone func()) { + defer onDone() + + conn.SetReadLimit(512) + err := conn.SetReadDeadline(time.Now().Add(p.pongWait)) + if err != nil { + p.logger.Warn(ctx).Err(err).Msgf("couldn't set read deadline %v", p.pongWait) + return + } + + conn.SetPongHandler(func(string) error { + err := conn.SetReadDeadline(time.Now().Add(p.pongWait)) + if err != nil { + // todo return err? + p.logger.Warn(ctx).Err(err).Msgf("couldn't set read deadline %v", p.pongWait) + } + return nil + }) + + for { + // The only use we have for reads right now + // is for ping, pong and close messages. + // https://pkg.go.dev/github.com/gorilla/websocket#hdr-Control_Messages + // Also, a read loop can detect client disconnects much quicker: + // https://groups.google.com/g/golang-nuts/c/FFzQO26jEoE/m/mYhcsK20EwAJ + _, _, err := conn.ReadMessage() + if err != nil { + if p.isClosedConnErr(err) { + p.logger.Debug(ctx).Err(err).Msg("closed connection") + } + + p.logger.Warn(ctx).Err(err).Msg("read error") + break + } + } +} + +func (p *webSocketProxy) isClosedConnErr(err error) bool { + str := err.Error() + if strings.Contains(str, "use of closed network connection") { + return true + } + return websocket.IsCloseError( + err, + websocket.CloseNormalClosure, + websocket.CloseGoingAway, + websocket.CloseAbnormalClosure, + ) +} + +func (p *webSocketProxy) readFromHTTPResponse(ctx context.Context, responseReader io.Reader, c chan []byte) { + defer close(c) + scanner := bufio.NewScanner(responseReader) + + for scanner.Scan() { + if len(scanner.Bytes()) == 0 { + p.logger.Warn(ctx).Err(scanner.Err()).Msg("[write] empty scan") + continue + } + + p.logger.Trace(ctx).Msgf("[write] scanned %v", scanner.Text()) + c <- scanner.Bytes() + } + + if sErr := scanner.Err(); sErr != nil { + p.logger.Err(ctx, sErr).Msg("failed reading data from original response") + c <- []byte(sErr.Error()) + } + + p.logger.Debug(ctx).Msg("scanner reached end of input data") +} + +func (p *webSocketProxy) startWebSocketWrite(ctx context.Context, messages chan []byte, conn *websocket.Conn, cancelFn func()) { + ticker := time.NewTicker(p.pingPeriod) + defer func() { + ticker.Stop() + cancelFn() + for range messages { + // throw away + } + }() + + for { + select { + case message, ok := <-messages: + conn.SetWriteDeadline(time.Now().Add(p.writeWait)) //nolint:errcheck // always returns nil + if !ok { + // readFromHTTPResponse closed the channel. + err := conn.WriteMessage(websocket.CloseMessage, []byte{}) + if err != nil { + p.logger.Warn(ctx).Err(err).Msg("[write] failed sending close message") + } + return + } + + if err := conn.WriteMessage(websocket.TextMessage, message); err != nil { + p.logger.Warn(ctx).Err(err).Msg("[write] error writing websocket message") + return + } + case <-ticker.C: + conn.SetWriteDeadline(time.Now().Add(p.writeWait)) //nolint:errcheck // always returns nil + if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + } +} diff --git a/pkg/foundation/grpcutil/websocket_test.go b/pkg/foundation/grpcutil/websocket_test.go new file mode 100644 index 000000000..b37a16955 --- /dev/null +++ b/pkg/foundation/grpcutil/websocket_test.go @@ -0,0 +1,156 @@ +// Copyright © 2022 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package grpcutil + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/conduitio/conduit/pkg/foundation/cchan" + "github.com/conduitio/conduit/pkg/foundation/log" + "github.com/gorilla/websocket" + "github.com/matryer/is" +) + +func TestWebSocket_NoUpgradeToWebSocket(t *testing.T) { + is := is.New(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + msg := "hi there" + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte(msg)) + is.NoErr(err) + }) + s := httptest.NewServer(newWebSocketProxy(h, log.Nop())) + defer s.Close() + + req, err := http.NewRequestWithContext(ctx, "GET", s.URL, nil) + is.NoErr(err) + + resp, err := http.DefaultClient.Do(req) + is.NoErr(err) + is.True(resp.Body != nil) // expected response to have a body + defer resp.Body.Close() + + bytes, err := io.ReadAll(resp.Body) + is.NoErr(err) + is.Equal(msg, string(bytes)) +} + +func TestWebSocket_Read_Single(t *testing.T) { + is := is.New(t) + + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Data written to a WebSocket is new-line delimited + _, err := w.Write([]byte("hi there\n")) + is.NoErr(err) + }) + s := httptest.NewServer(newWebSocketProxy(h, log.Nop())) + defer s.Close() + + // Convert http to ws + wsURL := "ws" + strings.TrimPrefix(s.URL, "http") + + // Connect to the server + ws, resp, err := websocket.DefaultDialer.Dial(wsURL, nil) + is.NoErr(err) + defer ws.Close() + defer resp.Body.Close() + + msgType, bytes, err := ws.ReadMessage() + is.NoErr(err) + is.Equal("hi there", string(bytes)) + is.Equal(websocket.TextMessage, msgType) + + _, _, err = ws.ReadMessage() + is.True(err != nil) + + err = ws.Close() + is.NoErr(err) +} + +func TestWebSocket_Read_Multiple(t *testing.T) { + is := is.New(t) + + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Data written to a WebSocket is new-line delimited + _, err := w.Write([]byte("first message\n")) + is.NoErr(err) + + _, err = w.Write([]byte("second message\n")) + is.NoErr(err) + }) + s := httptest.NewServer(newWebSocketProxy(h, log.Nop())) + defer s.Close() + + // Convert http to ws + wsURL := "ws" + strings.TrimPrefix(s.URL, "http") + + // Connect to the server + ws, resp, err := websocket.DefaultDialer.Dial(wsURL, nil) + is.NoErr(err) + defer ws.Close() + defer resp.Body.Close() + + msgType, bytes, err := ws.ReadMessage() + is.NoErr(err) + is.Equal("first message", string(bytes)) + is.Equal(websocket.TextMessage, msgType) + + msgType, bytes, err = ws.ReadMessage() + is.NoErr(err) + is.Equal("second message", string(bytes)) + is.Equal(websocket.TextMessage, msgType) + + _, _, err = ws.ReadMessage() + is.True(err != nil) + + err = ws.Close() + is.NoErr(err) +} + +func TestWebSocket_Read_ClientClosed(t *testing.T) { + is := is.New(t) + + handlerDone := make(chan struct{}) + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer close(handlerDone) + <-r.Context().Done() + }) + s := httptest.NewServer(newWebSocketProxy(h, log.Nop())) + defer s.Close() + + // Convert http to ws + wsURL := "ws" + strings.TrimPrefix(s.URL, "http") + + // Connect to the server + ws, resp, err := websocket.DefaultDialer.Dial(wsURL, nil) + is.NoErr(err) + defer ws.Close() + defer resp.Body.Close() + + err = ws.Close() + is.NoErr(err) + + _, ok, err := cchan.Chan[struct{}](handlerDone).RecvTimeout(context.Background(), time.Second) + is.True(!ok) // expected channel to be closed + is.NoErr(err) +}