TUN-3869: Improve reliability of graceful shutdown.

- Don't rely on the edge to close the connection on graceful shutdown in h2mux; start the muxer shutdown from cloudflared instead.
- Don't retry failed connections after graceful shutdown has started.
- After the graceful shutdown channel is closed, stop waiting for the retry timer and don't try to restart the tunnel loop.
- Use a receive-only channel for graceful shutdown in functions that only consume the signal.
This commit is contained in:
Igor Postelnik
2021-02-04 18:07:49 -06:00
parent dbd90f270e
commit 0b16a473da
6 changed files with 95 additions and 83 deletions

View File

@@ -113,7 +113,7 @@ func StartTunnelDaemon(
config *TunnelConfig,
connectedSignal *signal.Signal,
reconnectCh chan ReconnectSignal,
graceShutdownC chan struct{},
graceShutdownC <-chan struct{},
) error {
s, err := NewSupervisor(config, reconnectCh, graceShutdownC)
if err != nil {
@@ -131,14 +131,14 @@ func ServeTunnelLoop(
connectedSignal *signal.Signal,
cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal,
gracefulShutdownC chan struct{},
gracefulShutdownC <-chan struct{},
) error {
haConnections.Inc()
defer haConnections.Dec()
connLog := config.Log.With().Uint8(connection.LogFieldConnIndex, connIndex).Logger()
protocallFallback := &protocallFallback{
protocolFallback := &protocolFallback{
BackoffHandler{MaxRetries: config.Retries},
config.ProtocolSelector.Current(),
false,
@@ -162,82 +162,82 @@ func ServeTunnelLoop(
addr,
connIndex,
connectedFuse,
protocallFallback,
protocolFallback,
cloudflaredUUID,
reconnectCh,
protocallFallback.protocol,
protocolFallback.protocol,
gracefulShutdownC,
)
if !recoverable {
return err
}
err = waitForBackoff(ctx, &connLog, protocallFallback, config, connIndex, err)
if err != nil {
config.Observer.SendReconnect(connIndex)
duration, ok := protocolFallback.GetBackoffDuration(ctx)
if !ok {
return err
}
connLog.Info().Msgf("Retrying connection in %s seconds", duration)
select {
case <-ctx.Done():
return ctx.Err()
case <-gracefulShutdownC:
return nil
case <-protocolFallback.BackoffTimer():
if !selectNextProtocol(&connLog, protocolFallback, config.ProtocolSelector) {
return err
}
}
}
}
// protocallFallback is a wrapper around backoffHandler that will try fallback option when backoff reaches
// protocolFallback is a wrapper around backoffHandler that will try fallback option when backoff reaches
// max retries
type protocallFallback struct {
type protocolFallback struct {
BackoffHandler
protocol connection.Protocol
inFallback bool
}
func (pf *protocallFallback) reset() {
func (pf *protocolFallback) reset() {
pf.resetNow()
pf.inFallback = false
}
func (pf *protocallFallback) fallback(fallback connection.Protocol) {
func (pf *protocolFallback) fallback(fallback connection.Protocol) {
pf.resetNow()
pf.protocol = fallback
pf.inFallback = true
}
// Expect err to always be non nil
func waitForBackoff(
ctx context.Context,
log *zerolog.Logger,
protobackoff *protocallFallback,
config *TunnelConfig,
connIndex uint8,
err error,
) error {
duration, ok := protobackoff.GetBackoffDuration(ctx)
if !ok {
return err
}
config.Observer.SendReconnect(connIndex)
log.Info().
Err(err).
Uint8(connection.LogFieldConnIndex, connIndex).
Msgf("Retrying connection in %s seconds", duration)
protobackoff.Backoff(ctx)
if protobackoff.ReachedMaxRetries() {
fallback, hasFallback := config.ProtocolSelector.Fallback()
// selectNextProtocol picks connection protocol for the next retry iteration,
// returns true if it was able to pick the protocol, false if we are out of options and should stop retrying
func selectNextProtocol(
connLog *zerolog.Logger,
protocolBackoff *protocolFallback,
selector connection.ProtocolSelector,
) bool {
if protocolBackoff.ReachedMaxRetries() {
fallback, hasFallback := selector.Fallback()
if !hasFallback {
return err
return false
}
// Already using fallback protocol, no point to retry
if protobackoff.protocol == fallback {
return err
if protocolBackoff.protocol == fallback {
return false
}
log.Info().Msgf("Fallback to use %s", fallback)
protobackoff.fallback(fallback)
} else if !protobackoff.inFallback {
current := config.ProtocolSelector.Current()
if protobackoff.protocol != current {
protobackoff.protocol = current
config.Log.Info().Msgf("Change protocol to %s", current)
connLog.Info().Msgf("Switching to fallback protocol %s", fallback)
protocolBackoff.fallback(fallback)
} else if !protocolBackoff.inFallback {
current := selector.Current()
if protocolBackoff.protocol != current {
protocolBackoff.protocol = current
connLog.Info().Msgf("Changing protocol to %s", current)
}
}
return nil
return true
}
// ServeTunnel runs a single tunnel connection, returns nil on graceful shutdown,
@@ -250,11 +250,11 @@ func ServeTunnel(
addr *net.TCPAddr,
connIndex uint8,
fuse *h2mux.BooleanFuse,
backoff *protocallFallback,
backoff *protocolFallback,
cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal,
protocol connection.Protocol,
gracefulShutdownC chan struct{},
gracefulShutdownC <-chan struct{},
) (err error, recoverable bool) {
// Treat panics as recoverable errors
defer func() {
@@ -358,7 +358,7 @@ func ServeH2mux(
connectedFuse *connectedFuse,
cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal,
gracefulShutdownC chan struct{},
gracefulShutdownC <-chan struct{},
) error {
connLog.Debug().Msgf("Connecting via h2mux")
// Returns error from parsing the origin URL or handshake errors
@@ -404,7 +404,7 @@ func ServeHTTP2(
connIndex uint8,
connectedFuse connection.ConnectedFuse,
reconnectCh chan ReconnectSignal,
gracefulShutdownC chan struct{},
gracefulShutdownC <-chan struct{},
) error {
connLog.Debug().Msgf("Connecting via http2")
h2conn := connection.NewHTTP2Connection(
@@ -435,7 +435,7 @@ func ServeHTTP2(
return errGroup.Wait()
}
func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh chan struct{}) error {
func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gracefulShutdownCh <-chan struct{}) error {
select {
case reconnect := <-reconnectCh:
return reconnect
@@ -448,7 +448,7 @@ func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal, gr
type connectedFuse struct {
fuse *h2mux.BooleanFuse
backoff *protocallFallback
backoff *protocolFallback
}
func (cf *connectedFuse) Connected() {