cloudflared/connection/http2_test.go
João "Pisco" Fernandes 8bfe111cab TUN-8861: Add session limiter to TCP session manager
## Summary
In order to make cloudflared behavior more predictable and
prevent an exhaustion of resources, we have decided to add
session limits that can be configured by the user. This commit
adds the session limiter to the HTTP/TCP handling path.
For now the limiter is set to run only in unlimited mode.
2025-01-20 10:53:53 +00:00

624 lines
15 KiB
Go

package connection
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/gobwas/ws/wsutil"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/http2"
"github.com/cloudflare/cloudflared/tracing"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
// testTransport is the HTTP/2 transport, left at its zero-value configuration,
// that tests in this file use to open client connections via NewClientConn.
var testTransport = http2.Transport{}
// newTestHTTP2Connection builds an HTTP2Connection attached to one end of an
// in-memory net.Pipe and returns it together with the opposite end of the
// pipe, which callers use as the edge side of the connection.
//
// The connection uses index 0, a no-op zerolog logger, the package-level
// testOrchestrator, and a control stream created without an RPC client factory.
func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
	var (
		nopLogger = zerolog.Nop()
		index     uint8 // zero value: connection index 0
	)

	edgeSide, cfdSide := net.Pipe()
	observer := NewObserver(&nopLogger, &nopLogger)

	stream := NewControlStream(
		observer,
		mockConnectedFuse{},
		&TunnelProperties{},
		index,
		nil,
		nil,
		time.Second,
		nil,
		time.Second,
		HTTP2,
	)

	conn := NewHTTP2Connection(
		cfdSide,
		// OriginProxy is set in testConfigManager
		testOrchestrator,
		&pogs.ConnectionOptions{},
		observer,
		index,
		stream,
		&nopLogger,
	)

	return conn, edgeSide
}
// TestHTTP2ConfigurationSet sends a PUT request carrying a version-2
// configuration document, marked with the ConfigurationUpdate upgrade header,
// and expects a 200 response whose body reports version 2 as applied.
func TestHTTP2ConfigurationSet(t *testing.T) {
	http2Conn, edgeConn := newTestHTTP2Connection()

	ctx, cancel := context.WithCancel(context.Background())

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		_ = http2Conn.Serve(ctx)
	}()

	edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
	require.NoError(t, err)

	reqBody := []byte(`{
"version": 2,
"config": {"warp-routing": {"enabled": true}, "originRequest" : {"connectTimeout": 10}, "ingress" : [ {"hostname": "test", "service": "https://localhost:8000" } , {"service": "http_status:404"} ]}}
`)
	reader := bytes.NewReader(reqBody)
	req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://localhost:8080/ok", reader)
	require.NoError(t, err)
	req.Header.Set(InternalUpgradeHeader, ConfigurationUpdate)

	resp, err := edgeHTTP2Conn.RoundTrip(req)
	require.NoError(t, err)
	// Register the close as soon as the response is known to be non-nil so the
	// body is released even when one of the assertions below aborts the test.
	defer resp.Body.Close()
	require.Equal(t, http.StatusOK, resp.StatusCode)

	bdy, err := io.ReadAll(resp.Body)
	require.NoError(t, err)
	assert.Equal(t, `{"lastAppliedVersion":2,"err":null}`, string(bdy))

	// Stop serving and wait for the Serve goroutine to return.
	cancel()
	wg.Wait()
}
// TestServeHTTP issues one GET per table entry over a single HTTP/2 client
// connection and checks the status code, the body (when one is expected) and
// the response meta header, which must be responseMetaHeaderCfd for entries
// flagged isProxyError and responseMetaHeaderOrigin otherwise.
func TestServeHTTP(t *testing.T) {
	cases := []testRequest{
		{
			name:           "ok",
			endpoint:       "ok",
			expectedStatus: http.StatusOK,
			expectedBody:   []byte(http.StatusText(http.StatusOK)),
		},
		{
			name:           "large_file",
			endpoint:       "large_file",
			expectedStatus: http.StatusOK,
			expectedBody:   testLargeResp,
		},
		{
			name:           "Bad request",
			endpoint:       "400",
			expectedStatus: http.StatusBadRequest,
			expectedBody:   []byte(http.StatusText(http.StatusBadRequest)),
		},
		{
			name:           "Internal server error",
			endpoint:       "500",
			expectedStatus: http.StatusInternalServerError,
			expectedBody:   []byte(http.StatusText(http.StatusInternalServerError)),
		},
		{
			name:           "Proxy error",
			endpoint:       "error",
			expectedStatus: http.StatusBadGateway,
			expectedBody:   nil,
			isProxyError:   true,
		},
	}

	http2Conn, edgeConn := newTestHTTP2Connection()
	ctx, cancel := context.WithCancel(context.Background())

	var serveWG sync.WaitGroup
	serveWG.Add(1)
	go func() {
		defer serveWG.Done()
		_ = http2Conn.Serve(ctx)
	}()

	client, err := testTransport.NewClientConn(edgeConn)
	require.NoError(t, err)

	for _, tc := range cases {
		target := fmt.Sprintf("http://localhost:8080/%s", tc.endpoint)
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil)
		require.NoError(t, err)

		resp, err := client.RoundTrip(req)
		require.NoError(t, err)
		require.Equal(t, tc.expectedStatus, resp.StatusCode)

		// The body is only read and compared when the case specifies one.
		if want := tc.expectedBody; want != nil {
			got, err := io.ReadAll(resp.Body)
			require.NoError(t, err)
			require.Equal(t, want, got)
		}
		_ = resp.Body.Close()

		wantMeta := responseMetaHeaderOrigin
		if tc.isProxyError {
			wantMeta = responseMetaHeaderCfd
		}
		require.Equal(t, wantMeta, resp.Header.Get(ResponseMetaHeader))
	}

	// Stop serving and wait for the Serve goroutine to return.
	cancel()
	serveWG.Wait()
}
// mockNamedTunnelRPCClient is a test double for the registration RPC client.
// It reports its calls to the test by closing channels.
type mockNamedTunnelRPCClient struct {
	// shouldFail, when non-nil, is returned by RegisterConnection.
	shouldFail error
	// registered is closed by a successful RegisterConnection call.
	registered chan struct{}
	// unregistered is closed by GracefulShutdown.
	unregistered chan struct{}
}

