mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-05-22 04:16:34 +00:00

Whenever cloudflared receives a SIGTERM or SIGINT, it goes into graceful shutdown mode, which unregisters the connection and closes the control stream. Unregistering means we no longer receive any new requests and makes the edge close the connection, allowing in-flight requests to finish (within a 3-minute period). This was working fine for HTTP/2 connections, but the QUIC proxy was cancelling the context as soon as the control stream ended, forcing the process to stop immediately. This commit changes the behavior so that we wait the full grace period before cancelling the request.
585 lines
14 KiB
Go
585 lines
14 KiB
Go
package connection
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gobwas/ws/wsutil"
|
|
"github.com/google/uuid"
|
|
"github.com/rs/zerolog"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/net/http2"
|
|
|
|
"github.com/cloudflare/cloudflared/tunnelrpc"
|
|
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
)
|
|
|
|
var (
|
|
testTransport = http2.Transport{}
|
|
)
|
|
|
|
func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
|
|
edgeConn, cfdConn := net.Pipe()
|
|
var connIndex = uint8(0)
|
|
log := zerolog.Nop()
|
|
obs := NewObserver(&log, &log)
|
|
controlStream := NewControlStream(
|
|
obs,
|
|
mockConnectedFuse{},
|
|
&TunnelProperties{},
|
|
connIndex,
|
|
nil,
|
|
nil,
|
|
1*time.Second,
|
|
nil,
|
|
1*time.Second,
|
|
HTTP2,
|
|
)
|
|
return NewHTTP2Connection(
|
|
cfdConn,
|
|
// OriginProxy is set in testConfigManager
|
|
testOrchestrator,
|
|
&pogs.ConnectionOptions{},
|
|
obs,
|
|
connIndex,
|
|
controlStream,
|
|
&log,
|
|
), edgeConn
|
|
}
|
|
|
|
func TestHTTP2ConfigurationSet(t *testing.T) {
|
|
http2Conn, edgeConn := newTestHTTP2Connection()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
http2Conn.Serve(ctx)
|
|
}()
|
|
|
|
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
|
require.NoError(t, err)
|
|
|
|
endpoint := fmt.Sprintf("http://localhost:8080/ok")
|
|
reqBody := []byte(`{
|
|
"version": 2,
|
|
"config": {"warp-routing": {"enabled": true}, "originRequest" : {"connectTimeout": 10}, "ingress" : [ {"hostname": "test", "service": "https://localhost:8000" } , {"service": "http_status:404"} ]}}
|
|
`)
|
|
reader := bytes.NewReader(reqBody)
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPut, endpoint, reader)
|
|
require.NoError(t, err)
|
|
req.Header.Set(InternalUpgradeHeader, ConfigurationUpdate)
|
|
|
|
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
bdy, err := io.ReadAll(resp.Body)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, `{"lastAppliedVersion":2,"err":null}`, string(bdy))
|
|
cancel()
|
|
wg.Wait()
|
|
|
|
}
|
|
|
|
func TestServeHTTP(t *testing.T) {
|
|
tests := []testRequest{
|
|
{
|
|
name: "ok",
|
|
endpoint: "ok",
|
|
expectedStatus: http.StatusOK,
|
|
expectedBody: []byte(http.StatusText(http.StatusOK)),
|
|
},
|
|
{
|
|
name: "large_file",
|
|
endpoint: "large_file",
|
|
expectedStatus: http.StatusOK,
|
|
expectedBody: testLargeResp,
|
|
},
|
|
{
|
|
name: "Bad request",
|
|
endpoint: "400",
|
|
expectedStatus: http.StatusBadRequest,
|
|
expectedBody: []byte(http.StatusText(http.StatusBadRequest)),
|
|
},
|
|
{
|
|
name: "Internal server error",
|
|
endpoint: "500",
|
|
expectedStatus: http.StatusInternalServerError,
|
|
expectedBody: []byte(http.StatusText(http.StatusInternalServerError)),
|
|
},
|
|
{
|
|
name: "Proxy error",
|
|
endpoint: "error",
|
|
expectedStatus: http.StatusBadGateway,
|
|
expectedBody: nil,
|
|
isProxyError: true,
|
|
},
|
|
}
|
|
|
|
http2Conn, edgeConn := newTestHTTP2Connection()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
http2Conn.Serve(ctx)
|
|
}()
|
|
|
|
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
|
require.NoError(t, err)
|
|
|
|
for _, test := range tests {
|
|
endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
|
require.NoError(t, err)
|
|
|
|
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, test.expectedStatus, resp.StatusCode)
|
|
if test.expectedBody != nil {
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
require.NoError(t, err)
|
|
require.Equal(t, test.expectedBody, respBody)
|
|
}
|
|
if test.isProxyError {
|
|
require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeader))
|
|
} else {
|
|
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
|
|
}
|
|
}
|
|
cancel()
|
|
wg.Wait()
|
|
}
|
|
|
|
// mockNamedTunnelRPCClient is a fake registration client used by the
// control-stream tests.
type mockNamedTunnelRPCClient struct {
	// shouldFail, when non-nil, is returned by RegisterConnection.
	shouldFail error
	// registered is closed on successful registration; unregistered is
	// closed by GracefulShutdown.
	registered, unregistered chan struct{}
}
|
|
|
|
func (mc mockNamedTunnelRPCClient) SendLocalConfiguration(c context.Context, config []byte) error {
|
|
return nil
|
|
}
|
|
|
|
func (mc mockNamedTunnelRPCClient) RegisterConnection(
|
|
ctx context.Context,
|
|
auth pogs.TunnelAuth,
|
|
tunnelID uuid.UUID,
|
|
options *pogs.ConnectionOptions,
|
|
connIndex uint8,
|
|
edgeAddress net.IP,
|
|
) (*pogs.ConnectionDetails, error) {
|
|
if mc.shouldFail != nil {
|
|
return nil, mc.shouldFail
|
|
}
|
|
close(mc.registered)
|
|
return &pogs.ConnectionDetails{
|
|
Location: "LIS",
|
|
UUID: uuid.New(),
|
|
TunnelIsRemotelyManaged: false,
|
|
}, nil
|
|
}
|
|
|
|
func (mc mockNamedTunnelRPCClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) error {
|
|
close(mc.unregistered)
|
|
return nil
|
|
}
|
|
|
|
// Close is a no-op: the mock holds nothing to release.
func (mockNamedTunnelRPCClient) Close() {
}
|
|
|
|
// mockRPCClientFactory produces mockNamedTunnelRPCClient values that all share
// the factory's failure setting and signalling channels.
type mockRPCClientFactory struct {
	// shouldFail is copied into every client the factory creates.
	shouldFail error
	// registered and unregistered are shared with every created client.
	registered, unregistered chan struct{}
}
|
|
|
|
func (mf *mockRPCClientFactory) newMockRPCClient(context.Context, io.ReadWriteCloser, time.Duration) tunnelrpc.RegistrationClient {
|
|
return &mockNamedTunnelRPCClient{
|
|
shouldFail: mf.shouldFail,
|
|
registered: mf.registered,
|
|
unregistered: mf.unregistered,
|
|
}
|
|
}
|
|
|
|
// wsRespWriter is a test ResponseWriter for websocket traffic. Headers and the
// status code are recorded by the embedded ResponseRecorder, while body bytes
// are streamed through an io.Pipe so the test can read them as they arrive.
type wsRespWriter struct {
	*httptest.ResponseRecorder
	// readPipe and writePipe are the two ends of the body pipe.
	readPipe  *io.PipeReader
	writePipe *io.PipeWriter
	// closed is set once the handler has returned; panicked records a write
	// that arrived after that point.
	closed, panicked bool
}
|
|
|
|
func newWSRespWriter() *wsRespWriter {
|
|
readPipe, writePipe := io.Pipe()
|
|
return &wsRespWriter{
|
|
httptest.NewRecorder(),
|
|
readPipe,
|
|
writePipe,
|
|
false,
|
|
false,
|
|
}
|
|
}
|
|
|
|
// nowriter adapts an io.Reader to io.ReadWriter; its Write always fails.
type nowriter struct {
	io.Reader // reads are delegated to the wrapped reader
}
|
|
|
|
func (nowriter) Write(_ []byte) (int, error) {
|
|
return 0, fmt.Errorf("writer not implemented")
|
|
}
|
|
|
|
func (w *wsRespWriter) RespBody() io.ReadWriter {
|
|
return nowriter{w.readPipe}
|
|
}
|
|
|
|
func (w *wsRespWriter) Write(data []byte) (n int, err error) {
|
|
if w.closed {
|
|
w.panicked = true
|
|
return 0, errors.New("wsRespWriter panicked")
|
|
}
|
|
return w.writePipe.Write(data)
|
|
}
|
|
|
|
func (w *wsRespWriter) close() {
|
|
w.closed = true
|
|
}
|
|
|
|
func TestServeWS(t *testing.T) {
|
|
http2Conn, _ := newTestHTTP2Connection()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
respWriter := newWSRespWriter()
|
|
readPipe, writePipe := io.Pipe()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws/echo", readPipe)
|
|
require.NoError(t, err)
|
|
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
|
|
|
|
serveDone := make(chan struct{})
|
|
go func() {
|
|
defer close(serveDone)
|
|
http2Conn.ServeHTTP(respWriter, req)
|
|
respWriter.close()
|
|
}()
|
|
|
|
data := []byte("test websocket")
|
|
err = wsutil.WriteClientBinary(writePipe, data)
|
|
require.NoError(t, err)
|
|
|
|
respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
|
|
require.NoError(t, err)
|
|
require.Equal(t, data, respBody, fmt.Sprintf("Expect %s, got %s", string(data), string(respBody)))
|
|
|
|
cancel()
|
|
resp := respWriter.Result()
|
|
// http2RespWriter should rewrite status 101 to 200
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
|
|
|
|
<-serveDone
|
|
require.False(t, respWriter.panicked)
|
|
}
|
|
|
|
// TestNoWriteAfterServeHTTPReturns is a regression test of https://jira.cfops.it/browse/TUN-5184
|
|
// to make sure we don't write to the ResponseWriter after the ServeHTTP method returns
|
|
func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
|
|
cfdHTTP2Conn, edgeTCPConn := newTestHTTP2Connection()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
var wg sync.WaitGroup
|
|
|
|
serverDone := make(chan struct{})
|
|
go func() {
|
|
defer close(serverDone)
|
|
cfdHTTP2Conn.Serve(ctx)
|
|
}()
|
|
|
|
edgeTransport := http2.Transport{}
|
|
edgeHTTP2Conn, err := edgeTransport.NewClientConn(edgeTCPConn)
|
|
require.NoError(t, err)
|
|
message := []byte(t.Name())
|
|
|
|
for i := 0; i < 100; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
readPipe, writePipe := io.Pipe()
|
|
reqCtx, reqCancel := context.WithCancel(ctx)
|
|
req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
|
|
require.NoError(t, err)
|
|
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
|
|
|
|
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
// http2RespWriter should rewrite status 101 to 200
|
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
for {
|
|
select {
|
|
case <-reqCtx.Done():
|
|
return
|
|
default:
|
|
}
|
|
_ = wsutil.WriteClientBinary(writePipe, message)
|
|
}
|
|
}()
|
|
|
|
time.Sleep(time.Millisecond * 100)
|
|
reqCancel()
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
cancel()
|
|
<-serverDone
|
|
}
|
|
|
|
func TestServeControlStream(t *testing.T) {
|
|
http2Conn, edgeConn := newTestHTTP2Connection()
|
|
|
|
rpcClientFactory := mockRPCClientFactory{
|
|
registered: make(chan struct{}),
|
|
unregistered: make(chan struct{}),
|
|
}
|
|
|
|
obs := NewObserver(&log, &log)
|
|
controlStream := NewControlStream(
|
|
obs,
|
|
mockConnectedFuse{},
|
|
&TunnelProperties{},
|
|
1,
|
|
nil,
|
|
rpcClientFactory.newMockRPCClient,
|
|
1*time.Second,
|
|
nil,
|
|
1*time.Second,
|
|
HTTP2,
|
|
)
|
|
http2Conn.controlStreamHandler = controlStream
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
http2Conn.Serve(ctx)
|
|
}()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
|
require.NoError(t, err)
|
|
req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade)
|
|
|
|
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
|
require.NoError(t, err)
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
edgeHTTP2Conn.RoundTrip(req)
|
|
}()
|
|
|
|
<-rpcClientFactory.registered
|
|
cancel()
|
|
<-rpcClientFactory.unregistered
|
|
assert.False(t, http2Conn.stoppedGracefully)
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestFailRegistration(t *testing.T) {
|
|
http2Conn, edgeConn := newTestHTTP2Connection()
|
|
|
|
rpcClientFactory := mockRPCClientFactory{
|
|
shouldFail: errDuplicationConnection,
|
|
registered: make(chan struct{}),
|
|
unregistered: make(chan struct{}),
|
|
}
|
|
|
|
obs := NewObserver(&log, &log)
|
|
controlStream := NewControlStream(
|
|
obs,
|
|
mockConnectedFuse{},
|
|
&TunnelProperties{},
|
|
http2Conn.connIndex,
|
|
nil,
|
|
rpcClientFactory.newMockRPCClient,
|
|
1*time.Second,
|
|
nil,
|
|
1*time.Second,
|
|
HTTP2,
|
|
)
|
|
http2Conn.controlStreamHandler = controlStream
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
http2Conn.Serve(ctx)
|
|
}()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
|
require.NoError(t, err)
|
|
req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade)
|
|
|
|
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
|
require.NoError(t, err)
|
|
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
|
require.NoError(t, err)
|
|
require.Equal(t, http.StatusBadGateway, resp.StatusCode)
|
|
|
|
assert.NotNil(t, http2Conn.controlStreamErr)
|
|
cancel()
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestGracefulShutdownHTTP2(t *testing.T) {
|
|
http2Conn, edgeConn := newTestHTTP2Connection()
|
|
|
|
rpcClientFactory := mockRPCClientFactory{
|
|
registered: make(chan struct{}),
|
|
unregistered: make(chan struct{}),
|
|
}
|
|
events := &eventCollectorSink{}
|
|
|
|
shutdownC := make(chan struct{})
|
|
obs := NewObserver(&log, &log)
|
|
obs.RegisterSink(events)
|
|
controlStream := NewControlStream(
|
|
obs,
|
|
mockConnectedFuse{},
|
|
&TunnelProperties{},
|
|
http2Conn.connIndex,
|
|
nil,
|
|
rpcClientFactory.newMockRPCClient,
|
|
1*time.Second,
|
|
shutdownC,
|
|
1*time.Second,
|
|
HTTP2,
|
|
)
|
|
|
|
http2Conn.controlStreamHandler = controlStream
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
http2Conn.Serve(ctx)
|
|
}()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
|
require.NoError(t, err)
|
|
req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade)
|
|
|
|
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
|
require.NoError(t, err)
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
_, _ = edgeHTTP2Conn.RoundTrip(req)
|
|
}()
|
|
|
|
select {
|
|
case <-rpcClientFactory.registered:
|
|
break // ok
|
|
case <-time.Tick(time.Second):
|
|
t.Fatal("timeout out waiting for registration")
|
|
}
|
|
|
|
// signal graceful shutdown
|
|
close(shutdownC)
|
|
|
|
select {
|
|
case <-rpcClientFactory.unregistered:
|
|
break // ok
|
|
case <-time.Tick(time.Second):
|
|
t.Fatal("timeout out waiting for unregistered signal")
|
|
}
|
|
assert.True(t, controlStream.IsStopped())
|
|
|
|
cancel()
|
|
wg.Wait()
|
|
|
|
events.assertSawEvent(t, Event{
|
|
Index: http2Conn.connIndex,
|
|
EventType: Unregistering,
|
|
})
|
|
}
|
|
|
|
func benchmarkServeHTTP(b *testing.B, test testRequest) {
|
|
http2Conn, edgeConn := newTestHTTP2Connection()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
var wg sync.WaitGroup
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
http2Conn.Serve(ctx)
|
|
}()
|
|
|
|
endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
|
require.NoError(b, err)
|
|
|
|
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
|
require.NoError(b, err)
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
b.StartTimer()
|
|
resp, err := edgeHTTP2Conn.RoundTrip(req)
|
|
b.StopTimer()
|
|
require.NoError(b, err)
|
|
require.Equal(b, test.expectedStatus, resp.StatusCode)
|
|
if test.expectedBody != nil {
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
require.NoError(b, err)
|
|
require.Equal(b, test.expectedBody, respBody)
|
|
}
|
|
resp.Body.Close()
|
|
}
|
|
|
|
cancel()
|
|
wg.Wait()
|
|
}
|
|
|
|
func BenchmarkServeHTTPSimple(b *testing.B) {
|
|
test := testRequest{
|
|
name: "ok",
|
|
endpoint: "ok",
|
|
expectedStatus: http.StatusOK,
|
|
expectedBody: []byte(http.StatusText(http.StatusOK)),
|
|
}
|
|
|
|
benchmarkServeHTTP(b, test)
|
|
}
|
|
|
|
func BenchmarkServeHTTPLargeFile(b *testing.B) {
|
|
test := testRequest{
|
|
name: "large_file",
|
|
endpoint: "large_file",
|
|
expectedStatus: http.StatusOK,
|
|
expectedBody: testLargeResp,
|
|
}
|
|
|
|
benchmarkServeHTTP(b, test)
|
|
}
|