mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-28 16:19:58 +00:00
TUN-4168: Transparently proxy websocket connections using stdlib HTTP client instead of gorilla/websocket; move websocket client code into carrier package since it's only used by access subcommands now (#345).
This commit is contained in:
@@ -67,7 +67,7 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
||||
lbProbe: lbProbe,
|
||||
rule: ingress.ServiceWarpRouting,
|
||||
}
|
||||
if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, p.warpRouting.Proxy, logFields); err != nil {
|
||||
if err := p.proxyStreamRequest(serveCtx, w, req, p.warpRouting.Proxy, logFields); err != nil {
|
||||
p.logRequestError(err, cfRay, ingress.ServiceWarpRouting)
|
||||
return err
|
||||
}
|
||||
@@ -96,7 +96,7 @@ func (p *proxy) Proxy(w connection.ResponseWriter, req *http.Request, sourceConn
|
||||
return fmt.Errorf("Not a connection-oriented service")
|
||||
}
|
||||
|
||||
if err := p.proxyStreamRequest(serveCtx, w, req, sourceConnectionType, connectionProxy, logFields); err != nil {
|
||||
if err := p.proxyStreamRequest(serveCtx, w, req, connectionProxy, logFields); err != nil {
|
||||
p.logRequestError(err, cfRay, ruleNum)
|
||||
return err
|
||||
}
|
||||
@@ -152,7 +152,6 @@ func (p *proxy) proxyStreamRequest(
|
||||
serveCtx context.Context,
|
||||
w connection.ResponseWriter,
|
||||
req *http.Request,
|
||||
sourceConnectionType connection.Type,
|
||||
connectionProxy ingress.StreamBasedOriginProxy,
|
||||
fields logFields,
|
||||
) error {
|
||||
|
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/urfave/cli/v2"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
@@ -79,6 +80,11 @@ func (w *mockWSRespWriter) respBody() io.ReadWriter {
|
||||
return bytes.NewBuffer(data)
|
||||
}
|
||||
|
||||
func (w *mockWSRespWriter) Close() error {
|
||||
close(w.writeNotification)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *mockWSRespWriter) Read(data []byte) (int, error) {
|
||||
return w.reader.Read(data)
|
||||
}
|
||||
@@ -125,14 +131,14 @@ func TestProxySingleOrigin(t *testing.T) {
|
||||
require.NoError(t, ingressRule.StartOrigins(&wg, &log, ctx.Done(), errC))
|
||||
|
||||
proxy := NewOriginProxy(ingressRule, unusedWarpRoutingService, testTags, &log)
|
||||
t.Run("testProxyHTTP", testProxyHTTP(t, proxy))
|
||||
t.Run("testProxyWebsocket", testProxyWebsocket(t, proxy))
|
||||
t.Run("testProxySSE", testProxySSE(t, proxy))
|
||||
t.Run("testProxyHTTP", testProxyHTTP(proxy))
|
||||
t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
|
||||
t.Run("testProxySSE", testProxySSE(proxy))
|
||||
cancel()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
|
||||
func testProxyHTTP(proxy connection.OriginProxy) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
responseWriter := newMockHTTPRespWriter()
|
||||
req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
|
||||
@@ -145,23 +151,43 @@ func testProxyHTTP(t *testing.T, proxy connection.OriginProxy) func(t *testing.T
|
||||
}
|
||||
}
|
||||
|
||||
func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
|
||||
func testProxyWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
// WSRoute is a websocket echo handler
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
const testTimeout = 5 * time.Second * 1000
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
|
||||
defer cancel()
|
||||
readPipe, writePipe := io.Pipe()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s", hello.WSRoute), readPipe)
|
||||
responseWriter := newMockWSRespWriter(readPipe)
|
||||
req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
|
||||
req.Header.Set("Connection", "Upgrade")
|
||||
req.Header.Set("Upgrade", "websocket")
|
||||
responseWriter := newMockWSRespWriter(nil)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
finished := make(chan struct{})
|
||||
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
errGroup.Go(func() error {
|
||||
err = proxy.Proxy(responseWriter, req, connection.TypeWebsocket)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, http.StatusSwitchingProtocols, responseWriter.Code)
|
||||
}()
|
||||
return nil
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
select {
|
||||
case <-finished:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
t.Errorf("Test timed out")
|
||||
readPipe.Close()
|
||||
writePipe.Close()
|
||||
responseWriter.Close()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
msg := []byte("test websocket")
|
||||
err = wsutil.WriteClientText(writePipe, msg)
|
||||
@@ -179,12 +205,16 @@ func testProxyWebsocket(t *testing.T, proxy connection.OriginProxy) func(t *test
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, msg, returnedMsg)
|
||||
|
||||
cancel()
|
||||
wg.Wait()
|
||||
_ = readPipe.Close()
|
||||
_ = writePipe.Close()
|
||||
_ = responseWriter.Close()
|
||||
|
||||
close(finished)
|
||||
errGroup.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
func testProxySSE(t *testing.T, proxy connection.OriginProxy) func(t *testing.T) {
|
||||
func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
|
||||
return func(t *testing.T) {
|
||||
var (
|
||||
pushCount = 50
|
||||
|
Reference in New Issue
Block a user