cloudflared/supervisor/supervisor.go
Luis Neto 6496322bee TUN-9007: modify logic to resolve region when the tunnel token has an endpoint field
## Summary

As part of the FedRAMP work, the HA SD (SRV) lookup must be changed to use the SRV record `fed-v2-origintunneld`.

This work assumes that the tunnel token has an optional endpoint field which will be used to modify the behaviour of the HA SD lookup.

Finally, when the endpoint field is present it overrides the region to _fed_, and startup fails if any value is passed for the region flag.

Closes TUN-9007
2025-02-25 19:03:41 +00:00

328 lines
10 KiB
Go

package supervisor
import (
"context"
"errors"
"net"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/quic-go/quic-go"
"github.com/rs/zerolog"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/ingress"
"github.com/cloudflare/cloudflared/orchestration"
v3 "github.com/cloudflare/cloudflared/quic/v3"
"github.com/cloudflare/cloudflared/retry"
"github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/tunnelstate"
)
// Timing parameters used by the Supervisor.
const (
	// tunnelRetryDuration is passed to retry.NewBackoff by Run as the base time
	// of the backoff applied before restarting failed connections.
	tunnelRetryDuration = 10 * time.Second

	// registrationInterval is the pause initialize takes after launching each
	// additional HA connection.
	registrationInterval = 1 * time.Second
)
// Supervisor manages non-declarative tunnels. It establishes the tunnel's HA
// connections with the edge and reconnects them when they disconnect.
type Supervisor struct {
	// config is the tunnel configuration; initialize may lower HAConnections to
	// the number of available edge addresses.
	config *TunnelConfig
	// orchestrator is handed to the EdgeTunnelServer built in NewSupervisor.
	orchestrator *orchestration.Orchestrator
	// edgeIPs is the pool of edge addresses (static or discovered).
	edgeIPs *edgediscovery.Edge
	// edgeTunnelServer serves each individual connection (see Serve calls).
	edgeTunnelServer TunnelServer
	// tunnelErrors is unbuffered; every connection goroutine sends exactly one
	// result on it when its Serve loop ends.
	tunnelErrors chan tunnelError
	// tunnelsConnecting maps connection index -> channel for connections that
	// have been (re)started but not yet observed as connected by Run.
	tunnelsConnecting map[int]chan struct{}
	// tunnelsProtocolFallback holds per-connection protocol/backoff state keyed
	// by connection index.
	tunnelsProtocolFallback map[int]*protocolFallback
	// nextConnectedIndex and nextConnectedSignal are used to wait for all
	// currently-connecting tunnels to finish connecting so we can reset backoff timer
	nextConnectedIndex  int
	nextConnectedSignal chan struct{}
	// log is the connection-aware logger shared with the EdgeTunnelServer.
	log *ConnAwareLogger
	// logTransport is copied from config.LogTransport.
	logTransport *zerolog.Logger
	// reconnectCh is the channel of reconnect requests passed to NewSupervisor.
	reconnectCh chan ReconnectSignal
	// gracefulShutdownC becomes ready when a graceful shutdown is requested.
	gracefulShutdownC <-chan struct{}
}
// errEarlyShutdown is returned by initialize when a graceful shutdown is
// requested before the first connection has connected; Run treats it as a
// clean exit and returns nil.
var errEarlyShutdown = errors.New("shutdown started")

