mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-05-10 23:16:34 +00:00

This PR does two things. First, it changes how we fall back to a lower protocol: The current state is to try connecting with a protocol. If it fails, fall back to a lower protocol. And try connecting with that and so on. With this PR, if we fail to connect with a protocol, we will try to connect to other edge addresses first. Only if we fail to connect to those will we fall back to a lower protocol. Second, it fixes a behaviour where if we fail to connect to an edge addr, we keep re-trying the same address over and over again. This PR now switches between edge addresses on subsequent connection attempts. Note that through these switches, it still respects the backoff time. (We are connecting to a different edge, but this helps to not bombard an edge address with connect requests if a particular edge address stops working.)
421 lines
13 KiB
Go
421 lines
13 KiB
Go
package supervisor
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/lucas-clemente/quic-go"
|
|
"github.com/rs/zerolog"
|
|
|
|
"github.com/cloudflare/cloudflared/connection"
|
|
"github.com/cloudflare/cloudflared/edgediscovery"
|
|
"github.com/cloudflare/cloudflared/h2mux"
|
|
"github.com/cloudflare/cloudflared/orchestration"
|
|
"github.com/cloudflare/cloudflared/retry"
|
|
"github.com/cloudflare/cloudflared/signal"
|
|
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
|
"github.com/cloudflare/cloudflared/tunnelstate"
|
|
)
|
|
|
|
const (
	// ResolveTTL is the TTL used for SRV and TXT record resolution.
	ResolveTTL = time.Hour
	// tunnelRetryDuration is the base waiting time before a failed tunnel
	// connection is retried.
	tunnelRetryDuration = 10 * time.Second
	// registrationInterval is the pause between registering new tunnels.
	registrationInterval = time.Second

	subsystemRefreshAuth = "refresh_auth"
	// refreshAuthMaxBackoff is the maximum exponent for the 'Authenticate'
	// exponential backoff.
	refreshAuthMaxBackoff = 10
	// refreshAuthRetryDuration is the waiting time before a failed
	// 'Authenticate' connection is retried.
	refreshAuthRetryDuration = 10 * time.Second
)
|
|
|
|
// Supervisor manages non-declarative tunnels. Establishes TCP connections with the edge, and
|
|
// reconnects them if they disconnect.
|
|
type Supervisor struct {
|
|
cloudflaredUUID uuid.UUID
|
|
config *TunnelConfig
|
|
orchestrator *orchestration.Orchestrator
|
|
edgeIPs *edgediscovery.Edge
|
|
edgeTunnelServer TunnelServer
|
|
tunnelErrors chan tunnelError
|
|
tunnelsConnecting map[int]chan struct{}
|
|
tunnelsProtocolFallback map[int]*protocolFallback
|
|
// nextConnectedIndex and nextConnectedSignal are used to wait for all
|
|
// currently-connecting tunnels to finish connecting so we can reset backoff timer
|
|
nextConnectedIndex int
|
|
nextConnectedSignal chan struct{}
|
|
|
|
log *ConnAwareLogger
|
|
logTransport *zerolog.Logger
|
|
|
|
reconnectCredentialManager *reconnectCredentialManager
|
|
useReconnectToken bool
|
|
|
|
reconnectCh chan ReconnectSignal
|
|
gracefulShutdownC <-chan struct{}
|
|
}
|
|
|
|
// errEarlyShutdown is returned by initialize when graceful shutdown is
// requested before the first tunnel has connected. Run treats it as a clean exit.
var errEarlyShutdown = errors.New("shutdown started")

