mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 19:49:57 +00:00
TUN-5184: Make sure outstanding websocket write is finished, and no more writes after shutdown
This commit is contained in:
@@ -2,6 +2,7 @@ package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
@@ -27,7 +28,7 @@ var (
|
||||
)
|
||||
|
||||
func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
|
||||
edgeConn, originConn := net.Pipe()
|
||||
edgeConn, cfdConn := net.Pipe()
|
||||
var connIndex = uint8(0)
|
||||
log := zerolog.Nop()
|
||||
obs := NewObserver(&log, &log, false)
|
||||
@@ -41,7 +42,8 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
|
||||
1*time.Second,
|
||||
)
|
||||
return NewHTTP2Connection(
|
||||
originConn,
|
||||
cfdConn,
|
||||
// OriginProxy is set in testConfig
|
||||
testConfig,
|
||||
&pogs.ConnectionOptions{},
|
||||
obs,
|
||||
@@ -166,6 +168,8 @@ type wsRespWriter struct {
|
||||
*httptest.ResponseRecorder
|
||||
readPipe *io.PipeReader
|
||||
writePipe *io.PipeWriter
|
||||
closed bool
|
||||
panicked bool
|
||||
}
|
||||
|
||||
func newWSRespWriter() *wsRespWriter {
|
||||
@@ -174,46 +178,59 @@ func newWSRespWriter() *wsRespWriter {
|
||||
httptest.NewRecorder(),
|
||||
readPipe,
|
||||
writePipe,
|
||||
false,
|
||||
false,
|
||||
}
|
||||
}
|
||||
|
||||
type nowriter struct {
|
||||
io.Reader
|
||||
}
|
||||
|
||||
func (nowriter) Write(_ []byte) (int, error) {
|
||||
return 0, fmt.Errorf("writer not implemented")
|
||||
}
|
||||
|
||||
func (w *wsRespWriter) RespBody() io.ReadWriter {
|
||||
return nowriter{w.readPipe}
|
||||
}
|
||||
|
||||
func (w *wsRespWriter) Write(data []byte) (n int, err error) {
|
||||
if w.closed {
|
||||
w.panicked = true
|
||||
return 0, errors.New("wsRespWriter panicked")
|
||||
}
|
||||
return w.writePipe.Write(data)
|
||||
}
|
||||
|
||||
func (w *wsRespWriter) close() {
|
||||
w.closed = true
|
||||
}
|
||||
|
||||
func TestServeWS(t *testing.T) {
|
||||
http2Conn, _ := newTestHTTP2Connection()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
http2Conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
respWriter := newWSRespWriter()
|
||||
readPipe, writePipe := io.Pipe()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws/echo", readPipe)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
|
||||
|
||||
wg.Add(1)
|
||||
serveDone := make(chan struct{})
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer close(serveDone)
|
||||
http2Conn.ServeHTTP(respWriter, req)
|
||||
respWriter.close()
|
||||
}()
|
||||
|
||||
data := []byte("test websocket")
|
||||
err = wsutil.WriteClientText(writePipe, data)
|
||||
err = wsutil.WriteClientBinary(writePipe, data)
|
||||
require.NoError(t, err)
|
||||
|
||||
respBody, err := wsutil.ReadServerText(respWriter.RespBody())
|
||||
respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
|
||||
|
||||
@@ -223,7 +240,65 @@ func TestServeWS(t *testing.T) {
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
|
||||
|
||||
<-serveDone
|
||||
require.False(t, respWriter.panicked)
|
||||
}
|
||||
|
||||
// TestNoWriteAfterServeHTTPReturns is a regression test of https://jira.cfops.it/browse/TUN-5184
|
||||
// to make sure we don't write to the ResponseWriter after the ServeHTTP method returns
|
||||
func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
|
||||
cfdHTTP2Conn, edgeTCPConn := newTestHTTP2Connection()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var wg sync.WaitGroup
|
||||
|
||||
serverDone := make(chan struct{})
|
||||
go func() {
|
||||
defer close(serverDone)
|
||||
cfdHTTP2Conn.Serve(ctx)
|
||||
}()
|
||||
|
||||
edgeTransport := http2.Transport{}
|
||||
edgeHTTP2Conn, err := edgeTransport.NewClientConn(edgeTCPConn)
|
||||
require.NoError(t, err)
|
||||
message := []byte(t.Name())
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
readPipe, writePipe := io.Pipe()
|
||||
reqCtx, reqCancel := context.WithCancel(ctx)
|
||||
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
|
||||
|
||||
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
// http2RespWriter should rewrite status 101 to 200
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-reqCtx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
_ = wsutil.WriteClientBinary(writePipe, message)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
reqCancel()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
cancel()
|
||||
<-serverDone
|
||||
}
|
||||
|
||||
func TestServeControlStream(t *testing.T) {
|
||||
|
Reference in New Issue
Block a user