TUN-8861: Add session limiter to UDP session manager
## Summary

In order to make cloudflared's behavior more predictable and to prevent resource exhaustion, we are adding user-configurable session limits. This first commit introduces the session limiter and wires it into the UDP handling path. For now the limiter runs only in unlimited mode.
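The limiter itself lives in the new `session` package and is only referenced in this diff through `cfdsession.Limiter`, `cfdsession.NewLimiter(0)`, and `cfdsession.ErrTooManyActiveSessions`. As a rough mental model, here is a minimal sketch of the semantics those references imply; it is not the actual implementation, which may differ:

```go
// Sketch only: illustrates the Acquire/Release contract used by this commit.
// The real implementation lives in github.com/cloudflare/cloudflared/session.
package session

import (
	"errors"
	"sync"
)

// ErrTooManyActiveSessions is returned by Acquire once the configured limit is reached.
var ErrTooManyActiveSessions = errors.New("too many active sessions")

type Limiter interface {
	// Acquire reserves a slot for a new session of the given type (e.g. "udp"),
	// or fails if the limit has been reached.
	Acquire(sessionType string) error
	// Release frees a slot previously reserved by Acquire.
	Release()
}

type limiter struct {
	mu       sync.Mutex
	active   uint64
	maxConns uint64 // 0 means unlimited, which is how this commit wires it up
}

// NewLimiter returns a Limiter allowing at most maxConns concurrent sessions;
// maxConns == 0 disables the limit ("unlimited mode").
func NewLimiter(maxConns uint64) Limiter {
	return &limiter{maxConns: maxConns}
}

func (l *limiter) Acquire(sessionType string) error {
	// sessionType is presumably used for metrics labelling in the real code;
	// the sketch ignores it.
	l.mu.Lock()
	defer l.mu.Unlock()
	if l.maxConns != 0 && l.active >= l.maxConns {
		return ErrTooManyActiveSessions
	}
	l.active++
	return nil
}

func (l *limiter) Release() {
	l.mu.Lock()
	defer l.mu.Unlock()
	if l.active > 0 {
		l.active--
	}
}
```

Because the UDP path is constructed with `NewLimiter(0)` for now, `Acquire` never rejects and this commit changes no user-visible behavior yet.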
@@ -28,6 +28,8 @@ import (
 	"github.com/stretchr/testify/require"
 	"golang.org/x/net/nettest"
 
+	cfdsession "github.com/cloudflare/cloudflared/session"
+
 	"github.com/cloudflare/cloudflared/datagramsession"
 	"github.com/cloudflare/cloudflared/ingress"
 	"github.com/cloudflare/cloudflared/packet"
@@ -53,7 +55,8 @@ var _ ReadWriteAcker = (*streamReadWriteAcker)(nil)
 func TestQUICServer(t *testing.T) {
 	// This is simply a sample websocket frame message.
 	wsBuf := &bytes.Buffer{}
-	wsutil.WriteClientBinary(wsBuf, []byte("Hello"))
+	err := wsutil.WriteClientBinary(wsBuf, []byte("Hello"))
+	require.NoError(t, err)
 
 	var tests = []struct {
 		desc string
@@ -158,17 +161,19 @@ func TestQUICServer(t *testing.T) {
 
 			serverDone := make(chan struct{})
 			go func() {
+				// nolint: testifylint
 				quicServer(
 					ctx, t, quicListener, test.dest, test.connectionType, test.metadata, test.message, test.expectedResponse,
 				)
 				close(serverDone)
 			}()
 
+			// nolint: gosec
 			tunnelConn, _ := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), uint8(i))
 
 			connDone := make(chan struct{})
 			go func() {
-				tunnelConn.Serve(ctx)
+				_ = tunnelConn.Serve(ctx)
				close(connDone)
 			}()
 
@@ -254,14 +259,14 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, tr *tracing.T
 	case "/ok":
 		originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK)))
 	case "/slow_echo_body":
-		time.Sleep(5)
+		time.Sleep(5 * time.Nanosecond)
 		fallthrough
 	case "/echo_body":
 		resp := &http.Response{
 			StatusCode: http.StatusOK,
 		}
 		_ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
-		io.Copy(w, r.Body)
+		_, _ = io.Copy(w, r.Body)
 	case "/error":
 		return fmt.Errorf("Failed to proxy to origin")
 	default:
@@ -493,16 +498,16 @@ func TestBuildHTTPRequest(t *testing.T) {
 		test := test // capture range variable
 		t.Run(test.name, func(t *testing.T) {
 			req, err := buildHTTPRequest(context.Background(), test.connectRequest, test.body, 0, &log)
-			assert.NoError(t, err)
+			require.NoError(t, err)
 			test.req = test.req.WithContext(req.Context())
-			assert.Equal(t, test.req, req.Request)
+			require.Equal(t, test.req, req.Request)
 		})
 	}
 }
 
 func (moc *mockOriginProxyWithRequest) ProxyTCP(ctx context.Context, rwa ReadWriteAcker, tcpRequest *TCPRequest) error {
-	rwa.AckConnection("")
-	io.Copy(rwa, rwa)
+	_ = rwa.AckConnection("")
+	_, _ = io.Copy(rwa, rwa)
 	return nil
 }
 
@@ -520,16 +525,19 @@ func TestServeUDPSession(t *testing.T) {
 	edgeQUICSessionChan := make(chan quic.Connection)
 	go func() {
 		earlyListener, err := quic.Listen(udpListener, testTLSServerConfig, testQUICConfig)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 
 		edgeQUICSession, err := earlyListener.Accept(ctx)
-		require.NoError(t, err)
+		assert.NoError(t, err)
 
 		edgeQUICSessionChan <- edgeQUICSession
 	}()
 
 	// Random index to avoid reusing port
 	tunnelConn, datagramConn := testTunnelConnection(t, netip.MustParseAddrPort(udpListener.LocalAddr().String()), 28)
-	go tunnelConn.Serve(ctx)
+	go func() {
+		_ = tunnelConn.Serve(ctx)
+	}()
 
 	edgeQUICSession := <-edgeQUICSessionChan
 
@@ -545,14 +553,14 @@ func TestNopCloserReadWriterCloseBeforeEOF(t *testing.T) {
 
 	n, err := readerWriter.Read(buffer)
 	require.NoError(t, err)
-	require.Equal(t, n, 5)
+	require.Equal(t, 5, n)
 
 	// close
 	require.NoError(t, readerWriter.Close())
 
 	// read should get error
 	n, err = readerWriter.Read(buffer)
-	require.Equal(t, n, 0)
+	require.Equal(t, 0, n)
 	require.Equal(t, err, fmt.Errorf("closed by handler"))
 }
 
@@ -562,7 +570,7 @@ func TestNopCloserReadWriterCloseAfterEOF(t *testing.T) {
 
 	n, err := readerWriter.Read(buffer)
 	require.NoError(t, err)
-	require.Equal(t, n, 9)
+	require.Equal(t, 9, n)
 
 	// force another read to read eof
 	_, err = readerWriter.Read(buffer)
@@ -573,7 +581,7 @@ func TestNopCloserReadWriterCloseAfterEOF(t *testing.T) {
 
 	// read should get EOF still
 	n, err = readerWriter.Read(buffer)
-	require.Equal(t, n, 0)
+	require.Equal(t, 0, n)
 	require.Equal(t, err, io.EOF)
 }
 
@@ -669,6 +677,7 @@ func serveSession(ctx context.Context, datagramConn *datagramV2Connection, edgeQ
 		unregisterReason:     expectedReason,
 		calledUnregisterChan: unregisterFromEdgeChan,
 	}
+	// nolint: testifylint
 	go runRPCServer(ctx, edgeQUICSession, sessionRPCServer, nil, t)
 
 	<-unregisterFromEdgeChan
@@ -729,6 +738,7 @@ func (s mockSessionRPCServer) UnregisterUdpSession(ctx context.Context, sessionI
 
 func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8) (TunnelConnection, *datagramV2Connection) {
 	tlsClientConfig := &tls.Config{
+		// nolint: gosec
 		InsecureSkipVerify: true,
 		NextProtos: []string{"argotunnel"},
 	}
@@ -747,6 +757,7 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8)
 		index,
 		&log,
 	)
 	require.NoError(t, err)
 
 	// Start a session manager for the connection
 	sessionDemuxChan := make(chan *packet.Session, 4)
@@ -757,7 +768,9 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8)
 
 	datagramConn := &datagramV2Connection{
 		conn,
+		index,
 		sessionManager,
+		cfdsession.NewLimiter(0),
 		datagramMuxer,
 		packetRouter,
 		15 * time.Second,
@@ -796,6 +809,7 @@ func (m *mockReaderNoopWriter) Close() error {
 
 // GenerateTLSConfig sets up a bare-bones TLS config for a QUIC server
 func GenerateTLSConfig() *tls.Config {
+	// nolint: gosec
 	key, err := rsa.GenerateKey(rand.Reader, 1024)
 	if err != nil {
 		panic(err)
@@ -812,6 +826,7 @@ func GenerateTLSConfig() *tls.Config {
 	if err != nil {
 		panic(err)
 	}
+	// nolint: gosec
 	return &tls.Config{
 		Certificates: []tls.Certificate{tlsCert},
 		NextProtos: []string{"argotunnel"},
@@ -7,12 +7,15 @@ import (
 	"time"
 
 	"github.com/google/uuid"
+	pkgerrors "github.com/pkg/errors"
 	"github.com/quic-go/quic-go"
 	"github.com/rs/zerolog"
 	"go.opentelemetry.io/otel/attribute"
 	"go.opentelemetry.io/otel/trace"
 	"golang.org/x/sync/errgroup"
 
+	cfdsession "github.com/cloudflare/cloudflared/session"
+
 	"github.com/cloudflare/cloudflared/datagramsession"
 	"github.com/cloudflare/cloudflared/ingress"
 	"github.com/cloudflare/cloudflared/management"
@@ -38,10 +41,14 @@ type DatagramSessionHandler interface {
 }
 
 type datagramV2Connection struct {
-	conn quic.Connection
+	conn  quic.Connection
+	index uint8
 
 	// sessionManager tracks active sessions. It receives datagrams from quic connection via datagramMuxer
 	sessionManager datagramsession.Manager
+	// sessionLimiter tracks active sessions across the tunnel and limits new sessions if they are above the limit.
+	sessionLimiter cfdsession.Limiter
+
 	// datagramMuxer mux/demux datagrams from quic connection
 	datagramMuxer *cfdquic.DatagramMuxerV2
 	packetRouter *ingress.PacketRouter
@@ -58,6 +65,7 @@ func NewDatagramV2Connection(ctx context.Context,
 	index uint8,
 	rpcTimeout time.Duration,
 	streamWriteTimeout time.Duration,
+	sessionLimiter cfdsession.Limiter,
 	logger *zerolog.Logger,
 ) DatagramSessionHandler {
 	sessionDemuxChan := make(chan *packet.Session, demuxChanCapacity)
@@ -66,13 +74,15 @@ func NewDatagramV2Connection(ctx context.Context,
 	packetRouter := ingress.NewPacketRouter(icmpRouter, datagramMuxer, index, logger)
 
 	return &datagramV2Connection{
-		conn,
-		sessionManager,
-		datagramMuxer,
-		packetRouter,
-		rpcTimeout,
-		streamWriteTimeout,
-		logger,
+		conn: conn,
+		index: index,
+		sessionManager: sessionManager,
+		sessionLimiter: sessionLimiter,
+		datagramMuxer: datagramMuxer,
+		packetRouter: packetRouter,
+		rpcTimeout: rpcTimeout,
+		streamWriteTimeout: streamWriteTimeout,
+		logger: logger,
 	}
 }
 
@@ -109,12 +119,23 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
 		attribute.String("dst", fmt.Sprintf("%s:%d", dstIP, dstPort)),
 	))
 	log := q.logger.With().Int(management.EventTypeKey, int(management.UDP)).Logger()
 
+	// Try to start a new session
+	if err := q.sessionLimiter.Acquire(management.UDP.String()); err != nil {
+		log.Warn().Msgf("Too many concurrent sessions being handled, rejecting udp proxy to %s:%d", dstIP, dstPort)
+
+		err := pkgerrors.Wrap(err, "failed to start udp session due to rate limiting")
+		tracing.EndWithErrorStatus(registerSpan, err)
+		return nil, err
+	}
+
 	// Each session is a series of datagram from an eyeball to a dstIP:dstPort.
 	// (src port, dst IP, dst port) uniquely identifies a session, so it needs a dedicated connected socket.
 	originProxy, err := ingress.DialUDP(dstIP, dstPort)
 	if err != nil {
 		log.Err(err).Msgf("Failed to create udp proxy to %s:%d", dstIP, dstPort)
 		tracing.EndWithErrorStatus(registerSpan, err)
+		q.sessionLimiter.Release()
 		return nil, err
 	}
 	registerSpan.SetAttributes(
@@ -127,10 +148,14 @@ func (q *datagramV2Connection) RegisterUdpSession(ctx context.Context, sessionID
 		originProxy.Close()
 		log.Err(err).Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).Msgf("Failed to register udp session")
 		tracing.EndWithErrorStatus(registerSpan, err)
+		q.sessionLimiter.Release()
 		return nil, err
 	}
 
-	go q.serveUDPSession(session, closeAfterIdleHint)
+	go func() {
+		defer q.sessionLimiter.Release() // we do the release here, instead of inside the `serveUDPSession` just to keep all acquire/release calls in the same method.
+		q.serveUDPSession(session, closeAfterIdleHint)
+	}()
 
 	log.Debug().
 		Str(datagramsession.LogFieldSessionID, datagramsession.FormatSessionID(sessionID)).
@@ -170,7 +195,7 @@ func (q *datagramV2Connection) serveUDPSession(session *datagramsession.Session,
 
 // closeUDPSession first unregisters the session from session manager, then it tries to unregister from edge
 func (q *datagramV2Connection) closeUDPSession(ctx context.Context, sessionID uuid.UUID, message string) {
-	q.sessionManager.UnregisterSession(ctx, sessionID, message, false)
+	_ = q.sessionManager.UnregisterSession(ctx, sessionID, message, false)
 	quicStream, err := q.conn.OpenStream()
 	if err != nil {
 		// Log this at debug because this is not an error if session was closed due to lost connection
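Note on the change above: the key invariant in the new `RegisterUdpSession` path is that every successful `Acquire` is paired with exactly one `Release`. On the rejection path nothing was acquired, on the dial and registration error paths the slot is released immediately, and on success the release is deferred until the serving goroutine exits. A condensed sketch of that control flow, for illustration only (the helper name and closures are not part of the commit):

```go
package connection

import cfdsession "github.com/cloudflare/cloudflared/session"

// registerWithLimit is a hypothetical helper that condenses the acquire/release
// pattern used by RegisterUdpSession above; tracing, logging and the real
// session plumbing are omitted.
func registerWithLimit(limiter cfdsession.Limiter, register func() error, serve func()) error {
	// Reject up front if the tunnel is already at its session limit.
	if err := limiter.Acquire("udp"); err != nil {
		return err // nothing was acquired, so nothing to release
	}
	if err := register(); err != nil {
		limiter.Release() // registration failed: give the slot back immediately
		return err
	}
	go func() {
		// Release only once the session has finished serving, keeping all
		// acquire/release calls paired in the same method.
		defer limiter.Release()
		serve()
	}()
	return nil
}
```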
connection/quic_datagram_v2_test.go (new file, 96 lines)
@@ -0,0 +1,96 @@
+package connection
+
+import (
+	"context"
+	"net"
+	"testing"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/quic-go/quic-go"
+	"github.com/rs/zerolog"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/mock/gomock"
+
+	"github.com/cloudflare/cloudflared/mocks"
+	cfdsession "github.com/cloudflare/cloudflared/session"
+)
+
+type mockQuicConnection struct {
+}
+
+func (m *mockQuicConnection) AcceptStream(_ context.Context) (quic.Stream, error) {
+	return nil, nil
+}
+
+func (m *mockQuicConnection) AcceptUniStream(_ context.Context) (quic.ReceiveStream, error) {
+	return nil, nil
+}
+
+func (m *mockQuicConnection) OpenStream() (quic.Stream, error) {
+	return nil, nil
+}
+
+func (m *mockQuicConnection) OpenStreamSync(_ context.Context) (quic.Stream, error) {
+	return nil, nil
+}
+
+func (m *mockQuicConnection) OpenUniStream() (quic.SendStream, error) {
+	return nil, nil
+}
+
+func (m *mockQuicConnection) OpenUniStreamSync(_ context.Context) (quic.SendStream, error) {
+	return nil, nil
+}
+
+func (m *mockQuicConnection) LocalAddr() net.Addr {
+	return nil
+}
+
+func (m *mockQuicConnection) RemoteAddr() net.Addr {
+	return nil
+}
+
+func (m *mockQuicConnection) CloseWithError(_ quic.ApplicationErrorCode, s string) error {
+	return nil
+}
+
+func (m *mockQuicConnection) Context() context.Context {
+	return nil
+}
+
+func (m *mockQuicConnection) ConnectionState() quic.ConnectionState {
+	panic("not meant to be called")
+}
+
+func (m *mockQuicConnection) SendDatagram(_ []byte) error {
+	return nil
+}
+
+func (m *mockQuicConnection) ReceiveDatagram(_ context.Context) ([]byte, error) {
+	return nil, nil
+}
+
+func TestRateLimitOnNewDatagramV2UDPSession(t *testing.T) {
+	log := zerolog.Nop()
+	conn := &mockQuicConnection{}
+	ctrl := gomock.NewController(t)
+	sessionLimiterMock := mocks.NewMockLimiter(ctrl)
+
+	datagramConn := NewDatagramV2Connection(
+		context.Background(),
+		conn,
+		nil,
+		0,
+		0*time.Second,
+		0*time.Second,
+		sessionLimiterMock,
+		&log,
+	)
+
+	sessionLimiterMock.EXPECT().Acquire("udp").Return(cfdsession.ErrTooManyActiveSessions)
+	sessionLimiterMock.EXPECT().Release().Times(0)
+
+	_, err := datagramConn.RegisterUdpSession(context.Background(), uuid.New(), net.IPv4(0, 0, 0, 0), 1000, 1*time.Second, "")
+	require.ErrorIs(t, err, cfdsession.ErrTooManyActiveSessions)
+}