// tunnelError is the result a connection goroutine reports on
// Supervisor.tunnelErrors when its Serve loop ends.
type tunnelError struct {
	// index is the HA connection index that exited.
	index int
	// err is the error the connection exited with; nil means it exited cleanly.
	err error
}
// NewSupervisor builds a Supervisor for config.
//
// Edge addresses are taken from config.EdgeAddrs when any are supplied;
// otherwise they are discovered with edgediscovery.ResolveEdge using
// config.Region and config.EdgeIPVersion. An error from either path is returned
// unchanged and no Supervisor is built. On success, the returned Supervisor
// serves all of its connections through a single EdgeTunnelServer that shares
// one connection tracker and one connection-aware logger.
func NewSupervisor(config *TunnelConfig, orchestrator *orchestration.Orchestrator, reconnectCh chan ReconnectSignal, gracefulShutdownC <-chan struct{}) (*Supervisor, error) {
	var (
		edgeIPs *edgediscovery.Edge
		err     error
	)
	if staticAddrs := config.EdgeAddrs; len(staticAddrs) > 0 {
		// Caller pinned the edge addresses; no discovery needed.
		edgeIPs, err = edgediscovery.StaticEdge(config.Log, staticAddrs)
	} else {
		edgeIPs, err = edgediscovery.ResolveEdge(config.Log, config.Region, config.EdgeIPVersion)
	}
	if err != nil {
		return nil, err
	}

	tracker := tunnelstate.NewConnTracker(config.Log)
	connLogger := NewConnAwareLogger(config.Log, tracker, config.Observer)
	addrFallback := NewIPAddrFallback(config.MaxEdgeAddrRetries)

	// The v3 datagram metrics are given the default Prometheus registerer.
	metrics := v3.NewMetrics(prometheus.DefaultRegisterer)
	sessions := v3.NewSessionManager(metrics, config.Log, ingress.DialUDPAddrPort, orchestrator.GetFlowLimiter())

	server := &EdgeTunnelServer{
		config:            config,
		orchestrator:      orchestrator,
		sessionManager:    sessions,
		datagramMetrics:   metrics,
		edgeAddrs:         edgeIPs,
		edgeAddrHandler:   addrFallback,
		edgeBindAddr:      config.EdgeBindAddr,
		tracker:           tracker,
		reconnectCh:       reconnectCh,
		gracefulShutdownC: gracefulShutdownC,
		connAwareLogger:   connLogger,
	}

	sup := &Supervisor{
		config:                  config,
		orchestrator:            orchestrator,
		edgeIPs:                 edgeIPs,
		edgeTunnelServer:        server,
		tunnelErrors:            make(chan tunnelError),
		tunnelsConnecting:       map[int]chan struct{}{},
		tunnelsProtocolFallback: map[int]*protocolFallback{},
		log:                     connLogger,
		logTransport:            config.LogTransport,
		reconnectCh:             reconnectCh,
		gracefulShutdownC:       gracefulShutdownC,
	}
	return sup, nil
}
// Run supervises the tunnel's HA connections until ctx is cancelled or every
// connection has exited.
//
// It starts the ICMP router in the background when one is configured, then
// calls initialize to bring up the first connection followed by the remaining
// HA connections. After that it loops:
//   - a connection that exits with a ReconnectSignal is restarted immediately;
//   - a connection that exits with any other error is restarted after a shared
//     backoff, unless its protocol fallback reports that no retries are left;
//   - once a graceful shutdown has been requested, connection errors no longer
//     queue restarts, and Run returns when the active connection count reaches
//     zero.
//
// It returns nil on context cancellation, on a graceful shutdown that happens
// during initialization, and when all connections have exited; otherwise it
// returns the error from initialize.
func (s *Supervisor) Run(
	ctx context.Context,
	connectedSignal *signal.Signal,
) error {
	if s.config.ICMPRouterServer != nil {
		go func() {
			if err := s.config.ICMPRouterServer.Serve(ctx); err != nil {
				if errors.Is(err, net.ErrClosed) {
					s.log.Logger().Info().Err(err).Msg("icmp router terminated")
				} else {
					s.log.Logger().Err(err).Msg("icmp router terminated")
				}
			}
		}()
	}

	if err := s.initialize(ctx, connectedSignal); err != nil {
		if errors.Is(err, errEarlyShutdown) {
			return nil
		}
		return err
	}

	var tunnelsWaiting []int
	tunnelsActive := s.config.HAConnections

	backoff := retry.NewBackoff(s.config.Retries, tunnelRetryDuration, true)
	var backoffTimer <-chan time.Time

	// Local copy of the shutdown channel so its select case can be disabled
	// (set to nil) once observed. The channel is also received on by
	// initialize and the EdgeTunnelServer, so it is presumably closed to
	// broadcast shutdown. A closed channel stays ready forever and would
	// otherwise turn this select into a busy loop until the last connection
	// exits.
	shutdownC := s.gracefulShutdownC
	shuttingDown := false

	for {
		select {
		// Context cancelled
		case <-ctx.Done():
			// Collect every running connection's result. tunnelErrors is
			// unbuffered, so this keeps none of them blocked on their final send.
			for tunnelsActive > 0 {
				<-s.tunnelErrors
				tunnelsActive--
			}
			return nil

		// startTunnel completed with a response
		// (note that this may also be caused by context cancellation)
		case tErr := <-s.tunnelErrors:
			tunnelsActive--
			if tErr.err != nil && !shuttingDown {
				if _, ok := tErr.err.(ReconnectSignal); ok {
					// For tunnels that closed with reconnect signal, we reconnect immediately
					go s.startTunnel(ctx, tErr.index, s.newConnectedTunnelSignal(tErr.index))
					tunnelsActive++
					continue
				}
				// Make sure we don't continue if there is no more fallback allowed
				if _, canRetry := s.tunnelsProtocolFallback[tErr.index].GetMaxBackoffDuration(ctx); !canRetry {
					continue
				}
				s.log.ConnAwareLogger().Err(tErr.err).Int(connection.LogFieldConnIndex, tErr.index).Msg("Connection terminated")
				tunnelsWaiting = append(tunnelsWaiting, tErr.index)
				s.waitForNextTunnel(tErr.index)

				if backoffTimer == nil {
					backoffTimer = backoff.BackoffTimer()
				}
			} else if tunnelsActive == 0 {
				s.log.ConnAwareLogger().Msg("no more connections active and exiting")
				// All connected tunnels exited gracefully, no more work to do
				return nil
			}

		// Backoff was set and its timer expired
		case <-backoffTimer:
			backoffTimer = nil
			for _, index := range tunnelsWaiting {
				go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
			}
			tunnelsActive += len(tunnelsWaiting)
			tunnelsWaiting = nil

		// Tunnel successfully connected
		case <-s.nextConnectedSignal:
			if !s.waitForNextTunnel(s.nextConnectedIndex) && len(tunnelsWaiting) == 0 {
				// No more tunnels outstanding, clear backoff timer
				backoff.SetGracePeriod()
			}

		case <-shutdownC:
			shuttingDown = true
			// Receiving from a nil channel blocks forever, so this case will not
			// be selected again.
			shutdownC = nil
		}
	}
}
// initialize brings up the tunnel's HA connections.
//
// It first caps config.HAConnections at the number of available edge
// addresses. It then starts the first connection and waits for it to connect.
// Only after that succeeds are the remaining connections started, one every
// registrationInterval, all using the protocol the first connection was on
// when it signalled that it connected.
//
// Returns:
//   - nil once every connection has been launched;
//   - the first connection's result (which may be nil) if it exits before
//     signalling a successful connection;
//   - ctx.Err() if ctx is cancelled first, after collecting the first
//     connection's result;
//   - errEarlyShutdown if a graceful shutdown is requested first.
func (s *Supervisor) initialize(
	ctx context.Context,
	connectedSignal *signal.Signal,
) error {
	availableAddrs := s.edgeIPs.AvailableAddrs()
	if s.config.HAConnections > availableAddrs {
		s.log.Logger().Info().Msgf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
		s.config.HAConnections = availableAddrs
	}

	s.tunnelsProtocolFallback[0] = &protocolFallback{
		retry.NewBackoff(s.config.Retries, retry.DefaultBaseTime, true),
		s.config.ProtocolSelector.Current(),
		false,
	}

	go s.startFirstTunnel(ctx, connectedSignal)

	// Wait for response from first tunnel before proceeding to attempt other HA edge tunnels
	select {
	case <-ctx.Done():
		<-s.tunnelErrors
		return ctx.Err()
	case tunnelError := <-s.tunnelErrors:
		return tunnelError.err
	case <-s.gracefulShutdownC:
		return errEarlyShutdown
	case <-connectedSignal.Wait():
	}

	// At least one successful connection, so start the rest.
	//
	// Capture the protocol the first tunnel connected with once, right after
	// it connected.
	firstProtocol := s.tunnelsProtocolFallback[0].protocol

	// Populate every fallback entry BEFORE launching any goroutine. startTunnel
	// reads this map from its own goroutine, so adding entries while earlier
	// connections are already running is a concurrent map read/write.
	for i := 1; i < s.config.HAConnections; i++ {
		s.tunnelsProtocolFallback[i] = &protocolFallback{
			retry.NewBackoff(s.config.Retries, retry.DefaultBaseTime, true),
			firstProtocol,
			false,
		}
	}
	for i := 1; i < s.config.HAConnections; i++ {
		go s.startTunnel(ctx, i, s.newConnectedTunnelSignal(i))
		time.Sleep(registrationInterval)
	}

	return nil
}
// startFirstTunnel runs the first HA connection (index 0) and restarts it for
// as long as its failures look transient:
//   - "Unauthorized" errors, which are hoped to be edge propagation lag on new tunnels;
//   - duplicate-registration errors;
//   - QUIC idle-timeout and application errors;
//   - dial errors;
//   - ErrNoAddressesLeft when the edge addresses were supplied statically.
//
// It stops when Serve returns nil, when ctx is cancelled, when the
// connection's protocol fallback allows no further retries, or on any other
// error.
//
// The final error (nil on a clean exit) is sent on the unbuffered
// s.tunnelErrors, so this must run in its own goroutine. connectedSignal is
// passed through to Serve, which is expected to notify it once the connection
// registers.
func (s *Supervisor) startFirstTunnel(
	ctx context.Context,
	connectedSignal *signal.Signal,
) {
	var err error
	const firstConnIndex = 0
	isStaticEdge := len(s.config.EdgeAddrs) > 0

	// Read this connection's fallback state once, up front. The index-0 entry
	// is written by initialize before this goroutine starts. initialize later
	// adds entries for the other HA connections to the same map after we
	// signal that we connected, so re-reading the map on each retry would race
	// with those writes.
	fallback := s.tunnelsProtocolFallback[firstConnIndex]

	defer func() {
		s.tunnelErrors <- tunnelError{index: firstConnIndex, err: err}
	}()

	// If the first tunnel disconnects, keep restarting it.
	for {
		err = s.edgeTunnelServer.Serve(ctx, firstConnIndex, fallback, connectedSignal)
		if ctx.Err() != nil {
			return
		}
		if err == nil {
			return
		}
		// Make sure we don't continue if there is no more fallback allowed
		if _, canRetry := fallback.GetMaxBackoffDuration(ctx); !canRetry {
			return
		}
		// Try again for Unauthorized errors because we hope them to be
		// transient due to edge propagation lag on new Tunnels.
		if strings.Contains(err.Error(), "Unauthorized") {
			continue
		}
		switch err.(type) {
		case edgediscovery.ErrNoAddressesLeft:
			// If your provided addresses are not available, we will keep trying regardless.
			if !isStaticEdge {
				return
			}
		case connection.DupConnRegisterTunnelError,
			*quic.IdleTimeoutError,
			*quic.ApplicationError,
			edgediscovery.DialError,
			*connection.EdgeQuicDialError:
			// Try again for these types of errors
		default:
			// Uncaught errors should bail startup
			return
		}
	}
}
// startTunnel runs the HA connection at index until Serve returns, then
// reports the outcome (nil on a clean exit) on s.tunnelErrors. That channel is
// unbuffered and the report happens in a deferred send, so callers launch this
// in its own goroutine.
func (s *Supervisor) startTunnel(
	ctx context.Context,
	index int,
	connectedSignal *signal.Signal,
) {
	var serveErr error
	defer func() {
		s.tunnelErrors <- tunnelError{index: index, err: serveErr}
	}()

	fallback := s.tunnelsProtocolFallback[index]
	// index is bounded by config.HAConnections, presumably small enough to fit
	// a uint8 — confirm if HAConnections can ever exceed 255.
	// nolint: gosec
	serveErr = s.edgeTunnelServer.Serve(ctx, uint8(index), fallback, connectedSignal)
}
// newConnectedTunnelSignal registers a fresh notification channel for the
// connection at index in s.tunnelsConnecting. It makes that channel the one
// Run currently waits on (nextConnectedSignal/nextConnectedIndex) and returns
// it wrapped as a *signal.Signal to hand to the connection.
func (s *Supervisor) newConnectedTunnelSignal(index int) *signal.Signal {
	ch := make(chan struct{})
	s.tunnelsConnecting[index] = ch
	s.nextConnectedIndex, s.nextConnectedSignal = index, ch
	return signal.New(ch)
}
// waitForNextTunnel removes index from the set of connections still expected
// to connect. It then picks another pending connection (an arbitrary one, per
// Go map iteration order) as the one Run should wait on next, and reports
// whether one was found.
//
// When none remain, nextConnectedSignal is cleared to nil. Receiving from a
// nil channel never proceeds, so Run stops selecting on it.
func (s *Supervisor) waitForNextTunnel(index int) bool {
	delete(s.tunnelsConnecting, index)
	for pendingIdx, pendingCh := range s.tunnelsConnecting {
		s.nextConnectedIndex, s.nextConnectedSignal = pendingIdx, pendingCh
		return true
	}
	s.nextConnectedSignal = nil
	return false
}