TUN-5184: Make sure outstanding websocket write is finished, and no more writes after shutdown

This commit is contained in:
cthuang
2021-10-19 20:01:17 +01:00
committed by Nuno Diegues
parent 1ff5fd3fdc
commit db01127191
6 changed files with 212 additions and 70 deletions

View File

@@ -4,16 +4,17 @@ import (
"context"
"fmt"
"io"
"math/rand"
"net/http"
"net/url"
"testing"
"time"
"github.com/gobwas/ws/wsutil"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/websocket"
)
const (
@@ -50,7 +51,15 @@ func (moc *mockOriginProxy) ProxyHTTP(
isWebsocket bool,
) error {
if isWebsocket {
return wsEndpoint(w, req)
switch req.URL.Path {
case "/ws/echo":
return wsEchoEndpoint(w, req)
case "/ws/flaky":
return wsFlakyEndpoint(w, req)
default:
originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
return fmt.Errorf("Unknwon websocket endpoint %s", req.URL.Path)
}
}
switch req.URL.Path {
case "/ok":
@@ -78,32 +87,82 @@ func (moc *mockOriginProxy) ProxyTCP(
return nil
}
type nowriter struct {
io.Reader
type echoPipe struct {
reader *io.PipeReader
writer *io.PipeWriter
}
func (nowriter) Write(p []byte) (int, error) {
return 0, fmt.Errorf("Writer not implemented")
func (ep *echoPipe) Read(p []byte) (int, error) {
return ep.reader.Read(p)
}
func wsEndpoint(w ResponseWriter, r *http.Request) error {
func (ep *echoPipe) Write(p []byte) (int, error) {
return ep.writer.Write(p)
}
// A mock origin that echos data by streaming like a tcpOverWSConnection
// https://github.com/cloudflare/cloudflared/blob/master/ingress/origin_connection.go
func wsEchoEndpoint(w ResponseWriter, r *http.Request) error {
resp := &http.Response{
StatusCode: http.StatusSwitchingProtocols,
}
_ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
clientReader := nowriter{r.Body}
if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
return err
}
wsCtx, cancel := context.WithCancel(r.Context())
readPipe, writePipe := io.Pipe()
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
go func() {
for {
data, err := wsutil.ReadClientText(clientReader)
if err != nil {
return
}
if err := wsutil.WriteServerText(w, data); err != nil {
return
}
select {
case <-wsCtx.Done():
case <-r.Context().Done():
}
readPipe.Close()
writePipe.Close()
}()
<-r.Context().Done()
originConn := &echoPipe{reader: readPipe, writer: writePipe}
websocket.Stream(wsConn, originConn, &log)
cancel()
wsConn.Close()
return nil
}
type flakyConn struct {
closeAt time.Time
}
func (fc *flakyConn) Read(p []byte) (int, error) {
if time.Now().After(fc.closeAt) {
return 0, io.EOF
}
n := copy(p, "Read from flaky connection")
return n, nil
}
func (fc *flakyConn) Write(p []byte) (int, error) {
if time.Now().After(fc.closeAt) {
return 0, fmt.Errorf("flaky connection closed")
}
return len(p), nil
}
func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
resp := &http.Response{
StatusCode: http.StatusSwitchingProtocols,
}
if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
return err
}
wsCtx, cancel := context.WithCancel(r.Context())
wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, r), &log)
closedAfter := time.Millisecond * time.Duration(rand.Intn(50))
originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
websocket.Stream(wsConn, originConn, &log)
cancel()
wsConn.Close()
return nil
}