// SendLocalConfiguration ignores its arguments and always reports success.
func (c mockNamedTunnelRPCClient) SendLocalConfiguration(_ context.Context, _ []byte) error {
	return nil
}

// RegisterConnection returns the configured shouldFail error when one is set.
// Otherwise it closes the registered channel and returns fixed connection
// details with a freshly generated UUID. Because it closes a channel, a second
// successful call on the same mock panics.
func (c mockNamedTunnelRPCClient) RegisterConnection(
	ctx context.Context,
	auth pogs.TunnelAuth,
	tunnelID uuid.UUID,
	options *pogs.ConnectionOptions,
	connIndex uint8,
	edgeAddress net.IP,
) (*pogs.ConnectionDetails, error) {
	if failure := c.shouldFail; failure != nil {
		return nil, failure
	}

	close(c.registered)

	details := pogs.ConnectionDetails{
		Location:                "LIS",
		UUID:                    uuid.New(),
		TunnelIsRemotelyManaged: false,
	}
	return &details, nil
}

// GracefulShutdown closes the unregistered channel and reports success.
func (c mockNamedTunnelRPCClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) error {
	close(c.unregistered)
	return nil
}

// Close does nothing.
func (mockNamedTunnelRPCClient) Close() {}
// mockRPCClientFactory holds the settings copied into every mock RPC client it
// creates, so a test can observe registration events through its channels.
type mockRPCClientFactory struct {
	// shouldFail is handed to each created client as its registration error.
	shouldFail error
	// registered and unregistered are shared with each created client.
	registered   chan struct{}
	unregistered chan struct{}
}

// newMockRPCClient ignores its arguments and returns a
// mockNamedTunnelRPCClient populated from the factory's fields.
func (f *mockRPCClientFactory) newMockRPCClient(context.Context, io.ReadWriteCloser, time.Duration) tunnelrpc.RegistrationClient {
	client := mockNamedTunnelRPCClient{
		shouldFail:   f.shouldFail,
		registered:   f.registered,
		unregistered: f.unregistered,
	}
	return &client
}
// wsRespWriter is a response writer for the websocket tests. Status and
// headers are recorded by the embedded httptest.ResponseRecorder, while
// written bytes go through an in-memory pipe so the test can read them back.
type wsRespWriter struct {
	*httptest.ResponseRecorder
	readPipe  *io.PipeReader
	writePipe *io.PipeWriter
	// closed is set by close; any Write after that is treated as a violation.
	closed bool
	// panicked records that Write was called after close.
	panicked bool
}

