TUN-3268: Each connection has its own event digest to reconnect

This commit is contained in:
cthuang
2020-08-18 11:14:14 +01:00
parent 9323844ea7
commit 8eeb452cce
5 changed files with 230 additions and 263 deletions

View File

@@ -3,9 +3,7 @@ package origin
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/google/uuid"
@@ -37,7 +35,6 @@ const (
)
var (
errJWTUnset = errors.New("JWT unset")
errEventDigestUnset = errors.New("event digest unset")
)
@@ -58,14 +55,7 @@ type Supervisor struct {
logger logger.Service
jwtLock sync.RWMutex
jwt []byte
eventDigestLock sync.RWMutex
eventDigest []byte
connDigestLock sync.RWMutex
connDigest map[uint8][]byte
reconnectCredentialManager *reconnectCredentialManager
bufferPool *buffer.Pool
}
@@ -95,14 +85,14 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor
}
return &Supervisor{
cloudflaredUUID: cloudflaredUUID,
config: config,
edgeIPs: edgeIPs,
tunnelErrors: make(chan tunnelError),
tunnelsConnecting: map[int]chan struct{}{},
logger: config.Logger,
connDigest: make(map[uint8][]byte),
bufferPool: buffer.NewPool(512 * 1024),
cloudflaredUUID: cloudflaredUUID,
config: config,
edgeIPs: edgeIPs,
tunnelErrors: make(chan tunnelError),
tunnelsConnecting: map[int]chan struct{}{},
logger: config.Logger,
reconnectCredentialManager: newReconnectCredentialManager(metricsNamespace, tunnelSubsystem, config.HAConnections),
bufferPool: buffer.NewPool(512 * 1024),
}, nil
}
@@ -121,7 +111,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
var refreshAuthBackoffTimer <-chan time.Time
if s.config.UseReconnectToken {
if timer, err := s.refreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
refreshAuthBackoffTimer = timer
} else {
logger.Errorf("supervisor: initial refreshAuth failed, retrying in %v: %s", refreshAuthRetryDuration, err)
@@ -164,7 +154,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
tunnelsWaiting = nil
// Time to call Authenticate
case <-refreshAuthBackoffTimer:
newTimer, err := s.refreshAuth(ctx, refreshAuthBackoff, s.authenticate)
newTimer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate)
if err != nil {
logger.Errorf("supervisor: Authentication failed: %s", err)
// Permanent failure. Leave the `select` without setting the
@@ -237,7 +227,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
return
}
err = ServeTunnelLoop(ctx, s, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
// If the first tunnel disconnects, keep restarting it.
edgeErrors := 0
for s.unusedIPs() {
@@ -260,7 +250,7 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
return
}
}
err = ServeTunnelLoop(ctx, s, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
}
}
@@ -279,7 +269,7 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal
if err != nil {
return
}
err = ServeTunnelLoop(ctx, s, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
err = ServeTunnelLoop(ctx, s.reconnectCredentialManager, s.config, addr, uint8(index), connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
}
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
@@ -305,90 +295,6 @@ func (s *Supervisor) unusedIPs() bool {
return s.edgeIPs.AvailableAddrs() > s.config.HAConnections
}
// ReconnectToken returns the JWT currently held by the supervisor, reading it
// under jwtLock. It returns errJWTUnset when no token has been stored yet.
// The returned slice is the stored slice itself, not a copy.
func (s *Supervisor) ReconnectToken() ([]byte, error) {
	s.jwtLock.RLock()
	token := s.jwt
	s.jwtLock.RUnlock()

	if token != nil {
		return token, nil
	}
	return nil, errJWTUnset
}
// SetReconnectToken replaces the stored JWT with jwt while holding the write
// side of jwtLock. The slice is stored as given; it is not copied.
func (s *Supervisor) SetReconnectToken(jwt []byte) {
	s.jwtLock.Lock()
	s.jwt = jwt
	s.jwtLock.Unlock()
}
// EventDigest returns the event digest currently held by the supervisor,
// reading it under eventDigestLock. It returns errEventDigestUnset when no
// digest has been stored yet. The returned slice is the stored slice itself.
func (s *Supervisor) EventDigest() ([]byte, error) {
	s.eventDigestLock.RLock()
	digest := s.eventDigest
	s.eventDigestLock.RUnlock()

	if digest != nil {
		return digest, nil
	}
	return nil, errEventDigestUnset
}
// SetEventDigest replaces the stored event digest with eventDigest while
// holding the write side of eventDigestLock. The slice is stored as given.
func (s *Supervisor) SetEventDigest(eventDigest []byte) {
	s.eventDigestLock.Lock()
	s.eventDigest = eventDigest
	s.eventDigestLock.Unlock()
}
// ConnDigest looks up the digest recorded for connection connID, reading the
// connDigest map under connDigestLock. It returns an error naming the
// connection when the map has no entry for connID.
func (s *Supervisor) ConnDigest(connID uint8) ([]byte, error) {
	s.connDigestLock.RLock()
	stored, found := s.connDigest[connID]
	s.connDigestLock.RUnlock()

	if found {
		return stored, nil
	}
	return nil, fmt.Errorf("no connection digest for connection %v", connID)
}
// SetConnDigest records connDigest as the digest for connection connID,
// overwriting any previous entry. The write happens while holding the write
// side of connDigestLock. The slice is stored as given; it is not copied.
func (s *Supervisor) SetConnDigest(connID uint8, connDigest []byte) {
s.connDigestLock.Lock()
// Deferred so the lock is released even if the map write panics
// (writing to a nil map panics in Go).
defer s.connDigestLock.Unlock()
s.connDigest[connID] = connDigest
}
// refreshAuth calls authenticate once, passing backoff.Retries() as the number
// of previous attempts, and translates the result into a timer channel telling
// the caller when to refresh again.
//
// Outcomes:
//   - authenticate returns an error: authFail is incremented with the error
//     text as label. If backoff still allows a retry, backoff.BackoffTimer()
//     is returned with a nil error; otherwise the error is returned.
//   - AuthSuccess: the JWT is stored via SetReconnectToken, authSuccess is
//     incremented, and a timer for outcome.RefreshAfter() is returned.
//   - AuthUnknown: authFail is incremented and a timer for
//     outcome.RefreshAfter() is returned.
//   - AuthFail: authFail is incremented and the outcome itself is returned as
//     the error.
//   - any other type: authFail is incremented and an error naming the
//     unexpected type is returned.
//
// A nil timer is only ever returned together with a non-nil error.
func (s *Supervisor) refreshAuth(
	ctx context.Context,
	backoff *BackoffHandler,
	authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
) (retryTimer <-chan time.Time, err error) {
	log := s.config.Logger

	result, authErr := authenticate(ctx, backoff.Retries())
	if authErr != nil {
		s.config.Metrics.authFail.WithLabelValues(authErr.Error()).Inc()
		wait, canRetry := backoff.GetBackoffDuration(ctx)
		if !canRetry {
			// Backoff refuses another attempt: surface the error.
			return nil, authErr
		}
		log.Debugf("refresh_auth: Retrying in %v: %s", wait, authErr)
		return backoff.BackoffTimer(), nil
	}

	// clear backoff timer
	backoff.SetGracePeriod()

	switch typed := result.(type) {
	case tunnelpogs.AuthSuccess:
		s.SetReconnectToken(typed.JWT())
		s.config.Metrics.authSuccess.Inc()
		return timeAfter(typed.RefreshAfter()), nil

	case tunnelpogs.AuthUnknown:
		wait := typed.RefreshAfter()
		s.config.Metrics.authFail.WithLabelValues(typed.Error()).Inc()
		log.Debugf("refresh_auth: Retrying in %v: %s", wait, typed)
		return timeAfter(wait), nil

	case tunnelpogs.AuthFail:
		s.config.Metrics.authFail.WithLabelValues(typed.Error()).Inc()
		return nil, typed

	default:
		unexpected := fmt.Errorf("refresh_auth: Unexpected outcome type %T", result)
		s.config.Metrics.authFail.WithLabelValues(unexpected.Error()).Inc()
		return nil, unexpected
	}
}
func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) {
arbitraryEdgeIP, err := s.edgeIPs.GetAddrForRPC()
if err != nil {