mirror of https://github.com/cloudflare/cloudflared.git

TUN-3268: Each connection has its own event digest to reconnect
@@ -3,9 +3,7 @@ package origin
 import (
     "context"
     "errors"
-    "fmt"
     "net"
-    "sync"
     "time"

     "github.com/google/uuid"
@@ -37,7 +35,6 @@ const (
 )

 var (
     errJWTUnset         = errors.New("JWT unset")
     errEventDigestUnset = errors.New("event digest unset")
 )

@@ -58,14 +55,7 @@ type Supervisor struct {

     logger logger.Service

-    jwtLock sync.RWMutex
-    jwt     []byte
-
-    eventDigestLock sync.RWMutex
-    eventDigest     []byte
-
-    connDigestLock sync.RWMutex
-    connDigest     map[uint8][]byte
+    reconnectCredentialManager *reconnectCredentialManager

     bufferPool *buffer.Pool
 }
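The reconnect state removed from Supervisor here (JWT, event digest, per-connection digests) moves into the new reconnectCredentialManager, which this commit adds in a separate file of package origin that is not part of this diff (the file shown here appears to be origin/supervisor.go). Below is a rough sketch of what that type plausibly holds, pieced together from the removed fields, the newReconnectCredentialManager call further down, and the commit title; the per-connection event-digest map, the metrics fields, and every name not visible in this diff are assumptions, not code from the commit:

package origin

import (
    "fmt"
    "sync"

    "github.com/prometheus/client_golang/prometheus"
)

type reconnectCredentialManager struct {
    jwtLock sync.RWMutex
    jwt     []byte

    eventDigestLock sync.RWMutex
    eventDigest     map[uint8][]byte // assumed: one event digest per connection ID, per the commit title

    connDigestLock sync.RWMutex
    connDigest     map[uint8][]byte

    // Assumed: the manager owns the auth metrics that the removed refreshAuth
    // (end of this diff) reached through s.config.Metrics.
    authSuccess prometheus.Counter
    authFail    *prometheus.CounterVec
}

// EventDigest and SetEventDigest mirror the removed per-connection
// ConnDigest/SetConnDigest accessors; keying the event digest by connection ID
// is what the commit title implies.
func (cm *reconnectCredentialManager) EventDigest(connID uint8) ([]byte, error) {
    cm.eventDigestLock.RLock()
    defer cm.eventDigestLock.RUnlock()
    digest, ok := cm.eventDigest[connID]
    if !ok {
        return nil, fmt.Errorf("no event digest for connection %v", connID)
    }
    return digest, nil
}

func (cm *reconnectCredentialManager) SetEventDigest(connID uint8, digest []byte) {
    cm.eventDigestLock.Lock()
    defer cm.eventDigestLock.Unlock()
    cm.eventDigest[connID] = digest
}

The ReconnectToken, SetReconnectToken, ConnDigest, and SetConnDigest accessors visible at the bottom of this diff presumably move over in the same way, with only the receiver changed.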
@@ -95,14 +85,14 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor
     }

     return &Supervisor{
-        cloudflaredUUID:   cloudflaredUUID,
-        config:            config,
-        edgeIPs:           edgeIPs,
-        tunnelErrors:      make(chan tunnelError),
-        tunnelsConnecting: map[int]chan struct{}{},
-        logger:            config.Logger,
-        connDigest:        make(map[uint8][]byte),
-        bufferPool:        buffer.NewPool(512 * 1024),
+        cloudflaredUUID:            cloudflaredUUID,
+        config:                     config,
+        edgeIPs:                    edgeIPs,
+        tunnelErrors:               make(chan tunnelError),
+        tunnelsConnecting:          map[int]chan struct{}{},
+        logger:                     config.Logger,
+        reconnectCredentialManager: newReconnectCredentialManager(metricsNamespace, tunnelSubsystem, config.HAConnections),
+        bufferPool:                 buffer.NewPool(512 * 1024),
     }, nil
 }

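NewSupervisor now delegates to a constructor for the manager. A plausible shape for it, judging only from the arguments at the call site: the namespace and subsystem presumably feed the manager's own auth metrics (the removed refreshAuth below read these from s.config.Metrics), and config.HAConnections presumably pre-sizes the per-connection maps. Everything in this sketch, including the metric names, is illustrative rather than taken from the commit:

package origin

import "github.com/prometheus/client_golang/prometheus"

func newReconnectCredentialManager(namespace, subsystem string, haConnections int) *reconnectCredentialManager {
    authSuccess := prometheus.NewCounter(prometheus.CounterOpts{
        Namespace: namespace,
        Subsystem: subsystem,
        Name:      "tunnel_authenticate_success", // placeholder name, not cloudflared's real metric
        Help:      "Count of successful tunnel authentications",
    })
    authFail := prometheus.NewCounterVec(prometheus.CounterOpts{
        Namespace: namespace,
        Subsystem: subsystem,
        Name:      "tunnel_authenticate_fail", // placeholder name, not cloudflared's real metric
        Help:      "Count of failed tunnel authentications, by error",
    }, []string{"error"})
    prometheus.MustRegister(authSuccess, authFail)

    return &reconnectCredentialManager{
        eventDigest: make(map[uint8][]byte, haConnections),
        connDigest:  make(map[uint8][]byte, haConnections),
        authSuccess: authSuccess,
        authFail:    authFail,
    }
}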
@@ -121,7 +111,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
     var refreshAuthBackoffTimer <-chan time.Time

     if s.config.UseReconnectToken {
-        if timer, err := s.refreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
+        if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
             refreshAuthBackoffTimer = timer
         } else {
             logger.Errorf("supervisor: initial refreshAuth failed, retrying in %v: %s", refreshAuthRetryDuration, err)
@@ -164,7 +154,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
             tunnelsWaiting = nil
         // Time to call Authenticate
         case <-refreshAuthBackoffTimer:
-            newTimer, err := s.refreshAuth(ctx, refreshAuthBackoff, s.authenticate)
+            newTimer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate)
             if err != nil {
                 logger.Errorf("supervisor: Authentication failed: %s", err)
                 // Permanent failure. Leave the `select` without setting the
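Both refresh points in Run now call RefreshAuth on the credential manager instead of the removed (*Supervisor).refreshAuth, whose full body is visible at the end of this diff. The manager's method presumably keeps the same control flow, swapping s.config.Metrics and s.SetReconnectToken for the manager's own fields and methods and dropping the s.config.Logger debug lines it no longer has access to. A sketch under those assumptions (not the code added by the commit):

package origin

import (
    "context"
    "fmt"
    "time"

    tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)

// SetReconnectToken is the Supervisor setter moved onto the manager
// (see the removed version at the end of this diff).
func (cm *reconnectCredentialManager) SetReconnectToken(jwt []byte) {
    cm.jwtLock.Lock()
    defer cm.jwtLock.Unlock()
    cm.jwt = jwt
}

func (cm *reconnectCredentialManager) RefreshAuth(
    ctx context.Context,
    backoff *BackoffHandler,
    authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
) (retryTimer <-chan time.Time, err error) {
    authOutcome, err := authenticate(ctx, backoff.Retries())
    if err != nil {
        cm.authFail.WithLabelValues(err.Error()).Inc()
        if duration, ok := backoff.GetBackoffDuration(ctx); ok {
            return backoff.BackoffTimer(), nil
        }
        return nil, err
    }
    // Authentication succeeded: clear the backoff and schedule the next refresh.
    backoff.SetGracePeriod()

    switch outcome := authOutcome.(type) {
    case tunnelpogs.AuthSuccess:
        cm.SetReconnectToken(outcome.JWT())
        cm.authSuccess.Inc()
        // The removed code used a timeAfter indirection; plain time.After keeps the sketch self-contained.
        return time.After(outcome.RefreshAfter()), nil
    case tunnelpogs.AuthUnknown:
        cm.authFail.WithLabelValues(outcome.Error()).Inc()
        return time.After(outcome.RefreshAfter()), nil
    case tunnelpogs.AuthFail:
        cm.authFail.WithLabelValues(outcome.Error()).Inc()
        return nil, outcome
    default:
        err := fmt.Errorf("refresh_auth: unexpected outcome type %T", authOutcome)
        cm.authFail.WithLabelValues(err.Error()).Inc()
        return nil, err
    }
}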
@@ -237,7 +227,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
         return
     }

-    err = ServeTunnelLoop(ctx, s, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
+    err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
     // If the first tunnel disconnects, keep restarting it.
     edgeErrors := 0
     for s.unusedIPs() {
@@ -260,7 +250,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
                 return
             }
         }
-        err = ServeTunnelLoop(ctx, s, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
+        err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
     }
 }

@@ -279,7 +269,7 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal
     if err != nil {
         return
     }
-    err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
+    err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
 }

 func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
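All three ServeTunnelLoop call sites (twice in startFirstTunnel, once in startTunnel) now pass s.reconnectCredentialManager where they previously passed the Supervisor itself, so the corresponding parameter of ServeTunnelLoop in tunnel.go (not shown in this diff) must now accept the manager. In interface terms, what the tunnel loop needs from that argument is exactly the accessor set being removed from Supervisor below; the interface name and the per-connection EventDigest signature are assumptions:

package origin

// credentialManager is a hypothetical view of what ServeTunnelLoop requires.
// The concrete *reconnectCredentialManager sketched above would satisfy it
// once the remaining accessors visible at the end of this diff are moved over.
type credentialManager interface {
    ReconnectToken() ([]byte, error)
    SetReconnectToken(jwt []byte)
    EventDigest(connID uint8) ([]byte, error)
    SetEventDigest(connID uint8, digest []byte)
    ConnDigest(connID uint8) ([]byte, error)
    SetConnDigest(connID uint8, digest []byte)
}

Passing the narrower credential manager instead of the whole *Supervisor keeps the tunnel loop's dependency limited to reconnect credentials, which is presumably the point of the refactor.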
@@ -305,90 +295,6 @@ func (s *Supervisor) unusedIPs() bool {
     return s.edgeIPs.AvailableAddrs() > s.config.HAConnections
 }

-func (s *Supervisor) ReconnectToken() ([]byte, error) {
-    s.jwtLock.RLock()
-    defer s.jwtLock.RUnlock()
-    if s.jwt == nil {
-        return nil, errJWTUnset
-    }
-    return s.jwt, nil
-}
-
-func (s *Supervisor) SetReconnectToken(jwt []byte) {
-    s.jwtLock.Lock()
-    defer s.jwtLock.Unlock()
-    s.jwt = jwt
-}
-
-func (s *Supervisor) EventDigest() ([]byte, error) {
-    s.eventDigestLock.RLock()
-    defer s.eventDigestLock.RUnlock()
-    if s.eventDigest == nil {
-        return nil, errEventDigestUnset
-    }
-    return s.eventDigest, nil
-}
-
-func (s *Supervisor) SetEventDigest(eventDigest []byte) {
-    s.eventDigestLock.Lock()
-    defer s.eventDigestLock.Unlock()
-    s.eventDigest = eventDigest
-}
-
-func (s *Supervisor) ConnDigest(connID uint8) ([]byte, error) {
-    s.connDigestLock.RLock()
-    defer s.connDigestLock.RUnlock()
-    digest, ok := s.connDigest[connID]
-    if !ok {
-        return nil, fmt.Errorf("no connection digest for connection %v", connID)
-    }
-    return digest, nil
-}
-
-func (s *Supervisor) SetConnDigest(connID uint8, connDigest []byte) {
-    s.connDigestLock.Lock()
-    defer s.connDigestLock.Unlock()
-    s.connDigest[connID] = connDigest
-}
-
-func (s *Supervisor) refreshAuth(
-    ctx context.Context,
-    backoff *BackoffHandler,
-    authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
-) (retryTimer <-chan time.Time, err error) {
-    logger := s.config.Logger
-    authOutcome, err := authenticate(ctx, backoff.Retries())
-    if err != nil {
-        s.config.Metrics.authFail.WithLabelValues(err.Error()).Inc()
-        if duration, ok := backoff.GetBackoffDuration(ctx); ok {
-            logger.Debugf("refresh_auth: Retrying in %v: %s", duration, err)
-            return backoff.BackoffTimer(), nil
-        }
-        return nil, err
-    }
-    // clear backoff timer
-    backoff.SetGracePeriod()
-
-    switch outcome := authOutcome.(type) {
-    case tunnelpogs.AuthSuccess:
-        s.SetReconnectToken(outcome.JWT())
-        s.config.Metrics.authSuccess.Inc()
-        return timeAfter(outcome.RefreshAfter()), nil
-    case tunnelpogs.AuthUnknown:
-        duration := outcome.RefreshAfter()
-        s.config.Metrics.authFail.WithLabelValues(outcome.Error()).Inc()
-        logger.Debugf("refresh_auth: Retrying in %v: %s", duration, outcome)
-        return timeAfter(duration), nil
-    case tunnelpogs.AuthFail:
-        s.config.Metrics.authFail.WithLabelValues(outcome.Error()).Inc()
-        return nil, outcome
-    default:
-        err := fmt.Errorf("refresh_auth: Unexpected outcome type %T", authOutcome)
-        s.config.Metrics.authFail.WithLabelValues(err.Error()).Inc()
-        return nil, err
-    }
-}
-
 func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) {
     arbitraryEdgeIP, err := s.edgeIPs.GetAddrForRPC()
     if err != nil {