mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 20:29:57 +00:00
TUN-5141: Make sure websocket pinger returns before streaming returns
This commit is contained in:
@@ -48,7 +48,12 @@ type tcpOverWSConnection struct {
|
||||
}
|
||||
|
||||
func (wc *tcpOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||
wc.streamHandler(websocket.NewConn(ctx, tunnelConn, log), wc.conn, log)
|
||||
wsCtx, cancel := context.WithCancel(ctx)
|
||||
wsConn := websocket.NewConn(wsCtx, tunnelConn, log)
|
||||
wc.streamHandler(wsConn, wc.conn, log)
|
||||
cancel()
|
||||
// Makes sure wsConn stops sending ping before terminating the stream
|
||||
wsConn.WaitForShutdown()
|
||||
}
|
||||
|
||||
func (wc *tcpOverWSConnection) Close() {
|
||||
@@ -63,7 +68,12 @@ type socksProxyOverWSConnection struct {
|
||||
}
|
||||
|
||||
func (sp *socksProxyOverWSConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||
socks.StreamNetHandler(websocket.NewConn(ctx, tunnelConn, log), sp.accessPolicy, log)
|
||||
wsCtx, cancel := context.WithCancel(ctx)
|
||||
wsConn := websocket.NewConn(wsCtx, tunnelConn, log)
|
||||
socks.StreamNetHandler(wsConn, sp.accessPolicy, log)
|
||||
cancel()
|
||||
// Makes sure wsConn stops sending ping before terminating the stream
|
||||
wsConn.WaitForShutdown()
|
||||
}
|
||||
|
||||
func (sp *socksProxyOverWSConnection) Close() {
|
||||
|
@@ -19,6 +19,7 @@ import (
|
||||
"golang.org/x/net/proxy"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/socks"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
@@ -189,6 +190,53 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWsConnReturnsBeforeStreamReturns(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
eyeballConn, err := connection.NewHTTP2RespWriter(r, w, connection.TypeWebsocket)
|
||||
assert.NoError(t, err)
|
||||
|
||||
cfdConn, originConn := net.Pipe()
|
||||
tcpOverWSConn := tcpOverWSConnection{
|
||||
conn: cfdConn,
|
||||
streamHandler: DefaultStreamHandler,
|
||||
}
|
||||
go func() {
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
// Simulate losing connection to origin
|
||||
originConn.Close()
|
||||
}()
|
||||
ctx := context.WithValue(r.Context(), websocket.PingPeriodContextKey, time.Microsecond)
|
||||
tcpOverWSConn.Stream(ctx, eyeballConn, testLogger)
|
||||
})
|
||||
server := httptest.NewServer(handler)
|
||||
defer server.Close()
|
||||
client := server.Client()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
for i := 0; i < 50; i++ {
|
||||
eyeballConn, edgeConn := net.Pipe()
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodConnect, server.URL, edgeConn)
|
||||
assert.NoError(t, err)
|
||||
|
||||
resp, err := client.Transport.RoundTrip(req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, resp.StatusCode, http.StatusOK)
|
||||
|
||||
errGroup.Go(func() error {
|
||||
for {
|
||||
if err := wsutil.WriteClientBinary(eyeballConn, testMessage); err != nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
assert.NoError(t, errGroup.Wait())
|
||||
}
|
||||
|
||||
type wsEyeball struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
Reference in New Issue
Block a user