mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 19:29:57 +00:00
TUN-8861: Add session limiter to TCP session manager
## Summary In order to make cloudflared behavior more predictable and prevent an exhaustion of resources, we have decided to add session limits that can be configured by the user. This commit adds the session limiter to the HTTP/TCP handling path. For now the limiter is set to run only in unlimited mode.
This commit is contained in:
@@ -9,10 +9,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
pkgerrors "github.com/pkg/errors"
|
||||
"github.com/rs/zerolog"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/cloudflare/cloudflared/management"
|
||||
cfdsession "github.com/cloudflare/cloudflared/session"
|
||||
|
||||
"github.com/cloudflare/cloudflared/carrier"
|
||||
"github.com/cloudflare/cloudflared/cfio"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
@@ -30,11 +34,11 @@ const (
|
||||
|
||||
// Proxy represents a means to Proxy between cloudflared and the origin services.
|
||||
type Proxy struct {
|
||||
ingressRules ingress.Ingress
|
||||
warpRouting *ingress.WarpRoutingService
|
||||
management *ingress.ManagementService
|
||||
tags []pogs.Tag
|
||||
log *zerolog.Logger
|
||||
ingressRules ingress.Ingress
|
||||
warpRouting *ingress.WarpRoutingService
|
||||
tags []pogs.Tag
|
||||
sessionLimiter cfdsession.Limiter
|
||||
log *zerolog.Logger
|
||||
}
|
||||
|
||||
// NewOriginProxy returns a new instance of the Proxy struct.
|
||||
@@ -42,13 +46,15 @@ func NewOriginProxy(
|
||||
ingressRules ingress.Ingress,
|
||||
warpRouting ingress.WarpRoutingConfig,
|
||||
tags []pogs.Tag,
|
||||
sessionLimiter cfdsession.Limiter,
|
||||
writeTimeout time.Duration,
|
||||
log *zerolog.Logger,
|
||||
) *Proxy {
|
||||
proxy := &Proxy{
|
||||
ingressRules: ingressRules,
|
||||
tags: tags,
|
||||
log: log,
|
||||
ingressRules: ingressRules,
|
||||
tags: tags,
|
||||
sessionLimiter: sessionLimiter,
|
||||
log: log,
|
||||
}
|
||||
|
||||
proxy.warpRouting = ingress.NewWarpRoutingService(warpRouting, writeTimeout)
|
||||
@@ -64,7 +70,7 @@ func (p *Proxy) applyIngressMiddleware(rule *ingress.Rule, r *http.Request, w co
|
||||
}
|
||||
|
||||
if result.ShouldFilterRequest {
|
||||
w.WriteRespHeaders(result.StatusCode, nil)
|
||||
_ = w.WriteRespHeaders(result.StatusCode, nil)
|
||||
return fmt.Errorf("request filtered by middleware handler (%s) due to: %s", handler.Name(), result.Reason), true
|
||||
}
|
||||
}
|
||||
@@ -152,10 +158,18 @@ func (p *Proxy) ProxyTCP(
|
||||
return err
|
||||
}
|
||||
|
||||
logger := newTCPLogger(p.log, req)
|
||||
|
||||
// Try to start a new session
|
||||
if err := p.sessionLimiter.Acquire(management.TCP.String()); err != nil {
|
||||
logger.Warn().Msg("Too many concurrent sessions being handled, rejecting tcp proxy")
|
||||
return pkgerrors.Wrap(err, "failed to start tcp session due to rate limiting")
|
||||
}
|
||||
defer p.sessionLimiter.Release()
|
||||
|
||||
serveCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
logger := newTCPLogger(p.log, req)
|
||||
tracedCtx := tracing.NewTracedContext(serveCtx, req.CfTraceID, &logger)
|
||||
logger.Debug().Msg("tcp proxy stream started")
|
||||
|
||||
|
@@ -21,8 +21,13 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/urfave/cli/v2"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/cloudflare/cloudflared/mocks"
|
||||
|
||||
cfdsession "github.com/cloudflare/cloudflared/session"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cfio"
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
@@ -71,11 +76,6 @@ func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
|
||||
return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader")
|
||||
}
|
||||
|
||||
// respHeaders is a test function to read respHeaders
|
||||
func (w *mockHTTPRespWriter) headers() http.Header {
|
||||
return w.Header()
|
||||
}
|
||||
|
||||
func (m *mockHTTPRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
panic("Hijack not implemented")
|
||||
}
|
||||
@@ -113,7 +113,7 @@ func (w *mockWSRespWriter) Read(data []byte) (int, error) {
|
||||
return w.reader.Read(data)
|
||||
}
|
||||
|
||||
func (m *mockWSRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
func (w *mockWSRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
panic("Hijack not implemented")
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ func TestProxySingleOrigin(t *testing.T) {
|
||||
|
||||
require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))
|
||||
|
||||
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, time.Duration(0), &log)
|
||||
proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
|
||||
t.Run("testProxyHTTP", testProxyHTTP(proxy))
|
||||
t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
|
||||
t.Run("testProxySSE", testProxySSE(proxy))
|
||||
@@ -246,7 +246,7 @@ func testProxyWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
|
||||
_ = responseWriter.Close()
|
||||
|
||||
close(finished)
|
||||
errGroup.Wait()
|
||||
_ = errGroup.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -267,7 +267,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
log := zerolog.Nop()
|
||||
err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false)
|
||||
require.Equal(t, err.Error(), "context canceled")
|
||||
require.Equal(t, "context canceled", err.Error())
|
||||
|
||||
require.Equal(t, http.StatusOK, responseWriter.Code)
|
||||
}()
|
||||
@@ -275,7 +275,7 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
|
||||
for i := 0; i < pushCount; i++ {
|
||||
line := responseWriter.ReadBytes()
|
||||
expect := fmt.Sprintf("%d\n\n", i)
|
||||
require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line))
|
||||
require.Equal(t, []byte(expect), line, "Expect to read %v, got %v", expect, line)
|
||||
}
|
||||
|
||||
cancel()
|
||||
@@ -290,7 +290,9 @@ func TestProxySSEAllData(t *testing.T) {
|
||||
responseWriter := newMockSSERespWriter()
|
||||
|
||||
// responseWriter uses an unbuffered channel, so we call in a different go-routine
|
||||
go cfio.Copy(responseWriter, eyeballReader)
|
||||
go func() {
|
||||
_, _ = cfio.Copy(responseWriter, eyeballReader)
|
||||
}()
|
||||
|
||||
result := string(<-responseWriter.writeNotification)
|
||||
require.Equal(t, "data\r\r", result)
|
||||
@@ -366,7 +368,7 @@ func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.Unvalidat
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))
|
||||
|
||||
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, time.Duration(0), &log)
|
||||
proxy := NewOriginProxy(ingress, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
|
||||
|
||||
for _, test := range tests {
|
||||
responseWriter := newMockHTTPRespWriter()
|
||||
@@ -414,23 +416,18 @@ func TestProxyError(t *testing.T) {
|
||||
|
||||
log := zerolog.Nop()
|
||||
|
||||
proxy := NewOriginProxy(ing, noWarpRouting, testTags, time.Duration(0), &log)
|
||||
proxy := NewOriginProxy(ing, noWarpRouting, testTags, cfdsession.NewLimiter(0), time.Duration(0), &log)
|
||||
|
||||
responseWriter := newMockHTTPRespWriter()
|
||||
req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Error(t, proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false))
|
||||
require.Error(t, proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false))
|
||||
}
|
||||
|
||||
type replayer struct {
|
||||
sync.RWMutex
|
||||
writeDone chan struct{}
|
||||
rw *bytes.Buffer
|
||||
}
|
||||
|
||||
func newReplayer(buffer *bytes.Buffer) {
|
||||
|
||||
rw *bytes.Buffer
|
||||
}
|
||||
|
||||
func (r *replayer) Read(p []byte) (int, error) {
|
||||
@@ -471,7 +468,7 @@ func (r *replayer) Bytes() []byte {
|
||||
// eyeball sends tcp packets wrapped in websockets. (E.g: cloudflared access).
|
||||
func TestConnections(t *testing.T) {
|
||||
logger := logger.Create(nil)
|
||||
replayer := &replayer{rw: &bytes.Buffer{}}
|
||||
replayer := &replayer{rw: bytes.NewBuffer([]byte{})}
|
||||
type args struct {
|
||||
ingressServiceScheme string
|
||||
originService func(*testing.T, net.Listener)
|
||||
@@ -486,6 +483,9 @@ func TestConnections(t *testing.T) {
|
||||
|
||||
// requestheaders to be sent in the call to proxy.Proxy
|
||||
requestHeaders http.Header
|
||||
|
||||
// sessionLimiterResponse is the response of the cfdsession.Limiter#Acquire method call
|
||||
sessionLimiterResponse error
|
||||
}
|
||||
|
||||
type want struct {
|
||||
@@ -663,6 +663,25 @@ func TestConnections(t *testing.T) {
|
||||
err: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tcp-* proxy rate limited flow",
|
||||
args: args{
|
||||
ingressServiceScheme: "tcp://",
|
||||
originService: runEchoTCPService,
|
||||
eyeballResponseWriter: newTCPRespWriter(replayer),
|
||||
eyeballRequestBody: newTCPRequestBody([]byte("rate-limited")),
|
||||
warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)),
|
||||
connectionType: connection.TypeTCP,
|
||||
requestHeaders: map[string][]string{
|
||||
"Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
|
||||
},
|
||||
sessionLimiterResponse: cfdsession.ErrTooManyActiveSessions,
|
||||
},
|
||||
want: want{
|
||||
message: []byte{},
|
||||
err: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
@@ -674,8 +693,16 @@ func TestConnections(t *testing.T) {
|
||||
test.args.originService(t, ln)
|
||||
|
||||
ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
|
||||
ingressRule.StartOrigins(logger, ctx.Done())
|
||||
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, time.Duration(0), logger)
|
||||
_ = ingressRule.StartOrigins(logger, ctx.Done())
|
||||
|
||||
// Mock session limiter
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
sessionLimiter := mocks.NewMockLimiter(ctrl)
|
||||
sessionLimiter.EXPECT().Acquire("tcp").AnyTimes().Return(test.args.sessionLimiterResponse)
|
||||
sessionLimiter.EXPECT().Release().AnyTimes()
|
||||
|
||||
proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, sessionLimiter, time.Duration(0), logger)
|
||||
proxy.warpRouting = test.args.warpRoutingService
|
||||
|
||||
dest := ln.Addr().String()
|
||||
@@ -693,7 +720,7 @@ func TestConnections(t *testing.T) {
|
||||
respWriter = newTCPRespWriter(pipedReqBody.pipedConn)
|
||||
go func() {
|
||||
resp := pipedReqBody.roundtrip(test.args.ingressServiceScheme + ln.Addr().String())
|
||||
replayer.Write(resp)
|
||||
_, _ = replayer.Write(resp)
|
||||
}()
|
||||
}
|
||||
if test.args.connectionType == connection.TypeTCP {
|
||||
@@ -705,9 +732,9 @@ func TestConnections(t *testing.T) {
|
||||
}
|
||||
|
||||
cancel()
|
||||
assert.Equal(t, test.want.err, err != nil)
|
||||
assert.Equal(t, test.want.message, replayer.Bytes())
|
||||
assert.Equal(t, test.want.headers, respWriter.Header())
|
||||
require.Equal(t, test.want.err, err != nil)
|
||||
require.Equal(t, test.want.message, replayer.Bytes())
|
||||
require.Equal(t, test.want.headers, respWriter.Header())
|
||||
replayer.rw.Reset()
|
||||
})
|
||||
}
|
||||
@@ -720,7 +747,9 @@ type requestBody struct {
|
||||
|
||||
func newWSRequestBody(data []byte) *requestBody {
|
||||
pr, pw := io.Pipe()
|
||||
go wsutil.WriteClientBinary(pw, data)
|
||||
go func() {
|
||||
_ = wsutil.WriteClientBinary(pw, data)
|
||||
}()
|
||||
return &requestBody{
|
||||
pr: pr,
|
||||
pw: pw,
|
||||
@@ -728,7 +757,9 @@ func newWSRequestBody(data []byte) *requestBody {
|
||||
}
|
||||
func newTCPRequestBody(data []byte) *requestBody {
|
||||
pr, pw := io.Pipe()
|
||||
go pw.Write(data)
|
||||
go func() {
|
||||
_, _ = pw.Write(data)
|
||||
}()
|
||||
return &requestBody{
|
||||
pr: pr,
|
||||
pw: pw,
|
||||
@@ -740,8 +771,8 @@ func (r *requestBody) Read(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (r *requestBody) Close() error {
|
||||
r.pw.Close()
|
||||
r.pr.Close()
|
||||
_ = r.pw.Close()
|
||||
_ = r.pr.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -774,6 +805,7 @@ func (p *pipedRequestBody) roundtrip(addr string) []byte {
|
||||
panic(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||
panic(fmt.Errorf("resp returned status code: %d", resp.StatusCode))
|
||||
@@ -917,7 +949,9 @@ func runEchoTCPService(t *testing.T, l net.Listener) {
|
||||
go func() {
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for {
|
||||
@@ -971,12 +1005,15 @@ func runEchoWSService(t *testing.T, l net.Listener) {
|
||||
}
|
||||
}
|
||||
|
||||
// nolint: gosec
|
||||
server := http.Server{
|
||||
Handler: http.HandlerFunc(ws),
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.Serve(l)
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
Reference in New Issue
Block a user