// newWSRespWriter returns a wsRespWriter with a fresh recorder and pipe; the
// closed and panicked flags start out false.
func newWSRespWriter() *wsRespWriter {
	pr, pw := io.Pipe()
	return &wsRespWriter{
		ResponseRecorder: httptest.NewRecorder(),
		readPipe:         pr,
		writePipe:        pw,
	}
}

// nowriter adapts an io.Reader into an io.ReadWriter whose Write always fails.
type nowriter struct {
	io.Reader
}

// Write rejects every write with a "writer not implemented" error.
func (nowriter) Write(_ []byte) (int, error) {
	return 0, fmt.Errorf("writer not implemented")
}

// RespBody exposes the read end of the pipe, i.e. the bytes passed to Write.
// The returned value cannot be written to.
func (rw *wsRespWriter) RespBody() io.ReadWriter {
	return nowriter{Reader: rw.readPipe}
}

// Write forwards data to the pipe. Once close has been called it instead sets
// the panicked flag and returns an error without writing anything.
func (rw *wsRespWriter) Write(data []byte) (n int, err error) {
	if !rw.closed {
		return rw.writePipe.Write(data)
	}
	rw.panicked = true
	return 0, errors.New("wsRespWriter panicked")
}

// close marks the writer as closed; it does not close the underlying pipe.
func (rw *wsRespWriter) close() {
	rw.closed = true
}
// TestServeWS calls ServeHTTP directly with a websocket upgrade request for
// the /ws/echo endpoint, writes one binary client frame and expects the same
// payload back. It also checks the recorded status and meta header, and that
// nothing was written to the response writer after ServeHTTP returned.
func TestServeWS(t *testing.T) {
	http2Conn, _ := newTestHTTP2Connection()

	ctx, cancel := context.WithCancel(context.Background())

	recorder := newWSRespWriter()
	// The request body is the read end of a pipe; the test writes frames into
	// the write end.
	bodyReader, bodyWriter := io.Pipe()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws/echo", bodyReader)
	require.NoError(t, err)
	req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)

	handlerDone := make(chan struct{})
	go func() {
		defer close(handlerDone)
		http2Conn.ServeHTTP(recorder, req)
		// From here on any Write on recorder sets its panicked flag.
		recorder.close()
	}()

	payload := []byte("test websocket")
	err = wsutil.WriteClientBinary(bodyWriter, payload)
	require.NoError(t, err)

	echoed, err := wsutil.ReadServerBinary(recorder.RespBody())
	require.NoError(t, err)
	require.Equal(t, payload, echoed, "expect %s, got %s", string(payload), string(echoed))

	cancel()
	result := recorder.Result()
	defer result.Body.Close()
	// http2RespWriter should rewrite status 101 to 200
	require.Equal(t, http.StatusOK, result.StatusCode)
	require.Equal(t, responseMetaHeaderOrigin, result.Header.Get(ResponseMetaHeader))

	<-handlerDone
	require.False(t, recorder.panicked)
}
// TestNoWriteAfterServeHTTPReturns is a regression test of https://jira.cfops.it/browse/TUN-5184
// to make sure we don't write to the ResponseWriter after the ServeHTTP method returns.
//
// It opens 100 concurrent websocket requests against the /ws/flaky endpoint.
// Each request keeps writing client frames from a background goroutine and is
// cancelled after 100ms.
func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
	cfdHTTP2Conn, edgeTCPConn := newTestHTTP2Connection()

	ctx, cancel := context.WithCancel(context.Background())
	var wg sync.WaitGroup

	serverDone := make(chan struct{})
	go func() {
		defer close(serverDone)
		_ = cfdHTTP2Conn.Serve(ctx)
	}()

	edgeTransport := http2.Transport{}
	edgeHTTP2Conn, err := edgeTransport.NewClientConn(edgeTCPConn)
	require.NoError(t, err)
	message := []byte(t.Name())

	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()

			readPipe, writePipe := io.Pipe()
			reqCtx, reqCancel := context.WithCancel(ctx)
			// Cancels the request after the sleep below, and also on the
			// early-return paths so the per-request context is never leaked.
			defer reqCancel()

			req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
			// assert does not stop the goroutine, so bail out explicitly
			// instead of dereferencing a nil request.
			if !assert.NoError(t, err) {
				return
			}
			req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)

			resp, err := edgeHTTP2Conn.RoundTrip(req)
			// resp is nil when RoundTrip fails; closing its body would panic
			// and take down the whole test binary.
			if !assert.NoError(t, err) {
				return
			}
			_ = resp.Body.Close()
			// http2RespWriter should rewrite status 101 to 200
			assert.Equal(t, http.StatusOK, resp.StatusCode)

			wg.Add(1)
			go func() {
				defer wg.Done()
				// Keep writing frames until the request context is cancelled;
				// write errors are ignored on purpose.
				for {
					select {
					case <-reqCtx.Done():
						return
					default:
					}
					_ = wsutil.WriteClientBinary(writePipe, message)
				}
			}()

			time.Sleep(time.Millisecond * 100)
		}()
	}

	wg.Wait()
	cancel()
	<-serverDone
}
// TestServeControlStream issues a control-stream upgrade request and waits for
// the mock RPC client to signal a registration. It then cancels the serving
// context, waits for the mock to signal GracefulShutdown, and checks that the
// connection's stoppedGracefully flag is false.
func TestServeControlStream(t *testing.T) {
	http2Conn, edgeConn := newTestHTTP2Connection()

	factory := mockRPCClientFactory{
		registered:   make(chan struct{}),
		unregistered: make(chan struct{}),
	}

	// Swap in a control stream whose RPC clients are built by the mock factory.
	http2Conn.controlStreamHandler = NewControlStream(
		NewObserver(&log, &log),
		mockConnectedFuse{},
		&TunnelProperties{},
		1,
		nil,
		factory.newMockRPCClient,
		time.Second,
		nil,
		time.Second,
		HTTP2,
	)

	ctx, cancel := context.WithCancel(context.Background())

	var pending sync.WaitGroup
	pending.Add(1)
	go func() {
		defer pending.Done()
		_ = http2Conn.Serve(ctx)
	}()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
	require.NoError(t, err)
	req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade)

	client, err := testTransport.NewClientConn(edgeConn)
	require.NoError(t, err)

	// The round trip runs in the background and its result is discarded; the
	// test synchronises on the mock's channels instead.
	pending.Add(1)
	go func() {
		defer pending.Done()
		// nolint: bodyclose
		_, _ = client.RoundTrip(req)
	}()

	<-factory.registered
	cancel()
	<-factory.unregistered
	assert.False(t, http2Conn.stoppedGracefully)

	pending.Wait()
}
// TestFailRegistration configures the mock RPC client to fail registration
// with errDuplicationConnection and expects the control-stream request to be
// answered with 502 Bad Gateway and controlStreamErr to be set on the
// connection.
func TestFailRegistration(t *testing.T) {
	http2Conn, edgeConn := newTestHTTP2Connection()

	factory := mockRPCClientFactory{
		shouldFail:   errDuplicationConnection,
		registered:   make(chan struct{}),
		unregistered: make(chan struct{}),
	}

	// Swap in a control stream whose RPC clients are built by the mock factory.
	http2Conn.controlStreamHandler = NewControlStream(
		NewObserver(&log, &log),
		mockConnectedFuse{},
		&TunnelProperties{},
		http2Conn.connIndex,
		nil,
		factory.newMockRPCClient,
		time.Second,
		nil,
		time.Second,
		HTTP2,
	)

	ctx, cancel := context.WithCancel(context.Background())

	var serveWG sync.WaitGroup
	serveWG.Add(1)
	go func() {
		defer serveWG.Done()
		_ = http2Conn.Serve(ctx)
	}()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
	require.NoError(t, err)
	req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade)

	client, err := testTransport.NewClientConn(edgeConn)
	require.NoError(t, err)

	resp, err := client.RoundTrip(req)
	require.NoError(t, err)
	defer resp.Body.Close()

	require.Equal(t, http.StatusBadGateway, resp.StatusCode)
	require.Error(t, http2Conn.controlStreamErr)

	// Stop serving and wait for the Serve goroutine to return.
	cancel()
	serveWG.Wait()
}
// TestGracefulShutdownHTTP2 registers a control stream, closes the
// graceful-shutdown channel, and expects the mock RPC client to receive a
// GracefulShutdown call within one second. It then checks that the control
// stream reports itself stopped and that an Unregistering event was emitted
// for the connection's index.
func TestGracefulShutdownHTTP2(t *testing.T) {
	http2Conn, edgeConn := newTestHTTP2Connection()

	rpcClientFactory := mockRPCClientFactory{
		registered:   make(chan struct{}),
		unregistered: make(chan struct{}),
	}
	events := &eventCollectorSink{}

	shutdownC := make(chan struct{})
	obs := NewObserver(&log, &log)
	obs.RegisterSink(events)
	controlStream := NewControlStream(
		obs,
		mockConnectedFuse{},
		&TunnelProperties{},
		http2Conn.connIndex,
		nil,
		rpcClientFactory.newMockRPCClient,
		1*time.Second,
		shutdownC,
		1*time.Second,
		HTTP2,
	)

	http2Conn.controlStreamHandler = controlStream

	ctx, cancel := context.WithCancel(context.Background())
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		_ = http2Conn.Serve(ctx)
	}()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
	require.NoError(t, err)
	req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade)

	edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
	require.NoError(t, err)

	wg.Add(1)
	go func() {
		defer wg.Done()
		// nolint: bodyclose
		_, _ = edgeHTTP2Conn.RoundTrip(req)
	}()

	// time.After is used for these one-shot timeouts; time.Tick would create a
	// ticker that is never stopped.
	select {
	case <-rpcClientFactory.registered:
		// ok
	case <-time.After(time.Second):
		t.Fatal("timeout out waiting for registration")
	}

	// signal graceful shutdown
	close(shutdownC)

	select {
	case <-rpcClientFactory.unregistered:
		// ok
	case <-time.After(time.Second):
		t.Fatal("timeout out waiting for unregistered signal")
	}
	assert.True(t, controlStream.IsStopped())

	cancel()
	wg.Wait()

	events.assertSawEvent(t, Event{
		Index:     http2Conn.connIndex,
		EventType: Unregistering,
	})
}
// TestServeTCP_RateLimited sends a request carrying the TCP proxy source
// header and a tracer context value of "flow-rate-limited", and expects a
// 502 Bad Gateway response tagged with responseMetaHeaderCfdFlowRateLimited.
// NOTE(review): presumably the mock origin proxy keys off the tracer context
// value to simulate the rate limit — confirm against its definition.
func TestServeTCP_RateLimited(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	http2Conn, edgeConn := newTestHTTP2Connection()

	var serveWG sync.WaitGroup
	serveWG.Add(1)
	go func() {
		defer serveWG.Done()
		_ = http2Conn.Serve(ctx)
	}()

	client, err := testTransport.NewClientConn(edgeConn)
	require.NoError(t, err)

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080", nil)
	require.NoError(t, err)
	req.Header.Set(InternalTCPProxySrcHeader, "tcp")
	req.Header.Set(tracing.TracerContextName, "flow-rate-limited")

	resp, err := client.RoundTrip(req)
	require.NoError(t, err)
	defer resp.Body.Close()

	require.Equal(t, http.StatusBadGateway, resp.StatusCode)
	require.Equal(t, responseMetaHeaderCfdFlowRateLimited, resp.Header.Get(ResponseMetaHeader))

	// Stop serving and wait for the Serve goroutine to return.
	cancel()
	serveWG.Wait()
}
// benchmarkServeHTTP measures the round-trip time of the GET request described
// by test over a single HTTP/2 client connection. Only the RoundTrip call is
// timed: the timer is stopped while the response is validated and closed.
func benchmarkServeHTTP(b *testing.B, test testRequest) {
	http2Conn, edgeConn := newTestHTTP2Connection()
	ctx, cancel := context.WithCancel(context.Background())

	var serveWG sync.WaitGroup
	serveWG.Add(1)
	go func() {
		defer serveWG.Done()
		_ = http2Conn.Serve(ctx)
	}()

	target := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, target, nil)
	require.NoError(b, err)

	client, err := testTransport.NewClientConn(edgeConn)
	require.NoError(b, err)

	// Discard the setup cost accumulated so far.
	b.ResetTimer()
	for iter := 0; iter < b.N; iter++ {
		b.StartTimer()
		resp, err := client.RoundTrip(req)
		b.StopTimer()

		require.NoError(b, err)
		require.Equal(b, test.expectedStatus, resp.StatusCode)
		if want := test.expectedBody; want != nil {
			got, err := io.ReadAll(resp.Body)
			require.NoError(b, err)
			require.Equal(b, want, got)
		}
		resp.Body.Close()
	}

	// Stop serving and wait for the Serve goroutine to return.
	cancel()
	serveWG.Wait()
}
func BenchmarkServeHTTPSimple(b *testing.B) {
test := testRequest{
name: "ok",
endpoint: "ok",
expectedStatus: http.StatusOK,
expectedBody: []byte(http.StatusText(http.StatusOK)),
}
benchmarkServeHTTP(b, test)
}
func BenchmarkServeHTTPLargeFile(b *testing.B) {
test := testRequest{
name: "large_file",
endpoint: "large_file",
expectedStatus: http.StatusOK,
expectedBody: testLargeResp,
}
benchmarkServeHTTP(b, test)
}