// tunnelError is what a tunnel goroutine reports on Supervisor.tunnelErrors
// once it stops serving.
type tunnelError struct {
	// index is the connection index of the tunnel that stopped.
	index int
	// err is the error the tunnel stopped with; it may be nil.
	err error
}
|
|
|
|
func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrator, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) {
|
|
cloudflaredUUID, err := uuid.NewRandom()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err)
|
|
}
|
|
|
|
isStaticEdge := len(config.EdgeAddrs) > 0
|
|
|
|
var edgeIPs *edgediscovery.Edge
|
|
if isStaticEdge { // static edge addresses
|
|
edgeIPs, err = edgediscovery.StaticEdge(config.Log, config.EdgeAddrs)
|
|
} else {
|
|
edgeIPs, err = edgediscovery.ResolveEdge(config.Log, config.Region, config.EdgeIPVersion)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
reconnectCredentialManager := newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections)
|
|
|
|
tracker := tunnelstate.NewConnTracker(config.Log)
|
|
log := NewConnAwareLogger(config.Log, tracker, config.Observer)
|
|
|
|
edgeAddrHandler := NewIPAddrFallback(config.MaxEdgeAddrRetries)
|
|
|
|
edgeTunnelServer := EdgeTunnelServer{
|
|
config: config,
|
|
cloudflaredUUID: cloudflaredUUID,
|
|
orchestrator: orchestrator,
|
|
credentialManager: reconnectCredentialManager,
|
|
edgeAddrs: edgeIPs,
|
|
edgeAddrHandler: edgeAddrHandler,
|
|
tracker: tracker,
|
|
reconnectCh: reconnectCh,
|
|
gracefulShutdownC: gracefulShutdownC,
|
|
connAwareLogger: log,
|
|
}
|
|
|
|
useReconnectToken := false
|
|
if config.ClassicTunnel != nil {
|
|
useReconnectToken = config.ClassicTunnel.UseReconnectToken
|
|
}
|
|
|
|
return &Supervisor{
|
|
cloudflaredUUID: cloudflaredUUID,
|
|
config: config,
|
|
orchestrator: orchestrator,
|
|
edgeIPs: edgeIPs,
|
|
edgeTunnelServer: &edgeTunnelServer,
|
|
tunnelErrors: make(chan tunnelError),
|
|
tunnelsConnecting: map[int]chan struct{}{},
|
|
tunnelsProtocolFallback: map[int]*protocolFallback{},
|
|
log: log,
|
|
logTransport: config.LogTransport,
|
|
reconnectCredentialManager: reconnectCredentialManager,
|
|
useReconnectToken: useReconnectToken,
|
|
reconnectCh: reconnectCh,
|
|
gracefulShutdownC: gracefulShutdownC,
|
|
}, nil
|
|
}
|
|
|
|
func (s *Supervisor) Run(
|
|
ctx context.Context,
|
|
connectedSignal *signal.Signal,
|
|
) error {
|
|
if s.config.PacketConfig != nil {
|
|
go func() {
|
|
if err := s.config.PacketConfig.ICMPRouter.Serve(ctx); err != nil {
|
|
if errors.Is(err, net.ErrClosed) {
|
|
s.log.Logger().Info().Err(err).Msg("icmp router terminated")
|
|
} else {
|
|
s.log.Logger().Err(err).Msg("icmp router terminated")
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
if err := s.initialize(ctx, connectedSignal); err != nil {
|
|
if err == errEarlyShutdown {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
var tunnelsWaiting []int
|
|
tunnelsActive := s.config.HAConnections
|
|
|
|
backoff := retry.BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
|
|
var backoffTimer <-chan time.Time
|
|
|
|
refreshAuthBackoff := &retry.BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true}
|
|
var refreshAuthBackoffTimer <-chan time.Time
|
|
|
|
if s.useReconnectToken {
|
|
if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
|
|
refreshAuthBackoffTimer = timer
|
|
} else {
|
|
s.log.Logger().Err(err).
|
|
Dur("refreshAuthRetryDuration", refreshAuthRetryDuration).
|
|
Msgf("supervisor: initial refreshAuth failed, retrying in %v", refreshAuthRetryDuration)
|
|
refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration)
|
|
}
|
|
}
|
|
|
|
shuttingDown := false
|
|
for {
|
|
select {
|
|
// Context cancelled
|
|
case <-ctx.Done():
|
|
for tunnelsActive > 0 {
|
|
<-s.tunnelErrors
|
|
tunnelsActive--
|
|
}
|
|
return nil
|
|
// startTunnel completed with a response
|
|
// (note that this may also be caused by context cancellation)
|
|
case tunnelError := <-s.tunnelErrors:
|
|
tunnelsActive--
|
|
if tunnelError.err != nil && !shuttingDown {
|
|
switch tunnelError.err.(type) {
|
|
case ReconnectSignal:
|
|
// For tunnels that closed with reconnect signal, we reconnect immediately
|
|
go s.startTunnel(ctx, tunnelError.index, s.newConnectedTunnelSignal(tunnelError.index))
|
|
tunnelsActive++
|
|
continue
|
|
}
|
|
// Make sure we don't continue if there is no more fallback allowed
|
|
if _, retry := s.tunnelsProtocolFallback[tunnelError.index].GetMaxBackoffDuration(ctx); !retry {
|
|
continue
|
|
}
|
|
s.log.ConnAwareLogger().Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated")
|
|
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
|
|
s.waitForNextTunnel(tunnelError.index)
|
|
|
|
if backoffTimer == nil {
|
|
backoffTimer = backoff.BackoffTimer()
|
|
}
|
|
} else if tunnelsActive == 0 {
|
|
s.log.ConnAwareLogger().Msg("no more connections active and exiting")
|
|
// All connected tunnels exited gracefully, no more work to do
|
|
return nil
|
|
}
|
|
// Backoff was set and its timer expired
|
|
case <-backoffTimer:
|
|
backoffTimer = nil
|
|
for _, index := range tunnelsWaiting {
|
|
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
|
|
}
|
|
tunnelsActive += len(tunnelsWaiting)
|
|
tunnelsWaiting = nil
|
|
// Time to call Authenticate
|
|
case <-refreshAuthBackoffTimer:
|
|
newTimer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate)
|
|
if err != nil {
|
|
s.log.Logger().Err(err).Msg("supervisor: Authentication failed")
|
|
// Permanent failure. Leave the `select` without setting the
|
|
// channel to be non-null, so we'll never hit this case of the `select` again.
|
|
continue
|
|
}
|
|
refreshAuthBackoffTimer = newTimer
|
|
// Tunnel successfully connected
|
|
case <-s.nextConnectedSignal:
|
|
if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
|
|
// No more tunnels outstanding, clear backoff timer
|
|
backoff.SetGracePeriod()
|
|
}
|
|
case <-s.gracefulShutdownC:
|
|
shuttingDown = true
|
|
}
|
|
}
|
|
}
|
|
|
|
// Returns nil if initialization succeeded, else the initialization error.
|
|
// Attempts here will be made to connect one tunnel, if successful, it will
|
|
// connect the available tunnels up to config.HAConnections.
|
|
func (s *Supervisor) initialize(
|
|
ctx context.Context,
|
|
connectedSignal *signal.Signal,
|
|
) error {
|
|
availableAddrs := s.edgeIPs.AvailableAddrs()
|
|
if s.config.HAConnections > availableAddrs {
|
|
s.log.Logger().Info().Msgf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
|
|
s.config.HAConnections = availableAddrs
|
|
}
|
|
s.tunnelsProtocolFallback[0] = &protocolFallback{
|
|
retry.BackoffHandler{MaxRetries: s.config.Retries, RetryForever: true},
|
|
s.config.ProtocolSelector.Current(),
|
|
false,
|
|
}
|
|
|
|
go s.startFirstTunnel(ctx, connectedSignal)
|
|
|
|
// Wait for response from first tunnel before proceeding to attempt other HA edge tunnels
|
|
select {
|
|
case <-ctx.Done():
|
|
<-s.tunnelErrors
|
|
return ctx.Err()
|
|
case tunnelError := <-s.tunnelErrors:
|
|
return tunnelError.err
|
|
case <-s.gracefulShutdownC:
|
|
return errEarlyShutdown
|
|
case <-connectedSignal.Wait():
|
|
}
|
|
|
|
// At least one successful connection, so start the rest
|
|
for i := 1; i < s.config.HAConnections; i++ {
|
|
s.tunnelsProtocolFallback[i] = &protocolFallback{
|
|
retry.BackoffHandler{MaxRetries: s.config.Retries, RetryForever: true},
|
|
// Set the protocol we know the first tunnel connected with.
|
|
s.tunnelsProtocolFallback[0].protocol,
|
|
false,
|
|
}
|
|
go s.startTunnel(ctx, i, s.newConnectedTunnelSignal(i))
|
|
time.Sleep(registrationInterval)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// startTunnel starts the first tunnel connection. The resulting error will be sent on
|
|
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
|
|
func (s *Supervisor) startFirstTunnel(
|
|
ctx context.Context,
|
|
connectedSignal *signal.Signal,
|
|
) {
|
|
var (
|
|
err error
|
|
)
|
|
const firstConnIndex = 0
|
|
isStaticEdge := len(s.config.EdgeAddrs) > 0
|
|
defer func() {
|
|
s.tunnelErrors <- tunnelError{index: firstConnIndex, err: err}
|
|
}()
|
|
|
|
// If the first tunnel disconnects, keep restarting it.
|
|
for {
|
|
err = s.edgeTunnelServer.Serve(ctx, firstConnIndex, s.tunnelsProtocolFallback[firstConnIndex], connectedSignal)
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
if err == nil {
|
|
return
|
|
}
|
|
// Make sure we don't continue if there is no more fallback allowed
|
|
if _, retry := s.tunnelsProtocolFallback[firstConnIndex].GetMaxBackoffDuration(ctx); !retry {
|
|
return
|
|
}
|
|
// Try again for Unauthorized errors because we hope them to be
|
|
// transient due to edge propagation lag on new Tunnels.
|
|
if strings.Contains(err.Error(), "Unauthorized") {
|
|
continue
|
|
}
|
|
switch err.(type) {
|
|
case edgediscovery.ErrNoAddressesLeft:
|
|
// If your provided addresses are not available, we will keep trying regardless.
|
|
if !isStaticEdge {
|
|
return
|
|
}
|
|
case connection.DupConnRegisterTunnelError,
|
|
*quic.IdleTimeoutError,
|
|
edgediscovery.DialError,
|
|
*connection.EdgeQuicDialError:
|
|
// Try again for these types of errors
|
|
default:
|
|
// Uncaught errors should bail startup
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
|
// s.tunnelError as this is expected to run in a goroutine.
|
|
func (s *Supervisor) startTunnel(
|
|
ctx context.Context,
|
|
index int,
|
|
connectedSignal *signal.Signal,
|
|
) {
|
|
var (
|
|
err error
|
|
)
|
|
defer func() {
|
|
s.tunnelErrors <- tunnelError{index: index, err: err}
|
|
}()
|
|
|
|
err = s.edgeTunnelServer.Serve(ctx, uint8(index), s.tunnelsProtocolFallback[index], connectedSignal)
|
|
}
|
|
|
|
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
|
|
sig := make(chan struct{})
|
|
s.tunnelsConnecting[index] = sig
|
|
s.nextConnectedSignal = sig
|
|
s.nextConnectedIndex = index
|
|
return signal.New(sig)
|
|
}
|
|
|
|
func (s *Supervisor) waitForNextTunnel(index int) bool {
|
|
delete(s.tunnelsConnecting, index)
|
|
s.nextConnectedSignal = nil
|
|
for k, v := range s.tunnelsConnecting {
|
|
s.nextConnectedIndex = k
|
|
s.nextConnectedSignal = v
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (s *Supervisor) unusedIPs() bool {
|
|
return s.edgeIPs.AvailableAddrs() > s.config.HAConnections
|
|
}
|
|
|
|
func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error) {
|
|
arbitraryEdgeIP, err := s.edgeIPs.GetAddrForRPC()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.EdgeTLSConfigs[connection.H2mux], arbitraryEdgeIP.TCP)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer edgeConn.Close()
|
|
|
|
handler := h2mux.MuxedStreamFunc(func(*h2mux.MuxedStream) error {
|
|
// This callback is invoked by h2mux when the edge initiates a stream.
|
|
return nil // noop
|
|
})
|
|
muxerConfig := s.config.MuxerConfig.H2MuxerConfig(handler, s.logTransport)
|
|
muxer, err := h2mux.Handshake(edgeConn, edgeConn, *muxerConfig, h2mux.ActiveStreams)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
go muxer.Serve(ctx)
|
|
defer func() {
|
|
// If we don't wait for the muxer shutdown here, edgeConn.Close() runs before the muxer connections are done,
|
|
// and the user sees log noise: "error writing data", "connection closed unexpectedly"
|
|
<-muxer.Shutdown()
|
|
}()
|
|
|
|
stream, err := muxer.OpenRPCStream(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
rpcClient := connection.NewTunnelServerClient(ctx, stream, s.log.Logger())
|
|
defer rpcClient.Close()
|
|
|
|
const arbitraryConnectionID = uint8(0)
|
|
registrationOptions := s.config.registrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
|
|
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
|
|
return rpcClient.Authenticate(ctx, s.config.ClassicTunnel, registrationOptions)
|
|
}
|