mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-28 15:40:02 +00:00
TUN-3118: Changed graceful shutdown to immediately unregister tunnel from the edge, keep the connection open until the edge drops it or grace period expires
This commit is contained in:
@@ -3,6 +3,7 @@ package origin
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
@@ -54,6 +55,9 @@ type Supervisor struct {
|
||||
|
||||
reconnectCredentialManager *reconnectCredentialManager
|
||||
useReconnectToken bool
|
||||
|
||||
reconnectCh chan ReconnectSignal
|
||||
gracefulShutdownC chan struct{}
|
||||
}
|
||||
|
||||
type tunnelError struct {
|
||||
@@ -62,11 +66,13 @@ type tunnelError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor, error) {
|
||||
var (
|
||||
edgeIPs *edgediscovery.Edge
|
||||
err error
|
||||
)
|
||||
func NewSupervisor(config *TunnelConfig, reconnectCh chan ReconnectSignal, gracefulShutdownC chan struct{}) (*Supervisor, error) {
|
||||
cloudflaredUUID, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate cloudflared instance ID: %w", err)
|
||||
}
|
||||
|
||||
var edgeIPs *edgediscovery.Edge
|
||||
if len(config.EdgeAddrs) > 0 {
|
||||
edgeIPs, err = edgediscovery.StaticEdge(config.Log, config.EdgeAddrs)
|
||||
} else {
|
||||
@@ -90,11 +96,16 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor
|
||||
log: config.Log,
|
||||
reconnectCredentialManager: newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections),
|
||||
useReconnectToken: useReconnectToken,
|
||||
reconnectCh: reconnectCh,
|
||||
gracefulShutdownC: gracefulShutdownC,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
|
||||
if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
|
||||
func (s *Supervisor) Run(
|
||||
ctx context.Context,
|
||||
connectedSignal *signal.Signal,
|
||||
) error {
|
||||
if err := s.initialize(ctx, connectedSignal); err != nil {
|
||||
return err
|
||||
}
|
||||
var tunnelsWaiting []int
|
||||
@@ -131,7 +142,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
|
||||
case tunnelError := <-s.tunnelErrors:
|
||||
tunnelsActive--
|
||||
if tunnelError.err != nil {
|
||||
s.log.Err(tunnelError.err).Msg("supervisor: Tunnel disconnected")
|
||||
s.log.Err(tunnelError.err).Int(connection.LogFieldConnIndex, tunnelError.index).Msg("Connection terminated")
|
||||
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
|
||||
s.waitForNextTunnel(tunnelError.index)
|
||||
|
||||
@@ -139,14 +150,16 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
|
||||
backoffTimer = backoff.BackoffTimer()
|
||||
}
|
||||
|
||||
// Previously we'd mark the edge address as bad here, but now we'll just silently use
|
||||
// another.
|
||||
// Previously we'd mark the edge address as bad here, but now we'll just silently use another.
|
||||
} else if tunnelsActive == 0 {
|
||||
// all connected tunnels exited gracefully, no more work to do
|
||||
return nil
|
||||
}
|
||||
// Backoff was set and its timer expired
|
||||
case <-backoffTimer:
|
||||
backoffTimer = nil
|
||||
for _, index := range tunnelsWaiting {
|
||||
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index), reconnectCh)
|
||||
go s.startTunnel(ctx, index, s.newConnectedTunnelSignal(index))
|
||||
}
|
||||
tunnelsActive += len(tunnelsWaiting)
|
||||
tunnelsWaiting = nil
|
||||
@@ -171,14 +184,17 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
|
||||
}
|
||||
|
||||
// Returns nil if initialization succeeded, else the initialization error.
|
||||
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
|
||||
availableAddrs := int(s.edgeIPs.AvailableAddrs())
|
||||
func (s *Supervisor) initialize(
|
||||
ctx context.Context,
|
||||
connectedSignal *signal.Signal,
|
||||
) error {
|
||||
availableAddrs := s.edgeIPs.AvailableAddrs()
|
||||
if s.config.HAConnections > availableAddrs {
|
||||
s.log.Info().Msgf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
|
||||
s.config.HAConnections = availableAddrs
|
||||
}
|
||||
|
||||
go s.startFirstTunnel(ctx, connectedSignal, reconnectCh)
|
||||
go s.startFirstTunnel(ctx, connectedSignal)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
<-s.tunnelErrors
|
||||
@@ -190,7 +206,7 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
||||
// At least one successful connection, so start the rest
|
||||
for i := 1; i < s.config.HAConnections; i++ {
|
||||
ch := signal.New(make(chan struct{}))
|
||||
go s.startTunnel(ctx, i, ch, reconnectCh)
|
||||
go s.startTunnel(ctx, i, ch)
|
||||
time.Sleep(registrationInterval)
|
||||
}
|
||||
return nil
|
||||
@@ -198,7 +214,10 @@ func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Sig
|
||||
|
||||
// startTunnel starts the first tunnel connection. The resulting error will be sent on
|
||||
// s.tunnelErrors. It will send a signal via connectedSignal if registration succeed
|
||||
func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) {
|
||||
func (s *Supervisor) startFirstTunnel(
|
||||
ctx context.Context,
|
||||
connectedSignal *signal.Signal,
|
||||
) {
|
||||
var (
|
||||
addr *net.TCPAddr
|
||||
err error
|
||||
@@ -221,7 +240,8 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
|
||||
firstConnIndex,
|
||||
connectedSignal,
|
||||
s.cloudflaredUUID,
|
||||
reconnectCh,
|
||||
s.reconnectCh,
|
||||
s.gracefulShutdownC,
|
||||
)
|
||||
// If the first tunnel disconnects, keep restarting it.
|
||||
edgeErrors := 0
|
||||
@@ -253,14 +273,19 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
|
||||
firstConnIndex,
|
||||
connectedSignal,
|
||||
s.cloudflaredUUID,
|
||||
reconnectCh,
|
||||
s.reconnectCh,
|
||||
s.gracefulShutdownC,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// startTunnel starts a new tunnel connection. The resulting error will be sent on
|
||||
// s.tunnelErrors.
|
||||
func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) {
|
||||
func (s *Supervisor) startTunnel(
|
||||
ctx context.Context,
|
||||
index int,
|
||||
connectedSignal *signal.Signal,
|
||||
) {
|
||||
var (
|
||||
addr *net.TCPAddr
|
||||
err error
|
||||
@@ -281,7 +306,8 @@ func (s *Supervisor) startTunnel(ctx context.Context, index int, connectedSignal
|
||||
uint8(index),
|
||||
connectedSignal,
|
||||
s.cloudflaredUUID,
|
||||
reconnectCh,
|
||||
s.reconnectCh,
|
||||
s.gracefulShutdownC,
|
||||
)
|
||||
}
|
||||
|
||||
|
@@ -107,12 +107,18 @@ func (c *TunnelConfig) SupportedFeatures() []string {
|
||||
return features
|
||||
}
|
||||
|
||||
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error {
|
||||
s, err := NewSupervisor(config, cloudflaredID)
|
||||
func StartTunnelDaemon(
|
||||
ctx context.Context,
|
||||
config *TunnelConfig,
|
||||
connectedSignal *signal.Signal,
|
||||
reconnectCh chan ReconnectSignal,
|
||||
graceShutdownC chan struct{},
|
||||
) error {
|
||||
s, err := NewSupervisor(config, reconnectCh, graceShutdownC)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Run(ctx, connectedSignal, reconnectCh)
|
||||
return s.Run(ctx, connectedSignal)
|
||||
}
|
||||
|
||||
func ServeTunnelLoop(
|
||||
@@ -124,6 +130,7 @@ func ServeTunnelLoop(
|
||||
connectedSignal *signal.Signal,
|
||||
cloudflaredUUID uuid.UUID,
|
||||
reconnectCh chan ReconnectSignal,
|
||||
gracefulShutdownC chan struct{},
|
||||
) error {
|
||||
haConnections.Inc()
|
||||
defer haConnections.Dec()
|
||||
@@ -158,6 +165,7 @@ func ServeTunnelLoop(
|
||||
cloudflaredUUID,
|
||||
reconnectCh,
|
||||
protocallFallback.protocol,
|
||||
gracefulShutdownC,
|
||||
)
|
||||
if !recoverable {
|
||||
return err
|
||||
@@ -242,6 +250,7 @@ func ServeTunnel(
|
||||
cloudflaredUUID uuid.UUID,
|
||||
reconnectCh chan ReconnectSignal,
|
||||
protocol connection.Protocol,
|
||||
gracefulShutdownC chan struct{},
|
||||
) (err error, recoverable bool) {
|
||||
// Treat panics as recoverable errors
|
||||
defer func() {
|
||||
@@ -268,7 +277,17 @@ func ServeTunnel(
|
||||
}
|
||||
if protocol == connection.HTTP2 {
|
||||
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries))
|
||||
return ServeHTTP2(ctx, log, config, edgeConn, connOptions, connIndex, connectedFuse, reconnectCh)
|
||||
return ServeHTTP2(
|
||||
ctx,
|
||||
log,
|
||||
config,
|
||||
edgeConn,
|
||||
connOptions,
|
||||
connIndex,
|
||||
connectedFuse,
|
||||
reconnectCh,
|
||||
gracefulShutdownC,
|
||||
)
|
||||
}
|
||||
return ServeH2mux(
|
||||
ctx,
|
||||
@@ -280,6 +299,7 @@ func ServeTunnel(
|
||||
connectedFuse,
|
||||
cloudflaredUUID,
|
||||
reconnectCh,
|
||||
gracefulShutdownC,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -293,6 +313,7 @@ func ServeH2mux(
|
||||
connectedFuse *connectedFuse,
|
||||
cloudflaredUUID uuid.UUID,
|
||||
reconnectCh chan ReconnectSignal,
|
||||
gracefulShutdownC chan struct{},
|
||||
) (err error, recoverable bool) {
|
||||
config.Log.Debug().Msgf("Connecting via h2mux")
|
||||
// Returns error from parsing the origin URL or handshake errors
|
||||
@@ -302,6 +323,7 @@ func ServeH2mux(
|
||||
edgeConn,
|
||||
connIndex,
|
||||
config.Observer,
|
||||
gracefulShutdownC,
|
||||
)
|
||||
if err != nil {
|
||||
return err, recoverable
|
||||
@@ -312,13 +334,13 @@ func ServeH2mux(
|
||||
errGroup.Go(func() (err error) {
|
||||
if config.NamedTunnel != nil {
|
||||
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries))
|
||||
return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, credentialManager, connOptions, connectedFuse)
|
||||
return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, connOptions, connectedFuse)
|
||||
}
|
||||
registrationOptions := config.RegistrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
|
||||
return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
|
||||
})
|
||||
|
||||
errGroup.Go(listenReconnect(serveCtx, reconnectCh))
|
||||
errGroup.Go(listenReconnect(serveCtx, reconnectCh, gracefulShutdownC))
|
||||
|
||||
err = errGroup.Wait()
|
||||
if err != nil {
|
||||
@@ -367,9 +389,10 @@ func ServeHTTP2(
|
||||
connIndex uint8,
|
||||
connectedFuse connection.ConnectedFuse,
|
||||
reconnectCh chan ReconnectSignal,
|
||||
gracefulShutdownC chan struct{},
|
||||
) (err error, recoverable bool) {
|
||||
log.Debug().Msgf("Connecting via http2")
|
||||
server := connection.NewHTTP2Connection(
|
||||
h2conn := connection.NewHTTP2Connection(
|
||||
tlsServerConn,
|
||||
config.ConnectionConfig,
|
||||
config.NamedTunnel,
|
||||
@@ -377,15 +400,15 @@ func ServeHTTP2(
|
||||
config.Observer,
|
||||
connIndex,
|
||||
connectedFuse,
|
||||
gracefulShutdownC,
|
||||
)
|
||||
|
||||
errGroup, serveCtx := errgroup.WithContext(ctx)
|
||||
errGroup.Go(func() error {
|
||||
server.Serve(serveCtx)
|
||||
return fmt.Errorf("connection with edge closed")
|
||||
return h2conn.Serve(serveCtx)
|
||||
})
|
||||
|
||||
errGroup.Go(listenReconnect(serveCtx, reconnectCh))
|
||||
errGroup.Go(listenReconnect(serveCtx, reconnectCh, gracefulShutdownC))
|
||||
|
||||
err = errGroup.Wait()
|
||||
if err != nil {
|
||||
@@ -394,11 +417,13 @@ func ServeHTTP2(
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal) func() error {
|
||||
func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh chan struct{}) func() error {
|
||||
return func() error {
|
||||
select {
|
||||
case reconnect := <-reconnectCh:
|
||||
return &reconnect
|
||||
return reconnect
|
||||
case <-gracefulShutdownCh:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user