TUN-8861: Add session limiter to TCP session manager

## Summary
In order to make cloudflared behavior more predictable and
prevent an exhaustion of resources, we have decided to add
session limits that can be configured by the user. This commit
adds the session limiter to the HTTP/TCP handling path.
For now the limiter is set to run only in unlimited mode.
This commit is contained in:
João "Pisco" Fernandes
2025-01-14 14:05:18 +00:00
parent bf4954e96a
commit 8bfe111cab
12 changed files with 275 additions and 102 deletions

View File

@@ -21,8 +21,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/cli/v2"
"go.uber.org/mock/gomock"
"golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/mocks"
cfdsession "github.com/cloudflare/cloudflared/session"
"github.com/cloudflare/cloudflared/cfio"
"github.com/cloudflare/cloudflared/config"
"github.com/cloudflare/cloudflared/connection"
@@ -71,11 +76,6 @@ func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader")
}
// respHeaders is a test function to read respHeaders
func (w *mockHTTPRespWriter) headers() http.Header {
return w.Header()
}
func (m *mockHTTPRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
panic("Hijack not implemented")
}
@@ -113,7 +113,7 @@ func (w *mockWSRespWriter) Read(data []byte) (int, error) {
return w.reader.Read(data)
}
func (m *mockWSRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
func (w *mockWSRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
panic("Hijack not implemented")
}
@@ -162,7 +162,7 @@ func TestProxySingleOrigin(t *testing.T) {
require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, time.Duration(0), &log)
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
t.Run("testProxyHTTP", testProxyHTTP(proxy))
t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
t.Run("testProxySSE", testProxySSE(proxy))
@@ -246,7 +246,7 @@ func testProxyWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
_ = responseWriter.Close()
close(finished)
errGroup.Wait()
_ = errGroup.Wait()
}
}
@@ -267,7 +267,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
defer wg.Done()
log := zerolog.Nop()
err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false)
require.Equal(t, err.Error(), "context canceled")
require.Equal(t, "context canceled", err.Error())
require.Equal(t, http.StatusOK, responseWriter.Code)
}()
@@ -275,7 +275,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
for i := 0; i < pushCount; i++ {
line := responseWriter.ReadBytes()
expect := fmt.Sprintf("%d\n\n", i)
require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line))
require.Equal(t, []byte(expect), line, "Expect to read %v, got %v", expect, line)
}
cancel()
@@ -290,7 +290,9 @@ func TestProxySSEAllData(t *testing.T) {
responseWriter := newMockSSERespWriter()
// responseWriter uses an unbuffered channel, so we call in a different go-routine
go cfio.Copy(responseWriter, eyeballReader)
go func() {
_, _ = cfio.Copy(responseWriter, eyeballReader)
}()
result := string(<-responseWriter.writeNotification)
require.Equal(t, "data\r\r", result)
@@ -366,7 +368,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
ctx, cancel := context.WithCancel(context.Background())
require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, time.Duration(0), &log)
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
for _, test := range tests {
responseWriter := newMockHTTPRespWriter()
@@ -414,23 +416,18 @@ func TestProxyError(t *testing.T) {
log := zerolog.Nop()
proxy := NewOriginProxy(ing, noWarpRouting, testTags, time.Duration(0), &log)
proxy := NewOriginProxy(ing, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
responseWriter := newMockHTTPRespWriter()
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
assert.NoError(t, err)
require.NoError(t, err)
assert.Error(t, proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false))
require.Error(t, proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false))
}
type replayer struct {
sync.RWMutex
writeDone chan struct{}
rw *bytes.Buffer
}
func newReplayer(buffer *bytes.Buffer) {
rw *bytes.Buffer
}
func (r *replayer) Read(p []byte) (int, error) {
@@ -471,7 +468,7 @@ func (r *replayer) Bytes() []byte {
// eyeball sends tcp packets wrapped in websockets. (E.g: cloudflared access).
func TestConnections(t *testing.T) {
logger := logger.Create(nil)
replayer := &replayer{rw: &bytes.Buffer{}}
replayer := &replayer{rw: bytes.NewBuffer([]byte{})}
type args struct {
ingressServiceScheme string
originService func(*testing.T, net.Listener)
@@ -486,6 +483,9 @@ func TestConnections(t *testing.T) {
// requestheaders to be sent in the call to proxy.Proxy
requestHeaders http.Header
// sessionLimiterResponse is the response of the cfdsession.Limiter#Acquire method call
sessionLimiterResponse error
}
type want struct {
@@ -663,6 +663,25 @@ func TestConnections(t *testing.T) {
err: true,
},
},
{
name: "tcp-* proxy rate limited flow",
args: args{
ingressServiceScheme: "tcp://",
originService: runEchoTCPService,
eyeballResponseWriter: newTCPRespWriter(replayer),
eyeballRequestBody: newTCPRequestBody([]byte("rate-limited")),
warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)),
connectionType: connection.TypeTCP,
requestHeaders: map[string][]string{
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
},
sessionLimiterResponse: cfdsession.ErrTooManyActiveSessions,
},
want: want{
message: []byte{},
err: true,
},
},
}
for _, test := range tests {
@@ -674,8 +693,16 @@ func TestConnections(t *testing.T) {
test.args.originService(t, ln)
ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
ingressRule.StartOrigins(logger, ctx.Done())
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, time.Duration(0), logger)
_ = ingressRule.StartOrigins(logger, ctx.Done())
// Mock session limiter
ctrl := gomock.NewController(t)
defer ctrl.Finish()
sessionLimiter := mocks.NewMockLimiter(ctrl)
sessionLimiter.EXPECT().Acquire("tcp").AnyTimes().Return(test.args.sessionLimiterResponse)
sessionLimiter.EXPECT().Release().AnyTimes()
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, sessionLimiter, time.Duration(0), logger)
proxy.warpRouting = test.args.warpRoutingService
dest := ln.Addr().String()
@@ -693,7 +720,7 @@ func TestConnections(t *testing.T) {
respWriter = newTCPRespWriter(pipedReqBody.pipedConn)
go func() {
resp := pipedReqBody.roundtrip(test.args.ingressServiceScheme + ln.Addr().String())
replayer.Write(resp)
_, _ = replayer.Write(resp)
}()
}
if test.args.connectionType == connection.TypeTCP {
@@ -705,9 +732,9 @@ func TestConnections(t *testing.T) {
}
cancel()
assert.Equal(t, test.want.err, err != nil)
assert.Equal(t, test.want.message, replayer.Bytes())
assert.Equal(t, test.want.headers, respWriter.Header())
require.Equal(t, test.want.err, err != nil)
require.Equal(t, test.want.message, replayer.Bytes())
require.Equal(t, test.want.headers, respWriter.Header())
replayer.rw.Reset()
})
}
@@ -720,7 +747,9 @@ type requestBody struct {
func newWSRequestBody(data []byte) *requestBody {
pr, pw := io.Pipe()
go wsutil.WriteClientBinary(pw, data)
go func() {
_ = wsutil.WriteClientBinary(pw, data)
}()
return &requestBody{
pr: pr,
pw: pw,
@@ -728,7 +757,9 @@ func newWSRequestBody(data []byte) *requestBody {
}
func newTCPRequestBody(data []byte) *requestBody {
pr, pw := io.Pipe()
go pw.Write(data)
go func() {
_, _ = pw.Write(data)
}()
return &requestBody{
pr: pr,
pw: pw,
@@ -740,8 +771,8 @@ func (r *requestBody) Read(p []byte) (n int, err error) {
}
func (r *requestBody) Close() error {
r.pw.Close()
r.pr.Close()
_ = r.pw.Close()
_ = r.pr.Close()
return nil
}
@@ -774,6 +805,7 @@ func (p *pipedRequestBody) roundtrip(addr string) []byte {
panic(err)
}
defer conn.Close()
defer resp.Body.Close()
if resp.StatusCode != http.StatusSwitchingProtocols {
panic(fmt.Errorf("resp returned status code: %d", resp.StatusCode))
@@ -917,7 +949,9 @@ func runEchoTCPService(t *testing.T, l net.Listener) {
go func() {
for {
conn, err := l.Accept()
require.NoError(t, err)
if err != nil {
panic(err)
}
defer conn.Close()
for {
@@ -971,12 +1005,15 @@ func runEchoWSService(t *testing.T, l net.Listener) {
}
}
// nolint: gosec
server := http.Server{
Handler: http.HandlerFunc(ws),
}
go func() {
err := server.Serve(l)
require.NoError(t, err)
if err != nil {
panic(err)
}
}()
}