Mirror of https://github.com/cloudflare/cloudflared.git
TUN-3458: Upgrade to http2 when available, fallback to h2mux when we reach max retries
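In outline, the retry loop now carries a per-connection protocol and trades it for the fallback once the backoff budget is exhausted. Below is a minimal, self-contained sketch of that control flow; `fallbackTracker`, `retry`, and these `Protocol` constants are illustrative stand-ins for the `protocallFallback`, `waitForBackoff`, and `connection.Protocol` identifiers in the diff, not the real cloudflared API.

package main

import (
    "errors"
    "fmt"
)

type Protocol string

const (
    HTTP2 Protocol = "http2"
    H2mux Protocol = "h2mux"
)

// fallbackTracker is a simplified stand-in for the diff's protocallFallback:
// it counts failed attempts and remembers which protocol this connection uses.
type fallbackTracker struct {
    retries    uint
    maxRetries uint
    protocol   Protocol
    inFallback bool
}

func (ft *fallbackTracker) reachedMaxRetries() bool {
    return ft.retries == ft.maxRetries
}

func (ft *fallbackTracker) fallback(p Protocol) {
    ft.retries = 0 // analogous to resetNow(): restart the retry budget
    ft.protocol = p
    ft.inFallback = true
}

// retry models one pass through the error path of ServeTunnelLoop: count the
// failure, and once max retries is reached, switch to the fallback protocol,
// or give up if we are already on it.
func retry(ft *fallbackTracker, err error) error {
    ft.retries++
    if !ft.reachedMaxRetries() {
        return nil // back off and try the same protocol again
    }
    if ft.inFallback {
        return err // fallback protocol failed too; propagate the error
    }
    ft.fallback(H2mux)
    return nil
}

func main() {
    ft := &fallbackTracker{maxRetries: 3, protocol: HTTP2}
    connErr := errors.New("edge rejected http2 connection")
    for attempt := 1; ; attempt++ {
        if err := retry(ft, connErr); err != nil {
            fmt.Println("giving up:", err)
            return
        }
        fmt.Printf("attempt %d will use %s\n", attempt, ft.protocol)
    }
}

Running the sketch shows three attempts on http2, a switch to h2mux with a fresh retry budget, then a hard failure once h2mux is also exhausted, which is the behavior the hunks below implement.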
@@ -97,3 +97,11 @@ func (b BackoffHandler) GetBaseTime() time.Duration {
 func (b *BackoffHandler) Retries() int {
     return int(b.retries)
 }
+
+func (b *BackoffHandler) ReachedMaxRetries() bool {
+    return b.retries == b.MaxRetries
+}
+
+func (b *BackoffHandler) resetNow() {
+    b.resetDeadline = time.Now()
+}
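ReachedMaxRetries and resetNow are the two hooks the protocol-fallback logic below builds on. A compact sketch of their semantics, under the simplifying assumption that the wait doubles per retry and resetNow zeroes a counter (the real BackoffHandler instead moves a reset deadline):

package main

import (
    "fmt"
    "time"
)

// backoffHandler is a simplified stand-in for origin.BackoffHandler:
// retries counts failures, and the wait doubles each time up to maxRetries.
type backoffHandler struct {
    maxRetries uint
    baseTime   time.Duration
    retries    uint
}

// backoffDuration returns the next wait, and false once the budget is spent.
func (b *backoffHandler) backoffDuration() (time.Duration, bool) {
    if b.retries >= b.maxRetries {
        return 0, false
    }
    d := b.baseTime << b.retries // baseTime * 2^retries
    b.retries++
    return d, true
}

func (b *backoffHandler) reachedMaxRetries() bool { return b.retries == b.maxRetries }

// resetNow stands in for the method this commit adds: it lets the caller
// restart the backoff schedule, e.g. right after falling back to another protocol.
func (b *backoffHandler) resetNow() { b.retries = 0 }

func main() {
    b := &backoffHandler{maxRetries: 3, baseTime: 10 * time.Millisecond}
    for {
        d, ok := b.backoffDuration()
        if !ok {
            break
        }
        fmt.Printf("retry %d after %v (reached max: %v)\n", b.retries, d, b.reachedMaxRetries())
    }
    b.resetNow()
    fmt.Println("after resetNow, retries =", b.retries)
}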
@@ -17,10 +17,10 @@ import (
 )
 
 const (
-    // SRV and TXT record resolution TTL
-    ResolveTTL = time.Hour
-    // Waiting time before retrying a failed tunnel connection
-    tunnelRetryDuration = time.Second * 10
+    // SRV record resolution TTL
+    resolveTTL = time.Hour
+    // Interval between registering new tunnels
+    registrationInterval = time.Second
@@ -43,8 +43,6 @@ type Supervisor struct {
     cloudflaredUUID   uuid.UUID
     config            *TunnelConfig
     edgeIPs           *edgediscovery.Edge
-    lastResolve       time.Time
-    resolverC         chan resolveResult
     tunnelErrors      chan tunnelError
     tunnelsConnecting map[int]chan struct{}
     // nextConnectedIndex and nextConnectedSignal are used to wait for all
@@ -58,10 +56,6 @@ type Supervisor struct {
     useReconnectToken bool
 }
 
-type resolveResult struct {
-    err error
-}
-
 type tunnelError struct {
     index int
     addr  *net.TCPAddr
@@ -74,9 +68,9 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor, error) {
         err error
     )
     if len(config.EdgeAddrs) > 0 {
-        edgeIPs, err = edgediscovery.StaticEdge(config.Observer, config.EdgeAddrs)
+        edgeIPs, err = edgediscovery.StaticEdge(config.Logger, config.EdgeAddrs)
     } else {
-        edgeIPs, err = edgediscovery.ResolveEdge(config.Observer)
+        edgeIPs, err = edgediscovery.ResolveEdge(config.Logger)
     }
     if err != nil {
         return nil, err
@@ -93,14 +87,13 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor, error) {
         edgeIPs:           edgeIPs,
         tunnelErrors:      make(chan tunnelError),
         tunnelsConnecting: map[int]chan struct{}{},
-        logger:            config.Observer,
+        logger:            config.Logger,
         reconnectCredentialManager: newReconnectCredentialManager(connection.MetricsNamespace, connection.TunnelSubsystem, config.HAConnections),
         useReconnectToken:          useReconnectToken,
     }, nil
 }
 
 func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
-    logger := s.config.Observer
     if err := s.initialize(ctx, connectedSignal, reconnectCh); err != nil {
         return err
     }
@@ -117,7 +110,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
         if timer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate); err == nil {
             refreshAuthBackoffTimer = timer
         } else {
-            logger.Errorf("supervisor: initial refreshAuth failed, retrying in %v: %s", refreshAuthRetryDuration, err)
+            s.logger.Errorf("supervisor: initial refreshAuth failed, retrying in %v: %s", refreshAuthRetryDuration, err)
             refreshAuthBackoffTimer = time.After(refreshAuthRetryDuration)
         }
     }
@@ -136,7 +129,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
         case tunnelError := <-s.tunnelErrors:
             tunnelsActive--
             if tunnelError.err != nil {
-                logger.Infof("supervisor: Tunnel disconnected due to error: %s", tunnelError.err)
+                s.logger.Infof("supervisor: Tunnel disconnected due to error: %s", tunnelError.err)
                 tunnelsWaiting = append(tunnelsWaiting, tunnelError.index)
                 s.waitForNextTunnel(tunnelError.index)
@@ -159,7 +152,7 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
         case <-refreshAuthBackoffTimer:
             newTimer, err := s.reconnectCredentialManager.RefreshAuth(ctx, refreshAuthBackoff, s.authenticate)
             if err != nil {
-                logger.Errorf("supervisor: Authentication failed: %s", err)
+                s.logger.Errorf("supervisor: Authentication failed: %s", err)
                 // Permanent failure. Leave the `select` without setting the
                 // channel to be non-null, so we'll never hit this case of the `select` again.
                 continue
@@ -171,27 +164,15 @@ func (s *Supervisor) Run(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
                 // No more tunnels outstanding, clear backoff timer
                 backoff.SetGracePeriod()
             }
-        // DNS resolution returned
-        case result := <-s.resolverC:
-            s.lastResolve = time.Now()
-            s.resolverC = nil
-            if result.err == nil {
-                logger.Debug("supervisor: Service discovery refresh complete")
-            } else {
-                logger.Errorf("supervisor: Service discovery error: %s", result.err)
-            }
         }
     }
 }
 
 // Returns nil if initialization succeeded, else the initialization error.
 func (s *Supervisor) initialize(ctx context.Context, connectedSignal *signal.Signal, reconnectCh chan ReconnectSignal) error {
-    logger := s.logger
-
-    s.lastResolve = time.Now()
     availableAddrs := int(s.edgeIPs.AvailableAddrs())
     if s.config.HAConnections > availableAddrs {
-        logger.Infof("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
+        s.logger.Infof("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, availableAddrs)
         s.config.HAConnections = availableAddrs
     }
 
@@ -304,7 +285,7 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
         return nil, err
     }
 
-    edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.TLSConfig, arbitraryEdgeIP)
+    edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, s.config.EdgeTLSConfigs[connection.H2mux], arbitraryEdgeIP)
     if err != nil {
         return nil, err
     }
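The authenticate path above pins its dial to the h2mux entry of the new per-protocol map, since authentication still rides the h2mux transport. The shape of EdgeTLSConfigs is just a protocol-keyed map of *tls.Config; a minimal sketch follows, where the Protocol type and the ServerName/NextProtos values are assumptions for illustration, not cloudflared's actual edge TLS settings:

package main

import (
    "crypto/tls"
    "fmt"
)

type Protocol string

const (
    H2mux Protocol = "h2mux"
    HTTP2 Protocol = "http2"
)

func main() {
    // A protocol-keyed TLS config map in the spirit of TunnelConfig.EdgeTLSConfigs.
    // ServerName and NextProtos values here are illustrative assumptions only.
    edgeTLSConfigs := map[Protocol]*tls.Config{
        H2mux: {ServerName: "cftunnel.com"},
        HTTP2: {ServerName: "cftunnel.com", NextProtos: []string{"h2"}},
    }
    for proto, cfg := range edgeTLSConfigs {
        fmt.Printf("%s dials with SNI %q and ALPN %v\n", proto, cfg.ServerName, cfg.NextProtos)
    }
}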
origin/tunnel.go
@@ -62,13 +62,14 @@ type TunnelConfig struct {
     ReportedVersion string
     Retries         uint
     RunFromTerminal bool
-    TLSConfig       *tls.Config
 
-    NamedTunnel     *connection.NamedTunnelConfig
-    ClassicTunnel   *connection.ClassicTunnelConfig
-    MuxerConfig     *connection.MuxerConfig
-    TunnelEventChan chan ui.TunnelEvent
-    IngressRules    ingress.Ingress
+    NamedTunnel      *connection.NamedTunnelConfig
+    ClassicTunnel    *connection.ClassicTunnelConfig
+    MuxerConfig      *connection.MuxerConfig
+    TunnelEventChan  chan ui.TunnelEvent
+    IngressRules     ingress.Ingress
+    ProtocolSelector connection.ProtocolSelector
+    EdgeTLSConfigs   map[connection.Protocol]*tls.Config
 }
 
 type muxerShutdownError struct{}
@@ -157,7 +158,7 @@ func ServeTunnelLoop(ctx context.Context,
     credentialManager *reconnectCredentialManager,
     config *TunnelConfig,
     addr *net.TCPAddr,
-    connectionIndex uint8,
+    connIndex uint8,
     connectedSignal *signal.Signal,
     cloudflaredUUID uuid.UUID,
     reconnectCh chan ReconnectSignal,
@@ -165,7 +166,11 @@ func ServeTunnelLoop(ctx context.Context,
     haConnections.Inc()
     defer haConnections.Dec()
 
-    backoff := BackoffHandler{MaxRetries: config.Retries}
+    protocallFallback := &protocallFallback{
+        BackoffHandler{MaxRetries: config.Retries},
+        config.ProtocolSelector.Current(),
+        false,
+    }
     connectedFuse := h2mux.NewBooleanFuse()
     go func() {
         if connectedFuse.Await() {
@@ -174,29 +179,90 @@ func ServeTunnelLoop(ctx context.Context,
     }()
     // Ensure the above goroutine will terminate if we return without connecting
     defer connectedFuse.Fuse(false)
+    // Each connection to keep its own copy of protocol, because individual connections might fallback
+    // to another protocol when a particular metal doesn't support new protocol
     for {
         err, recoverable := ServeTunnel(
             ctx,
             credentialManager,
             config,
-            addr, connectionIndex,
+            addr,
+            connIndex,
             connectedFuse,
-            &backoff,
+            protocallFallback,
             cloudflaredUUID,
             reconnectCh,
+            protocallFallback.protocol,
         )
-        if recoverable {
-            if duration, ok := backoff.GetBackoffDuration(ctx); ok {
-                if config.TunnelEventChan != nil {
-                    config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Reconnecting}
-                }
-                config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connectionIndex, duration, err)
-                backoff.Backoff(ctx)
-                continue
-            }
-        }
-        return err
+        if !recoverable {
+            return err
+        }
+
+        err = waitForBackoff(ctx, protocallFallback, config, connIndex, err)
+        if err != nil {
+            return err
+        }
     }
 }
 
+// protocallFallback is a wrapper around backoffHandler that will try fallback option when backoff reaches
+// max retries
+type protocallFallback struct {
+    BackoffHandler
+    protocol   connection.Protocol
+    inFallback bool
+}
+
+func (pf *protocallFallback) reset() {
+    pf.resetNow()
+    pf.inFallback = false
+}
+
+func (pf *protocallFallback) fallback(fallback connection.Protocol) {
+    pf.resetNow()
+    pf.protocol = fallback
+    pf.inFallback = true
+}
+
+// Expect err to always be non nil
+func waitForBackoff(
+    ctx context.Context,
+    protobackoff *protocallFallback,
+    config *TunnelConfig,
+    connIndex uint8,
+    err error,
+) error {
+    duration, ok := protobackoff.GetBackoffDuration(ctx)
+    if !ok {
+        return err
+    }
+
+    if config.TunnelEventChan != nil {
+        config.TunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Reconnecting}
+    }
+
+    config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connIndex, duration, err)
+    protobackoff.Backoff(ctx)
+
+    if protobackoff.ReachedMaxRetries() {
+        fallback, hasFallback := config.ProtocolSelector.Fallback()
+        if !hasFallback {
+            return err
+        }
+        // Already using fallback protocol, no point to retry
+        if protobackoff.protocol == fallback {
+            return err
+        }
+        config.Logger.Infof("Fallback to use %s", fallback)
+        protobackoff.fallback(fallback)
+    } else if !protobackoff.inFallback {
+        current := config.ProtocolSelector.Current()
+        if protobackoff.protocol != current {
+            protobackoff.protocol = current
+            config.Logger.Infof("Change protocol to %s", current)
+        }
+    }
+    return nil
+}
+
 func ServeTunnel(
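waitForBackoff above relies on two methods of connection.ProtocolSelector: Current() for the preferred protocol and Fallback() for an optional second choice. A hand-rolled stand-in showing just that contract (staticSelector is illustrative, not the selector the commit's NewProtocolSelector returns):

package main

import "fmt"

type Protocol string

const (
    HTTP2 Protocol = "http2"
    H2mux Protocol = "h2mux"
)

// staticSelector models only the contract waitForBackoff needs:
// Current() names the preferred protocol, Fallback() optionally a second choice.
type staticSelector struct {
    current  Protocol
    fallback *Protocol
}

func (s *staticSelector) Current() Protocol { return s.current }

func (s *staticSelector) Fallback() (Protocol, bool) {
    if s.fallback == nil {
        return "", false
    }
    return *s.fallback, true
}

func main() {
    h2mux := H2mux
    sel := &staticSelector{current: HTTP2, fallback: &h2mux}
    fmt.Println("preferred:", sel.Current())
    if fb, ok := sel.Fallback(); ok {
        fmt.Println("fallback:", fb)
    }
}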
@@ -204,11 +270,12 @@ func ServeTunnel(
     credentialManager *reconnectCredentialManager,
     config *TunnelConfig,
     addr *net.TCPAddr,
-    connectionIndex uint8,
+    connIndex uint8,
     fuse *h2mux.BooleanFuse,
-    backoff *BackoffHandler,
+    backoff *protocallFallback,
     cloudflaredUUID uuid.UUID,
     reconnectCh chan ReconnectSignal,
+    protocol connection.Protocol,
 ) (err error, recoverable bool) {
     // Treat panics as recoverable errors
     defer func() {
@@ -226,11 +293,11 @@ func ServeTunnel(
     // If launch-ui flag is set, send disconnect msg
     if config.TunnelEventChan != nil {
         defer func() {
-            config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Disconnected}
+            config.TunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Disconnected}
         }()
     }
 
-    edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.TLSConfig, addr)
+    edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr)
     if err != nil {
         return err, true
     }
@@ -238,11 +305,11 @@ func ServeTunnel(
         fuse:    fuse,
         backoff: backoff,
     }
-    if config.Protocol == connection.HTTP2 {
+    if protocol == connection.HTTP2 {
         connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries))
-        return ServeHTTP2(ctx, config, edgeConn, connOptions, connectionIndex, connectedFuse, reconnectCh)
+        return ServeHTTP2(ctx, config, edgeConn, connOptions, connIndex, connectedFuse, reconnectCh)
     }
-    return ServeH2mux(ctx, credentialManager, config, edgeConn, connectionIndex, connectedFuse, cloudflaredUUID, reconnectCh)
+    return ServeH2mux(ctx, credentialManager, config, edgeConn, connIndex, connectedFuse, cloudflaredUUID, reconnectCh)
 }
 
 func ServeH2mux(
@@ -255,6 +322,7 @@ func ServeH2mux(
     cloudflaredUUID uuid.UUID,
     reconnectCh chan ReconnectSignal,
 ) (err error, recoverable bool) {
+    config.Logger.Debugf("Connecting via h2mux")
     // Returns error from parsing the origin URL or handshake errors
     handler, err, recoverable := connection.NewH2muxConnection(ctx, config.ConnectionConfig, config.MuxerConfig, config.ProxyConfig.URL.String(), edgeConn, connectionIndex, config.Observer)
     if err != nil {
@@ -266,10 +334,10 @@ func ServeH2mux(
     errGroup.Go(func() (err error) {
         if config.NamedTunnel != nil {
             connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries))
-            return handler.ServeNamedTunnel(ctx, config.NamedTunnel, credentialManager, connOptions, connectedFuse)
+            return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, credentialManager, connOptions, connectedFuse)
         }
         registrationOptions := config.RegistrationOptions(connectionIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
-        return handler.ServeClassicTunnel(ctx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
+        return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
     })
 
     errGroup.Go(listenReconnect(serveCtx, reconnectCh))
@@ -295,7 +363,7 @@ func ServeH2mux(
         config.Logger.Info("Muxer shutdown")
         return err, true
     case *ReconnectSignal:
-        config.Logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, err.Delay)
+        config.Logger.Infof("Restarting connection %d due to reconnect signal in %s", connectionIndex, err.Delay)
         err.DelayBeforeReconnect()
         return err, true
     default:
@@ -319,10 +387,8 @@ func ServeHTTP2(
     connectedFuse connection.ConnectedFuse,
     reconnectCh chan ReconnectSignal,
 ) (err error, recoverable bool) {
-    server, err := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse)
-    if err != nil {
-        return err, false
-    }
+    config.Logger.Debugf("Connecting via http2")
+    server := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse)
 
     errGroup, serveCtx := errgroup.WithContext(ctx)
     errGroup.Go(func() error {
@@ -352,12 +418,12 @@ func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal) func() error {
 
 type connectedFuse struct {
     fuse    *h2mux.BooleanFuse
-    backoff *BackoffHandler
+    backoff *protocallFallback
 }
 
 func (cf *connectedFuse) Connected() {
     cf.fuse.Fuse(true)
-    cf.backoff.SetGracePeriod()
+    cf.backoff.reset()
 }
 
 func (cf *connectedFuse) IsConnected() bool {
origin/tunnel_test.go (new file)
@@ -0,0 +1,90 @@
+package origin
+
+import (
+    "context"
+    "fmt"
+    "testing"
+    "time"
+
+    "github.com/cloudflare/cloudflared/connection"
+    "github.com/cloudflare/cloudflared/logger"
+    "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
+    "github.com/stretchr/testify/assert"
+)
+
+type dynamicMockFetcher struct {
+    percentage int32
+    err        error
+}
+
+func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher {
+    return func() (int32, error) {
+        if dmf.err != nil {
+            return 0, dmf.err
+        }
+        return dmf.percentage, nil
+    }
+}
+func TestWaitForBackoffFallback(t *testing.T) {
+    maxRetries := uint(3)
+    backoff := BackoffHandler{
+        MaxRetries: maxRetries,
+        BaseTime:   time.Millisecond * 10,
+    }
+    ctx := context.Background()
+    logger, err := logger.New()
+    assert.NoError(t, err)
+    resolveTTL := time.Duration(0)
+    namedTunnel := &connection.NamedTunnelConfig{
+        Auth: pogs.TunnelAuth{
+            AccountTag: "test-account",
+        },
+    }
+    mockFetcher := dynamicMockFetcher{
+        percentage: 0,
+    }
+    protocolSelector, err := connection.NewProtocolSelector(connection.HTTP2.String(), namedTunnel, mockFetcher.fetch(), resolveTTL, logger)
+    assert.NoError(t, err)
+    config := &TunnelConfig{
+        Logger:           logger,
+        ProtocolSelector: protocolSelector,
+    }
+    connIndex := uint8(1)
+
+    initProtocol := protocolSelector.Current()
+    assert.Equal(t, connection.HTTP2, initProtocol)
+
+    protocallFallback := &protocallFallback{
+        backoff,
+        initProtocol,
+        false,
+    }
+
+    // Retry #0 and #1. At retry #2, we switch protocol, so the fallback loop has one more retry than this
+    for i := 0; i < int(maxRetries-1); i++ {
+        err := waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("Some error"))
+        assert.NoError(t, err)
+        assert.Equal(t, initProtocol, protocallFallback.protocol)
+    }
+
+    // Retry fallback protocol
+    for i := 0; i < int(maxRetries); i++ {
+        err := waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("Some error"))
+        assert.NoError(t, err)
+        fallback, ok := protocolSelector.Fallback()
+        assert.True(t, ok)
+        assert.Equal(t, fallback, protocallFallback.protocol)
+    }
+
+    currentGlobalProtocol := protocolSelector.Current()
+    assert.Equal(t, initProtocol, currentGlobalProtocol)
+
+    // No protocol to fallback, return error
+    err = waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("Some error"))
+    assert.Error(t, err)
+
+    protocallFallback.reset()
+    err = waitForBackoff(ctx, protocallFallback, config, connIndex, fmt.Errorf("New error"))
+    assert.NoError(t, err)
+    assert.Equal(t, initProtocol, protocallFallback.protocol)
+}
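To run just the new test with standard Go tooling (from a checkout of the repository):

    go test ./origin -run TestWaitForBackoffFallback -v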