TUN-3458: Upgrade to http2 when available, fall back to h2mux when we reach max retries

cthuang
2020-10-14 14:42:00 +01:00
parent b5cdf3b2c7
commit a490443630
13 changed files with 632 additions and 159 deletions
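The mechanism, in outline: each HA connection starts on the protocol chosen by the ProtocolSelector (HTTP/2 when available) and keeps its own copy of that choice; once its backoff handler exhausts MaxRetries, it switches to the selector's fallback (h2mux) with a fresh retry budget, and gives up only when the fallback is exhausted too. A minimal, self-contained sketch of that decision loop, using simplified stand-in types rather than the real connection package ones:

package main

import "fmt"

// Sketch only: stand-ins for connection.Protocol values.
type protocol string

const (
	http2 protocol = "http2"
	h2mux protocol = "h2mux"
)

// fallbackState condenses BackoffHandler + protocallFallback from the diff.
type fallbackState struct {
	retries, maxRetries uint
	protocol            protocol
	inFallback          bool
}

// onFailure mirrors waitForBackoff: retry the current protocol until
// MaxRetries is reached, then switch to the fallback exactly once and
// reset the retry budget (what resetNow achieves via the reset deadline).
func (s *fallbackState) onFailure(fallback protocol, hasFallback bool) error {
	if s.retries >= s.maxRetries { // GetBackoffDuration would return !ok
		return fmt.Errorf("out of retries on %s", s.protocol)
	}
	s.retries++ // Backoff
	if s.retries == s.maxRetries { // ReachedMaxRetries
		if !hasFallback || s.protocol == fallback {
			return fmt.Errorf("no fallback left, giving up")
		}
		s.protocol, s.inFallback, s.retries = fallback, true, 0
	}
	return nil
}

func main() {
	s := &fallbackState{maxRetries: 3, protocol: http2}
	for i := 1; ; i++ {
		if err := s.onFailure(h2mux, true); err != nil {
			fmt.Println(err) // failure 6: fallback exhausted too
			return
		}
		fmt.Printf("failure %d: retrying on %s\n", i, s.protocol)
	}
}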


@@ -97,3 +97,11 @@ func (b BackoffHandler) GetBaseTime() time.Duration {
func (b *BackoffHandler) Retries() int {
return int(b.retries)
}
func (b *BackoffHandler) ReachedMaxRetries() bool {
return b.retries == b.MaxRetries
}
func (b *BackoffHandler) resetNow() {
b.resetDeadline = time.Now()
}
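These two additions are the hooks for the protocol fallback later in this diff: waitForBackoff calls ReachedMaxRetries after each Backoff to decide when to switch protocols, and protocallFallback calls resetNow so that the protocol it switches to starts with a fresh retry budget.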


@@ -17,10 +17,10 @@ import (
)
const (
// SRV and TXT record resolution TTL
ResolveTTL = time.Hour
// Waiting time before retrying a failed tunnel connection
tunnelRetryDuration = time.Second * 10
// SRV record resolution TTL
resolveTTL = time.Hour
// Interval between registering new tunnels
registrationInterval = time.Second
@@ -43,8 +43,6 @@ type Supervisor struct {
cloudflaredUUID uuid.UUID
config *TunnelConfig
edgeIPs *edgediscovery.Edge
lastResolve time.Time
resolverC chan resolveResult
tunnelErrors chan tunnelError
tunnelsConnecting map[int]chan struct{}
// nextConnectedIndex and nextConnectedSignal are used to wait for all
@@ -58,10 +56,6 @@ type Supervisor struct {
useReconnectToken bool
}
type resolveResult struct {
err error
}
type tunnelError struct {
index int
addr *net.TCPAddr
@@ -74,9 +68,9 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor
err error
)
if len(config.EdgeAddrs) > 0 {
edgeIPs, err = edgediscovery.StaticEdge(config.Observer, config.EdgeAddrs)
edgeIPs, err = edgediscovery.StaticEdge(config.Logger, config.EdgeAddrs)
} else {
edgeIPs, err = edgediscovery.ResolveEdge(config.Observer)
edgeIPs, err = edgediscovery.ResolveEdge(config.Logger)
}
if err != nil {
return nil, err
@@ -93,14 +87,13 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor
edgeIPs: edgeIPs,
tunnelErrors: make(chan tunnelError),
tunnelsConnecting: map[int]chan struct{}{},
logger: config.Observer,
logger: config.Logger,
reconnectCredentialManager: newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections),
useReconnectToken: useReconnectToken,
}, nil
}
func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
logger := s.config.Observer
if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
return err
}
@@ -117,7 +110,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
refreshAuthBackoffTimer = timer
} else {
logger.Errorf("supervisor: initial refreshAuth failed, retrying in %v: %s", refreshAuthRetryDuration, err)
s.logger.Errorf("supervisor: initial refreshAuth failed, retrying in %v: %s", refreshAuthRetryDuration, err)
refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration)
}
}
@@ -136,7 +129,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
case tunnelError := <-s.tunnelErrors:
tunnelsActive--
if tunnelError.err != nil {
logger.Infof("supervisor: Tunnel disconnected due to error: %s", tunnelError.err)
s.logger.Infof("supervisor: Tunnel disconnected due to error: %s", tunnelError.err)
tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
s.waitForNextTunnel(tunnelError.index)
@@ -159,7 +152,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
case <-refreshAuthBackoffTimer:
newTimer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate)
if err != nil {
logger.Errorf("supervisor: Authentication failed: %s", err)
s.logger.Errorf("supervisor: Authentication failed: %s", err)
// Permanent failure. Leave the `select` without setting the
// channel to be non-nil, so we'll never hit this case of the `select` again.
continue
@@ -171,27 +164,15 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, re
// No more tunnels outstanding, clear backoff timer
backoff.SetGracePeriod()
}
// DNS resolution returned
case result := <-s.resolverC:
s.lastResolve = time.Now()
s.resolverC = nil
if result.err == nil {
logger.Debug("supervisor: Service discovery refresh complete")
} else {
logger.Errorf("supervisor: Service discovery error: %s", result.err)
}
}
}
}
// Returns nil if initialization succeeded, else the initialization error.
func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
logger := s.logger
s.lastResolve = time.Now()
availableAddrs := int(s.edgeIPs.AvailableAddrs())
if s.config.HAConnections > availableAddrs {
logger.Infof("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
s.logger.Infof("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
s.config.HAConnections = availableAddrs
}
@@ -304,7 +285,7 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
return nil, err
}
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.TLSConfig, arbitraryEdgeIP)
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.EdgeTLSConfigs[connection.H2mux], arbitraryEdgeIP)
if err != nil {
return nil, err
}


@@ -62,13 +62,14 @@ type TunnelConfig struct {
ReportedVersion string
Retries uint
RunFromTerminal bool
TLSConfig *tls.Config
NamedTunnel *connection.NamedTunnelConfig
ClassicTunnel *connection.ClassicTunnelConfig
MuxerConfig *connection.MuxerConfig
TunnelEventChan chan ui.TunnelEvent
IngressRules ingress.Ingress
NamedTunnel *connection.NamedTunnelConfig
ClassicTunnel *connection.ClassicTunnelConfig
MuxerConfig *connection.MuxerConfig
TunnelEventChan chan ui.TunnelEvent
IngressRules ingress.Ingress
ProtocolSelector connection.ProtocolSelector
EdgeTLSConfigs map[connection.Protocol]*tls.Config
}
type muxerShutdownError struct{}
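TLSConfig used to be a single *tls.Config; it is replaced by EdgeTLSConfigs so each protocol can dial the edge with its own TLS settings (h2mux and HTTP/2 presumably differ in ALPN at least; the diff only shows the lookups). A hedged sketch of how the map might be populated by whoever builds TunnelConfig — h2muxTLSConfig and http2TLSConfig are hypothetical variables, assumed to be constructed elsewhere:

// Only the map shape and the lookups EdgeTLSConfigs[connection.H2mux] /
// EdgeTLSConfigs[protocol] come from this diff; the rest is illustrative.
config := &TunnelConfig{
	// ...other fields as above...
	ProtocolSelector: protocolSelector,
	EdgeTLSConfigs: map[connection.Protocol]*tls.Config{
		connection.H2mux: h2muxTLSConfig, // assumption: h2mux-specific TLS/ALPN
		connection.HTTP2: http2TLSConfig, // assumption: standard "h2" ALPN
	},
}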
@@ -157,7 +158,7 @@ func ServeTunnelLoop(ctx context.Context,
credentialManager *reconnectCredentialManager,
config *TunnelConfig,
addr *net.TCPAddr,
connectionIndex uint8,
connIndex uint8,
connectedSignal *signal.Signal,
cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal,
@@ -165,7 +166,11 @@ func ServeTunnelLoop(ctx context.Context,
haConnections.Inc()
defer haConnections.Dec()
backoff := BackoffHandler{MaxRetries: config.Retries}
protocallFallback := &protocallFallback{
BackoffHandler{MaxRetries: config.Retries},
config.ProtocolSelector.Current(),
false,
}
connectedFuse := h2mux.NewBooleanFuse()
go func() {
if connectedFuse.Await() {
@@ -174,29 +179,90 @@ func ServeTunnelLoop(ctx context.Context,
}()
// Ensure the above goroutine will terminate if we return without connecting
defer connectedFuse.Fuse(false)
// Each connection keeps its own copy of the protocol, because individual connections might fall back
// to another protocol when a particular metal doesn't support the new protocol
for {
err, recoverable := ServeTunnel(
ctx,
credentialManager,
config,
addr, connectionIndex,
addr,
connIndex,
connectedFuse,
&backoff,
protocallFallback,
cloudflaredUUID,
reconnectCh,
protocallFallback.protocol,
)
if recoverable {
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
if config.TunnelEventChan != nil {
config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Reconnecting}
}
config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connectionIndex, duration, err)
backoff.Backoff(ctx)
continue
}
if !recoverable {
return err
}
err = waitForBackoff(ctx, protocallFallback, config, connIndex, err)
if err != nil {
return err
}
}
}
// protocallFallback is a wrapper around BackoffHandler that tries the fallback option when backoff
// reaches max retries
type protocallFallback struct {
BackoffHandler
protocol connection.Protocol
inFallback bool
}
func (pf *protocallFallback) reset() {
pf.resetNow()
pf.inFallback = false
}
func (pf *protocallFallback) fallback(fallback connection.Protocol) {
pf.resetNow()
pf.protocol = fallback
pf.inFallback = true
}
// Expects err to always be non-nil
func waitForBackoff(
ctx context.Context,
protobackoff *protocallFallback,
config *TunnelConfig,
connIndex uint8,
err error,
) error {
duration, ok := protobackoff.GetBackoffDuration(ctx)
if !ok {
return err
}
if config.TunnelEventChan != nil {
config.TunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Reconnecting}
}
config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connIndex, duration, err)
protobackoff.Backoff(ctx)
if protobackoff.ReachedMaxRetries() {
fallback, hasFallback := config.ProtocolSelector.Fallback()
if !hasFallback {
return err
}
// Already using the fallback protocol, no point in retrying
if protobackoff.protocol == fallback {
return err
}
config.Logger.Infof("Fallback to use %s", fallback)
protobackoff.fallback(fallback)
} else if !protobackoff.inFallback {
current := config.ProtocolSelector.Current()
if protobackoff.protocol != current {
protobackoff.protocol = current
config.Logger.Infof("Change protocol to %s", current)
}
}
return nil
}
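Concretely, with MaxRetries = 3 and HTTP/2 selected: failures 1 and 2 retry on HTTP/2; failure 3 trips ReachedMaxRetries, logs the switch, and moves to h2mux with a fresh retry budget; failures 4 and 5 retry on h2mux; failure 6 reaches max retries again and, since h2mux is already the fallback, waitForBackoff finally returns the error. TestWaitForBackoffFallback below walks exactly this sequence.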
func ServeTunnel(
@@ -204,11 +270,12 @@ func ServeTunnel(
credentialManager *reconnectCredentialManager,
config *TunnelConfig,
addr *net.TCPAddr,
connectionIndex uint8,
connIndex uint8,
fuse *h2mux.BooleanFuse,
backoff *BackoffHandler,
backoff *protocallFallback,
cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal,
protocol connection.Protocol,
) (err error, recoverable bool) {
// Treat panics as recoverable errors
defer func() {
@@ -226,11 +293,11 @@ func ServeTunnel(
// If launch-ui flag is set, send disconnect msg
if config.TunnelEventChan != nil {
defer func() {
config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Disconnected}
config.TunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Disconnected}
}()
}
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.TLSConfig, addr)
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr)
if err != nil {
return err, true
}
@@ -238,11 +305,11 @@ func ServeTunnel(
fuse: fuse,
backoff: backoff,
}
if config.Protocol == connection.HTTP2 {
if protocol == connection.HTTP2 {
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries))
return ServeHTTP2(ctx, config, edgeConn, connOptions, connectionIndex, connectedFuse, reconnectCh)
return ServeHTTP2(ctx, config, edgeConn, connOptions, connIndex, connectedFuse, reconnectCh)
}
return ServeH2mux(ctx, credentialManager, config, edgeConn, connectionIndex, connectedFuse, cloudflaredUUID, reconnectCh)
return ServeH2mux(ctx, credentialManager, config, edgeConn, connIndex, connectedFuse, cloudflaredUUID, reconnectCh)
}
func ServeH2mux(
@@ -255,6 +322,7 @@ func ServeH2mux(
cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal,
) (err error, recoverable bool) {
config.Logger.Debugf("Connecting via h2mux")
// Returns error from parsing the origin URL or handshake errors
handler, err, recoverable := connection.NewH2muxConnection(ctx, config.ConnectionConfig, config.MuxerConfig, config.ProxyConfig.URL.String(), edgeConn, connectionIndex, config.Observer)
if err != nil {
@@ -266,10 +334,10 @@ func ServeH2mux(
errGroup.Go(func() (err error) {
if config.NamedTunnel != nil {
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries))
return handler.ServeNamedTunnel(ctx, config.NamedTunnel, credentialManager, connOptions, connectedFuse)
return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, credentialManager, connOptions, connectedFuse)
}
registrationOptions := config.RegistrationOptions(connectionIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
return handler.ServeClassicTunnel(ctx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
})
errGroup.Go(listenReconnect(serveCtx, reconnectCh))
@@ -295,7 +363,7 @@ func ServeH2mux(
config.Logger.Info("Muxer shutdown")
return err, true
case *ReconnectSignal:
config.Logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, err.Delay)
config.Logger.Infof("Restarting connection %d due to reconnect signal in %s", connectionIndex, err.Delay)
err.DelayBeforeReconnect()
return err, true
default:
@@ -319,10 +387,8 @@ func ServeHTTP2(
connectedFuse connection.ConnectedFuse,
reconnectCh chan ReconnectSignal,
) (err error, recoverable bool) {
server, err := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse)
if err != nil {
return err, false
}
config.Logger.Debugf("Connecting via http2")
server := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse)
errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
@@ -352,12 +418,12 @@ func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal) fu
type connectedFuse struct {
fuse *h2mux.BooleanFuse
backoff *BackoffHandler
backoff *protocallFallback
}
func (cf *connectedFuse) Connected() {
cf.fuse.Fuse(true)
cf.backoff.SetGracePeriod()
cf.backoff.reset()
}
func (cf *connectedFuse) IsConnected() bool {

origin/tunnel_test.go (new file)

@@ -0,0 +1,90 @@
package origin
import (
"context"
"fmt"
"testing"
"time"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/stretchr/testify/assert"
)
type dynamicMockFetcher struct {
percentage int32
err error
}
func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher {
return func() (int32, error) {
if dmf.err != nil {
return 0, dmf.err
}
return dmf.percentage, nil
}
}
func TestWaitForBackoffFallback(t *testing.T) {
maxRetries := uint(3)
backoff := BackoffHandler{
MaxRetries: maxRetries,
BaseTime: time.Millisecond * 10,
}
ctx := context.Background()
logger, err := logger.New()
assert.NoError(t, err)
resolveTTL := time.Duration(0)
namedTunnel := &connection.NamedTunnelConfig{
Auth: pogs.TunnelAuth{
AccountTag: "test-account",
},
}
mockFetcher := dynamicMockFetcher{
percentage: 0,
}
protocolSelector, err := connection.NewProtocolSelector(connection.HTTP2.String(), namedTunnel, mockFetcher.fetch(), resolveTTL, logger)
assert.NoError(t, err)
config := &TunnelConfig{
Logger: logger,
ProtocolSelector: protocolSelector,
}
connIndex := uint8(1)
initProtocol := protocolSelector.Current()
assert.Equal(t, connection.HTTP2, initProtocol)
protocallFallback := &protocallFallback{
backoff,
initProtocol,
false,
}
// Retries #0 and #1 stay on the initial protocol. At retry #2 we switch to the fallback, so the loop below runs one more iteration than this one
for i := 0; i < int(maxRetries-1); i++ {
err := waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("Some error"))
assert.NoError(t, err)
assert.Equal(t, initProtocol, protocallFallback.protocol)
}
// Retry fallback protocol
for i := 0; i < int(maxRetries); i++ {
err := waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("Some error"))
assert.NoError(t, err)
fallback, ok := protocolSelector.Fallback()
assert.True(t, ok)
assert.Equal(t, fallback, protocallFallback.protocol)
}
currentGlobalProtocol := protocolSelector.Current()
assert.Equal(t, initProtocol, currentGlobalProtocol)
// No protocol to fall back to, return error
err = waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("Some error"))
assert.Error(t, err)
protocallFallback.reset()
err = waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("New error"))
assert.NoError(t, err)
assert.Equal(t, initProtocol, protocallFallback.protocol)
}
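The final assertions cover the recovery path: reset() clears inFallback and expires the backoff's reset deadline, so the next waitForBackoff call lands in the !protobackoff.inFallback branch, re-reads ProtocolSelector.Current(), and moves the connection back to HTTP/2.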