TUN-4168: Transparently proxy websocket connections using stdlib HTTP client instead of gorilla/websocket; move websocket client code into carrier package since it's only used by access subcommands now (#345).

This commit is contained in:
Igor Postelnik
2021-04-02 01:10:43 -05:00
parent b25d38dd72
commit 3ad99b241c
12 changed files with 455 additions and 315 deletions

View File

@@ -67,7 +67,7 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
lbProbe: lbProbe,
rule: ingress.ServiceWarpRouting,
}
if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, p.warpRouting.Proxy, logFields); err != nil {
if err := p.proxyStreamRequest(serveCtx, w, req, p.warpRouting.Proxy, logFields); err != nil {
p.logRequestError(err, cfRay, ingress.ServiceWarpRouting)
return err
}
@@ -96,7 +96,7 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
return fmt.Errorf("Not a connection-oriented service")
}
if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, connectionProxy, logFields); err != nil {
if err := p.proxyStreamRequest(serveCtx, w, req, connectionProxy, logFields); err != nil {
p.logRequestError(err, cfRay, ruleNum)
return err
}
@@ -152,7 +152,6 @@ func (p *proxy) proxyStreamRequest(
serveCtx context.Context,
w connection.ResponseWriter,
req *http.Request,
sourceConnectionType connection.Type,
connectionProxy ingress.StreamBasedOriginProxy,
fields logFields,
) error {

View File

@@ -19,6 +19,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/cli/v2"
"golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection"
@@ -79,6 +80,11 @@ func (w *mockWSRespWriter) respBody() io.ReadWriter {
return bytes.NewBuffer(data)
}
func (w *mockWSRespWriter) Close() error {
close(w.writeNotification)
return nil
}
func (w *mockWSRespWriter) Read(data []byte) (int, error) {
return w.reader.Read(data)
}
@@ -125,14 +131,14 @@ func TestProxySingleOrigin(t *testing.T) {
require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC))
proxy := NewOriginProxy(ingressRule, unusedWarpRoutingService, testTags, &log)
t.Run("testProxyHTTP", testProxyHTTP(t, proxy))
t.Run("testProxyWebsocket", testProxyWebsocket(t, proxy))
t.Run("testProxySSE", testProxySSE(t, proxy))
t.Run("testProxyHTTP", testProxyHTTP(proxy))
t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
t.Run("testProxySSE", testProxySSE(proxy))
cancel()
wg.Wait()
}
func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
func testProxyHTTP(proxy connection.OriginProxy) func(t *testing.T) {
return func(t *testing.T) {
responseWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
@@ -145,23 +151,43 @@ func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T
}
}
func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
func testProxyWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
return func(t *testing.T) {
// WSRoute is a websocket echo handler
ctx, cancel := context.WithCancel(context.Background())
const testTimeout = 5 * time.Second * 1000
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
readPipe, writePipe := io.Pipe()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s", hello.WSRoute), readPipe)
responseWriter := newMockWSRespWriter(readPipe)
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
responseWriter := newMockWSRespWriter(nil)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
finished := make(chan struct{})
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
err = proxy.Proxy(responseWriter, req, connection.TypeWebsocket)
require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, responseWriter.Code)
}()
return nil
})
errGroup.Go(func() error {
select {
case <-finished:
case <-ctx.Done():
}
if ctx.Err() == context.DeadlineExceeded {
t.Errorf("Test timed out")
readPipe.Close()
writePipe.Close()
responseWriter.Close()
}
return nil
})
msg := []byte("test websocket")
err = wsutil.WriteClientText(writePipe, msg)
@@ -179,12 +205,16 @@ func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *test
require.NoError(t, err)
require.Equal(t, msg, returnedMsg)
cancel()
wg.Wait()
_ = readPipe.Close()
_ = writePipe.Close()
_ = responseWriter.Close()
close(finished)
errGroup.Wait()
}
}
func testProxySSE(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
return func(t *testing.T) {
var (
pushCount = 50