TUN-9319: Add dynamic loading of features to connections via ConnectionOptionsSnapshot

Make sure to enforce snapshots of features and client information for each connection
so that the feature information can change in the background. This allows for new features
to only be applied to a connection if it completely disconnects and attempts a reconnect.

Updates the feature refresh time to 1 hour from previous cloudflared versions which
refreshed every 6 hours.

Closes TUN-9319
This commit is contained in:
Devin Carr 2025-05-14 20:11:05 +00:00
parent 02705c44b2
commit 3bf9217de5
14 changed files with 359 additions and 106 deletions

74
client/config.go Normal file
View File

@ -0,0 +1,74 @@
package client
import (
"fmt"
"net"
"github.com/google/uuid"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/features"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
// Config captures the local client runtime configuration.
type Config struct {
ConnectorID uuid.UUID
Version string
Arch string
featureSelector features.FeatureSelector
}
func NewConfig(version string, arch string, featureSelector features.FeatureSelector) (*Config, error) {
connectorID, err := uuid.NewRandom()
if err != nil {
return nil, fmt.Errorf("unable to generate a connector UUID: %w", err)
}
return &Config{
ConnectorID: connectorID,
Version: version,
Arch: arch,
featureSelector: featureSelector,
}, nil
}
// ConnectionOptionsSnapshot is a snapshot of the current client information used to initialize a connection.
//
// The FeatureSnapshot is the features that are available for this connection. At the client level they may
// change, but they will not change within the scope of this struct.
type ConnectionOptionsSnapshot struct {
client pogs.ClientInfo
originLocalIP net.IP
numPreviousAttempts uint8
FeatureSnapshot features.FeatureSnapshot
}
func (c *Config) ConnectionOptionsSnapshot(originIP net.IP, previousAttempts uint8) *ConnectionOptionsSnapshot {
snapshot := c.featureSelector.Snapshot()
return &ConnectionOptionsSnapshot{
client: pogs.ClientInfo{
ClientID: c.ConnectorID[:],
Version: c.Version,
Arch: c.Arch,
Features: snapshot.FeaturesList,
},
originLocalIP: originIP,
numPreviousAttempts: previousAttempts,
FeatureSnapshot: snapshot,
}
}
func (c ConnectionOptionsSnapshot) ConnectionOptions() *pogs.ConnectionOptions {
return &pogs.ConnectionOptions{
Client: c.client,
OriginLocalIP: c.originLocalIP,
ReplaceExisting: false,
CompressionQuality: 0,
NumPreviousAttempts: c.numPreviousAttempts,
}
}
func (c ConnectionOptionsSnapshot) LogFields(event *zerolog.Event) *zerolog.Event {
return event.Strs("features", c.client.Features)
}

50
client/config_test.go Normal file
View File

@ -0,0 +1,50 @@
package client
import (
"net"
"testing"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/features"
)
func TestGenerateConnectionOptions(t *testing.T) {
version := "1234"
arch := "linux_amd64"
originIP := net.ParseIP("192.168.1.1")
var previousAttempts uint8 = 4
config, err := NewConfig(version, arch, &mockFeatureSelector{})
require.NoError(t, err)
require.Equal(t, version, config.Version)
require.Equal(t, arch, config.Arch)
// Validate ConnectionOptionsSnapshot fields
connOptions := config.ConnectionOptionsSnapshot(originIP, previousAttempts)
require.Equal(t, version, connOptions.client.Version)
require.Equal(t, arch, connOptions.client.Arch)
require.Equal(t, config.ConnectorID[:], connOptions.client.ClientID)
// Vaidate snapshot feature fields against the connOptions generated
snapshot := config.featureSelector.Snapshot()
require.Equal(t, features.DatagramV3, snapshot.DatagramVersion)
require.Equal(t, features.DatagramV3, connOptions.FeatureSnapshot.DatagramVersion)
pogsConnOptions := connOptions.ConnectionOptions()
require.Equal(t, connOptions.client, pogsConnOptions.Client)
require.Equal(t, originIP, pogsConnOptions.OriginLocalIP)
require.False(t, pogsConnOptions.ReplaceExisting)
require.Equal(t, uint8(0), pogsConnOptions.CompressionQuality)
require.Equal(t, previousAttempts, pogsConnOptions.NumPreviousAttempts)
}
type mockFeatureSelector struct{}
func (m *mockFeatureSelector) Snapshot() features.FeatureSnapshot {
return features.FeatureSnapshot{
PostQuantum: features.PostQuantumPrefer,
DatagramVersion: features.DatagramV3,
FeaturesList: []string{features.FeaturePostQuantum, features.FeatureDatagramV3_1},
}
}

View File

@ -15,7 +15,6 @@ import (
"github.com/coreos/go-systemd/v22/daemon"
"github.com/facebookgo/grace/gracenet"
"github.com/getsentry/sentry-go"
"github.com/google/uuid"
"github.com/mitchellh/go-homedir"
"github.com/pkg/errors"
"github.com/rs/zerolog"
@ -446,14 +445,7 @@ func StartServer(
log.Err(err).Msg("Couldn't start tunnel")
return err
}
var clientID uuid.UUID
if tunnelConfig.NamedTunnel != nil {
clientID, err = uuid.FromBytes(tunnelConfig.NamedTunnel.Client.ClientID)
if err != nil {
// set to nil for classic tunnels
clientID = uuid.Nil
}
}
connectorID := tunnelConfig.ClientConfig.ConnectorID
// Disable ICMP packet routing for quick tunnels
if quickTunnelURL != "" {
@ -471,7 +463,7 @@ func StartServer(
c.String("management-hostname"),
c.Bool("management-diagnostics"),
serviceIP,
clientID,
connectorID,
c.String(cfdflags.ConnectorLabel),
logger.ManagementLogger.Log,
logger.ManagementLogger,
@ -503,14 +495,14 @@ func StartServer(
sources = append(sources, ipv6.String())
}
readinessServer := metrics.NewReadyServer(clientID, tracker)
readinessServer := metrics.NewReadyServer(connectorID, tracker)
cliFlags := nonSecretCliFlags(log, c, nonSecretFlagsList)
diagnosticHandler := diagnostic.NewDiagnosticHandler(
log,
0,
diagnostic.NewSystemCollectorImpl(buildInfo.CloudflaredVersion),
tunnelConfig.NamedTunnel.Credentials.TunnelID,
clientID,
connectorID,
tracker,
cliFlags,
sources,

View File

@ -10,13 +10,13 @@ import (
"strings"
"time"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/urfave/cli/v2"
"github.com/urfave/cli/v2/altsrc"
"golang.org/x/term"
"github.com/cloudflare/cloudflared/client"
"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
"github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/config"
@ -125,27 +125,29 @@ func prepareTunnelConfig(
observer *connection.Observer,
namedTunnel *connection.TunnelProperties,
) (*supervisor.TunnelConfig, *orchestration.Config, error) {
clientID, err := uuid.NewRandom()
transportProtocol := c.String(flags.Protocol)
isPostQuantumEnforced := c.Bool(flags.PostQuantum)
featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, c.StringSlice(flags.Features), isPostQuantumEnforced, log)
if err != nil {
return nil, nil, errors.Wrap(err, "can't generate connector UUID")
return nil, nil, errors.Wrap(err, "Failed to create feature selector")
}
log.Info().Msgf("Generated Connector ID: %s", clientID)
clientConfig, err := client.NewConfig(info.Version(), info.OSArch(), featureSelector)
if err != nil {
return nil, nil, err
}
log.Info().Msgf("Generated Connector ID: %s", clientConfig.ConnectorID)
tags, err := NewTagSliceFromCLI(c.StringSlice(flags.Tag))
if err != nil {
log.Err(err).Msg("Tag parse failure")
return nil, nil, errors.Wrap(err, "Tag parse failure")
}
tags = append(tags, pogs.Tag{Name: "ID", Value: clientID.String()})
tags = append(tags, pogs.Tag{Name: "ID", Value: clientConfig.ConnectorID.String()})
transportProtocol := c.String(flags.Protocol)
isPostQuantumEnforced := c.Bool(flags.PostQuantum)
featureSelector, err := features.NewFeatureSelector(ctx, namedTunnel.Credentials.AccountTag, c.StringSlice(flags.Features), c.Bool(flags.PostQuantum), log)
if err != nil {
return nil, nil, errors.Wrap(err, "Failed to create feature selector")
}
clientFeatures := featureSelector.ClientFeatures()
pqMode := featureSelector.PostQuantumMode()
clientFeatures := featureSelector.Snapshot()
pqMode := clientFeatures.PostQuantum
if pqMode == features.PostQuantumStrict {
// Error if the user tries to force a non-quic transport protocol
if transportProtocol != connection.AutoSelectFlag && transportProtocol != connection.QUIC.String() {
@ -154,12 +156,6 @@ func prepareTunnelConfig(
transportProtocol = connection.QUIC.String()
}
namedTunnel.Client = pogs.ClientInfo{
ClientID: clientID[:],
Features: clientFeatures,
Version: info.Version(),
Arch: info.OSArch(),
}
cfg := config.GetConfiguration()
ingressRules, err := ingress.ParseIngressFromConfigAndCLI(cfg, c, log)
if err != nil {
@ -224,10 +220,8 @@ func prepareTunnelConfig(
}
tunnelConfig := &supervisor.TunnelConfig{
ClientConfig: clientConfig,
GracePeriod: gracePeriod,
ReplaceExisting: c.Bool(flags.Force),
OSArch: info.OSArch(),
ClientID: clientID.String(),
EdgeAddrs: c.StringSlice(flags.Edge),
Region: resolvedRegion,
EdgeIPVersion: edgeIPVersion,
@ -246,7 +240,6 @@ func prepareTunnelConfig(
NamedTunnel: namedTunnel,
ProtocolSelector: protocolSelector,
EdgeTLSConfigs: edgeTLSConfigs,
FeatureSelector: featureSelector,
MaxEdgeAddrRetries: uint8(c.Int(flags.MaxEdgeAddrRetries)), // nolint: gosec
RPCTimeout: c.Duration(flags.RpcTimeout),
WriteStreamTimeout: c.Duration(flags.WriteStreamTimeout),

View File

@ -57,7 +57,6 @@ type Orchestrator interface {
type TunnelProperties struct {
Credentials Credentials
Client pogs.ClientInfo
QuickTunnelUrl string
}

View File

@ -16,10 +16,10 @@ import (
"github.com/rs/zerolog"
"golang.org/x/net/http2"
"github.com/cloudflare/cloudflared/client"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/tracing"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
// note: these constants are exported so we can reuse them in the edge-side code
@ -39,7 +39,7 @@ type HTTP2Connection struct {
conn net.Conn
server *http2.Server
orchestrator Orchestrator
connOptions *pogs.ConnectionOptions
connOptions *client.ConnectionOptionsSnapshot
observer *Observer
connIndex uint8
@ -54,7 +54,7 @@ type HTTP2Connection struct {
func NewHTTP2Connection(
conn net.Conn,
orchestrator Orchestrator,
connOptions *pogs.ConnectionOptions,
connOptions *client.ConnectionOptionsSnapshot,
observer *Observer,
connIndex uint8,
controlStreamHandler ControlStreamHandler,
@ -118,7 +118,7 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var requestErr error
switch connType {
case TypeControlStream:
requestErr = c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions, c.orchestrator)
requestErr = c.controlStreamHandler.ServeControlStream(r.Context(), respWriter, c.connOptions.ConnectionOptions(), c.orchestrator)
if requestErr != nil {
c.controlStreamErr = requestErr
}

View File

@ -20,6 +20,7 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/net/http2"
"github.com/cloudflare/cloudflared/client"
"github.com/cloudflare/cloudflared/tracing"
"github.com/cloudflare/cloudflared/tunnelrpc"
@ -51,7 +52,7 @@ func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
cfdConn,
// OriginProxy is set in testConfigManager
testOrchestrator,
&pogs.ConnectionOptions{},
&client.ConnectionOptionsSnapshot{},
obs,
connIndex,
controlStream,
@ -74,7 +75,7 @@ func TestHTTP2ConfigurationSet(t *testing.T) {
require.NoError(t, err)
reqBody := []byte(`{
"version": 2,
"version": 2,
"config": {"warp-routing": {"enabled": true}, "originRequest" : {"connectTimeout": 10}, "ingress" : [ {"hostname": "test", "service": "https://localhost:8000" } , {"service": "http_status:404"} ]}}
`)
reader := bytes.NewReader(reqBody)

View File

@ -17,6 +17,7 @@ import (
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/client"
cfdflow "github.com/cloudflare/cloudflared/flow"
cfdquic "github.com/cloudflare/cloudflared/quic"
@ -43,7 +44,7 @@ type quicConnection struct {
orchestrator Orchestrator
datagramHandler DatagramSessionHandler
controlStreamHandler ControlStreamHandler
connOptions *pogs.ConnectionOptions
connOptions *client.ConnectionOptionsSnapshot
connIndex uint8
rpcTimeout time.Duration
@ -59,7 +60,7 @@ func NewTunnelConnection(
orchestrator Orchestrator,
datagramSessionHandler DatagramSessionHandler,
controlStreamHandler ControlStreamHandler,
connOptions *pogs.ConnectionOptions,
connOptions *client.ConnectionOptionsSnapshot,
rpcTimeout time.Duration,
streamWriteTimeout time.Duration,
gracePeriod time.Duration,
@ -130,7 +131,7 @@ func (q *quicConnection) Serve(ctx context.Context) error {
// serveControlStream will serve the RPC; blocking until the control plane is done.
func (q *quicConnection) serveControlStream(ctx context.Context, controlStream quic.Stream) error {
return q.controlStreamHandler.ServeControlStream(ctx, controlStream, q.connOptions, q.orchestrator)
return q.controlStreamHandler.ServeControlStream(ctx, controlStream, q.connOptions.ConnectionOptions(), q.orchestrator)
}
// Close the connection with no errors specified.

View File

@ -29,6 +29,7 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/net/nettest"
"github.com/cloudflare/cloudflared/client"
cfdflow "github.com/cloudflare/cloudflared/flow"
"github.com/cloudflare/cloudflared/datagramsession"
@ -843,7 +844,7 @@ func testTunnelConnection(t *testing.T, serverAddr netip.AddrPort, index uint8)
&mockOrchestrator{originProxy: &mockOriginProxyWithRequest{}},
datagramConn,
fakeControlStream{},
&pogs.ConnectionOptions{},
&client.ConnectionOptionsSnapshot{},
15*time.Second,
0*time.Second,
0*time.Second,

View File

@ -33,6 +33,15 @@ type staticFeatures struct {
PostQuantumMode *PostQuantumMode
}
type FeatureSnapshot struct {
PostQuantum PostQuantumMode
DatagramVersion DatagramVersion
// We provide the list of features since we need it to send in the ConnectionOptions during connection
// registrations.
FeaturesList []string
}
type PostQuantumMode uint8
const (

View File

@ -7,6 +7,7 @@ import (
"hash/fnv"
"net"
"slices"
"sync"
"time"
"github.com/rs/zerolog"
@ -15,22 +16,29 @@ import (
const (
featureSelectorHostname = "cfd-features.argotunnel.com"
lookupTimeout = time.Second * 10
defaultLookupFreq = time.Hour
)
// If the TXT record adds other fields, the umarshal logic will ignore those keys
// If the TXT record is missing a key, the field will unmarshal to the default Go value
type featuresRecord struct {
DatagramV3Percentage uint32 `json:"dv3_1"`
// DatagramV3Percentage int32 `json:"dv3"` // Removed in TUN-9291
// PostQuantumPercentage int32 `json:"pq"` // Removed in TUN-7970
}
func NewFeatureSelector(ctx context.Context, accountTag string, cliFeatures []string, pq bool, logger *zerolog.Logger) (*FeatureSelector, error) {
return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), cliFeatures, pq)
func NewFeatureSelector(ctx context.Context, accountTag string, cliFeatures []string, pq bool, logger *zerolog.Logger) (FeatureSelector, error) {
return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), cliFeatures, pq, defaultLookupFreq)
}
type FeatureSelector interface {
Snapshot() FeatureSnapshot
}
// FeatureSelector determines if this account will try new features; loaded once during startup.
type FeatureSelector struct {
type featureSelector struct {
accountHash uint32
logger *zerolog.Logger
resolver resolver
@ -38,10 +46,12 @@ type FeatureSelector struct {
staticFeatures staticFeatures
cliFeatures []string
features featuresRecord
// lock protects concurrent access to dynamic features
lock sync.RWMutex
remoteFeatures featuresRecord
}
func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, cliFeatures []string, pq bool) (*FeatureSelector, error) {
func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, cliFeatures []string, pq bool, refreshFreq time.Duration) (*featureSelector, error) {
// Combine default features and user-provided features
var pqMode *PostQuantumMode
if pq {
@ -52,7 +62,7 @@ func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.
staticFeatures := staticFeatures{
PostQuantumMode: pqMode,
}
selector := &FeatureSelector{
selector := &featureSelector{
accountHash: switchThreshold(accountTag),
logger: logger,
resolver: resolver,
@ -60,14 +70,32 @@ func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.
cliFeatures: dedupAndRemoveFeatures(cliFeatures),
}
if err := selector.init(ctx); err != nil {
// Load the remote features
if err := selector.refresh(ctx); err != nil {
logger.Err(err).Msg("Failed to fetch features, default to disable")
}
// Spin off reloading routine
go selector.refreshLoop(ctx, refreshFreq)
return selector, nil
}
func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode {
func (fs *featureSelector) Snapshot() FeatureSnapshot {
fs.lock.RLock()
defer fs.lock.RUnlock()
return FeatureSnapshot{
PostQuantum: fs.postQuantumMode(),
DatagramVersion: fs.datagramVersion(),
FeaturesList: fs.clientFeatures(),
}
}
func (fs *featureSelector) accountEnabled(percentage uint32) bool {
return percentage > fs.accountHash
}
func (fs *featureSelector) postQuantumMode() PostQuantumMode {
if fs.staticFeatures.PostQuantumMode != nil {
return *fs.staticFeatures.PostQuantumMode
}
@ -75,7 +103,7 @@ func (fs *FeatureSelector) PostQuantumMode() PostQuantumMode {
return PostQuantumPrefer
}
func (fs *FeatureSelector) DatagramVersion() DatagramVersion {
func (fs *featureSelector) datagramVersion() DatagramVersion {
// If user provides the feature via the cli, we take it as priority over remote feature evaluation
if slices.Contains(fs.cliFeatures, FeatureDatagramV3_1) {
return DatagramV3
@ -85,16 +113,20 @@ func (fs *FeatureSelector) DatagramVersion() DatagramVersion {
return DatagramV2
}
if fs.accountEnabled(fs.remoteFeatures.DatagramV3Percentage) {
return DatagramV3
}
return DatagramV2
}
// ClientFeatures will return the list of currently available features that cloudflared should provide to the edge.
func (fs *FeatureSelector) ClientFeatures() []string {
// clientFeatures will return the list of currently available features that cloudflared should provide to the edge.
func (fs *featureSelector) clientFeatures() []string {
// Evaluate any remote features along with static feature list to construct the list of features
return dedupAndRemoveFeatures(slices.Concat(defaultFeatures, fs.cliFeatures, []string{string(fs.DatagramVersion())}))
return dedupAndRemoveFeatures(slices.Concat(defaultFeatures, fs.cliFeatures, []string{string(fs.datagramVersion())}))
}
func (fs *FeatureSelector) init(ctx context.Context) error {
func (fs *featureSelector) refresh(ctx context.Context) error {
record, err := fs.resolver.lookupRecord(ctx)
if err != nil {
return err
@ -105,11 +137,29 @@ func (fs *FeatureSelector) init(ctx context.Context) error {
return err
}
fs.features = features
fs.lock.Lock()
defer fs.lock.Unlock()
fs.remoteFeatures = features
return nil
}
func (fs *featureSelector) refreshLoop(ctx context.Context, refreshFreq time.Duration) {
ticker := time.NewTicker(refreshFreq)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
err := fs.refresh(ctx)
if err != nil {
fs.logger.Err(err).Msg("Failed to refresh feature selector")
}
}
}
}
// resolver represents an object that can look up featuresRecord
type resolver interface {
lookupRecord(ctx context.Context) ([]byte, error)

View File

@ -3,17 +3,36 @@ package features
import (
"context"
"encoding/json"
"fmt"
"testing"
"time"
"github.com/rs/zerolog"
"github.com/stretchr/testify/require"
)
const (
testAccountTag = "123456"
testAccountHash = 74 // switchThreshold of `accountTag`
)
func TestUnmarshalFeaturesRecord(t *testing.T) {
tests := []struct {
record []byte
expectedPercentage uint32
}{
{
record: []byte(`{"dv3_1":0}`),
expectedPercentage: 0,
},
{
record: []byte(`{"dv3_1":39}`),
expectedPercentage: 39,
},
{
record: []byte(`{"dv3_1":100}`),
expectedPercentage: 100,
},
{
record: []byte(`{}`), // Unmarshal to default struct if key is not present
},
@ -29,6 +48,7 @@ func TestUnmarshalFeaturesRecord(t *testing.T) {
var features featuresRecord
err := json.Unmarshal(test.record, &features)
require.NoError(t, err)
require.Equal(t, test.expectedPercentage, features.DatagramV3Percentage, test)
}
}
@ -57,10 +77,11 @@ func TestFeaturePrecedenceEvaluationPostQuantum(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
resolver := &staticResolver{record: featuresRecord{}}
selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, []string{}, test.cli)
selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, []string{}, test.cli, time.Second)
require.NoError(t, err)
require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures())
require.Equal(t, test.expectedVersion, selector.PostQuantumMode())
snapshot := selector.Snapshot()
require.ElementsMatch(t, test.expectedFeatures, snapshot.FeaturesList)
require.Equal(t, test.expectedVersion, snapshot.PostQuantum)
})
}
}
@ -100,10 +121,11 @@ func TestFeaturePrecedenceEvaluationDatagramVersion(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
resolver := &staticResolver{record: test.remote}
selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, test.cli, false)
selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, test.cli, false, time.Second)
require.NoError(t, err)
require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures())
require.Equal(t, test.expectedVersion, selector.DatagramVersion())
snapshot := selector.Snapshot()
require.ElementsMatch(t, test.expectedFeatures, snapshot.FeaturesList)
require.Equal(t, test.expectedVersion, snapshot.DatagramVersion)
})
}
}
@ -133,34 +155,99 @@ func TestDeprecatedFeaturesRemoved(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
resolver := &staticResolver{record: test.remote}
selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, test.cli, false)
selector, err := newFeatureSelector(context.Background(), test.name, &logger, resolver, test.cli, false, time.Second)
require.NoError(t, err)
require.ElementsMatch(t, test.expectedFeatures, selector.ClientFeatures())
snapshot := selector.Snapshot()
require.ElementsMatch(t, test.expectedFeatures, snapshot.FeaturesList)
})
}
}
func TestRefreshFeaturesRecord(t *testing.T) {
percentages := []uint32{0, 10, testAccountHash - 1, testAccountHash, testAccountHash + 1, 100, 101, 1000}
selector := newTestSelector(t, percentages, false, time.Minute)
// Starting out should default to DatagramV2
snapshot := selector.Snapshot()
require.Equal(t, DatagramV2, snapshot.DatagramVersion)
for _, percentage := range percentages {
snapshot = selector.Snapshot()
if percentage > testAccountHash {
require.Equal(t, DatagramV3, snapshot.DatagramVersion)
} else {
require.Equal(t, DatagramV2, snapshot.DatagramVersion)
}
// Manually progress the next refresh
_ = selector.refresh(context.Background())
}
// Make sure a resolver error doesn't override the last fetched features
snapshot = selector.Snapshot()
require.Equal(t, DatagramV3, snapshot.DatagramVersion)
}
func TestSnapshotIsolation(t *testing.T) {
percentages := []uint32{testAccountHash, testAccountHash + 1}
selector := newTestSelector(t, percentages, false, time.Minute)
// Starting out should default to DatagramV2
snapshot := selector.Snapshot()
require.Equal(t, DatagramV2, snapshot.DatagramVersion)
// Manually progress the next refresh
_ = selector.refresh(context.Background())
snapshot2 := selector.Snapshot()
require.Equal(t, DatagramV3, snapshot2.DatagramVersion)
require.NotEqual(t, snapshot.DatagramVersion, snapshot2.DatagramVersion)
}
func TestStaticFeatures(t *testing.T) {
percentages := []uint32{0}
// PostQuantum Enabled from user flag
selector := newTestSelector(t, percentages, true)
require.Equal(t, PostQuantumStrict, selector.PostQuantumMode())
selector := newTestSelector(t, percentages, true, time.Second)
snapshot := selector.Snapshot()
require.Equal(t, PostQuantumStrict, snapshot.PostQuantum)
// PostQuantum Disabled (or not set)
selector = newTestSelector(t, percentages, false)
require.Equal(t, PostQuantumPrefer, selector.PostQuantumMode())
selector = newTestSelector(t, percentages, false, time.Second)
snapshot = selector.Snapshot()
require.Equal(t, PostQuantumPrefer, snapshot.PostQuantum)
}
func newTestSelector(t *testing.T, percentages []uint32, pq bool) *FeatureSelector {
accountTag := t.Name()
func newTestSelector(t *testing.T, percentages []uint32, pq bool, refreshFreq time.Duration) *featureSelector {
logger := zerolog.Nop()
selector, err := newFeatureSelector(context.Background(), accountTag, &logger, &staticResolver{}, []string{}, pq)
resolver := &mockResolver{
percentages: percentages,
}
selector, err := newFeatureSelector(context.Background(), testAccountTag, &logger, resolver, []string{}, pq, refreshFreq)
require.NoError(t, err)
return selector
}
type mockResolver struct {
nextIndex int
percentages []uint32
}
func (mr *mockResolver) lookupRecord(ctx context.Context) ([]byte, error) {
if mr.nextIndex >= len(mr.percentages) {
return nil, fmt.Errorf("no more record to lookup")
}
record, err := json.Marshal(featuresRecord{
DatagramV3Percentage: mr.percentages[mr.nextIndex],
})
mr.nextIndex++
return record, err
}
type staticResolver struct {
record featuresRecord
}

View File

@ -1,6 +1,8 @@
package v3
import (
"fmt"
"github.com/prometheus/client_golang/prometheus"
"github.com/cloudflare/cloudflared/quic"
@ -32,28 +34,28 @@ type metrics struct {
}
func (m *metrics) IncrementFlows(connIndex uint8) {
m.totalUDPFlows.WithLabelValues(string(connIndex)).Inc()
m.activeUDPFlows.WithLabelValues(string(connIndex)).Inc()
m.totalUDPFlows.WithLabelValues(fmt.Sprintf("%d", connIndex)).Inc()
m.activeUDPFlows.WithLabelValues(fmt.Sprintf("%d", connIndex)).Inc()
}
func (m *metrics) DecrementFlows(connIndex uint8) {
m.activeUDPFlows.WithLabelValues(string(connIndex)).Dec()
m.activeUDPFlows.WithLabelValues(fmt.Sprintf("%d", connIndex)).Dec()
}
func (m *metrics) PayloadTooLarge(connIndex uint8) {
m.payloadTooLarge.WithLabelValues(string(connIndex)).Inc()
m.payloadTooLarge.WithLabelValues(fmt.Sprintf("%d", connIndex)).Inc()
}
func (m *metrics) RetryFlowResponse(connIndex uint8) {
m.retryFlowResponses.WithLabelValues(string(connIndex)).Inc()
m.retryFlowResponses.WithLabelValues(fmt.Sprintf("%d", connIndex)).Inc()
}
func (m *metrics) MigrateFlow(connIndex uint8) {
m.migratedFlows.WithLabelValues(string(connIndex)).Inc()
m.migratedFlows.WithLabelValues(fmt.Sprintf("%d", connIndex)).Inc()
}
func (m *metrics) UnsupportedRemoteCommand(connIndex uint8, command string) {
m.unsupportedRemoteCommands.WithLabelValues(string(connIndex), command).Inc()
m.unsupportedRemoteCommands.WithLabelValues(fmt.Sprintf("%d", connIndex), command).Inc()
}
func NewMetrics(registerer prometheus.Registerer) Metrics {

View File

@ -17,6 +17,7 @@ import (
"github.com/rs/zerolog"
"golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/client"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/edgediscovery/allregions"
@ -38,10 +39,8 @@ const (
)
type TunnelConfig struct {
ClientConfig *client.Config
GracePeriod time.Duration
ReplaceExisting bool
OSArch string
ClientID string
CloseConnOnce *sync.Once // Used to close connectedSignal no more than once
EdgeAddrs []string
Region string
@ -72,22 +71,13 @@ type TunnelConfig struct {
DisableQUICPathMTUDiscovery bool
QUICConnectionLevelFlowControlLimit uint64
QUICStreamLevelFlowControlLimit uint64
FeatureSelector *features.FeatureSelector
}
func (c *TunnelConfig) connectionOptions(originLocalAddr string, numPreviousAttempts uint8) *pogs.ConnectionOptions {
func (c *TunnelConfig) connectionOptions(originLocalAddr string, previousAttempts uint8) *client.ConnectionOptionsSnapshot {
// attempt to parse out origin IP, but don't fail since it's informational field
host, _, _ := net.SplitHostPort(originLocalAddr)
originIP := net.ParseIP(host)
return &pogs.ConnectionOptions{
Client: c.NamedTunnel.Client,
OriginLocalIP: originIP,
ReplaceExisting: c.ReplaceExisting,
CompressionQuality: 0,
NumPreviousAttempts: numPreviousAttempts,
}
return c.ClientConfig.ConnectionOptionsSnapshot(originIP, previousAttempts)
}
func StartTunnelDaemon(
@ -463,6 +453,8 @@ func (e *EdgeTunnelServer) serveConnection(
case connection.QUIC:
// nolint: gosec
connOptions := e.config.connectionOptions(addr.UDP.String(), uint8(backoff.Retries()))
// nolint: zerologlint
connOptions.LogFields(connLog.Logger().Debug().Uint8(connection.LogFieldConnIndex, connIndex)).Msgf("Tunnel connection options")
return e.serveQUIC(ctx,
addr.UDP.AddrPort(),
connLog,
@ -479,6 +471,8 @@ func (e *EdgeTunnelServer) serveConnection(
// nolint: gosec
connOptions := e.config.connectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries()))
// nolint: zerologlint
connOptions.LogFields(connLog.Logger().Debug().Uint8(connection.LogFieldConnIndex, connIndex)).Msgf("Tunnel connection options")
if err := e.serveHTTP2(
ctx,
connLog,
@ -508,11 +502,11 @@ func (e *EdgeTunnelServer) serveHTTP2(
ctx context.Context,
connLog *ConnAwareLogger,
tlsServerConn net.Conn,
connOptions *pogs.ConnectionOptions,
connOptions *client.ConnectionOptionsSnapshot,
controlStreamHandler connection.ControlStreamHandler,
connIndex uint8,
) error {
pqMode := e.config.FeatureSelector.PostQuantumMode()
pqMode := connOptions.FeatureSnapshot.PostQuantum
if pqMode == features.PostQuantumStrict {
return unrecoverableError{errors.New("HTTP/2 transport does not support post-quantum")}
}
@ -550,19 +544,19 @@ func (e *EdgeTunnelServer) serveQUIC(
ctx context.Context,
edgeAddr netip.AddrPort,
connLogger *ConnAwareLogger,
connOptions *pogs.ConnectionOptions,
connOptions *client.ConnectionOptionsSnapshot,
controlStreamHandler connection.ControlStreamHandler,
connIndex uint8,
) (err error, recoverable bool) {
tlsConfig := e.config.EdgeTLSConfigs[connection.QUIC]
pqMode := e.config.FeatureSelector.PostQuantumMode()
pqMode := connOptions.FeatureSnapshot.PostQuantum
curvePref, err := curvePreference(pqMode, fips.IsFipsEnabled(), tlsConfig.CurvePreferences)
if err != nil {
return err, true
}
connLogger.Logger().Info().Msgf("Using %v as curve preferences", curvePref)
connLogger.Logger().Info().Msgf("Tunnel connection curve preferences: %v", curvePref)
tlsConfig.CurvePreferences = curvePref
@ -600,12 +594,12 @@ func (e *EdgeTunnelServer) serveQUIC(
if err != nil {
connLogger.ConnAwareLogger().Err(err).Msgf("Failed to dial a quic connection")
e.reportErrorToSentry(err)
e.reportErrorToSentry(err, connOptions.FeatureSnapshot.PostQuantum)
return err, true
}
var datagramSessionManager connection.DatagramSessionHandler
if e.config.FeatureSelector.DatagramVersion() == features.DatagramV3 {
if connOptions.FeatureSnapshot.DatagramVersion == features.DatagramV3 {
datagramSessionManager = connection.NewDatagramV3Connection(
ctx,
conn,
@ -672,7 +666,7 @@ func (e *EdgeTunnelServer) serveQUIC(
// The reportErrorToSentry is an helper function that handles
// verifies if an error should be reported to Sentry.
func (e *EdgeTunnelServer) reportErrorToSentry(err error) {
func (e *EdgeTunnelServer) reportErrorToSentry(err error, pqMode features.PostQuantumMode) {
dialErr, ok := err.(*connection.EdgeQuicDialError)
if ok {
// The TransportError provides an Unwrap function however
@ -681,7 +675,7 @@ func (e *EdgeTunnelServer) reportErrorToSentry(err error) {
if ok &&
transportErr.ErrorCode.IsCryptoError() &&
fips.IsFipsEnabled() &&
e.config.FeatureSelector.PostQuantumMode() == features.PostQuantumStrict {
pqMode == features.PostQuantumStrict {
// Only report to Sentry when using FIPS, PQ,
// and the error is a Crypto error reported by
// an EdgeQuicDialError