Mirror of https://github.com/cloudflare/cloudflared.git, synced 2025-07-27 19:19:57 +00:00
TUN-8427: Fix BackoffHandler's internally shared clock structure
A clock structure was used to support time travel in unit tests, but it is a globally shared object and is likely unsafe to share across tests. Reordering the tests produced intermittent failures in TestWaitForBackoffFallback, specifically on Windows builds. Making the clock a shim inside the BackoffHandler struct resolves the shared-object overrides in unit testing. Additionally, the reset-retries functionality was moved inline with the BackoffHandler's ResetNow function, to better align with the expected behavior of that method. Also removes the unused reconnectCredentialManager.
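The hunks below replace retry.BackoffHandler{...} struct literals with a retry.NewBackoff(...) constructor, and the updated test stubs the clock per instance via backoff.Clock.After. As a rough sketch of the shape this implies: NewBackoff's signature and the exported Clock field follow the usage visible in this diff, while the internals (private field names, backoff formula, ResetNow body) are illustrative assumptions, not the actual retry package.

package retry

import "time"

// Clock is a small shim over the time package, owned by each handler
// instance so that a test override cannot leak into other tests.
type Clock struct {
    Now   func() time.Time
    After func(d time.Duration) <-chan time.Time
}

// BackoffHandler carries its own Clock instead of sharing a global one.
type BackoffHandler struct {
    Clock        Clock // exported so tests can stub Now/After per instance
    maxRetries   uint
    baseTime     time.Duration
    retryForever bool
    retries      uint
}

// NewBackoff mirrors the constructor used by the call sites in this diff:
// retry.NewBackoff(maxRetries, baseTime, retryForever).
func NewBackoff(maxRetries uint, baseTime time.Duration, retryForever bool) BackoffHandler {
    return BackoffHandler{
        Clock:        Clock{Now: time.Now, After: time.After},
        maxRetries:   maxRetries,
        baseTime:     baseTime,
        retryForever: retryForever,
    }
}

// BackoffTimer waits through the instance clock; stubbing Clock.After
// (e.g. with the diff's immediateTimeAfter helper) makes it deterministic.
func (b *BackoffHandler) BackoffTimer() <-chan time.Time {
    if b.retries >= b.maxRetries && !b.retryForever {
        return nil // out of retries
    }
    b.retries++
    return b.Clock.After(b.baseTime * time.Duration(1<<b.retries))
}

// ResetNow also clears the retry counter, matching the commit message's
// note that reset-retries now lives inline with ResetNow.
func (b *BackoffHandler) ResetNow() {
    b.retries = 0
}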
@@ -1,138 +0,0 @@
-package supervisor
-
-import (
-    "context"
-    "errors"
-    "fmt"
-    "sync"
-    "time"
-
-    "github.com/prometheus/client_golang/prometheus"
-
-    "github.com/cloudflare/cloudflared/retry"
-    tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
-)
-
-var (
-    errJWTUnset = errors.New("JWT unset")
-)
-
-// reconnectTunnelCredentialManager is invoked by functions in tunnel.go to
-// get/set parameters for ReconnectTunnel RPC calls.
-type reconnectCredentialManager struct {
-    mu          sync.RWMutex
-    jwt         []byte
-    eventDigest map[uint8][]byte
-    connDigest  map[uint8][]byte
-    authSuccess prometheus.Counter
-    authFail    *prometheus.CounterVec
-}
-
-func newReconnectCredentialManager(namespace, subsystem string, haConnections int) *reconnectCredentialManager {
-    authSuccess := prometheus.NewCounter(
-        prometheus.CounterOpts{
-            Namespace: namespace,
-            Subsystem: subsystem,
-            Name:      "tunnel_authenticate_success",
-            Help:      "Count of successful tunnel authenticate",
-        },
-    )
-    authFail := prometheus.NewCounterVec(
-        prometheus.CounterOpts{
-            Namespace: namespace,
-            Subsystem: subsystem,
-            Name:      "tunnel_authenticate_fail",
-            Help:      "Count of tunnel authenticate errors by type",
-        },
-        []string{"error"},
-    )
-    prometheus.MustRegister(authSuccess, authFail)
-    return &reconnectCredentialManager{
-        eventDigest: make(map[uint8][]byte, haConnections),
-        connDigest:  make(map[uint8][]byte, haConnections),
-        authSuccess: authSuccess,
-        authFail:    authFail,
-    }
-}
-
-func (cm *reconnectCredentialManager) ReconnectToken() ([]byte, error) {
-    cm.mu.RLock()
-    defer cm.mu.RUnlock()
-    if cm.jwt == nil {
-        return nil, errJWTUnset
-    }
-    return cm.jwt, nil
-}
-
-func (cm *reconnectCredentialManager) SetReconnectToken(jwt []byte) {
-    cm.mu.Lock()
-    defer cm.mu.Unlock()
-    cm.jwt = jwt
-}
-
-func (cm *reconnectCredentialManager) EventDigest(connID uint8) ([]byte, error) {
-    cm.mu.RLock()
-    defer cm.mu.RUnlock()
-    digest, ok := cm.eventDigest[connID]
-    if !ok {
-        return nil, fmt.Errorf("no event digest for connection %v", connID)
-    }
-    return digest, nil
-}
-
-func (cm *reconnectCredentialManager) SetEventDigest(connID uint8, digest []byte) {
-    cm.mu.Lock()
-    defer cm.mu.Unlock()
-    cm.eventDigest[connID] = digest
-}
-
-func (cm *reconnectCredentialManager) ConnDigest(connID uint8) ([]byte, error) {
-    cm.mu.RLock()
-    defer cm.mu.RUnlock()
-    digest, ok := cm.connDigest[connID]
-    if !ok {
-        return nil, fmt.Errorf("no connection digest for connection %v", connID)
-    }
-    return digest, nil
-}
-
-func (cm *reconnectCredentialManager) SetConnDigest(connID uint8, digest []byte) {
-    cm.mu.Lock()
-    defer cm.mu.Unlock()
-    cm.connDigest[connID] = digest
-}
-
-func (cm *reconnectCredentialManager) RefreshAuth(
-    ctx context.Context,
-    backoff *retry.BackoffHandler,
-    authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
-) (retryTimer <-chan time.Time, err error) {
-    authOutcome, err := authenticate(ctx, backoff.Retries())
-    if err != nil {
-        cm.authFail.WithLabelValues(err.Error()).Inc()
-        if _, ok := backoff.GetMaxBackoffDuration(ctx); ok {
-            return backoff.BackoffTimer(), nil
-        }
-        return nil, err
-    }
-    // clear backoff timer
-    backoff.SetGracePeriod()
-
-    switch outcome := authOutcome.(type) {
-    case tunnelpogs.AuthSuccess:
-        cm.SetReconnectToken(outcome.JWT())
-        cm.authSuccess.Inc()
-        return retry.Clock.After(outcome.RefreshAfter()), nil
-    case tunnelpogs.AuthUnknown:
-        duration := outcome.RefreshAfter()
-        cm.authFail.WithLabelValues(outcome.Error()).Inc()
-        return retry.Clock.After(duration), nil
-    case tunnelpogs.AuthFail:
-        cm.authFail.WithLabelValues(outcome.Error()).Inc()
-        return nil, outcome
-    default:
-        err := fmt.Errorf("refresh_auth: Unexpected outcome type %T", authOutcome)
-        cm.authFail.WithLabelValues(err.Error()).Inc()
-        return nil, err
-    }
-}
@@ -1,120 +0,0 @@
-package supervisor
-
-import (
-    "context"
-    "errors"
-    "fmt"
-    "testing"
-    "time"
-
-    "github.com/stretchr/testify/assert"
-    "github.com/stretchr/testify/require"
-
-    "github.com/cloudflare/cloudflared/retry"
-    tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
-)
-
-func TestRefreshAuthBackoff(t *testing.T) {
-    rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
-
-    var wait time.Duration
-    retry.Clock.After = func(d time.Duration) <-chan time.Time {
-        wait = d
-        return time.After(d)
-    }
-    backoff := &retry.BackoffHandler{MaxRetries: 3}
-    auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
-        return nil, fmt.Errorf("authentication failure")
-    }
-
-    // authentication failures should consume the backoff
-    for i := uint(0); i < backoff.MaxRetries; i++ {
-        retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
-        require.NoError(t, err)
-        require.NotNil(t, retryChan)
-        require.Greater(t, wait.Seconds(), 0.0)
-        require.Less(t, wait.Seconds(), float64((1<<(i+1))*time.Second))
-    }
-    retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
-    require.Error(t, err)
-    require.Nil(t, retryChan)
-
-    // now we actually make contact with the remote server
-    _, _ = rcm.RefreshAuth(context.Background(), backoff, func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
-        return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
-    })
-
-    // The backoff timer should have been reset. To confirm this, make timeNow
-    // return a value after the backoff timer's grace period
-    retry.Clock.Now = func() time.Time {
-        expectedGracePeriod := time.Duration(time.Second * 2 << backoff.MaxRetries)
-        return time.Now().Add(expectedGracePeriod * 2)
-    }
-    _, ok := backoff.GetMaxBackoffDuration(context.Background())
-    require.True(t, ok)
-}
-
-func TestRefreshAuthSuccess(t *testing.T) {
-    rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
-
-    var wait time.Duration
-    retry.Clock.After = func(d time.Duration) <-chan time.Time {
-        wait = d
-        return time.After(d)
-    }
-
-    backoff := &retry.BackoffHandler{MaxRetries: 3}
-    auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
-        return tunnelpogs.NewAuthSuccess([]byte("jwt"), 19), nil
-    }
-
-    retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
-    assert.NoError(t, err)
-    assert.NotNil(t, retryChan)
-    assert.Equal(t, 19*time.Hour, wait)
-
-    token, err := rcm.ReconnectToken()
-    assert.NoError(t, err)
-    assert.Equal(t, []byte("jwt"), token)
-}
-
-func TestRefreshAuthUnknown(t *testing.T) {
-    rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
-
-    var wait time.Duration
-    retry.Clock.After = func(d time.Duration) <-chan time.Time {
-        wait = d
-        return time.After(d)
-    }
-
-    backoff := &retry.BackoffHandler{MaxRetries: 3}
-    auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
-        return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
-    }
-
-    retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
-    assert.NoError(t, err)
-    assert.NotNil(t, retryChan)
-    assert.Equal(t, 19*time.Hour, wait)
-
-    token, err := rcm.ReconnectToken()
-    assert.Equal(t, errJWTUnset, err)
-    assert.Nil(t, token)
-}
-
-func TestRefreshAuthFail(t *testing.T) {
-    rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
-
-    backoff := &retry.BackoffHandler{MaxRetries: 3}
-    auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
-        return tunnelpogs.NewAuthFail(errors.New("auth fail")), nil
-    }
-
-    retryChan, err := rcm.RefreshAuth(context.Background(), backoff, auth)
-    assert.Error(t, err)
-    assert.Nil(t, retryChan)
-
-    token, err := rcm.ReconnectToken()
-    assert.Equal(t, errJWTUnset, err)
-    assert.Nil(t, token)
-}
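Note the pattern in the deleted tests above: they assign to retry.Clock.After and retry.Clock.Now, which are package-level state. Below is a hypothetical, self-contained sketch (the clockdemo package and both test names are invented for illustration) of why such global stubs are order-dependent, which is exactly the intermittent failure mode the commit message describes.

package clockdemo

import (
    "testing"
    "time"
)

// Clock mimics the old package-global shim that the deleted tests mutate.
var Clock = struct {
    Now   func() time.Time
    After func(d time.Duration) <-chan time.Time
}{time.Now, time.After}

// TestFastClock stubs the shared clock and never restores it.
func TestFastClock(t *testing.T) {
    Clock.After = func(time.Duration) <-chan time.Time {
        c := make(chan time.Time, 1)
        c <- time.Now() // fire immediately, no real waiting
        return c
    }
    <-Clock.After(time.Hour) // returns at once because of the stub
}

// TestRealClock assumes the real time.After. If TestFastClock ran first,
// the leaked stub fires instantly and this test fails; run it alone and
// it passes. That order sensitivity is the shared-object override the
// commit removes by moving the clock into each BackoffHandler.
func TestRealClock(t *testing.T) {
    start := time.Now()
    <-Clock.After(50 * time.Millisecond)
    if time.Since(start) < 50*time.Millisecond {
        t.Fatal("clock was stubbed by another test; expected real time.After")
    }
}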
@@ -49,8 +49,6 @@ type Supervisor struct {
     log          *ConnAwareLogger
     logTransport *zerolog.Logger
 
-    reconnectCredentialManager *reconnectCredentialManager
-
     reconnectCh       chan ReconnectSignal
     gracefulShutdownC <-chan struct{}
 }
@@ -76,8 +74,6 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
         return nil, err
     }
 
-    reconnectCredentialManager := newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections)
-
     tracker := tunnelstate.NewConnTracker(config.Log)
     log := NewConnAwareLogger(config.Log, tracker, config.Observer)
@@ -87,7 +83,6 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
     edgeTunnelServer := EdgeTunnelServer{
         config:            config,
         orchestrator:      orchestrator,
-        credentialManager: reconnectCredentialManager,
         edgeAddrs:         edgeIPs,
         edgeAddrHandler:   edgeAddrHandler,
         edgeBindAddr:      edgeBindAddr,
@@ -98,18 +93,17 @@ func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrato
     }
 
     return &Supervisor{
-        config:                     config,
-        orchestrator:               orchestrator,
-        edgeIPs:                    edgeIPs,
-        edgeTunnelServer:           &edgeTunnelServer,
-        tunnelErrors:               make(chan tunnelError),
-        tunnelsConnecting:          map[int]chan struct{}{},
-        tunnelsProtocolFallback:    map[int]*protocolFallback{},
-        log:                        log,
-        logTransport:               config.LogTransport,
-        reconnectCredentialManager: reconnectCredentialManager,
-        reconnectCh:                reconnectCh,
-        gracefulShutdownC:          gracefulShutdownC,
+        config:                  config,
+        orchestrator:            orchestrator,
+        edgeIPs:                 edgeIPs,
+        edgeTunnelServer:        &edgeTunnelServer,
+        tunnelErrors:            make(chan tunnelError),
+        tunnelsConnecting:       map[int]chan struct{}{},
+        tunnelsProtocolFallback: map[int]*protocolFallback{},
+        log:                     log,
+        logTransport:            config.LogTransport,
+        reconnectCh:             reconnectCh,
+        gracefulShutdownC:       gracefulShutdownC,
     }, nil
 }
@@ -138,7 +132,7 @@ func (s *Supervisor) Run(
     var tunnelsWaiting []int
     tunnelsActive := s.config.HAConnections
 
-    backoff := retry.BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
+    backoff := retry.NewBackoff(s.config.Retries, tunnelRetryDuration, true)
     var backoffTimer <-chan time.Time
 
     shuttingDown := false
@@ -212,7 +206,7 @@ func (s *Supervisor) initialize(
         s.config.HAConnections = availableAddrs
     }
     s.tunnelsProtocolFallback[0] = &protocolFallback{
-        retry.BackoffHandler{MaxRetries: s.config.Retries, RetryForever: true},
+        retry.NewBackoff(s.config.Retries, retry.DefaultBaseTime, true),
         s.config.ProtocolSelector.Current(),
         false,
     }
@@ -234,7 +228,7 @@ func (s *Supervisor) initialize(
     // At least one successful connection, so start the rest
     for i := 1; i < s.config.HAConnections; i++ {
         s.tunnelsProtocolFallback[i] = &protocolFallback{
-            retry.BackoffHandler{MaxRetries: s.config.Retries, RetryForever: true},
+            retry.NewBackoff(s.config.Retries, retry.DefaultBaseTime, true),
             // Set the protocol we know the first tunnel connected with.
             s.tunnelsProtocolFallback[0].protocol,
             false,
@@ -203,7 +203,6 @@ func (f *ipAddrFallback) ShouldGetNewAddress(connIndex uint8, err error) (needsN
 type EdgeTunnelServer struct {
     config            *TunnelConfig
     orchestrator      *orchestration.Orchestrator
-    credentialManager *reconnectCredentialManager
     edgeAddrHandler   EdgeAddrHandler
     edgeAddrs         *edgediscovery.Edge
     edgeBindAddr      net.IP
@@ -24,14 +24,18 @@ func (dmf *dynamicMockFetcher) fetch() edgediscovery.PercentageFetcher {
     }
 }
 
+func immediateTimeAfter(time.Duration) <-chan time.Time {
+    c := make(chan time.Time, 1)
+    c <- time.Now()
+    return c
+}
+
 func TestWaitForBackoffFallback(t *testing.T) {
     maxRetries := uint(3)
-    backoff := retry.BackoffHandler{
-        MaxRetries: maxRetries,
-        BaseTime:   time.Millisecond * 10,
-    }
+    backoff := retry.NewBackoff(maxRetries, 40*time.Millisecond, false)
+    backoff.Clock.After = immediateTimeAfter
     log := zerolog.Nop()
-    resolveTTL := time.Duration(0)
+    resolveTTL := 10 * time.Second
     mockFetcher := dynamicMockFetcher{
         protocolPercents: edgediscovery.ProtocolPercents{edgediscovery.ProtocolPercent{Protocol: "quic", Percentage: 100}},
     }
@@ -64,21 +68,23 @@
     }
 
     // Retry fallback protocol
-    for i := 0; i < int(maxRetries); i++ {
-        protoFallback.BackoffTimer() // simulate retry
-        ok := selectNextProtocol(&log, protoFallback, protocolSelector, nil)
-        assert.True(t, ok)
-        fallback, ok := protocolSelector.Fallback()
-        assert.True(t, ok)
-        assert.Equal(t, fallback, protoFallback.protocol)
-    }
+    protoFallback.BackoffTimer() // simulate retry
+    ok := selectNextProtocol(&log, protoFallback, protocolSelector, nil)
+    assert.True(t, ok)
+    fallback, ok := protocolSelector.Fallback()
+    assert.True(t, ok)
+    assert.Equal(t, fallback, protoFallback.protocol)
     assert.Equal(t, connection.HTTP2, protoFallback.protocol)
 
     currentGlobalProtocol := protocolSelector.Current()
     assert.Equal(t, initProtocol, currentGlobalProtocol)
 
+    // Simulate max retries again (retries reset after protocol switch)
+    for i := 0; i < int(maxRetries); i++ {
+        protoFallback.BackoffTimer()
+    }
     // No protocol to fallback, return error
     protoFallback.BackoffTimer() // simulate retry
-    ok := selectNextProtocol(&log, protoFallback, protocolSelector, nil)
+    ok = selectNextProtocol(&log, protoFallback, protocolSelector, nil)
    assert.False(t, ok)
 
     protoFallback.reset()
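For contrast with the global-clock pattern shown earlier, here is a sketch of a test under the per-instance arrangement this commit lands on. It builds on the hypothetical NewBackoff/Clock sketch given after the commit message; the test itself is invented for illustration and is not part of the commit.

package retry

import (
    "testing"
    "time"
)

// TestBackoffClockIsIsolated: stubbing one handler's clock leaves every
// other handler on the real clock, so test execution order cannot matter.
func TestBackoffClockIsIsolated(t *testing.T) {
    stubbed := NewBackoff(3, 10*time.Millisecond, false)
    stubbed.Clock.After = func(time.Duration) <-chan time.Time {
        c := make(chan time.Time, 1)
        c <- time.Now()
        return c
    }
    <-stubbed.BackoffTimer() // fires immediately through the stub

    // A handler built afterwards still sees the real clock; the override
    // above lived only inside `stubbed`.
    fresh := NewBackoff(3, 10*time.Millisecond, false)
    start := time.Now()
    <-fresh.BackoffTimer()
    if time.Since(start) < 10*time.Millisecond {
        t.Fatal("expected a real wait from the unstubbed clock")
    }
}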